- Notifications
You must be signed in to change notification settings - Fork28
Official code repository for NeurIPS 2022 paper "SatMAE: Pretraining Transformers for Temporal and Multi-Spectral Satellite Imagery"
License
sustainlab-group/SatMAE
Folders and files
| Name | Name | Last commit message | Last commit date | |
|---|---|---|---|---|
Repository files navigation
This is the official repository for the NeurIPS 2022 paper"SatMAE: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery".
Authors:Yezhen Cong1,Samar Khanna1,Chenlin Meng,Patrick Liu,Erik Rozi,Yutong He,Marshall Burke,David B. Lobell,Stefano Ermon.
1 Equal contribution, order determined via coin flip.
Pre-training and finetuning on fMoW-Temporal are MEMORY-HEAVY.Please make sure you have enough memory.For context, we ran our experiments on 8 NVIDIA V100 GPUs.
You can download the fMoW datasethere. Then follow this piece ofcode to preprocess it. We will also upload the pre-processed dataset soon. The metadata files arehere.
After you download the dataset and metadata files, your directory should look like:
<PATH_TO_DATASET_ROOT_FOLDER>--- train_62classes.csv--- val_62classes.csv--- fmow------- train---------- airport---------- ...------- val---------- airport---------- ...For pretraining, this is the default command:
python -m torch.distributed.launch --nproc_per_node=8 \ --nnodes=1 --master_port=1234 main_pretrain.py \ --batch_size 8 --accum_iter 16 \ --norm_pix_loss --epochs 100 \ --blr 1.5e-4 --mask_ratio 0.75 \ --input_size 224 --patch_size 16 \ --model mae_vit_large_patch16 \ --model_type temporal \ --dataset_type temporal \ --train_path<PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv --output_dir<PATH_TO_YOUR_OUTPUT_FOLDER> \ --log_dir<PATH_TO_YOUR_OUTPUT_FOLDER> \ --num_workers 8
Note that if you want to use wandb, login to wandb after activating condaand before running the code by doingwandb login in the shell,and add--wandb <YOUR_WANDB_PROJECT_NAME> to the command above.This applies to all following commands.You will also have to edit the entity name inmain_pretrain.py andmain_finetune.py.
To finetune, the basic command is:
python -m torch.distributed.launch --nproc_per_node=8 \ --nnodes=1 --master_port=1234 main_finetune.py \ --output_dir<PATH_TO_YOUR_OUTPUT_FOLDER> \ --log_dir<PATH_TO_YOUR_OUTPUT_FOLDER> \ --batch_size 4 --accum_iter 4 \ --model vit_large_patch16 --epochs 50 --blr 1e-3 --layer_decay 0.75 \ --weight_decay 0.05 --drop_path 0.2 --reprob 0.25 \ --mixup 0.8 --cutmix 1.0 --model_type temporal \ --finetune<PATH_TO_YOUR_PRETRAIN_CHECKPOINT> \ --dist_eval --num_workers 8 --dataset temporal \ --train_path<PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \ --test_path<PATH_TO_DATASET_ROOT_FOLDER>/val_62classes.csv
Note: If you are using our provided checkpoint, please add--nb_classes 1000.This is a legacy issue which won't affect the model performance since the actual number of classes is less than 1000.To resume a finetuning job, you can replace the--finetune <PATH_TO_YOUR_PRETRAIN_CHECKPOINT> to--resume <PATH_TO_YOUR_PRETRAIN_CHECKPOINT> in the command above.To finetune from scratch, simply omit the--finetune argument.
To evaluate, the basic command is:
python -m torch.distributed.launch --nproc_per_node=8 \ --nnodes=1 --master_port=1234 main_finetune.py \ --output_dir<PATH_TO_YOUR_OUTPUT_FOLDER> \ --log_dir<PATH_TO_YOUR_OUTPUT_FOLDER> \ --batch_size 16 \ --model vit_large_patch16 \ --model_type temporal \ --resume<PATH_TO_YOUR_FINEtune_CHECKPOINT> \ --dist_eval --eval --num_workers 8 --dataset fmow_temporal \ --train_path<PATH_TO_DATASET_ROOT_FOLDER>/train_62classes.csv \ --test_path<PATH_TO_DATASET_ROOT_FOLDER>/val_62classes.csv
Similarly, if you are using our provided checkpoint, please add--nb_classes 1000.
You can download model weights pre-trained on fMoW-temporal and weights finetuned on fMoW-temporalhere.
| Model | Top 1 Accuracy | Pretrain | Finetune |
|---|---|---|---|
| ViT-Large | 77.78% | download, 800 epochs | download, 29 epochs |
| Model | Top 1 Accuracy | Pretrain | Finetune |
|---|---|---|---|
| ViT-Large | 79.99% | download, 50 epochs | download, 24 epochs |
The accuracy of SatMAE on fMoW-Temporal (reported above) is achieved without using test-time augmentation (see paper).
Training multi-spectral SatMAE is similar to trainingtemporal SatMAE.
You can access and download the fMoW-Sentinel dataset we collectedhere.Try thislink if the previous one doesn't display correctly.
Note that when loading thetrain.csv orval.csv files, you may have to preprocess a columncalledimage_path. Theimage_path for any row can be constructed like this:
fmow-sentinel/<split>/<category>/<category>_<location_id>/<category>_<location_id>_<image_id>.tifFor pretraining, this is the default command:
python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \--wandb satmae_pretrain \--batch_size 16 --accum_iter 32 --blr 0.0001 \--epochs 200 --warmup_epochs 20 --num_workers 16 \--input_size 96 --patch_size 8 \--mask_ratio 0.75 \--model_type group_c \--dataset_type sentinel --dropped_bands 0 9 10 \--grouped_bands 0 1 2 6 --grouped_bands 3 4 5 7 --grouped_bands 8 9 \--train_path /home/fmow-sentinel-filtered-csv/train.csv \--output_dir /home/experiments/pretrain \--log_dir /home/experiments/pretrain
You can use the--spatial_mask argument to toggle on consistent spatial masking(rather than independent masking). See paper for details (independent masking performs better).
To resume a pretraining job, you can use--resume PATH/TO/CKPT.PTH(eg:--resume /home/experiments/pretrain/checkpoint-175.pth).
To finetune, the basic command is:
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \--wandb satmae_finetune \--batch_size 8 --accum_iter 16 --blr 0.0002 \--epochs 30 --num_workers 16 \--input_size 96 --patch_size 8 \--weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \--model_type group_c \--dataset_type sentinel --dropped_bands 0 9 10 \--train_path /home/fmow-sentinel-filtered-csv/train.csv \--test_path /home/fmow-sentinel-filtered-csv/val.csv \--output_dir /home/experiments/finetune \--log_dir /home/experiments/finetune \--finetune /home/experiments/pretain/checkpoint-199.pth
To finetune from scratch, simply omit the--finetune argument.To resume a finetuning job, you can replace--finetune path/to/pretrained_checkpoint.pthwith--resume path/to/finetune_checkpoint.pth in the command above.
We will be uploading model checkpointshere.The pretrained checkpoints have been trained for 200 epochs,so the accuracy numbers might be higher than in the paper(where the models were pretrained for 50 epochs).
The Top 1 accuracy is measured on the validation set offMoW-Sentinel.
| Model | Top 1 Accuracy | Pretrain | Finetune |
|---|---|---|---|
| ViT-Base (200 epochs) | 62.65% | download | download |
| ViT-Large (200 epochs) | 63.84% | download | download |
If you are interested, please check out some of our group's follow-on work on foundation models for remote-sensing data:
- DiffusionSat: A Generative Foundation Model for Satellite Imagery, in ICLR 2024.
- ExPLoRA: Parameter-Efficient Extended Pre-training to Adapt Vision Transformers under Domain Shifts, in ICML 2025.
In particular, ExPLoRA is a superset of SatMAE and contains code to do full or parameter-efficient pre-training in either MAE or DinoV2 flavors.
Code from this repository is inspired from the Masked Autoencoders (MAE)repository.
If you found our project helpful, please cite our paper:
@inproceedings{ satmae2022, title={Sat{MAE}: Pre-training Transformers for Temporal and Multi-Spectral Satellite Imagery}, author={Yezhen Cong and Samar Khanna and Chenlin Meng and Patrick Liu and Erik Rozi and Yutong He and Marshall Burke and David B. Lobell and Stefano Ermon}, booktitle={Advances in Neural Information Processing Systems}, editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, year={2022}, url={https://openreview.net/forum?id=WBhqzpF6KYH}}About
Official code repository for NeurIPS 2022 paper "SatMAE: Pretraining Transformers for Temporal and Multi-Spectral Satellite Imagery"
Resources
License
Uh oh!
There was an error while loading.Please reload this page.