Commit 94e13b7

Merge pull request #2529 from huggingface/rope_vit
Adding Naver rope-vit compatibility to EVA ViT

2 parents 8d41071 + cec7290, commit 94e13b7

16 files changed: +873, -167 lines changed

‎README.md

Lines changed: 1 addition & 0 deletions

@@ -508,6 +508,7 @@ All model architecture families include variants with pretrained weights. There
 * Res2Net - https://arxiv.org/abs/1904.01169
 * ResNeSt - https://arxiv.org/abs/2004.08955
 * ReXNet - https://arxiv.org/abs/2007.00992
+* ROPE-ViT - https://arxiv.org/abs/2403.13298
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
 * Sequencer2D - https://arxiv.org/abs/2205.01972

‎timm/layers/__init__.py

Lines changed: 86 additions & 16 deletions

@@ -1,16 +1,40 @@
+from ._fx import (
+    create_feature_extractor,
+    get_graph_node_names,
+    register_notrace_function,
+    register_notrace_module,
+    is_notrace_module,
+    is_notrace_function,
+    get_notrace_modules,
+    get_notrace_functions,
+)
 from .activations import *
-from .adaptive_avgmax_pool import \
-    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .adaptive_avgmax_pool import (
+    adaptive_avgmax_pool2d,
+    select_adaptive_pool2d,
+    AdaptiveAvgMaxPool2d,
+    SelectAdaptivePool2d,
+)
 from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
 from .blur_pool import BlurPool2d, create_aa
 from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
-    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn, \
-    set_reentrant_ckpt, use_reentrant_ckpt
+from .config import (
+    is_exportable,
+    is_scriptable,
+    is_no_jit,
+    use_fused_attn,
+    set_exportable,
+    set_scriptable,
+    set_no_jit,
+    set_layer_config,
+    set_fused_attn,
+    set_reentrant_ckpt,
+    use_reentrant_ckpt,
+)
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
@@ -20,8 +44,17 @@
 from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
-from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, \
-    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .evo_norm import (
+    EvoNorm2dB0,
+    EvoNorm2dB1,
+    EvoNorm2dB2,
+    EvoNorm2dS0,
+    EvoNorm2dS0a,
+    EvoNorm2dS1,
+    EvoNorm2dS1a,
+    EvoNorm2dS2,
+    EvoNorm2dS2a,
+)
 from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
 from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
 from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
@@ -37,19 +70,50 @@
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d, \
-    SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
+from .norm_act import (
+    BatchNormAct2d,
+    GroupNormAct,
+    GroupNorm1Act,
+    LayerNormAct,
+    LayerNormAct2d,
+    SyncBatchNormAct,
+    convert_sync_batchnorm,
+    FrozenBatchNormAct2d,
+    freeze_batch_norm_2d,
+    unfreeze_batch_norm_2d,
+)
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
 from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
-from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
-    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple, resize_rel_pos_bias_table_levit
-from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
-    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
-    FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
+from .pos_embed_rel import (
+    RelPosMlp,
+    RelPosBias,
+    RelPosBiasTf,
+    gen_relative_position_index,
+    gen_relative_log_coords,
+    resize_rel_pos_bias_table,
+    resize_rel_pos_bias_table_simple,
+    resize_rel_pos_bias_table_levit,
+)
+from .pos_embed_sincos import (
+    pixel_freq_bands,
+    freq_bands,
+    build_sincos2d_pos_embed,
+    build_fourier_pos_embed,
+    build_rotary_pos_embed,
+    apply_rot_embed,
+    apply_rot_embed_cat,
+    apply_rot_embed_list,
+    apply_keep_indices_nlc,
+    FourierEmbed,
+    RotaryEmbedding,
+    RotaryEmbeddingCat,
+    RotaryEmbeddingMixed,
+    get_mixed_freqs,
+)
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
@@ -60,5 +124,11 @@
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
 from .typing import LayerType, PadType
-from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
-    init_weight_jax, init_weight_vit
+from .weight_init import (
+    trunc_normal_,
+    trunc_normal_tf_,
+    variance_scaling_,
+    lecun_normal_,
+    init_weight_jax,
+    init_weight_vit,
+)
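With the import block reorganized, the FX tracing helpers defined in the new timm/layers/_fx.py are re-exported from the package root, so they can be imported straight from timm.layers. A minimal sketch, using only names that appear in the diff above:

from timm.layers import (
    register_notrace_module,
    register_notrace_function,
    get_notrace_modules,
    get_notrace_functions,
)

# The registries are populated as timm.layers is imported (the decorators added
# elsewhere in this commit register CondConv2d, Conv2dSame, InplaceAbn, BatchNormAct2d, ...).
print(len(get_notrace_modules()), 'leaf modules registered')
print(len(get_notrace_functions()), 'autowrap functions registered')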

‎timm/layers/_fx.py

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+from typing import Callable, Dict, List, Optional, Union, Tuple, Type
+
+import torch
+from torch import nn
+
+try:
+    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
+    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+    from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
+
+__all__ = [
+    'register_notrace_module',
+    'is_notrace_module',
+    'get_notrace_modules',
+    'register_notrace_function',
+    'is_notrace_function',
+    'get_notrace_functions',
+    'create_feature_extractor',
+    'get_graph_node_names',
+]
+
+# modules to treat as leafs when tracing
+_leaf_modules = set()
+
+
+def register_notrace_module(module: Type[nn.Module]):
+    """
+    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+    """
+    _leaf_modules.add(module)
+    return module
+
+
+def is_notrace_module(module: Type[nn.Module]):
+    return module in _leaf_modules
+
+
+def get_notrace_modules():
+    return list(_leaf_modules)
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+
+def register_notrace_function(name_or_fn):
+    _autowrap_functions.add(name_or_fn)
+    return name_or_fn
+
+
+def is_notrace_function(func: Callable):
+    return func in _autowrap_functions
+
+
+def get_notrace_functions():
+    return list(_autowrap_functions)
+
+
+def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
+    return _get_graph_node_names(
+        model,
+        tracer_kwargs={
+            'leaf_modules': list(_leaf_modules),
+            'autowrap_functions': list(_autowrap_functions)
+        }
+    )
+
+
+def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
+    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+    return _create_feature_extractor(
+        model, return_nodes,
+        tracer_kwargs={
+            'leaf_modules': list(_leaf_modules),
+            'autowrap_functions': list(_autowrap_functions)
+        }
+    )
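The file above only defines the registries and thin wrappers around torchvision's FX feature-extraction entry points; the leaf/autowrap sets are filled in by the decorators applied throughout timm/layers. A hedged usage sketch follows; MyCustomBlock and Net are hypothetical modules invented for illustration, not part of timm:

import torch
from torch import nn
from timm.layers import register_notrace_module, create_feature_extractor, get_graph_node_names


@register_notrace_module  # trace as a single leaf node instead of tracing through it
class MyCustomBlock(nn.Module):
    def forward(self, x):
        # data-dependent control flow that plain torch.fx symbolic tracing cannot handle
        return x.flip(-1) if x.shape[-1] % 2 == 0 else x


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, padding=1)
        self.block = MyCustomBlock()

    def forward(self, x):
        return self.block(self.stem(x))


model = Net()
train_nodes, eval_nodes = get_graph_node_names(model)        # 'block' shows up as a single node
extractor = create_feature_extractor(model, return_nodes=['block'])
features = extractor(torch.randn(1, 3, 32, 32))              # dict: {'block': tensor}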

‎timm/layers/attention.py

Lines changed: 3 additions & 0 deletions

@@ -4,10 +4,13 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_function
 from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@torch.fx.wrap
+@register_notrace_function
 def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
     return scores if attn_mask is None else scores + attn_mask
 
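maybe_add_mask is a one-line helper, but wrapping it with torch.fx.wrap and register_notrace_function keeps the None-check out of traced graphs, so a single trace of an attention block works whether or not an attention mask is supplied. A small behavioural sketch (maybe_add_mask is exported from timm.layers per the __init__.py diff above):

import torch
from timm.layers import maybe_add_mask

scores = torch.randn(1, 4, 4)
mask = torch.zeros(4, 4).masked_fill(torch.eye(4, dtype=torch.bool), float('-inf'))

print(torch.equal(maybe_add_mask(scores, None), scores))  # True: no mask, scores pass through unchanged
print(maybe_add_mask(scores, mask)[0, 0, 0])              # -inf: mask added element-wise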

‎timm/layers/attention_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
importtorch
1313
importtorch.nnasnn
1414

15-
from.configimportuse_fused_attn
15+
from .configimportuse_fused_attn
1616
from .helpersimportto_2tuple
1717
from .pos_embedimportresample_abs_pos_embed
1818
from .pos_embed_sincosimportapply_rot_embed,RotaryEmbedding

‎timm/layers/cond_conv2d.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
 from torch import nn as nn
 from torch.nn import functional as F
 
+from ._fx import register_notrace_module
 from .helpers import to_2tuple
 from .conv2d_same import conv2d_same
 from .padding import get_padding_value
@@ -30,6 +31,7 @@ def condconv_initializer(weight):
     return condconv_initializer
 
 
+@register_notrace_module
 class CondConv2d(nn.Module):
     """ Conditionally Parameterized Convolution
     Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

‎timm/layers/conv2d_same.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 from typing import Tuple, Optional
 
+from ._fx import register_notrace_module
 from .config import is_exportable, is_scriptable
 from .padding import pad_same, pad_same_arg, get_padding_value
 
@@ -27,6 +28,7 @@ def conv2d_same(
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
 
 
+@register_notrace_module
 class Conv2dSame(nn.Conv2d):
     """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
     """

‎timm/layers/inplace_abn.py

Lines changed: 3 additions & 0 deletions

@@ -15,7 +15,10 @@ def inplace_abn(x, weight, bias, running_mean, running_var,
 def inplace_abn_sync(**kwargs):
     inplace_abn(**kwargs)
 
+from ._fx import register_notrace_module
 
+
+@register_notrace_module
 class InplaceAbn(nn.Module):
     """Activated Batch Normalization

‎timm/layers/non_local_attn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
fromtorchimportnn
99
fromtorch.nnimportfunctionalasF
1010

11+
from ._fximportregister_notrace_module
1112
from .conv_bn_actimportConvNormAct
1213
from .helpersimportmake_divisible
1314
from .trace_utilsimport_assert
@@ -69,6 +70,7 @@ def reset_parameters(self):
6970
nn.init.constant_(m.bias,0)
7071

7172

73+
@register_notrace_module
7274
classBilinearAttnTransform(nn.Module):
7375

7476
def__init__(self,in_channels,block_size,groups,act_layer=nn.ReLU,norm_layer=nn.BatchNorm2d):

‎timm/layers/norm_act.py

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 from torchvision.ops.misc import FrozenBatchNorm2d
 
+from ._fx import register_notrace_module
 from .create_act import create_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
 from .norm import RmsNorm, RmsNorm2d
@@ -39,6 +40,7 @@ def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
     return nn.Identity() if act is None else act
 
 
+@register_notrace_module
 class BatchNormAct2d(nn.BatchNorm2d):
     """BatchNorm + Activation
 
@@ -134,6 +136,7 @@ def forward(self, x):
         return x
 
 
+@register_notrace_module
 class SyncBatchNormAct(nn.SyncBatchNorm):
     # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
     # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
@@ -191,6 +194,7 @@ def convert_sync_batchnorm(module, process_group=None):
     return module_output
 
 
+@register_notrace_module
 class FrozenBatchNormAct2d(torch.nn.Module):
     """
     BatchNormAct2d where the batch statistics and the affine parameters are fixed
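Registering the fused norm+act classes as leaf modules means FX-based feature extraction sees each BatchNormAct2d (and its Sync/Frozen variants) as a single call_module node rather than tracing through the normalization and activation internals. A hedged end-to-end sketch; the model name and the choice of return node are illustrative assumptions, not taken from this diff:

import torch
import timm
from timm.layers import get_graph_node_names, create_feature_extractor

# 'efficientnet_b0' is assumed here as a model built with BatchNormAct2d layers
model = timm.create_model('efficientnet_b0', pretrained=False)
train_nodes, eval_nodes = get_graph_node_names(model)
extractor = create_feature_extractor(model, return_nodes=[eval_nodes[-2]])  # a node near the head, chosen arbitrarily
out = extractor(torch.randn(1, 3, 224, 224))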

