Add T2T_ViT #2426


Draft: brianhou0208 wants to merge 2 commits into huggingface:main from brianhou0208:t2t_vit

Conversation

brianhou0208 (Contributor) commented Jan 22, 2025 (edited)

Hi @rwightman, this PR resolves #2364, please check.

Result

Tested the T2T-ViT models and weights on the ImageNet val dataset:

| Model | Acc@1 | Acc@5 | FLOPs (G) | MACs (G) | Params (M) |
|---|---|---|---|---|---|
| t2t_vit_7 | 71.6760 | 90.8860 | 2.0261 | 0.9755 | 4.2557 |
| t2t_vit_10 | 75.1500 | 92.8060 | 2.6476 | 1.2854 | 5.8347 |
| t2t_vit_12 | 76.4800 | 93.4840 | 3.0620 | 1.492 | 6.8874 |
| t2t_vit_14 | 81.5000 | 95.6660 | 8.7526 | 4.334 | 21.4658 |
| t2t_vit_19 | 81.9320 | 95.7440 | 15.6663 | 7.7868 | 39.0851 |
| t2t_vit_24 | 82.2760 | 95.8860 | 25.4543 | 12.6759 | 64.0010 |
| t2t_vit_t_14 | 81.6880 | 95.8520 | 8.6881 | 4.334 | 21.4654 |
| t2t_vit_t_19 | 82.4420 | 96.0820 | 15.6018 | 7.7868 | 39.0847 |
| t2t_vit_t_24 | 82.5540 | 96.0640 | 25.3898 | 12.6759 | 64.0006 |
test code

```python
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.utils.metrics import AverageMeter, accuracy

device = torch.device('cuda:0')

if __name__ == "__main__":
    val_dataset = datasets.ImageFolder(
        './data/val',
        transforms.Compose([
            transforms.Resize(int(224 / 0.9), interpolation=3),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        ])
    )
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True)

    for name in timm.list_models('t2t_vit*'):
        model = timm.create_model(name, pretrained=True).eval()
        model.to(device)
        top1 = AverageMeter()
        top5 = AverageMeter()
        with torch.no_grad():
            for images, target in tqdm(val_loader):
                images = images.to(device)
                target = target.to(device)
                output = model(images)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1, images.size(0))
                top5.update(acc5, images.size(0))
        print(f"Model {name} ACC@1 {top1.avg:.4f} ACC@5 {top5.avg:.4f}")
```
output log

```
100%|██████████████████████████████████████████████| 196/196 [00:39<00:00,  4.92it/s]
Model t2t_vit_7 ACC@1 71.6760 ACC@5 90.8860
FLOPs 2.0261 GFLOPS / MACs 975.534 MMACs / Params 4.2557 M
100%|██████████████████████████████████████████████| 196/196 [00:39<00:00,  4.96it/s]
Model t2t_vit_10 ACC@1 75.1500 ACC@5 92.8060
FLOPs 2.6476 GFLOPS / MACs 1.2854 GMACs / Params 5.8347 M
100%|██████████████████████████████████████████████| 196/196 [00:40<00:00,  4.88it/s]
Model t2t_vit_12 ACC@1 76.4800 ACC@5 93.4840
FLOPs 3.062 GFLOPS / MACs 1.492 GMACs / Params 6.8874 M
100%|██████████████████████████████████████████████| 196/196 [01:08<00:00,  2.87it/s]
Model t2t_vit_14 ACC@1 81.5000 ACC@5 95.6660
FLOPs 8.7526 GFLOPS / MACs 4.334 GMACs / Params 21.4658 M
100%|██████████████████████████████████████████████| 196/196 [01:45<00:00,  1.86it/s]
Model t2t_vit_19 ACC@1 81.9320 ACC@5 95.7440
FLOPs 15.6663 GFLOPS / MACs 7.7868 GMACs / Params 39.0851 M
100%|██████████████████████████████████████████████| 196/196 [02:31<00:00,  1.30it/s]
Model t2t_vit_24 ACC@1 82.2760 ACC@5 95.8860
FLOPs 25.4543 GFLOPS / MACs 12.6759 GMACs / Params 64.001 M
100%|██████████████████████████████████████████████| 196/196 [01:28<00:00,  2.20it/s]
Model t2t_vit_t_14 ACC@1 81.6880 ACC@5 95.8520
FLOPs 8.6881 GFLOPS / MACs 4.334 GMACs / Params 21.4654 M
100%|██████████████████████████████████████████████| 196/196 [02:04<00:00,  1.57it/s]
Model t2t_vit_t_19 ACC@1 82.4420 ACC@5 96.0820
FLOPs 15.6018 GFLOPS / MACs 7.7868 GMACs / Params 39.0847 M
100%|██████████████████████████████████████████████| 196/196 [02:51<00:00,  1.15it/s]
Model t2t_vit_t_24 ACC@1 82.5540 ACC@5 96.0640
FLOPs 25.3898 GFLOPS / MACs 12.6759 GMACs / Params 64.0006 M
```
calculate FLOPs/MACs/Params tool

report from calflops

```python
from calflops import calculate_flops

def flops_param(model):
    flops, macs, params = calculate_flops(
        model=model,
        input_shape=(1, 3, 224, 224),
        output_as_string=True,
        output_precision=4,
        print_detailed=False,
        print_results=False
    )
    print(f"FLOPs {flops} / MACs {macs} / Params {params}")
```
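(Presumably this helper was called once per model, e.g. `flops_param(model)` right after `timm.create_model` in the eval loop above, which would produce the FLOPs/MACs/Params lines in the log.)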

Reference

paper: https://arxiv.org/pdf/2101.11986
code: https://github.com/yitu-opensource/T2T-ViT

brianhou0208 marked this pull request as draft on January 23, 2025, 16:06
rwightman (Collaborator)

@brianhou0208 thanks for the work, it looks like a good job getting it in shape. I took a closer look using your code, but I have some doubts about this model:

  1. It requires a workaround with AMP + float16 to avoid NaN (see next post).
  2. Compared to simpler models it's really not performing better given the speed, especially compared to these https://huggingface.co/collections/timm/searching-for-better-vit-baselines-663eb74f64f847d2f35a9c19 : they are faster and more accurate at a fraction of the param count, and they have fewer MACs/activations. Even some models that have been there longer, like deit3 (e.g. deit3_medium_patch16_224), are faster/simpler/smaller than these.

For speed comparisons I disabled F.sdpa in the existing vit to be fair. Simpler vits with higher accuracy (imagenet-1k pretrain also, to be fair) are often 30-40% faster.
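(For illustration only, a rough sketch of this kind of throughput comparison; the batch size, iteration counts, and plain-eager fp32 settings are assumptions rather than the setup actually used, it does not replicate the F.sdpa change mentioned above, and `t2t_vit_14` is only registered if this PR's branch is installed:)

```python
import time
import torch
import timm

# Hypothetical timing harness, not the benchmark actually used for the numbers above.
device = torch.device('cuda:0')
for name in ('t2t_vit_14', 'deit3_medium_patch16_224'):
    model = timm.create_model(name, pretrained=False).eval().to(device)
    x = torch.randn(64, 3, 224, 224, device=device)
    with torch.no_grad():
        for _ in range(5):                      # warmup iterations
            model(x)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(20):                     # timed iterations
            model(x)
        torch.cuda.synchronize()
    print(f'{name}: {64 * 20 / (time.time() - start):.1f} img/s')
```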

So not convinced this is worth the add. Was there a particular reason you had interest in the model?

rwightman (Collaborator)

```python
def single_attn(self, x: torch.Tensor) -> torch.Tensor:
    k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
    if not torch.jit.is_scripting():
        with torch.autocast(device_type=v.device.type, enabled=False):
            y = self._attn_impl(k, q, v)
    else:
        y = self._attn_impl(k, q, v)
    # skip connection
    y = v + self.dp(self.proj(y))  # same as token_transformer in T2T layer, use v as skip connection
    return y

def _attn_impl(self, k, q, v):
    kp, qp = self.prm_exp(k), self.prm_exp(q)  # (B, T, m), (B, T, m)
    D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2)  # (B, T, m) * (B, m) -> (B, T, 1)
    kptv = torch.einsum('bin,bim->bnm', v.float(), kp)  # (B, emb, m)
    y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb) / Diag
    return y
```
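The `torch.autocast(..., enabled=False)` block is the workaround referred to above: it keeps the attention computation in float32 even when the surrounding model runs under AMP, which avoids the float16 NaNs.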

brianhou0208 (Contributor, Author)

Hi @rwightman, I agree with your observation. The T2T-ViT model does not have advantages over other models. The only advantage might be that it does not use any nn.Conv2d at all, relying instead on nn.Unfold to extract patches.
Most ViT-based models require some form of convolution for input processing, but the T2T-ViT architecture bypasses convolution completely; maybe this architecture can be explored further...
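(A minimal sketch of that unfold-based patch extraction; the kernel/stride/padding and embedding dim are illustrative values, not the actual t2t_vit config:)

```python
import torch
import torch.nn as nn

# Illustrative "soft split": overlapping patches via nn.Unfold, then a linear projection.
# No nn.Conv2d involved anywhere.
x = torch.randn(1, 3, 224, 224)                      # (B, C, H, W)
soft_split = nn.Unfold(kernel_size=7, stride=4, padding=2)
patches = soft_split(x).transpose(1, 2)              # (B, num_tokens, C * 7 * 7)
proj = nn.Linear(3 * 7 * 7, 64)                      # token embedding, example dim
tokens = proj(patches)                               # (B, num_tokens, 64)
print(tokens.shape)                                  # torch.Size([1, 3136, 64])
```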

Another issue occurs when using pre-trained weights and testing whether first_conv adapts to the input shape (C, H, W). If first_conv is set to None, the test_model_default_cfgs_non_std test will fail:

```python
first_conv = cfg['first_conv']
if isinstance(first_conv, str):
    first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
    assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
```

In test_model_load_pretrained, if first_conv points at a layer that is not a conv, as in T2T-ViT, passing it to an nn.Linear instead of an nn.Conv2d will also report an error:
```python
input_convs = pretrained_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:
    if isinstance(input_convs, str):
        input_convs = (input_convs,)
    for input_conv_name in input_convs:
        weight_name = input_conv_name + '.weight'
        try:
            state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
            _logger.info(
                f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
        except NotImplementedError as e:
            del state_dict[weight_name]
            strict = False
            _logger.warning(
                f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
```

Since this involves modifying test_models, and adding T2T-ViT is not worth the effort, I should probably close this PR.

rwightman (Collaborator) commented Jan 24, 2025 (edited)

@brianhou0208 I don't know if not having the input conv is a 'feature'. My very first vit impl here, before the official JAX code (which used the Conv2D trick) was released, was this:

```python
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Unfold image into fixed size patches, flatten into seq, project to embedding dim.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, flatten_channels_last=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert img_size[0] % patch_size[0] == 0, 'image height must be divisible by the patch height'
        assert img_size[1] % patch_size[1] == 0, 'image width must be divisible by the patch width'
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        patch_dim = in_chans * patch_size[0] * patch_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.flatten_channels_last = flatten_channels_last
        self.num_patches = num_patches
        self.proj = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        Ph, Pw = self.patch_size
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        if self.flatten_channels_last:
            # flatten patches with channels last like the paper (likely using TF)
            x = x.unfold(2, Ph, Ph).unfold(3, Pw, Pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, Ph * Pw * C)
        else:
            x = x.permute(0, 2, 3, 1).unfold(1, Ph, Ph).unfold(2, Pw, Pw).reshape(B, -1, C * Ph * Pw)
        x = self.proj(x)
        return x
```

The conv approach was faster since it was an optimized kernel and not a chain of API calls. I suppose torch.compile would rectify most of that, but I still don't see the downside to the conv.
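(For illustration, a minimal sketch of why the Conv2d trick is a drop-in replacement for unfold + linear; the shapes and dims here are arbitrary example values, and the weight copy is only there to show the two paths compute the same embedding:)

```python
import torch
import torch.nn as nn

B, C, H, W, P, D = 2, 3, 224, 224, 16, 768
x = torch.randn(B, C, H, W)

# Conv2d "trick": kernel_size == stride == patch_size, one fused op.
proj_conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
tokens_conv = proj_conv(x).flatten(2).transpose(1, 2)             # (B, N, D)

# Unfold + Linear path, reusing the same weights so the outputs match.
proj_lin = nn.Linear(C * P * P, D)
proj_lin.weight.data = proj_conv.weight.data.reshape(D, -1)
proj_lin.bias.data = proj_conv.bias.data
patches = x.unfold(2, P, P).unfold(3, P, P) \
           .permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)   # (B, N, C*P*P)
tokens_unfold = proj_lin(patches)                                 # (B, N, D)

print(torch.allclose(tokens_conv, tokens_unfold, atol=1e-4))      # True
```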

Also, the packed vit I started working on (have yet to pick it back up) has to push patchification further into the data pipeline: https://github.com/huggingface/pytorch-image-models/blob/379780bb6ca3304d63bf8ca789d5bbce5949d0b5/timm/models/vision_transformer_packed.py
