# Dynamic-LLaVA

[ICLR 2025] The official PyTorch implementation of "Dynamic-LLaVA: Efficient Multimodal Large Language Models via Dynamic Vision-language Context Sparsification".
## Install

Make sure your working directory is the repository root, and then:
```bash
conda create -n dynamic_llava python=3.10 -y
conda activate dynamic_llava
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
```
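If the installation succeeded, a quick sanity check (a minimal sketch; the module names `llava` and `flash_attn` are taken from the entry points used below and from the flash-attn package) is:

```bash
# Verify that the editable install and flash-attn are importable.
python -c "import llava; print('llava OK')"
python -c "import flash_attn; print('flash_attn OK')"
```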
## Train

Dynamic-LLaVA is trained on 8 A100 GPUs with 80GB memory. To train on fewer GPUs, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. Always keep the global batch size the same: `per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`.
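For example, the provided scripts use `per_device_train_batch_size 8` and `gradient_accumulation_steps 1` on 8 GPUs, i.e. a global batch size of 8 x 1 x 8 = 64; on 4 GPUs you could keep the same global batch size with `per_device_train_batch_size 8` and `gradient_accumulation_steps 2`.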
Please download the annotation of the final mixture of our instruction tuning data, `llava_v1_5_mix665k.json`, and download the images from the constituting datasets:
- COCO: train2017
- GQA: images
- OCR-VQA: download script (we save all files as `.jpg`)
- TextVQA: train_val_images
- VisualGenome: part1, part2
After downloading all of them, organize the data as follows in `./playground/data`:
```
├── coco
│   └── train2017
├── gqa
│   └── images
├── ocr_vqa
│   └── images
├── textvqa
│   └── train_images
└── vg
    ├── VG_100K
    └── VG_100K_2
```
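Before training, you can sanity-check this layout with a small script (a minimal sketch; it only verifies that the expected folders and the annotation file exist):

```bash
# Check that every image folder expected by the mixture dataset is present.
for d in coco/train2017 gqa/images ocr_vqa/images textvqa/train_images vg/VG_100K vg/VG_100K_2; do
    [ -d "./playground/data/$d" ] && echo "OK      $d" || echo "MISSING $d"
done
# Check the annotation file as well.
[ -f "./playground/data/llava_v1_5_mix665k.json" ] && echo "OK      llava_v1_5_mix665k.json" || echo "MISSING llava_v1_5_mix665k.json"
```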
To train Dynamic-LLaVA, you can get the base checkpoints from [LLaVA-1.5-7B] and [LLaVA-1.5-13B] for training Dynamic-LLaVA-7B and Dynamic-LLaVA-13B, respectively.
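For example (a hedged sketch, assuming the public Hugging Face releases `liuhaotian/llava-v1.5-7b` and `liuhaotian/llava-v1.5-13b` are used as the base checkpoints), they can be fetched with git-lfs:

```bash
# Download the LLaVA-1.5 base checkpoints (requires git-lfs).
git lfs install
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
git clone https://huggingface.co/liuhaotian/llava-v1.5-13b
```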
We provide the training scripts for Dynamic-LLaVA-7B and Dynamic-LLaVA-13B, which you can find in `run/`.
To train Dynamic-LLaVA-7B, you can directly run the shell script `run/train_dynamic_llava_7b.sh`; the detailed command is as follows:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash run/train_dynamic_llava_7b.sh
```
The details of `run/train_dynamic_llava_7b.sh` are as follows:
```bash
#!/bin/bash

# [llava-v1.5-7b]: your open-resource checkpoint path
# [./playground/data/llava_v1_5_mix665k.json], [./playground/data]: your instruct-following dataset
deepspeed llava/train/train_sparse.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path [llava-v1.5-7b] \
    --version v1 \
    --data_path [./playground/data/llava_v1_5_mix665k.json] \
    --image_folder [./playground/data] \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length False \
    --requires_image True \
    --bf16 True \
    --output_dir ./results/dynamic-llava-7b \
    --num_train_epochs 1.0 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 40000 \
    --save_total_limit 1 \
    --learning_rate 5e-6 \
    --weight_decay 0. \
    --predictor_lr 2e-4 \
    --predictor_weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --mask_loss_weight 100.0 \
    --gumbel_start_tau 1.0 \
    --gumbel_end_tau 0.1 \
    --use_vision_predictor True \
    --use_text_predictor True \
    --use_output_text_predictor True \
    --use_instruct_predictor False \
    --vision_keep_rate 0.2 \
    --output_text_keep_rate 0.5 \
    --output_text_len_for_training 50
```
To train Dynamic-LLaVA-13B, you can directly run the shell script `run/train_dynamic_llava_13b.sh`; the detailed command is as follows:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash run/train_dynamic_llava_13b.sh
```
The details of `run/train_dynamic_llava_13b.sh` are as follows:
```bash
#!/bin/bash

# [llava-v1.5-13b]: your open-resource checkpoint path
# [./playground/data/llava_v1_5_mix665k.json], [./playground/data]: your instruct-following dataset
deepspeed llava/train/train_sparse.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path [llava-v1.5-13b] \
    --version v1 \
    --data_path [./playground/data/llava_v1_5_mix665k.json] \
    --image_folder [./playground/data] \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length False \
    --requires_image True \
    --bf16 True \
    --output_dir ./results/dynamic-llava-13b \
    --num_train_epochs 1.0 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 40000 \
    --save_total_limit 1 \
    --learning_rate 5e-6 \
    --weight_decay 0. \
    --predictor_lr 2e-4 \
    --predictor_weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --mask_loss_weight 100.0 \
    --gumbel_start_tau 1.0 \
    --gumbel_end_tau 0.1 \
    --use_vision_predictor True \
    --use_text_predictor True \
    --use_output_text_predictor True \
    --use_instruct_predictor False \
    --vision_keep_rate 0.2 \
    --output_text_keep_rate 0.5 \
    --output_text_len_for_training 50
```
## Evaluation

We provide the evaluation scripts for the benchmarks.
To evaluate Dynamic-LLaVA-7B on the VQAv2 benchmark, you can directly run the shell script `run/dynamic_eval/eval_for_vqav2.sh`; the detailed command is as follows:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash run/dynamic_eval/eval_for_vqav2.sh
```
The details of `run/dynamic_eval/eval_for_vqav2.sh` are as follows:
```bash
#!/bin/bash

gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
IFS=',' read -ra GPULIST <<< "$gpu_list"

CHUNKS=${#GPULIST[@]}

CKPT="dynamic-llava-7b"
SPLIT="llava_vqav2_mscoco_test-dev2015"

# [./results/dynamic-llava-7b]: your Dynamic-LLaVA checkpoint path
# [./playground/data/eval/vqav2/...]: your benchmark path
for IDX in $(seq 0 $((CHUNKS-1))); do
    CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.dynamic_eval.model_vqa_loader \
        --model-path [./results/dynamic-llava-7b] \
        --question-file [./playground/data/eval/vqav2/$SPLIT.jsonl] \
        --image-folder [./playground/data/eval/vqav2/test2015] \
        --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
        --num-chunks $CHUNKS \
        --chunk-idx $IDX \
        --temperature 0 \
        --conv-mode vicuna_v1 &
done

wait

output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl

# Clear out the output file if it exists.
> "$output_file"

# Loop through the indices and concatenate each file.
for IDX in $(seq 0 $((CHUNKS-1))); do
    cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
done

python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT \
    --test_dir "./playground/data/eval/vqav2" \
    --result_dir "./playground/data/eval/vqav2"
```
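The script shards the question file into one chunk per visible GPU (`--num-chunks`/`--chunk-idx`) and merges the per-chunk answer files afterwards, so it also works on a single GPU, e.g.:

```bash
CUDA_VISIBLE_DEVICES=0 bash run/dynamic_eval/eval_for_vqav2.sh
```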
Then, submit the results in `./playground/data/eval/vqav2/answers_upload` to the evaluation server.
To evaluate Dynamic-LLaVA-7B on the GQA benchmark, you can directly run the shell script `run/dynamic_eval/eval_for_gqa.sh`; the detailed command is as follows:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash run/dynamic_eval/eval_for_gqa.sh
```
The details of `run/dynamic_eval/eval_for_gqa.sh` are as follows:
```bash
#!/bin/bash

gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
IFS=',' read -ra GPULIST <<< "$gpu_list"

CHUNKS=${#GPULIST[@]}

CKPT="dynamic-llava-7b"
SPLIT="llava_gqa_testdev_balanced"
GQADIR="[./playground/data/eval/gqa/data]"  # your benchmark path

# [./results/dynamic-llava-7b]: your Dynamic-LLaVA checkpoint path
# [./playground/data/eval/gqa/...]: your benchmark path
for IDX in $(seq 0 $((CHUNKS-1))); do
    CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.dynamic_eval.model_vqa_loader \
        --model-path [./results/dynamic-llava-7b] \
        --question-file [./playground/data/eval/gqa/$SPLIT.jsonl] \
        --image-folder [./playground/data/eval/gqa/data/images] \
        --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
        --num-chunks $CHUNKS \
        --chunk-idx $IDX \
        --temperature 0 \
        --conv-mode vicuna_v1 &
done

wait

output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl

# Clear out the output file if it exists.
> "$output_file"

# Loop through the indices and concatenate each file.
for IDX in $(seq 0 $((CHUNKS-1))); do
    cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
done

python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json

cd $GQADIR
python eval/eval.py --tier testdev_balanced
```
## Acknowledgement

This project is based on LLaVA. Thanks for their wonderful work.