
Commit 80f8efe

release the training and evaluation codes of Seer-Large, which achieves Avg.Len. of 4.3 on CALVIN ABC-D

1 parent 4828149 · commit 80f8efe

27 files changed: +377 -114 lines changed

.gitignore

Lines changed: 6 additions & 0 deletions

@@ -1,3 +1,9 @@
+# workspace
+calvin
+checkpoints
+eval_logs
+evaluate
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md

Lines changed: 2 additions & 3 deletions

@@ -5,7 +5,6 @@
 
 <h3 align="center">
 <a href="https://arxiv.org/pdf/2412.15109">Arxiv</a> |
-<a>Video</a> |
 <a href="https://nimolty.github.io/Seer/">Webpage</a>
 </h3>
 
@@ -59,13 +58,13 @@ This section details the pre-training process of Seer in real-world experiments,
 Relevant checkpoints are available on the [website](https://drive.google.com/drive/folders/1F3IE95z2THAQ_lt3DKUFdRGc86Thsnc7?usp=sharing).
 |Model|Checkpoint|
 |:------:|:------:|
-|CALVIN ABC-D|[Seer](https://drive.google.com/drive/folders/17Gv9snGCkViuhHmzN3eTWlI0tMfGSGT3?usp=sharing) / [Seer Large](https://drive.google.com/drive/folders/1AFabqfDEi69oMo0FTGhEiH2QSRLYBR9r?usp=drive_link)|
+|CALVIN ABC-D|[Seer](https://drive.google.com/drive/folders/17Gv9snGCkViuhHmzN3eTWlI0tMfGSGT3?usp=sharing) (Avg. Len.: 3.98) / [Seer Large](https://drive.google.com/drive/folders/1AFabqfDEi69oMo0FTGhEiH2QSRLYBR9r?usp=drive_link) (Avg. Len.: 4.30)|
 |Real-World|[Seer (Droid Pre-trained)](https://drive.google.com/drive/folders/1rT8JKLhJGIo97jfYUm2JiFUrogOq-dgJ?usp=drive_link)|
 
 ## 📆 TODO <a name="todos"></a>
 - [x] Release real-world experiment code.
 - [x] Release CALVIN ABC-D experiment code (Seer).
-- [ ] Release CALVIN ABC-D experiment code (Seer-Large).
+- [x] Release CALVIN ABC-D experiment code (Seer-Large).
 - [ ] Release LIBERO-LONG experiment code.
 
 ## License <a name="license"></a>

docs/CALVIN_ABC-D_INSTALL.md

Lines changed: 7 additions & 1 deletion

@@ -1,6 +1,6 @@
 # Installation
 
-**(1) Env**
+**(1) Conda Env**
 ```python
 conda create -n seer python=3.10
 conda activate seer
@@ -28,3 +28,9 @@ cd ${YOUR_PATH_TO_SEER}
 pip install -r requirements.txt
 pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121
 ```
+
+**(5) Create a soft link to CALVIN**
+```python
+cd ${YOUR_PATH_TO_SEER}
+ln -s $CALVIN_ROOT calvin
+```
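A quick sanity check after step (5), assuming `$CALVIN_ROOT` points at a CALVIN checkout that contains `dataset/task_ABC_D` (the path the run scripts in this commit expect):

```bash
# verify the soft link resolves and the ABC-D split is reachable through it
ls -l calvin                                     # expect: calvin -> $CALVIN_ROOT
test -d calvin/dataset/task_ABC_D && echo "ABC-D dataset found"
```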

docs/CALVIN_ABC-D_RUN.md

Lines changed: 22 additions & 6 deletions

@@ -5,28 +5,44 @@ For convenience, some checkpoints, such as the MAE-pretrained ViT-B model, are provided
 * :exclamation: **pretrain.sh, finetune.sh, scratch.sh, eval.sh:**
 Please update the following:
 * **calvin_dataset_path** to the directory where you have stored the CALVIN ABC-D data.
-* **checkpoint_path** to the parent directory where your experiment checkpoints are saved.
+* **save_checkpoint_path** to the parent directory where your experiment checkpoints are saved. We recommend creating a ```checkpoints``` folder in the project root directory.
 * **finetune_from_pretrained_ckpt** to the location of your pre-trained checkpoint.
 * **resume_from_checkpoint** to the location of your fine-tuned checkpoint.
-* **vit_ckpt_path** to the location of your ViT checkpoint (downloaded from the [website](https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing)).
+* **vit_checkpoint_path** to the location of your ViT checkpoint (downloaded from the [website](https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing)). We recommend storing it at ```checkpoints/vit_mae/mae_pretrain_vit_base.pth```.
 
 * :exclamation: **networkx:**
 Due to compatibility issues between the networkx library in CALVIN and Python 3.10, we provide a compatible version of networkx.zip on the [website](https://drive.google.com/file/d/1z-d1SaI0rXfBtBicw1zPSsP-wE-26oLq/view?usp=sharing). Download and unzip it, then replace the existing networkx library in the following path:
 
 ## Seer
 ### Pre-train
 ```bash
+# Pre-train Seer on Calvin ABC-D dataset
 bash scripts/CALVIN_ABC_D/Seer/pretrain.sh
+# Pre-train Seer-Large on Calvin ABC-D dataset
+bash scripts/CALVIN_ABC_D/Seer-Large/pretrain.sh
 ```
+
 ### Fine-tune
 ```bash
+# Fine-tune Seer on Calvin ABC-D dataset
 bash scripts/CALVIN_ABC_D/Seer/finetune.sh
+# Fine-tune Seer-Large on Calvin ABC-D dataset
+bash scripts/CALVIN_ABC_D/Seer-Large/finetune.sh
 ```
-### Eval
+
+### Train from Scratch
 ```bash
-bash scripts/CALVIN_ABC_D/Seer/eval.sh
+# Train Seer on Calvin ABC-D dataset from scratch
+bash scripts/CALVIN_ABC_D/Seer/scratch.sh
+# Train Seer-Large on Calvin ABC-D dataset from scratch
+bash scripts/CALVIN_ABC_D/Seer-Large/scratch.sh
 ```
-### Scratch
+
+### Eval
 ```bash
-bash scripts/CALVIN_ABC_D/Seer/scratch.sh
+# Evaluate Seer on Calvin ABC-D benchmark
+bash scripts/CALVIN_ABC_D/Seer/eval.sh
+# Evaluate Seer-Large on Calvin ABC-D benchmark
+bash scripts/CALVIN_ABC_D/Seer-Large/eval.sh
 ```
+
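The networkx note above leaves the destination path implicit. A minimal sketch of the swap, assuming the networkx installed in the active conda env is the one to replace and that networkx.zip unpacks to a top-level `networkx/` directory (both assumptions; adjust to your setup):

```bash
# locate the env's site-packages via the installed package, then swap in the patched copy
SITE_PACKAGES="$(python -c 'import networkx, os; print(os.path.dirname(os.path.dirname(networkx.__file__)))')"
unzip networkx.zip -d /tmp/networkx_patch
rm -rf "${SITE_PACKAGES}/networkx"
cp -r /tmp/networkx_patch/networkx "${SITE_PACKAGES}/networkx"
```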

docs/REAL-WORLD_PRETRAIN.md

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Pre-train
 ## Notice
-We provide code for pre-training on both the DROID and OXE datasets. Users should update the checkpoint_path to the directory where you want to save the training checkpoints, and modify the root_dir to the location where the preprocessed real data is stored. Additionally, users should configure the SLURM information in the provided scripts.
+We provide code for pre-training on both the DROID and OXE datasets. Users should update the save_checkpoint_path to the directory where you want to save the training checkpoints, and modify the root_dir to the location where the preprocessed real data is stored. Additionally, users should configure the SLURM information in the provided scripts.
 
 Preparation
 ```python
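The "SLURM information" mentioned above usually amounts to a few header directives; a hedged sketch with placeholder values (partition, node counts, and resource lines here are hypothetical, not taken from the repo's scripts):

```bash
#!/bin/bash
#SBATCH --job-name=seer_pretrain     # placeholder values; fill in per cluster
#SBATCH --partition=your_partition
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=64
```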

eval_calvin.py

Lines changed: 2 additions & 1 deletion

@@ -35,7 +35,7 @@ def main():
     model = SeerAgent(
         finetune_type=args.finetune_type,
         clip_device=device_id,
-        checkpoint_path=args.vit_ckpt_path,
+        vit_checkpoint_path=args.vit_checkpoint_path,
         sequence_length=args.sequence_length,
         num_resampler_query=args.num_resampler_query,
         num_obs_token_per_image=args.num_obs_token_per_image,
@@ -44,6 +44,7 @@ def main():
         action_pred_steps=args.action_pred_steps,
         obs_pred=args.obs_pred,
         atten_only_obs=args.atten_only_obs,
+        attn_robot_proprio_state=args.attn_robot_proprio_state,
         atten_goal=args.atten_goal,
         atten_goal_state=args.atten_goal_state,
         mask_l_obs_ratio=args.mask_l_obs_ratio,
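If your own launch code constructs SeerAgent, the rename flows from CLI flag to constructor keyword as sketched below; the argparse wiring is an assumption for illustration, and only the flag and keyword names come from this commit:

```python
import argparse

parser = argparse.ArgumentParser()
# renamed in this commit: the old --vit_ckpt_path flag and checkpoint_path= keyword are gone
parser.add_argument("--vit_checkpoint_path", type=str,
                    default="checkpoints/vit_mae/mae_pretrain_vit_base.pth")
args = parser.parse_args([])

# model = SeerAgent(..., vit_checkpoint_path=args.vit_checkpoint_path, ...)
print(args.vit_checkpoint_path)
```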

models/seer_model.py

Lines changed: 26 additions & 12 deletions

@@ -1,3 +1,5 @@
+
+import os
 import random
 from functools import partial
 from copy import deepcopy
@@ -18,6 +20,7 @@
 
 def generate_attention_mask(K, num_A, num_B, atten_goal, atten_goal_state,
                             atten_only_obs,
+                            attn_robot_proprio_state,
                             mask_l_obs_ratio,
                             num_obs_token, action_pred_steps):
     # num_A: 1+1+self.NUM_RESAMPLER_QUERY*2+1*2
@@ -43,6 +46,8 @@ def generate_attention_mask(K, num_A, num_B, atten_goal, atten_goal_state,
         attention_mask[start_index+num_A+num_obs_token:start_index+num_A+num_obs_token+action_pred_steps] = -float('inf')
         attention_mask[start_index+num_A+num_obs_token:start_index+num_A+num_obs_token+action_pred_steps, start_index+2:start_index+num_A] = 0.0
         attention_mask[start_index+num_A+num_obs_token:start_index+num_A+num_obs_token+action_pred_steps, start_index+num_A:start_index+num_A+num_obs_token] = 0.0
+        if attn_robot_proprio_state:
+            attention_mask[start_index+num_A+num_obs_token:start_index+num_A+num_obs_token+action_pred_steps, start_index+1:start_index+2] = 0.0
         if mask_l_obs_ratio > 0:
             count = int(mask_l_obs_ratio * (num_obs_token))
             selected_numbers = np.random.choice(range(num_obs_token), size=count, replace=False)
@@ -112,12 +117,13 @@
         self,
         finetune_type,
         clip_device,
-        checkpoint_path,
+        vit_checkpoint_path,
         sequence_length=10,
         num_resampler_query=9,
         num_obs_token_per_image=10,
         obs_pred=False,
         atten_only_obs=False,
+        attn_robot_proprio_state=False,
         atten_goal=False,
         atten_goal_state=False,
         mask_l_obs_ratio=0.0,
@@ -142,19 +148,20 @@ def __init__(
         self.atten_goal = atten_goal
         self.atten_goal_state = atten_goal_state
         self.atten_only_obs = atten_only_obs
+        self.attn_robot_proprio_state = attn_robot_proprio_state
         self.mask_l_obs_ratio = mask_l_obs_ratio
         self.hidden_dim = hidden_dim
         self.phase = phase
         assert self.phase in ["pretrain", "finetune", "evaluate"]
         self.gripper_width = gripper_width
-        self.checkpoint_path = checkpoint_path
+        self.vit_checkpoint_path = vit_checkpoint_path
 
         # text projector
         self.text_projector = nn.Linear(512, self.hidden_dim)
 
         # state encoder
-        ARM_STATE_FEATURE_DIM = 384
-        GRIPPER_STATE_FEATURE_DIM = 384
+        ARM_STATE_FEATURE_DIM = self.hidden_dim
+        GRIPPER_STATE_FEATURE_DIM = self.hidden_dim
         self.arm_state_encoder = nn.Linear(6, ARM_STATE_FEATURE_DIM)
         self.gripper_state_encoder = nn.Linear(2, GRIPPER_STATE_FEATURE_DIM)
         self.state_projector = nn.Linear(ARM_STATE_FEATURE_DIM+GRIPPER_STATE_FEATURE_DIM, self.hidden_dim)
@@ -204,6 +211,7 @@ def __init__(
             atten_goal=self.atten_goal,
             atten_goal_state=self.atten_goal_state,
             atten_only_obs=self.atten_only_obs,
+            attn_robot_proprio_state=self.attn_robot_proprio_state,
             mask_l_obs_ratio=self.mask_l_obs_ratio,
             num_obs_token=this_num_obs_token,
             action_pred_steps=self.action_pred_steps),
@@ -218,21 +226,22 @@ def __init__(
         self.transformer_backbone = GPT2Model(config)
 
         # action decoder
+        MLP_hidden_dim = self.hidden_dim // 2
         self.action_decoder = nn.Sequential(
-            nn.Linear(self.hidden_dim, 192),
+            nn.Linear(self.hidden_dim, MLP_hidden_dim),
             nn.ReLU(),
-            nn.Linear(192, 192),
+            nn.Linear(MLP_hidden_dim, MLP_hidden_dim),
             nn.ReLU(),
         )
         self.arm_action_decoder = nn.Sequential(
-            nn.Linear(192, 6),
+            nn.Linear(MLP_hidden_dim, 6),
             torch.nn.Tanh(),
         )
         self.gripper_action_decoder = nn.Sequential(
-            nn.Linear(192, 1),
+            nn.Linear(MLP_hidden_dim, 1),
             torch.nn.Sigmoid(),
         )
-        self.IMAGE_DECODER_hidden_dim = 384
+        self.IMAGE_DECODER_hidden_dim = self.hidden_dim
         self.NUM_MASK_TOKEN = int(calvin_input_image_size**2/patch_size/patch_size)  # i.e. num_patch
         self.PATCH_SIZE = patch_size
         self.mask_token = nn.Parameter(torch.zeros(1, 1, self.IMAGE_DECODER_hidden_dim))
@@ -249,11 +258,15 @@ def __init__(
         self.initialize_weights()
 
         # freeze vision encoder
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        msg = self.vision_encoder.load_state_dict(checkpoint['model'], strict=False)
+        print(self.vit_checkpoint_path)
+        vit_checkpoint = torch.load(self.vit_checkpoint_path, map_location='cpu')
+        self.vision_encoder.load_state_dict(vit_checkpoint['model'], strict=False)
 
         # # freeze text encoder
-        self.clip_model, self.image_processor = clip.load("ViT-B/32", device=clip_device)
+        if os.path.exists("checkpoints/clip/ViT-B-32.pt"):
+            self.clip_model, self.image_processor = clip.load("checkpoints/clip/ViT-B-32.pt", device=clip_device)
+        else:
+            self.clip_model, self.image_processor = clip.load("ViT-B/32", device=clip_device)
 
     def initialize_weights(self):
         # initialization
@@ -298,6 +311,7 @@ def forward(self, image_primary, image_wrist, state, text_token, action=None):
             atten_goal=self.atten_goal,
             atten_goal_state=self.atten_goal_state,
             atten_only_obs=self.atten_only_obs,
+            attn_robot_proprio_state=self.attn_robot_proprio_state,
             mask_l_obs_ratio=self.mask_l_obs_ratio,
             num_obs_token=this_num_obs_token,
             action_pred_steps=self.action_pred_steps).to(self.device),
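The attn_robot_proprio_state flag threaded through the diff above opens one extra column of the additive attention mask, letting the action-prediction tokens read the proprioceptive state token. A minimal, self-contained sketch of that masking convention, where 0.0 means "may attend" and -inf means "blocked" (the token layout and indices here are illustrative assumptions, not the model's actual layout):

```python
import torch

T = 6                                    # toy sequence: [lang, state, q1, q2, obs, act]
mask = torch.full((T, T), float('-inf'))
mask.fill_diagonal_(0.0)                 # every token sees itself
mask[5, 2:5] = 0.0                       # action token attends to queries + obs token

attn_robot_proprio_state = True
if attn_robot_proprio_state:             # the new flag additionally opens the state slot
    mask[5, 1] = 0.0                     # action token may now read the proprio state

scores = torch.randn(T, T)               # stand-in attention logits
weights = torch.softmax(scores + mask, dim=-1)
print(weights[5])                        # zero weight wherever the mask is -inf
```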

real_controller/controller.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def setup_model(self):
         self.model = SeerAgent(
             finetune_type=self.args.finetune_type,
             clip_device=self.device_id,
-            checkpoint_path=self.args.vit_ckpt_path,
+            save_checkpoint_path=self.args.vit_checkpoint_path,
             sequence_length=self.args.sequence_length,
             num_resampler_query=self.args.num_resampler_query,
             num_obs_token_per_image=self.args.num_obs_token_per_image,
scripts/CALVIN_ABC_D/Seer-Large/eval.sh

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+#!/bin/bash
+export GIT_PYTHON_REFRESH=quiet
+calvin_dataset_path="calvin/dataset/task_ABC_D"
+calvin_conf_path="calvin/calvin_models/conf"
+vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
+### NEED TO CHANGE the checkpoint path ###
+resume_from_checkpoint="checkpoints/CALVIN_ABC_D/Seer_Large/12.pth" # checkpoint path to be evaluated
+IFS='/' read -ra path_parts <<< "$resume_from_checkpoint"
+run_name="${path_parts[-2]}"
+log_name="${path_parts[-1]}"
+log_folder="eval_logs/$run_name"
+mkdir -p "$log_folder"
+log_file="eval_logs/$run_name/evaluate_$log_name.log"
+node=1
+node_num=8
+
+torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 eval_calvin.py \
+    --traj_cons \
+    --rgb_pad 10 \
+    --gripper_pad 4 \
+    --gradient_accumulation_steps 1 \
+    --bf16_module "vision_encoder" \
+    --vit_checkpoint_path ${vit_checkpoint_path} \
+    --calvin_dataset ${calvin_dataset_path} \
+    --calvin_conf_path ${calvin_conf_path} \
+    --workers 16 \
+    --lr_scheduler cosine \
+    --save_every_iter 50000 \
+    --num_epochs 20 \
+    --seed 42 \
+    --batch_size 64 \
+    --precision fp32 \
+    --weight_decay 1e-4 \
+    --num_resampler_query 16 \
+    --num_obs_token_per_image 16 \
+    --run_name ${run_name} \
+    --transformer_layers 24 \
+    --hidden_dim 1024 \
+    --transformer_heads 16 \
+    --phase "evaluate" \
+    --finetune_type "calvin" \
+    --action_pred_steps 3 \
+    --sequence_length 10 \
+    --future_steps 3 \
+    --window_size 13 \
+    --obs_pred \
+    --resume_from_checkpoint ${resume_from_checkpoint} | tee ${log_file} \
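The run/log names in this script come from splitting the checkpoint path on `/`. As a standalone sketch of that idiom (negative array subscripts such as `${path_parts[-2]}` require bash 4.3 or newer):

```bash
#!/bin/bash
resume_from_checkpoint="checkpoints/CALVIN_ABC_D/Seer_Large/12.pth"
IFS='/' read -ra path_parts <<< "$resume_from_checkpoint"   # split on '/'
echo "${path_parts[-2]}"   # Seer_Large -> run_name
echo "${path_parts[-1]}"   # 12.pth     -> log_name
```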
scripts/CALVIN_ABC_D/Seer-Large/finetune.sh

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+#!/bin/bash
+### need to change to your path ###
+calvin_dataset_path="calvin/dataset/task_ABC_D"
+save_checkpoint_path="checkpoints/"
+finetune_from_pretrained_ckpt="checkpoints/pretrain_Seer_ptbs512_24layers_16heads_hd1024-Large_calvin_abc_d/9.pth"
+vit_checkpoint_path="checkpoints/vit_mae/mae_pretrain_vit_base.pth" # downloaded from https://drive.google.com/file/d/1bSsvRI4mDM3Gg51C6xO0l9CbojYw3OEt/view?usp=sharing
+node=8
+node_num=8
+torchrun --nnodes=${node} --nproc_per_node=${node_num} --master_port=10211 train.py \
+    --traj_cons \
+    --rgb_pad 10 \
+    --gripper_pad 4 \
+    --gradient_accumulation_steps 1 \
+    --bf16_module "vision_encoder" \
+    --vit_checkpoint_path ${vit_checkpoint_path} \
+    --calvin_dataset ${calvin_dataset_path} \
+    --workers 8 \
+    --lr_scheduler cosine \
+    --save_every_iter 100000 \
+    --num_epochs 20 \
+    --seed 42 \
+    --batch_size 8 \
+    --precision fp32 \
+    --learning_rate 1e-3 \
+    --warmup_epochs 3 \
+    --finetune_type "calvin" \
+    --wandb_project seer \
+    --weight_decay 1e-4 \
+    --num_resampler_query 16 \
+    --num_obs_token_per_image 16 \
+    --run_name finetune_Seer-Large_calvin_abc_d \
+    --save_checkpoint_path ${save_checkpoint_path} \
+    --transformer_layers 24 \
+    --hidden_dim 1024 \
+    --transformer_heads 16 \
+    --phase "finetune" \
+    --action_pred_steps 3 \
+    --sequence_length 10 \
+    --future_steps 3 \
+    --window_size 13 \
+    --obs_pred \
+    --loss_image \
+    --loss_action \
+    --save_checkpoint \
+    --report_to_wandb \
+    --offline \
+    --finetune_from_pretrained_ckpt ${finetune_from_pretrained_ckpt} \
+
+
49+
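As a scale check, assuming --batch_size is the per-GPU batch (the usual torchrun convention), this fine-tuning launch processes 8 nodes × 8 GPUs × 8 samples = 512 samples per optimizer step, since --gradient_accumulation_steps is 1; read this way it is consistent with the ptbs512 tag in the pre-trained checkpoint's directory name, if that tag denotes a total batch size of 512.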
