Visualize attention map for vision transformer #1232

Unanswered
kiashann asked this question in Q&A

Hi, I want to extract the attention map from a pretrained vision transformer for a specific image.
How can I do that?


Replies: 12 comments 28 replies


Hi @kiashann

This is a toy example that visualizes the whole attention map and the attention map for the class token only (see here for more information).

```python
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from timm.models import create_model
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor, InterpolationMode


def to_tensor(img):
    transform_fn = Compose([
        Resize(249, interpolation=InterpolationMode.BICUBIC),  # bicubic (was the int code 3)
        CenterCrop(224),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    return transform_fn(img)


def show_img(img):
    img = np.asarray(img)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()


def show_img2(img1, img2, alpha=0.8):
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)
    plt.figure(figsize=(10, 10))
    plt.imshow(img1)
    plt.imshow(img2, alpha=alpha)
    plt.axis('off')
    plt.show()


def my_forward_wrapper(attn_obj):
    def my_forward(x):
        B, N, C = x.shape
        qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
        attn = attn.softmax(dim=-1)
        attn = attn_obj.attn_drop(attn)
        attn_obj.attn_map = attn                    # save full attention map (B, heads, N, N)
        attn_obj.cls_attn_map = attn[:, :, 0, 2:]   # class-token row over the 196 patch tokens

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = attn_obj.proj(x)
        x = attn_obj.proj_drop(x)
        return x
    return my_forward


img = Image.open('n02102480_Sussex_spaniel.JPEG').convert('RGB')
x = to_tensor(img)

model = create_model('deit_small_distilled_patch16_224', pretrained=True)
model.eval()  # inference mode (disables dropout)
model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)

with torch.no_grad():
    y = model(x.unsqueeze(0))

attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(14, 14).detach()

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5  # undo Normalize for display
cls_resized = F.interpolate(cls_weight.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

show_img(img)
show_img(attn_map)
show_img(cls_weight)
show_img(img_resized)
show_img2(img_resized, cls_resized, alpha=0.8)
```

Original image:
[image]

Attention map for the last layer (198 x 198, where 198 = 196 image patches + 1 class token + 1 distillation token):
[image]

Class attention map for the last layer (14 x 14):
[image]

Class attention map overlaid on the image:
[image]

4 replies
@arnavsm

The attention scores are a bit scattered here; usually the cls token focuses on certain patches, and those patches are consistent. Are you sure taking the mean is a good idea? It might also be better to roll out the cls token attention over multiple blocks.
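Attention rollout (Abnar & Zuidema, 2020) is one way to combine attention across blocks instead of reading only the last one. Below is a minimal NumPy sketch, assuming you have already collected one head-stacked attention array per block (for example the `attn_map` saved by the patched forward, with the batch dimension removed and converted via `.detach().cpu().numpy()`). The function name `attention_rollout` and its arguments are illustrative, not part of timm.

```python
import numpy as np


def attention_rollout(attn_maps, num_prefix_tokens=1):
    """Roll out attention across blocks.

    attn_maps: list of arrays shaped (heads, N, N), ordered first block to last.
    Returns the rolled-out attention of the class token (row 0) over the
    patch tokens, shape (N - num_prefix_tokens,).
    """
    n = attn_maps[0].shape[-1]
    eye = np.eye(n)
    rollout = eye
    for attn in attn_maps:
        a = attn.mean(axis=0)                  # average heads -> (N, N)
        a = a + eye                            # account for the residual connection
        a = a / a.sum(axis=-1, keepdims=True)  # keep each row summing to 1
        rollout = a @ rollout                  # compose with the earlier blocks
    return rollout[0, num_prefix_tokens:]


# Random row-stochastic maps standing in for 12 DeiT-small blocks
# (6 heads, 198 tokens = 1 class + 1 distillation + 196 patches).
rng = np.random.default_rng(0)
maps = []
for _ in range(12):
    e = np.exp(rng.normal(size=(6, 198, 198)))
    maps.append(e / e.sum(axis=-1, keepdims=True))
cls_rollout = attention_rollout(maps, num_prefix_tokens=2).reshape(14, 14)
print(cls_rollout.shape)  # (14, 14)
```

Note the design difference from an element-wise product of per-block class weights (as with `torch.prod` over stacked maps): rollout composes blocks with a matrix product, so attention that flows through intermediate tokens is also counted.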

@arnavsm

So I tried taking the product of the cls token weights over the other blocks like this, but it raises this error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 13
     11 # Forward pass through all blocks
     12 for block in model.blocks:
---> 13     x, attn_map = block.attn.forward(x)
     14     outputs.append(x)
     15     attn_maps.append(attn_map)

Cell In[8], line 4, in my_forward_wrapper.<locals>.my_forward(x)
      2 def my_forward(x):
      3     B, N, C = x.shape
----> 4     qkv = attn_obj.qkv(x).reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads).permute(2, 0, 3, 1, 4)
      5     q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
      7     attn = (q @ k.transpose(-2, -1)) * attn_obj.scale

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152)
```

Code:

```python
model = create_model('deit_small_distilled_patch16_224', pretrained=True)

# Replace forward function in all blocks
for block in model.blocks:
    block.attn.forward = my_forward_wrapper(block.attn)

# Forward pass through the model
outputs = []
attn_maps = []
cls_weights = []

# Forward pass through all blocks
for block in model.blocks:
    x, attn_map = block.attn.forward(x)
    outputs.append(x)
    attn_maps.append(attn_map)
    cls_weights.append(block.attn.cls_attn_map.min(dim=1).values.view(14, 14).detach())

# Combine class scores of all blocks
cls_weight_combined = torch.prod(torch.stack(cls_weights), dim=0)

# Resize input image and class weights
img_resized = x.permute(0, 2, 3, 1) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight_combined.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)

# Visualize
show_img(image)
show_img(attn_maps[-1])  # Attention map from the last block
show_img(cls_weight_combined)  # Combined class weights
show_img(img_resized.squeeze())  # Squeeze the batch dimension
show_img2(img_resized.squeeze(), cls_resized, alpha=0.8)  # Squeeze the batch dimension
```
@hankyul2

Hi @arnavs04

That's a good question, and you're right: there may be a better way to merge different attention maps. The code I gave uses a very basic way of merging them.

The error you posted looks like a dimension mismatch error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152)

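The attention block's qkv layer expects a sequence of token embeddings with 384 channels, but here it received the raw image tensor (3 x 224 x 224), so the last dimension is 224 instead of 384. I suggest applying the patch embedding layer to the input image before passing it to the attention blocks. You should also make sure the other parts of the model (e.g., class/distillation tokens, position embeddings, normalization, MLP) are applied appropriately.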

Thank you.

Hankyul.
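To see what the patch embedding does to the image: timm's ViT patch embedding is a convolution whose kernel size and stride equal the patch size, which is equivalent to cutting the image into non-overlapping 16 x 16 patches and applying one linear layer to each flattened patch. The NumPy sketch below shows only this shape bookkeeping; `patchify` is a hypothetical helper, the projection weights are random placeholders rather than pretrained values, and the class/distillation tokens and position embeddings that DeiT adds afterwards are omitted.

```python
import numpy as np


def patchify(img, patch_size=16):
    """Split a (C, H, W) image into flattened non-overlapping patches.

    Returns (num_patches, C * patch_size * patch_size), patches ordered
    row by row (left to right, top to bottom).
    """
    c, h, w = img.shape
    gh, gw = h // patch_size, w // patch_size
    patches = img.reshape(c, gh, patch_size, gw, patch_size)
    patches = patches.transpose(1, 3, 0, 2, 4)  # (gh, gw, C, p, p)
    return patches.reshape(gh * gw, c * patch_size * patch_size)


rng = np.random.default_rng(0)
img = rng.normal(size=(3, 224, 224))          # stands in for to_tensor(img)
patches = patchify(img)                       # (196, 768)
w_proj = rng.normal(size=(768, 384)) * 0.02   # placeholder projection to 384 channels
tokens = patches @ w_proj                     # (196, 384): per-token width the blocks expect
print(patches.shape, tokens.shape)            # (196, 768) (196, 384)
```

Passing the image directly instead makes the last dimension 224 rather than 384, which is the 672x224 vs 384x1152 mismatch in the traceback.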

@arnavsm

Ohh, thank you for the help!!


@kiashann
Thank you for your valuable code. The whole code works fine, but I need to understand how these lines work: `model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)` and `attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()`. So I ran these lines in the console: `attn_obj = model.blocks[-1].attn` and `qkv = attn_obj.qkv(x)`, but got this error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (672x224 and 384x1152). I'd like to know why. I need to make sure whether `x` is the transformed image or some other variable. When I debugged the code, I found that B, N, C from `x.shape` are 1, 384, and 198, which are different from the dimensions of the transformed image.

5 replies
@hankyul2

Hi @mae338

I am happy to help you.

  • In `y = model(x.unsqueeze(0))`, `x.unsqueeze(0)` is the transformed image, with shape (1, 3, 224, 224).
  • In `model.blocks[-1].attn.forward = my_forward_wrapper(model.blocks[-1].attn)`, we replace the original `forward` with the `my_forward` function returned by `my_forward_wrapper`, so that the attention map is saved as an instance variable of `model.blocks[-1].attn`.
  • In `model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()`, we average the attention map over the head dimension (`.mean(dim=1)`) so that we can display the whole attention map.
  • To understand it better, I recommend printing the shape of every tensor, which helps you see the overall workflow of ViT.
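The head-averaging step in the third bullet can be checked with plain arrays. This NumPy sketch uses random softmax-normalized scores in place of a real attention map of shape (1, heads, N, N); 6 heads and 198 tokens match DeiT-small distilled, but the values are synthetic.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(1, 6, 198, 198))                   # (B, heads, N, N)
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

# NumPy equivalent of attn_map.mean(dim=1).squeeze(0): average heads, drop batch axis.
avg = attn.mean(axis=1).squeeze(0)
print(avg.shape)                                             # (198, 198)

# Each row is still a distribution over keys: an average of distributions is a distribution.
print(np.allclose(avg.sum(axis=-1), 1.0))                    # True
```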

Thank you.

Hankyul

@mae338

Thank you so much, Hankyul. I appreciate your effort to help me. I'd be pleased if you could help me understand this as well. What about:

```python
def my_forward(x):
    B, N, C = x.shape
    qkv = attn_obj.qkv(x)
```

More specifically, what does the `x` parameter refer to?
And why did you create the function `my_forward`?
I also found that `my_forward` must return `x`; otherwise the whole program doesn't work. I wonder why.

@hankyul2

Hi @mae338

I hope this could help you.

  • We replace the original Attention `forward()` method with the `my_forward()` method, which only inserts two extra lines for saving attention maps (`attn_obj.attn_map = attn`, `attn_obj.cls_attn_map = attn[:, :, 0, 2:]`). Thus, everything else (the computation, signature, and return value) should be the same as in the original method.
  • `x` is the sequence of image tokens (Batch x N x D) input to the attention layer. The image is tokenized by a linear (patch embedding) layer, and the token shape depends on the model size. In our case, DeiT-small/16 splits the 224 x 224 image into 196 patches of 16 x 16 pixels; with the extra class and distillation tokens, N = 198, and each token has D = 384 channels.
  • Since the attention layer passes its output on to the next layer (such as the MLP), `my_forward()` must also return its result. If it returns nothing, the next layers receive `None`, which they do not expect, and an error is raised.
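The three points above can be traced with a NumPy re-implementation of the computation `my_forward()` performs: project tokens to q, k, v, split heads, apply the scaled dot-product softmax, save the attention map, and return an output with the same (B, N, C) shape as the input. This is a sketch, not timm code: `attention_forward` is a hypothetical name, the weights are random placeholders without biases, dropout is omitted (it is inactive at inference), and the dict `saved` stands in for the `attn_obj.attn_map` attributes.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_forward(x, w_qkv, w_proj, num_heads, saved):
    """Multi-head self-attention on tokens x of shape (B, N, C)."""
    B, N, C = x.shape
    head_dim = C // num_heads
    qkv = (x @ w_qkv).reshape(B, N, 3, num_heads, head_dim).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]                                  # each (B, heads, N, head_dim)
    attn = softmax((q @ k.transpose(0, 1, 3, 2)) * head_dim ** -0.5)  # (B, heads, N, N)
    saved["attn_map"] = attn                                          # like attn_obj.attn_map
    saved["cls_attn_map"] = attn[:, :, 0, 2:]                         # class row, patch columns
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)
    return out @ w_proj                                               # same shape as the input


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 198, 384))               # 1 image, 198 tokens, 384 channels
w_qkv = rng.normal(size=(384, 3 * 384)) * 0.02   # 384 -> 1152, as in the error message
w_proj = rng.normal(size=(384, 384)) * 0.02
saved = {}
y = attention_forward(x, w_qkv, w_proj, num_heads=6, saved=saved)
print(y.shape, saved["attn_map"].shape, saved["cls_attn_map"].shape)
# (1, 198, 384) (1, 6, 198, 198) (1, 6, 196)
```

If the function returned nothing, the caller would get `None` instead of `y`, which is the failure described in the last bullet.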

Thank you.

Hankyul

@mae338

Hi @hankyul2,

I'd like to view the outputs of a pretrained ViT model (vit_base_patch16_224), especially the input to the MLP head. I tried the same function, but it didn't work. Any suggestions?
Thank you
mae

@hankyul2

Hi @mae338

I hope this can help you.

You can extract the final features of a pretrained ViT with `y = model.forward_head(x, pre_logits=True)` and visualize them for your purpose. Since it has been a long time since I uploaded the initial code in my first comment, I copied and modified it as follows:

```python
# dependency
!wget https://user-images.githubusercontent.com/31476895/167238573-b0cc3a6d-d3ee-462b-8630-a8f253e69bb2.png
!pip install -Uq fastai timm==0.6.13 huggingface_hub

##########################################
# code
import numpy as np
from PIL import Image
from timm.models import create_model
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, Normalize, ToTensor


def to_tensor(img):
    transform_fn = Compose([Resize(249), CenterCrop(224), ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    return transform_fn(img)


img = Image.open('167238573-b0cc3a6d-d3ee-462b-8630-a8f253e69bb2.png').convert('RGB')
x = to_tensor(img)
model = create_model('deit_small_distilled_patch16_224', pretrained=True)
x = model.forward_features(x.unsqueeze(0))
y = model.forward_head(x, pre_logits=True)
print(y.shape)

##########################################
# output
# torch.Size([1, 384])
##########################################
```

Thank you.

Hankyul


@hankyul2

I would like to apply this code to the 'vit_small_patch16_384' model from timm. How should I modify the code for this purpose?
(I understand that '224' in the given code refers to the image size, but how is '14' determined?)
I apologize if this is due to my lack of knowledge.

```python
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(14, 14).detach()

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 14, 14), (224, 224), mode='bilinear').view(224, 224, 1)
```
3 replies
@hankyul2

Hi @tomos7231

  1. 384 is the input image resolution of the ViT model. If you want to extract an attention map using the code above, you should change the input resolution (224 -> 384) and the patch grid (14x14 -> 24x24), as in the code below.

```python
attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(24, 24).detach()  # (14 -> 24)

img_resized = x.permute(1, 2, 0) * 0.5 + 0.5
cls_resized = F.interpolate(cls_weight.view(1, 1, 24, 24), (384, 384), mode='bilinear').view(384, 384, 1)  # (14 -> 24), (224 -> 384)
```

  2. 14 is the number of patches along each spatial dimension (H and W) of the feature map. It is determined by the image size and the patch size: 224 / 16 = 14, so each image is divided into 14 x 14 = 196 patches of 16 x 16 pixels.
  3. If you just want to extract an attention map, I would recommend @rwightman's solution, which is more convenient.
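The relationship in point 2 can be written as a small helper. This is a sketch assuming square images and patches and a plain ViT/DeiT token layout; `patch_grid` is a hypothetical name, not a timm function.

```python
def patch_grid(img_size, patch_size, num_prefix_tokens=1):
    """Return (grid side, number of patches, total tokens) for a square ViT input."""
    side = img_size // patch_size
    num_patches = side * side
    return side, num_patches, num_patches + num_prefix_tokens


print(patch_grid(224, 16, num_prefix_tokens=2))  # DeiT distilled: (14, 196, 198)
print(patch_grid(384, 16))                       # vit_small_patch16_384: (24, 576, 577)
print(patch_grid(224, 32))                       # e.g. vit_base_patch32_224: (7, 49, 50)
```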

Thank you.

Hankyul

@tomos7231

Hi @hankyul2

Understood. Thank you for the reply!

@hibiki-iwanaga

@hankyul2
Thank you for such an insightful explanation above.

  1. I also tried to apply this code to the 'vit_small_patch16_384' model as shown above, but I encountered the following error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[26], line 24
     20 # y = model(image)
     23 attn_map = model.blocks[-1].attn.attn_map.mean(dim=1).squeeze(0).detach()
---> 24 cls_weight = model.blocks[-1].attn.cls_attn_map.mean(dim=1).view(24, 24).detach()
     26 img_resized = image2.permute(0, 2, 3, 1) * 0.5 + 0.5
     27 cls_resized = F.interpolate(cls_weight.view(1, 1, 24, 24), (384, 384), mode='bilinear').view(384, 384, 1)

RuntimeError: shape '[24, 24]' is invalid for input of size 575
```

  2. If I want to apply this model to a regression problem and estimate multiple parameters with a single model, is it possible to check the attention map for each parameter? In a multi-class classification problem, we can check the attention map when a certain class is predicted to see which areas the model focuses on to recognize that class. In a regression problem, however, how should this be done?
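On point 1: the 575 likely comes from the slice `attn[:, :, 0, 2:]` in `my_forward`, which drops two leading tokens because the DeiT distilled model has a class token and a distillation token. `vit_small_patch16_384` has only a class token, so its 577 tokens minus 2 leave 575, which cannot be reshaped to 24 x 24; slicing from index 1 leaves 576. A NumPy sketch with a synthetic attention array is below. Recent timm ViT models also expose the number of leading tokens as `model.num_prefix_tokens`, which is worth confirming for the installed version.

```python
import numpy as np

num_heads, num_tokens = 6, 577                 # vit_small_patch16_384: 1 class token + 24 * 24 patches
attn = np.random.default_rng(0).random((1, num_heads, num_tokens, num_tokens))

wrong = attn[:, :, 0, 2:]                      # DeiT-distilled slice: 575 values per head
num_prefix_tokens = 1                          # class token only for this model
cls_attn = attn[:, :, 0, num_prefix_tokens:]   # 576 values per head
cls_weight = cls_attn.mean(axis=1).reshape(24, 24)
print(wrong.shape[-1], cls_attn.shape[-1], cls_weight.shape)  # 575 576 (24, 24)
```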


So, this doesn't include the visualization helpers yet, but I have added a simpler extraction helper to get the attention activations via one of two methods: fx or hooks.

WIP, but it can be seen at https://github.com/huggingface/pytorch-image-models/pull/2168/files#diff-358e0d5feb2c109ff53d21bc4fa8a6af94566be622b0f1167316216b0036b8b3

```python
import timm
import torch
from timm.utils import AttentionExtract

timm.layers.set_fused_attn(False)
mm = timm.create_model('vit_base_patch16_224')
input = torch.randn(2, 3, 224, 224)
ee = AttentionExtract(mm, method='fx')
oo = ee(input)
```

```python
for n, t in oo.items():
    print(n, t.shape)
```

```
blocks.0.attn.softmax torch.Size([2, 12, 197, 197])
blocks.1.attn.softmax torch.Size([2, 12, 197, 197])
blocks.2.attn.softmax torch.Size([2, 12, 197, 197])
blocks.3.attn.softmax torch.Size([2, 12, 197, 197])
blocks.4.attn.softmax torch.Size([2, 12, 197, 197])
blocks.5.attn.softmax torch.Size([2, 12, 197, 197])
blocks.6.attn.softmax torch.Size([2, 12, 197, 197])
blocks.7.attn.softmax torch.Size([2, 12, 197, 197])
blocks.8.attn.softmax torch.Size([2, 12, 197, 197])
blocks.9.attn.softmax torch.Size([2, 12, 197, 197])
blocks.10.attn.softmax torch.Size([2, 12, 197, 197])
blocks.11.attn.softmax torch.Size([2, 12, 197, 197])
```
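The extractor returns a dict of per-block softmax outputs of shape (B, heads, 197, 197) for `vit_base_patch16_224`. As a sketch of one way to turn that into per-block class-token heatmaps, the NumPy code below builds a stand-in dict with the same key pattern and shapes (random values, two blocks only, not real activations) and reduces each entry to a (B, 14, 14) map. With real outputs, each tensor would first be converted with `.detach().cpu().numpy()`. The helper name `cls_heatmaps` is hypothetical.

```python
import numpy as np


def cls_heatmaps(attn_outputs, num_prefix_tokens=1):
    """Map {name: (B, heads, N, N)} to {name: (B, side, side)} class-token heatmaps."""
    maps = {}
    for name, attn in attn_outputs.items():
        cls_row = attn[:, :, 0, num_prefix_tokens:].mean(axis=1)  # (B, num_patches), heads averaged
        side = int(round(np.sqrt(cls_row.shape[-1])))
        maps[name] = cls_row.reshape(cls_row.shape[0], side, side)
    return maps


rng = np.random.default_rng(0)
fake = {f"blocks.{i}.attn.softmax": rng.random((2, 12, 197, 197)) for i in range(2)}
heatmaps = cls_heatmaps(fake)
print(heatmaps["blocks.1.attn.softmax"].shape)  # (2, 14, 14)
```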
10 replies
@Astroboy-01

Hello @rwightman, thank you for showing us how to extract the attention layers and for maintaining your wonderful timm library. I am using fast.ai alongside timm models, and I've trained a ViT for a classification task on my own dataset. What would be the best way to load the weights of my ViT and visualize the attention activations over my input image using the timm visualization helpers? Thank you.

@SarthakJShetty

I have a somewhat trivial question, slightly off-topic from what @rwightman discussed, but related to the attention map that @hankyul2 posted in their initial visualization:

I'm assuming that the ViT in the OP's question is trained for a classification task. Must the attention map contain high activations along the diagonal, similar to the attention maps generated when training seq2seq models? This seems quite unintuitive to me, as I wouldn't expect a classification ViT to self-attend to patches in such a manner. I would instead expect every patch to attend to certain "interesting" parts of the image, which would show up as certain columns of the attention map with high activations.

Is my understanding incorrect, or should we expect the attention map to have high activations along the diagonal?

Thank you in advance!

@arnavsm

@SarthakJShetty You can see the attention maps after every block for DeiT here: in the initial layers there is a clear diagonal, and you can see how the attention maps change afterwards. Please check out my notebook. It contains visualizations both with and without attention rollout. I hope it helps.

My assumption, or rather intuition, behind these diagonal patterns centers on the class token, which I believe is the reason we don't see information flow immediately: the class token takes in information from the other tokens without much change. However, this is what I think happens in the shallow layers; I believe it is only in the later, deeper layers that the attention starts to play out.

[Figure: attention maps for all layers]
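One simple way to put a number on the diagonal pattern is the average attention each token pays to itself, i.e. the mean of the diagonal of the head-averaged map, computed per block. This NumPy sketch uses two synthetic maps (a strongly diagonal one and a uniform one) only to illustrate the measure; the function name `self_attention_mass` is illustrative, and values for real DeiT maps would have to be computed from the saved attention.

```python
import numpy as np


def self_attention_mass(attn):
    """Mean attention a token assigns to itself; attn has shape (heads, N, N)."""
    avg = attn.mean(axis=0)                  # average over heads -> (N, N)
    return float(np.diagonal(avg).mean())


n = 198
uniform = np.full((6, n, n), 1.0 / n)                          # every token attends equally to all
diagonal = 0.9 * np.eye(n) + 0.1 * np.full((n, n), 1.0 / n)    # rows still sum to 1
diagonal = np.broadcast_to(diagonal, (6, n, n))                # same map for all 6 heads
print(round(self_attention_mass(uniform), 4))   # 0.0051 (= 1/198)
print(round(self_attention_mass(diagonal), 4))  # 0.9005
```

Applied to each block's map in order, a value that starts high and drops with depth would correspond to the diagonal-to-global transition described above.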

@SarthakJShetty

Thank you for the clarification @arnavs04! This helps. Now that I think about it, the visualizations I'm observing almost always look like Attention Map 5 and beyond.

Quick question: when you say "the attention starts to play out", you mean when the queries actually start attending to relevant parts of the image (and not just predominantly self-attending in a diagonal fashion), correct? I.e., when the attention maps start looking non-diagonal, like attention maps 5-12?

@arnavsm

@SarthakJShetty
One of the major reasons why vision transformers (ViTs) are thought to be “better” than CNNs is their ability to share global contextual information right from the beginning (i.e., in the shallow layers). Unlike CNNs, which progress from local to global information, ViTs can, in theory, access global information from the outset. However, attention maps show that not much global information is actually shared in the early stages. This has led to many works proposing modifications to optimize the self-attention mechanism in vision transformers, effectively “tuning” them similarly to CNNs.

In the diagram, you can see how information moves gradually from local to global, resembling the behaviour of a CNN (hence the diagonal structure slowly transitioning to a more globally attended pattern). This example might not be fully representative, as there is usually some global information shared even in the early stages of a vision transformer.
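A common way to measure how local or global a layer's attention is, is the mean attention distance: for each query patch, the attention-weighted average distance (in pixels) to the key patches. The sketch below is a NumPy version over patch tokens only (prefix tokens dropped and rows renormalized), assuming a 14 x 14 grid of 16-pixel patches; `mean_attention_distance` is an illustrative name and the inputs are synthetic.

```python
import numpy as np


def mean_attention_distance(attn, grid=14, patch_size=16, num_prefix_tokens=1):
    """Attention-weighted mean distance (pixels) between query and key patches.

    attn: (heads, N, N) attention for one image. Returns one value per head.
    """
    patch_attn = attn[:, num_prefix_tokens:, num_prefix_tokens:]
    patch_attn = patch_attn / patch_attn.sum(axis=-1, keepdims=True)  # renormalize over patches
    rows, cols = np.divmod(np.arange(grid * grid), grid)              # patch k -> (k // grid, k % grid)
    coords = np.stack([rows, cols], axis=1) * patch_size              # grid position in pixels
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (P, P)
    return (patch_attn * dist).sum(axis=-1).mean(axis=-1)             # average over queries


n = 1 + 14 * 14
self_only = np.broadcast_to(np.eye(n), (2, n, n))  # each patch attends only to itself
uniform = np.full((2, n, n), 1.0 / n)              # attention spread evenly over all tokens
print(mean_attention_distance(self_only))          # [0. 0.]
print(mean_attention_distance(uniform).round(1))   # much larger: global attention
```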

Here is a paper that may clear up your doubts: Do Vision Transformers See Like Convolutional Neural Networks? I haven't gone through it completely, as it has been on my reading list.

I am currently working on Vision Transformers. If you have any questions or would like to discuss them with me, let me know! I'll be glad to help!


FYI, there's a fix on main for the node/module matching so that outputs remain in order of traversal (which usually matches the order of the forward pass, at least for timm models), regardless of how many matching names/wildcards are specified.

0 replies

Hi @hankyul2! Thanks for your excellent explanation above.

I understood most of it but was still confused about `attn_obj.cls_attn_map = attn[:, :, 0, 2:]`.

Why is `cls_attn_map` extracted with the index `attn[:, :, 0, 2:]`?

Thanks!

5 replies
@arnavsm

This is the DeiT architecture, which has both a class token and a distillation token for prediction.
Indexing with 0 selects the class token's row, i.e. the similarity scores of all tokens with respect to the class token. Then `2:` keeps only the scores of the patch tokens with respect to the class token, excluding the class token and the distillation token themselves.
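To make the indexing concrete: the attention array has shape (B, heads, N, N), where the third axis is the query token and the fourth is the key token. `attn[:, :, 0, 2:]` therefore reads row 0 (the class token as query) across the patch-token columns, which is different from `attn[:, :, 2:, 0]` (how much each patch attends to the class token). After `.view(14, 14)`, patch k lands at grid position (k // 14, k % 14). A small NumPy check with a synthetic array:

```python
import numpy as np

B, heads, N = 1, 6, 198                 # DeiT distilled: [cls, dist, 196 patches]
attn = np.zeros((B, heads, N, N))
attn[:, :, 0, 2 + 30] = 1.0             # class token attends only to patch 30

cls_attn_map = attn[:, :, 0, 2:]        # (B, heads, 196): query = cls, keys = patches
grid = cls_attn_map.mean(axis=1).reshape(B, 14, 14)
print(np.argwhere(grid[0] == 1.0))      # [[2 2]] since 30 // 14 == 2 and 30 % 14 == 2
```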

@zichunxx

Thanks for your response @arnavs04! The similarity scores you mentioned are the dot product of two vectors from the query and key matrices. Is that right?

@arnavsm

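@zichunxx Yup, exactly, after scaling by 1/sqrt(head_dim) (`attn_obj.scale`) and applying a row-wise softmax. Together they form the (n+2) x (n+2) attention matrix.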

@zichunxx

Thanks for your generous help! @arnavs04 I have read your notebook which is very thorough and helpful.

I have noticed that some vision transformers are implemented as an encoder without a cls token. In this situation, how do we plot the overlaid image to show which patches receive higher weight? Thanks!

@arnavsm

I apologize for the delay; I didn't see the reply.

The attribution map is resized with bilinear interpolation to the H x W resolution of the original image. This heatmap is then blended as 0.5 x heat_map + 0.5 x original_image to get the saliency map. You can of course tweak the 0.5 and 0.5 weights. This is the goal of post-hoc, model-agnostic explainability methods for vision transformers.
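For a model without a class token, one heuristic (an assumption here, not something established in this thread) is to score each patch by the average attention it receives from all patch queries, i.e. the column mean of the head-averaged map, and then upsample and blend as described above. In this NumPy sketch the upsampling is a nearest-neighbour repeat to keep it dependency-free; `F.interpolate(..., mode='bilinear')` as in the original code is the smoother alternative. Function names and inputs are illustrative.

```python
import numpy as np


def patch_scores_without_cls(attn, grid=14):
    """Average attention each patch receives; attn is (heads, P, P) over patches only."""
    received = attn.mean(axis=0).mean(axis=0)   # average heads, then queries -> (P,)
    return received.reshape(grid, grid)


def overlay(image, heat, alpha=0.5):
    """Blend (H, W, 3) image in [0, 1] with (H, W) heatmap: alpha * heat + (1 - alpha) * image."""
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)  # scale heatmap to [0, 1]
    return alpha * heat[..., None] + (1 - alpha) * image


rng = np.random.default_rng(0)
scores = rng.random((6, 196, 196))
attn = scores / scores.sum(axis=-1, keepdims=True)               # rows sum to 1, like a softmax
heat = patch_scores_without_cls(attn)                            # (14, 14)
heat_full = np.repeat(np.repeat(heat, 16, axis=0), 16, axis=1)   # (224, 224)
image = rng.random((224, 224, 3))
blended = overlay(image, heat_full)
print(heat.shape, heat_full.shape, blended.shape)  # (14, 14) (224, 224) (224, 224, 3)
```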


https://github.com/facebookresearch/dino/blob/main/visualize_attention.py

This might be helpful.

0 replies

https://huggingface.co/spaces/timm/timmAttentionViz

1 reply
@hankyul2

Awesome


Thank you for your letter; I will reply ASAP. BW, Tianwen Zhou
0 replies

[image]

Does directly resizing a 14×14 attention map to 224×224 make sense? I've seen this approach used frequently in visualizing attention rollouts for ViTs, but I'm trying to understand what the resized attention values actually represent. Since the original 14×14 map corresponds to attention over 16×16 patches of a 224×224 image, does interpolating the attention values introduce artifacts or distort their meaning? Would it make more sense to map each value of the 14×14 attention rollout matrix directly to a fixed 16×16 region instead, giving a 224×224 mask? Additionally, I couldn't find a definitive source explaining why interpolating attention maps is standard practice. Is this done purely for visualization, or is there a theoretical justification behind it? Any insights would be greatly appreciated. Thanks in advance :))
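The two options in the question can be compared directly. The NumPy sketch below upsamples a 14 x 14 map to 224 x 224 in two ways: block (nearest) upsampling, which copies each value over its own 16 x 16 patch, and a bilinear resize written to follow the half-pixel `align_corners=False` convention that `F.interpolate(..., mode='bilinear')` uses by default. The helpers are illustrative, not library functions.

```python
import numpy as np


def block_upsample(grid, patch_size=16):
    """Copy each grid value over its own patch_size x patch_size pixel block."""
    return np.kron(grid, np.ones((patch_size, patch_size)))


def bilinear_upsample(grid, out_h, out_w):
    """Bilinear resize with half-pixel centers (align_corners=False convention)."""
    in_h, in_w = grid.shape

    def source_coords(out_n, in_n):
        src = (np.arange(out_n) + 0.5) * (in_n / out_n) - 0.5
        src = np.clip(src, 0, None)
        i0 = np.floor(src).astype(int)
        i1 = np.minimum(i0 + 1, in_n - 1)
        return i0, i1, src - i0

    y0, y1, wy = source_coords(out_h, in_h)
    x0, x1, wx = source_coords(out_w, in_w)
    top = grid[y0][:, x0] * (1 - wx) + grid[y0][:, x1] * wx
    bottom = grid[y1][:, x0] * (1 - wx) + grid[y1][:, x1] * wx
    return top * (1 - wy)[:, None] + bottom * wy[:, None]


grid = np.zeros((14, 14))
grid[7, 7] = 1.0                                   # one highly attended patch
block = block_upsample(grid)
smooth = bilinear_upsample(grid, 224, 224)
print(np.count_nonzero(block), block.max())        # 256 1.0
print(np.count_nonzero(smooth) > 256, smooth.max() < 1.0)  # True True
```

Block upsampling keeps each patch's value exact and confined to its 16 x 16 region, while bilinear spreads it into neighbouring patches and lowers the peak. In that sense bilinear resizing is mainly a visualization choice; the underlying values are only defined per patch.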

0 replies
Category: Q&A
Labels: None yet
14 participants: @kiashann, @rwightman, @SarthakJShetty, @hankyul2, @KishoreP1, @Astroboy-01, @zichunxx, @tomos7231, @TianwenZhou, @arnavsm, @anika81199, @mae338, @hibiki-iwanaga, @buenyamink
