-
Hi, I want to extract the attention map from a pretrained vision transformer for a specific image.
Replies: 12 comments 28 replies
-
This is a toy example to visualize the whole attention map and the attention map only for the class token (see here for more information).

```python
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from timm.models import create_model
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor


def to_tensor(img):
    transform_fn = Compose([Resize(249, 3), CenterCrop(224), ToTensor(),
                            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)


def show_img(img):
    img = np.asarray(img)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()


def show_img2(img1, img2, alpha=0.8):
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)
    plt.figure(figsize=(10, 10))
    plt.imshow(img1)
    plt.imshow(img2, alpha=alpha)
    plt.axis('off')
    plt.show()


def my_forward_wrapper(attn_obj):
    def my_forward(x):
        B, N, C = x.shape
        qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)
        attn_obj.attn_map = attn
        attn_obj.cls_attn_map = attn[:, :, 0, 2:]

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    return my_forward


img = Image.open('n02102480_Sussex_spaniel.JPEG')
x = to_tensor(img)

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)
y = model(x.unsqueeze(0))

attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(14, 14).detach()

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

show_img(img)
show_img(attn_map)
show_img(cls_weight)
show_img(img_resized)
show_img2(img_resized, cls_resized, alpha=0.8)
```

Attention map for the last layer: 198 x 198, where 198 = 196 (image patches) + 1 (cls) + 1 (distill).
-
The attention scores are a bit scattered here; usually the cls token focuses on certain patches, and they are consistent. Are you sure taking the mean is a good idea? Also, it might be better to roll out the cls token attention over multiple blocks.
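For reference, here is a minimal attention-rollout sketch (a generic illustration, not code from this thread), assuming the per-block, head-averaged attention maps have already been collected with a wrapper like the one above; `attn_maps` is a hypothetical list of (N, N) tensors, one per block:

```python
import torch


def attention_rollout(attn_maps, add_residual=True):
    """Multiply attention across blocks, adding the identity to account for
    residual connections and renormalizing rows, as in attention rollout."""
    rollout = None
    for attn in attn_maps:                       # each attn: (N, N), already head-averaged
        a = attn
        if add_residual:
            a = a + torch.eye(a.size(-1), device=a.device)
        a = a / a.sum(dim=-1, keepdim=True)      # rows sum to 1 again
        rollout = a if rollout is None else a @ rollout
    return rollout                               # (N, N); row 0 is the cls-token rollout
```

Row 0 of the result (dropping the cls/distillation columns) can then be reshaped to 14 x 14 and overlaid on the image just like `cls_weight` in the toy example.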
-
So I tried taking the product of the cls token weights over the blocks like this, but it shows this error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 13
     11 # Forward pass through all blocks
     12 for block in model.blocks:
---> 13     x, attn_map = block.attn.forward(x)
     14     outputs.append(x)
     15     attn_maps.append(attn_map)

Cell In[8], line 4, in my_forward_wrapper.<locals>.my_forward(x)
      2 def my_forward(x):
      3     B, N, C = x.shape
----> 4     qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
      5     q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
      7     attn = (q @ k.transpose(-2, -1)) * attn_obj.scale

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152)
```

Code:

```python
model = create_model('deit_small_distilled_patch16_224', pretrained=True)

# Replace forward function in all blocks
for block in model.blocks:
    block.attn.forward = my_forward_wrapper(block.attn)

# Forward pass through the model
outputs = []
attn_maps = []
cls_weights = []

# Forward pass through all blocks
for block in model.blocks:
    x, attn_map = block.attn.forward(x)
    outputs.append(x)
    attn_maps.append(attn_map)
    cls_weights.append(block.attn.cls_attn_map.min(dim=1).values.view(14, 14).detach())

# Combine class scores of all blocks
cls_weight_combined = torch.prod(torch.stack(cls_weights), dim=0)

# Resize input image and class weights
img_resized = x.permute(0, 2, 3, 1) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight_combined.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

# Visualize
show_img(image)
show_img(attn_maps[-1])          # Attention map from the last block
show_img(cls_weight_combined)    # Combined class weights
show_img(img_resized.squeeze())  # Squeeze the batch dimension
show_img2(img_resized.squeeze(), cls_resized, alpha=0.8)  # Squeeze the batch dimension
```
-
Hi @arnavs04, that's a good question. You're right, there may be a better way to merge different attention maps; the code I gave contains only very basic ways of merging them. The error you posted looks like a dimension mismatch: I suggest applying the patch embedding layer to the input image before passing it into the attention block. You should also make sure the other layers (e.g., normalization, MLP) are applied appropriately.
Thank you. Hankyul.
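One way to get the per-block maps without re-implementing the patch embedding and MLP steps by hand (a sketch built on the toy example above, not the author's exact fix) is to wrap every block's attention and then run a normal `model(...)` forward pass, so the patch embedding, normalization, and MLP layers are all applied by the model itself; the maps stored on each block can be combined afterwards:

```python
import torch

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
for block in model.blocks:
    block.attn.forward = my_forward_wrapper(block.attn)  # wrapper from the toy example above

with torch.no_grad():
    _ = model(x.unsqueeze(0))  # x is the (3, 224, 224) tensor produced by to_tensor(img)

# The wrapper stored the maps as attributes during the forward pass
attn_maps = [b.attn.attn_map.squeeze(0) for b in model.blocks]                      # each (heads, 198, 198)
cls_weights = [b.attn.cls_attn_map.mean(dim=1).view(14, 14) for b in model.blocks]  # each (14, 14)
cls_weight_combined = torch.prod(torch.stack(cls_weights), dim=0)
```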
-
Ohh, thank you for the help!!
-
@kiashann
-
Hi @mae338, I am happy to help you.
Thank you. Hankyul |
-
Thank you so much, Hankyul. I appreciate your effort to help me. I'd be pleased if you could help me understand this as well:
-
Hi @mae338, I hope this helps you.
Thank you. Hankyul |
-
Hi @hankyul2, I'd like to view the outputs of a pretrained ViT model (vit-base-patch16-224), especially the input to the MLP head. I tried the same function but it didn't work. Any suggestions?
-
Hi @mae338, I hope this can help you. You can extract the final features of a pre-trained ViT like this:

```python
# dependency
!wget https://user-images.githubusercontent.com/31476895/167238573-b0cc3a6d-d3ee-462b-8630-a8f253e69bb2.png
!pip install -Uq fastai timm==0.6.13 huggingface_hub

# code
import numpy as np
from PIL import Image
from timm.models import create_model
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor


def to_tensor(img):
    transform_fn = Compose([Resize(249), CenterCrop(224), ToTensor(),
                            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)


img = Image.open('167238573-b0cc3a6d-d3ee-462b-8630-a8f253e69bb2.png').convert('RGB')
x = to_tensor(img)

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
x = model.forward_features(x.unsqueeze(0))
y = model.forward_head(x, pre_logits=True)
print(y.shape)

# output
# torch.Size([1, 384])
```

Thank you. Hankyul
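The same forward_features / forward_head(pre_logits=True) pattern should also work for the vit-base model asked about above (a sketch, assuming a reasonably recent timm version); the pre-logits output is exactly the tensor that feeds the MLP head:

```python
from timm.models import create_model

model = create_model('vit_base_patch16_224', pretrained=True)
img_tensor = to_tensor(img)                              # same preprocessing as the snippet above
feats = model.forward_features(img_tensor.unsqueeze(0))  # token features, roughly (1, 197, 768)
pre_logits = model.forward_head(feats, pre_logits=True)  # input to the classifier head, (1, 768)
print(pre_logits.shape)
```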
-
I would like to apply this code to the 'vit_small_patch16_384' model from timm. How should I modify the code for this purpose?
-
```python
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(24, 24).detach()  # (14 -> 24)
img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 24, 24), (384, 384), mode='bilinear').view(384, 384, 1)  # (14 -> 24), (224 -> 384)
```
Thank you. Hankyul |
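One caveat worth adding: plain vit_small_patch16_384 has only a class token (no distillation token), so the slice inside my_forward_wrapper presumably also needs to change from `attn[:, :, 0, 2:]` to `attn[:, :, 0, 1:]`; otherwise the 575 remaining values cannot be reshaped to 24 x 24. The preprocessing also has to produce a 384 x 384 input; the resize value below is only an illustrative guess:

```python
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor

# inside my_forward_wrapper, for a model with a cls token but no distillation token:
#     attn_obj.cls_attn_map = attn[:, :, 0, 1:]   # 576 patch tokens -> 24 x 24

def to_tensor_384(img):
    # same preprocessing as the toy example, scaled to the 384 x 384 input resolution
    transform_fn = Compose([Resize(426), CenterCrop(384), ToTensor(),
                            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)
```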
-
Well understood. Thank you for the reply!
-
@hankyul2
2. If I want to apply this model to a regression problem and estimate multiple parameters with a single model, is it possible to check the attention map for each parameter? In a multi-class classification problem, we can check the attention map when a certain class is predicted to see which areas are being focused on to recognize that class. However, in the case of a regression problem, how should this be done?
-
So, this doesn't include the visualization helpers yet, but I have added a simpler extraction helper to get the attention activations via one of two methods, fx or hooks. WIP, but it can be seen at https://github.com/huggingface/pytorch-image-models/pull/2168/files#diff-358e0d5feb2c109ff53d21bc4fa8a6af94566be622b0f1167316216b0036b8b3
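For readers who just want the general idea, here is a minimal forward-hook sketch (a generic illustration of the "hooks" approach, not the timm helper itself); the module names in the usage comment assume timm's non-fused attention path, where `attn_drop` receives the post-softmax attention matrix:

```python
import torch


def collect_with_hooks(model, module_names, example_input):
    """Register forward hooks on the named submodules, run one forward pass,
    and return a dict of their captured outputs."""
    captured, handles = {}, []
    modules = dict(model.named_modules())

    def make_hook(name):
        def hook(_module, _inputs, output):
            captured[name] = output.detach()
        return hook

    for name in module_names:
        handles.append(modules[name].register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(example_input)
    for handle in handles:
        handle.remove()
    return captured

# e.g. capture the post-softmax attention of every block (non-fused attention path):
# maps = collect_with_hooks(model,
#                           [f'blocks.{i}.attn.attn_drop' for i in range(len(model.blocks))],
#                           torch.randn(1, 3, 224, 224))
```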
-
Hello @rwightman, thank you for showing us how to extract the attention layers and for maintaining your wonderful timm library. I would like to ask: I am using fast.ai alongside timm models, and I've trained a ViT for a classification task on my own dataset. What would be the best way to load the weights of my ViT and visualize the attention activations over my input image using the timm visualization helpers? Thank you.
-
I have a bit of a trivial question here, slightly off-topic to what @rwightman discussed, but related to the attention map that @hankyul2 posted in their initial visualization. I'm assuming that the ViT in OP's question is being trained for a classification task. Is it necessary that the attention map contain high activations along the diagonal, similar to the attention maps generated while training seq2seq models? This seems quite unintuitive to me, as I wouldn't expect a classification ViT to self-attend to patches in such a manner. I would instead expect every patch to attend to certain "interesting" parts of the image, almost like having certain columns in the attention map with high activations. Is my understanding incorrect, or should we expect the attention map to have high activations along the diagonal? Thank you in advance!
-
@SarthakJShetty You can see the attention maps after every block for DeiT: in the initial layers you can see the clear diagonal and how the attention maps then change. Please check out my notebook; it contains visualizations with attention rollout but also without it. I hope it can be of help. My assumption, or rather intuition, behind these diagonal patterns centers on the class token, which I believe is the reason we don't see information flow immediately, as the class token takes in information from the other tokens without much change. However, this is what I think happens in shallow layers; the attention only really starts to play out in the later, deeper layers.
-
Thank you for the clarification @arnavs04! This helps. Now that I think about it, the visualizations I'm observing are almost always like attention map 5 and beyond. Quick question: when you say "the attention starts to play out", do you mean when the queries actually start attending to relevant parts of the image (and not just predominantly self-attending in a diagonal fashion), i.e., when the attention actually starts looking non-diagonal, like attention maps 5-12?
-
@SarthakJShetty In the diagram, you can see how information moves gradually from local to global, resembling the behaviour of a CNN (hence the diagonal structure slowly transitioning to a more globally attended pattern). This example might not be fully representative, as there is usually some global information shared even in the early stages of a vision transformer. Here is a paper that may clear up your doubts: Do Vision Transformers See Like Convolutional Neural Networks? I haven't gone through it completely, as it has been sitting on my reading list. I am currently working on vision transformers, so if you have any questions or would like to discuss them with me, let me know. I'll be glad to help!
-
FYI, there's a fix on main for the node/module matching so that outputs remain in order of traversal (which usually matches the order of the forward pass, at least for timm models) regardless of how many matching names/wildcards are specified.
-
Hi @hankyul2! Thanks for your excellent explanation above. I understood most of it but was still confused about why … Thanks!
-
This is the DeiT architecture, which has both a class token and a distillation token for prediction (hence the 196 + 1 + 1 = 198 tokens in the attention map above).
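For concreteness, a quick shape check (a sanity-check sketch, not part of the original reply) confirms where the 198 comes from:

```python
from timm.models import create_model

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
tokens = model.patch_embed.num_patches + 2  # 196 patch tokens + 1 cls token + 1 distillation token
print(tokens)                               # 198, matching the 198 x 198 attention map above
```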
-
Thanks for your response @arnavs04! The similarity scores you mentioned are the dot product of two vectors from the query and key matrices. Is that right?
-
@zichunxx Yup, exactly! The (n+2) x (n+2) attention matrix.
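As a quick sanity check of that, here is a toy snippet with random tensors (shapes only, per attention head): each entry of the matrix is the scaled dot product of one query row with one key row, and a softmax over the key dimension yields the (n+2) x (n+2) attention matrix discussed above.

```python
import torch

n, d = 196, 64                    # patch tokens per image, per-head dim (DeiT-S: 384 dim / 6 heads)
q = torch.randn(1, n + 2, d)      # n patch tokens + 1 cls token + 1 distillation token
k = torch.randn(1, n + 2, d)

attn = (q @ k.transpose(-2, -1)) * d ** -0.5  # pairwise query-key dot products, scaled by 1/sqrt(d)
attn = attn.softmax(dim=-1)                   # each row sums to 1
print(attn.shape)                             # torch.Size([1, 198, 198]) -> the (n+2) x (n+2) matrix
```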
-
Thanks for your generous help @arnavs04! I have read your notebook, which is very thorough and helpful. I have noticed that some vision transformers are implemented as an encoder without the cls token. In that situation, how do we plot the overlaid image to illustrate which patches are attended to with higher weight? Thanks!
-
I apologize for the delay, I didn't see the reply. The attribution map is resized with bilinear interpolation to the same H x W resolution as the original image. This heatmap is then blended as 0.5 x heat_map + 0.5 x original_image to get the saliency map. Obviously you can tweak the two 0.5 weights. This is the goal of post-hoc, model-agnostic explainability methods for vision transformers.
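As a rough sketch of that overlay for a cls-token-free model (assuming the last-block attention was captured as `attn_map` with shape (heads, N, N), where the N patch tokens lie on a grid x grid layout; averaging over queries is just one reasonable choice):

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def overlay_attention(img_chw, attn_map, grid=14, alpha=0.5):
    """Blend alpha * heat_map + (1 - alpha) * original_image, as described above."""
    # average over heads and over queries -> how much attention each patch receives
    patch_weight = attn_map.mean(dim=0).mean(dim=0).view(1, 1, grid, grid)
    H, W = img_chw.shape[1:]
    heat = F.interpolate(patch_weight, (H, W), mode='bilinear', align_corners=False)[0, 0]
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)  # normalize to [0, 1]

    img = img_chw.permute(1, 2, 0) * 0.5 + 0.5                     # undo the (0.5, 0.5) normalization
    blended = alpha * heat.unsqueeze(-1) + (1 - alpha) * img
    plt.imshow(blended.clamp(0, 1))
    plt.axis('off')
    plt.show()
```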
-
https://github.com/facebookresearch/dino/blob/main/visualize_attention.py might be helpful.
-
Awesome