[CVPR 2024] The official pytorch implementation of "A General and Efficient Training for Transformer via Token Expansion".
The official PyTorch implementation of Token Expansion (ToE).
2024.4.17: Add the example of applying ToE to YOLOS for object detection.
2023.11.26: Release the code of ToE.
🔥: We find that ToE is highly effective for PEFT methods. For example, ToE achieves 2x training speed and 30% GPU memory saving for the training of Visual Prompt Tuning (VPT), while improving accuracy by 2%-6% on the VTAB-1k benchmark. We will release the code soon. Stay tuned!
The ''initialization-expansion-merging'' pipeline of the proposed ToE, illustrated with the 1st training stage.
torch>=1.12.0
torchvision>=0.13.0
timm==0.9.2
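To confirm your environment matches these requirements, you can print the installed versions (a trivial sanity check, not part of the repo's code):

```python
import torch
import torchvision
import timm

# Expect torch>=1.12.0, torchvision>=0.13.0, timm==0.9.2
print(torch.__version__, torchvision.__version__, timm.__version__)
```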
We provide the main code in token_select.py. It can be seamlessly integrated into the training of ViTs.
You can find examples of applying ToE to popular ViTs (e.g., DeiT in ToE/deit and LV-ViT in ToE/lvvit) and to existing efficient training frameworks (e.g., EfficientTrain in ToE/EfficientTrain).
Only a few changes to the existing code are required. The changes we make to the original model code are wrapped between two # ---------------------# comment lines.
For example (ToE/deit/main.py):
model = create_model(
    args.model,
    pretrained=False,
    num_classes=args.nb_classes,
    drop_rate=args.drop,
    drop_path_rate=args.drop_path,
    drop_block_rate=None,
    img_size=args.input_size,
)
# ---------------------#
model.token_select = TokenSelect(
    expansion_step=args.expansion_step,
    keep_rate=args.keep_rate,
    initialization_keep_rate=args.initialization_keep_rate,
    expansion_multiple_stage=args.expansion_multiple_stage,
    distance=args.distance,
)
# ---------------------#
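The constructor arguments mirror the ToE command-line flags used in the training scripts below. The sketch below restates the snippet above with the literal values from run/imagenet_small_run1.sh and our reading of what each argument controls; the exact semantics are defined in token_select.py:

```python
from token_select import TokenSelect

# Literal values taken from run/imagenet_small_run1.sh; our reading is that
# training is split into stages starting at epochs 0, 100 and 200, and that the
# ratio of kept tokens used from each stage onward is 0.5, 0.75 and 1.0.
model.token_select = TokenSelect(
    expansion_step=[0, 100, 200],    # epochs at which the kept-token set is expanded
    keep_rate=[0.5, 0.75, 1.0],      # token keep ratio per stage (1.0 = all tokens)
    initialization_keep_rate=0.25,   # keep ratio for the initialization phase
    expansion_multiple_stage=2,      # number of sub-stages within each expansion
    distance=args.distance,          # token-similarity distance, as in the snippet above
)
```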
Download and extract ImageNet images from http://image-net.org/. The directory structure should be:
│ILSVRC2012/
├── train/
│   ├── n01440764
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   ├── ......
│   ├── ......
├── val/
│   ├── n01440764
│   │   ├── ILSVRC2012_val_00000293.JPEG
│   │   ├── ILSVRC2012_val_00002138.JPEG
│   │   ├── ......
│   ├── ......
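With this layout, the train and val splits can be loaded with the standard torchvision ImageFolder API. A minimal sketch, independent of the repo's own data pipeline (the path and transform below are placeholders):

```python
import os
from torchvision import datasets, transforms

data_path = "/path/to/ILSVRC2012"  # placeholder for your dataset_dir

# Minimal transform for illustration; the training scripts use their own augmentation.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

train_set = datasets.ImageFolder(os.path.join(data_path, "train"), transform=transform)
val_set = datasets.ImageFolder(os.path.join(data_path, "val"), transform=transform)
print(len(train_set.classes))  # should print 1000 for ImageNet-1K
```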
Download and extract COCO 2017 train and val images with annotations from http://cocodataset.org. We expect the directory structure to be the following:
path/to/coco/
  annotations/  # annotation json files
  train2017/    # train images
  val2017/      # val images
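A quick way to sanity-check this layout is to load the annotation files with pycocotools (a generic sketch, not part of the repo's code; the path is a placeholder):

```python
from pycocotools.coco import COCO

coco_path = "path/to/coco"  # placeholder
val_ann = f"{coco_path}/annotations/instances_val2017.json"

coco_val = COCO(val_ann)
print(len(coco_val.getImgIds()))  # 5000 images in val2017
print(len(coco_val.getCatIds()))  # 80 object categories
```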
Follow the original LV-ViT repo to prepare the data.
We provide the NFNet-F6 generated dense label map on Google Drive and Baidu Yun (password: y6j2). As NFNet-F6 is trained on pure ImageNet data, no extra training data is involved.
Follow the original YOLOS repo to prepare the data.
[FB] Pretrained model for YOLOS-T
[FB] Pretrained model for YOLOS-S
The detailed training scripts are provided in the corresponding code paths (e.g., DeiT in ToE/deit/run/). You should prepare the environments and datasets first.
We provide a few simple training examples:
We train DeiT-small on four GPUs; the ImageNet-1K dataset is required.
cd ToE/deit
bash run/imagenet_small_run1.sh
run/imagenet_small_run1.sh:
result_dir=[your_result_path]
dataset_dir=[your_dataset_path]
device=0,1,2,3
master_port=6666

CUDA_VISIBLE_DEVICES=$device torchrun --nproc_per_node=4 --master_port=$master_port main.py \
--patch-size 16 \
--model deit_small_patch16_224 \
--batch-size 256 \
--data-path $dataset_dir \
--output_dir $result_dir \
--num_workers 8 \
--seed 3407 \
--expansion-step 0 100 200 \
--keep-rate 0.5 0.75 1.0 \
--initialization-keep-rate 0.25 \
--expansion-multiple-stage 2 \
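The ToE-specific flags above define a staged schedule. The sketch below shows how we interpret the interplay of --expansion-step and --keep-rate (our reading for illustration, not code from the repo):

```python
def keep_rate_at(epoch, expansion_step=(0, 100, 200), keep_rate=(0.5, 0.75, 1.0)):
    """Our reading of the ToE flags: the keep rate in force switches at each expansion step."""
    rate = keep_rate[0]
    for step, r in zip(expansion_step, keep_rate):
        if epoch >= step:
            rate = r
    return rate

# Epochs 0-99 keep 50% of tokens, 100-199 keep 75%, and from epoch 200 on all tokens are used.
assert keep_rate_at(50) == 0.5
assert keep_rate_at(150) == 0.75
assert keep_rate_at(250) == 1.0
```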
We train LV-ViT-S on four GPUs; the ImageNet-1K dataset and the label data (see the original LV-ViT repo) are required.
cd ToE/lvvit
bash run/imagenet_small_run1.sh
run/imagenet_small_run1.sh:
#!/bin/bash
result_dir=[your_result_path]
dataset_dir=[your_dataset_path]
label_dir=[your_token_labeling_dataset_path]
device=0,1,2,3
master_port=6666
shift

CUDA_VISIBLE_DEVICES=$device python3 -m torch.distributed.launch --nproc_per_node=4 main.py "$@" $dataset_dir \
--output $result_dir \
--model lvvit_s \
-b 256 \
--img-size 224 \
--drop-path 0.1 \
--token-label \
--token-label-data $label_dir \
--token-label-size 14 \
--model-ema \
--apex-amp \
--expansion-step 0 100 200 \
--keep-rate 0.4 0.7 1.0 \
--initialization-keep-rate 0.2 \
--expansion-multiple-stage 2 \
We train EfficientTrain (DeiT-small) on eight GPUs; the ImageNet-1K dataset is required.
cd ToE/EfficientTrain
bash run/imagenet_small_run1.sh
run/imagenet_small_run1.sh:
result_dir=[your_result_path]
dataset_dir=[your_dataset_path]

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python ET_training.py \
--data_path $dataset_dir \
--output_dir $result_dir \
--model deit_small_patch16_224 \
--final_bs 256 --epochs 300 \
--num_gpus 8 --num_workers 8 \
The ToE speedup configurations are set in ToE/EfficientTrain/ET_training.py. For example:
"deit_tiny_patch16_224": " --use_amp true --clip_grad 5.0 \--expansion-step 0 100 200 --keep-rate 0.6 0.8 1.0 \--initialization-keep-rate 0.3 --expansion-multiple-stage 2 "
We train YOLOS-S on eight GPUs; the COCO dataset and the pretrained model (see the original YOLOS repo) are required.
cd ToE/YOLOS
bash run/coco_small_run1.sh
run/coco_small_run1.sh:
result_dir=[your_result_path]
dataset_dir=[your_dataset_path]
pretrain_path=[your_pretrain_model_path]
device=0,1,2,3,4,5,6,7
master_port=6666

CUDA_VISIBLE_DEVICES=$device python -m torch.distributed.launch --nproc_per_node=8 --master_port=$master_port --use_env main.py \
    --coco_path $dataset_dir \
    --batch_size 1 \
    --lr 2.5e-5 \
    --epochs 150 \
    --backbone_name small \
    --pre_trained $pretrain_path \
    --eval_size 800 \
    --init_pe_size 512 864 \
    --mid_pe_size 512 864 \
    --output_dir $result_dir \
    --num_workers 8 \
    --expansion-step 5 50 100 \
    --keep-rate 0.5 0.75 1.0 \
    --initialization-keep-rate 0.25 \
    --expansion-multiple-stage 2 \
This project is based on DeiT, LV-ViT, EfficientTrain, YOLOS and timm. Thanks for their wonderful work.
If you find this work useful in your research, please consider citing:
@article{huang2024general,
  title={A General and Efficient Training for Transformer via Token Expansion},
  author={Huang, Wenxuan and Shen, Yunhang and Xie, Jiao and Zhang, Baochang and He, Gaoqi and Li, Ke and Sun, Xing and Lin, Shaohui},
  journal={CVPR},
  year={2024}
}