import os

import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

import vision_transformer as vits
# Model and visualization settings
arch = "vit_small"
mode = "simpool"
gamma = None
patch_size = 16
num_classes = 0
checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
checkpoint_key = "teacher"
cm = plt.get_cmap("viridis")
attn_map_size = 224
width_display = 290
height_display = 290

# Example images shown in the interface; each example carries a value for both
# inputs (image path and a default resolution matching the dropdown default)
example_dir = "examples/"
example_list = [[example_dir + example, 32] for example in os.listdir(example_dir)]
# Load the ViT-S/16 backbone with SimPool and its pretrained weights
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)
state_dict = torch.load(checkpoint, map_location="cpu")  # load on CPU so the demo also runs without a GPU
state_dict = state_dict[checkpoint_key]
# Strip DistributedDataParallel/backbone prefixes, then keep only keys present in the model
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
msg = model.load_state_dict(state_dict, strict=True)
model.eval()
def get_attention_map(img, resolution=32):
    # Resize the input so the token grid is (input_size // patch_size) per side
    input_size = resolution * 14
    data_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    x = data_transforms(img)
    with torch.no_grad():
        attn = model.get_simpool_attention(x[None, :, :, :])
    # Reshape the per-patch attention into a 2D map and normalize it to [0, 1]
    attn = attn.reshape(1, 1, input_size // patch_size, input_size // patch_size)
    attn = attn / attn.sum()
    attn = attn.squeeze()
    attn = (attn - attn.min()) / (attn.max() - attn.min())
    # Suppress weak activations below 0.1
    attn = torch.threshold(attn, 0.1, 0)
    # Colorize with the viridis colormap and upsample for display
    attn_img = Image.fromarray(np.uint8(cm(attn.detach().numpy()) * 255)).convert("RGB")
    attn_img = attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
    return attn_img
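
# A minimal sketch of calling get_attention_map directly, outside Gradio.
# The file name below is hypothetical; any RGB image in example_dir works:
#   img = Image.open("examples/sample.JPEG").convert("RGB")
#   attn_img = get_attention_map(img, resolution=32)
#   attn_img.save("attention_map.png")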
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=[16, 32, 64, 128],
                    label="Attention Map Resolution",
                    value=32),
    ],
    outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
    examples=example_list,
    title="Explore the Attention Maps of SimPool🔍",
    description="Upload an image or pick one of the examples to explore the focus areas of a ViT-S model with SimPool, pretrained with DINO on ImageNet-1k.",
)
demo = gr.TabbedInterface([attention_interface],
                          ["Visualize Attention Maps"],
                          title="SimPool Attention Map Visualizer 🌌")

if __name__ == "__main__":
    demo.launch(share=True)