Description
This feature request is related to challenges in Masked Image Modeling (MIM) pre-training using vision transformer models in `timm`. Currently, embedding and feature extraction are tightly coupled within `forward_features`, making it difficult to inject mask operations after initial embedding and positional encoding but before the transformer stages, which is a common MIM requirement. Researchers need to access the embedded tokens for masking before passing them through the subsequent transformer layers.
Describe the solution you'd like
I propose refactoring all vision transformer models (e.g. `vit`, `swin_transformer`, etc.) to uniformly expose two distinct interfaces:
- `embed(self, x)`: This method should take the input tensor `x` (e.g., image patches) and return the embedded vectors, including positional encodings if applicable. The output should be ready for the transformer encoder stages.
- `forward_stages(self, x)`: This method should take the output of `embed(x)` (i.e., the embedded and position-encoded tokens) and pass it through the transformer encoder layers.

This separation would make the existing `forward_features(x)` effectively equivalent to `forward_stages(embed(x))`. It would allow researchers to easily perform mask operations on the embedded tokens returned by `embed(x)` before passing them to `forward_stages(x)`, enabling flexible MIM pre-training experiments.
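To make the proposal concrete, below is a minimal sketch of how such a split could look on a ViT-style model. This is not the actual `timm` code: the class `ViTLike` and its attributes (`patch_embed`, `pos_embed`, `blocks`, `norm`) are simplified assumptions used only to illustrate the `embed` / `forward_stages` contract and the `forward_features(x) == forward_stages(embed(x))` equivalence.

```python
import torch
import torch.nn as nn


class ViTLike(nn.Module):
    """Toy ViT-style model illustrating the proposed embed / forward_stages split.

    The attribute names and layer choices are illustrative assumptions,
    not the actual timm implementation.
    """

    def __init__(self, num_patches: int = 196, embed_dim: int = 768, depth: int = 4):
        super().__init__()
        # Stand-in for a real patch embedding (flattened 16x16 RGB patches -> embed_dim).
        self.patch_embed = nn.Linear(3 * 16 * 16, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(embed_dim, nhead=8, batch_first=True) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        # Project patches and add positional encodings; output is ready for the encoder stages.
        return self.patch_embed(x) + self.pos_embed

    def forward_stages(self, x: torch.Tensor) -> torch.Tensor:
        # Run already-embedded (and possibly masked) tokens through the transformer layers.
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # The existing entry point becomes a simple composition of the two new methods.
        return self.forward_stages(self.embed(x))
```

With this contract, a MIM pipeline can call `embed`, manipulate the returned tokens, and then call `forward_stages`, without touching either method's signature.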
Describe alternatives you've considered
I have considered alternative solutions, such as adding a mask parameter directly to the `forward_features` and `forward` methods, similar to `vision_transformer`'s implementation.
pytorch-image-models/timm/models/vision_transformer.py
Lines 933 to 935 in a7c5368

```python
def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""
```
While seemingly straightforward, this approach presents drawbacks:
- Function Signature: It still alters the method's interface, even if default parameter values mitigate direct breakage.
- Mask Value Flexibility: More importantly, it would limit flexibility in how masked-out positions are handled, restricting whether masked token values can be learned (e.g., a learnable mask token) or simply set to zero. Separating `embed` and `forward_stages` provides full control, as sketched below.
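As a concrete example of that flexibility, the snippet below reuses the hypothetical `ViTLike` sketch above (the masking ratio and shapes are arbitrary) and fills masked positions either with a learnable mask token or with zeros before invoking `forward_stages`. Neither strategy requires changing any model signature.

```python
import torch
import torch.nn as nn

model = ViTLike()
patches = torch.randn(2, 196, 3 * 16 * 16)    # (batch, num_patches, flattened patch)

tokens = model.embed(patches)                  # (2, 196, 768), position-encoded
masked = torch.rand(tokens.shape[:2]) < 0.75   # True where a token is masked out

# Option A: fill masked positions with a learnable mask token.
mask_token = nn.Parameter(torch.zeros(1, 1, tokens.shape[-1]))
tokens_a = torch.where(masked.unsqueeze(-1), mask_token.expand_as(tokens), tokens)

# Option B: simply zero out the masked positions.
tokens_b = tokens.masked_fill(masked.unsqueeze(-1), 0.0)

features = model.forward_stages(tokens_a)      # or tokens_b
```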
Additional context
If this refactoring aligns with the library's design, I would gladly contribute a Pull Request to implement these changes across relevant vision transformer models.