Commit e370fe5

Add a config for training with infos including complex prompts

1 parent 5bf944c · commit e370fe5

File tree

2 files changed: +226 −2 lines

configs/grounding/mv-grounding_8xb12_embodiedscan-vg-9dof-full.py

Lines changed: 2 additions & 2 deletions
@@ -142,7 +142,7 @@
                  dataset=dict(type=dataset_type,
                               data_root=data_root,
                               ann_file='embodiedscan_infos_train.pkl',
-                              vg_file='embodiedscan_train_full_vg.json',
+                              vg_file='embodiedscan_train_vg.json',
                               metainfo=metainfo,
                               pipeline=train_pipeline,
                               test_mode=False,
@@ -157,7 +157,7 @@
                       dataset=dict(type=dataset_type,
                                    data_root=data_root,
                                    ann_file='embodiedscan_infos_val.pkl',
-                                   vg_file='embodiedscan_val_full_vg.json',
+                                   vg_file='embodiedscan_val_vg.json',
                                    metainfo=metainfo,
                                    pipeline=test_pipeline,
                                    test_mode=True,
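The change to the existing full config simply drops the `_full` suffix so the train and val dataloaders read the standard VG split files. A quick way to confirm the renamed annotation files are actually in place before launching training is a stdlib-only check like the sketch below (the filenames and data_root are taken from the config; the script itself is illustrative and not part of the repo):

import os

data_root = 'data'  # same data_root the config uses
for fname in ('embodiedscan_infos_train.pkl',
              'embodiedscan_train_vg.json',
              'embodiedscan_infos_val.pkl',
              'embodiedscan_val_vg.json'):
    path = os.path.join(data_root, fname)
    print(path, '->', 'found' if os.path.exists(path) else 'MISSING')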
Lines changed: 224 additions & 0 deletions (new file)
_base_ = ['../default_runtime.py']
n_points = 100000

backend_args = None
# Uncomment the following if you use ceph or other file clients.
# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient
# for more details.
# file_client_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/scannet/':
#         's3://openmmlab/datasets/detection3d/scannet_processed/',
#         'data/scannet/':
#         's3://openmmlab/datasets/detection3d/scannet_processed/'
#     }))

metainfo = dict(classes='all')

model = dict(
    type='SparseFeatureFusion3DGrounder',
    num_queries=256,
    voxel_size=0.01,
    data_preprocessor=dict(type='Det3DDataPreprocessor',
                           mean=[123.675, 116.28, 103.53],
                           std=[58.395, 57.12, 57.375],
                           bgr_to_rgb=True,
                           pad_size_divisor=32),
    backbone=dict(
        type='mmdet.ResNet',
        depth=50,
        base_channels=16,  # to make it consistent with mink resnet
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    backbone_lidar=dict(type='MinkResNet', in_channels=3, depth=34),
    use_xyz_feat=True,
    # change due to no img feature fusion
    neck_3d=dict(type='MinkNeck',
                 num_classes=1,
                 in_channels=[128, 256, 512, 1024],
                 out_channels=256,
                 voxel_size=0.01,
                 pts_prune_threshold=1000),
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        layer_cfg=dict(
            # query self attention layer
            self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            # cross attention layer query to text
            cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            # cross attention layer query to image
            cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0),
            ffn_cfg=dict(embed_dims=256,
                         feedforward_channels=2048,
                         ffn_drop=0.0)),
        post_norm_cfg=None),
    bbox_head=dict(type='GroundingHead',
                   num_classes=256,
                   sync_cls_avg_factor=True,
                   decouple_bbox_loss=True,
                   decouple_groups=4,
                   share_pred_layer=True,
                   decouple_weights=[0.2, 0.2, 0.2, 0.4],
                   contrastive_cfg=dict(max_text_len=256,
                                        log_scale='auto',
                                        bias=True),
                   loss_cls=dict(type='mmdet.FocalLoss',
                                 use_sigmoid=True,
                                 gamma=2.0,
                                 alpha=0.25,
                                 loss_weight=1.0),
                   loss_bbox=dict(type='BBoxCDLoss',
                                  mode='l1',
                                  loss_weight=1.0,
                                  group='g8')),
    coord_type='DEPTH',
    # training and testing settings
    train_cfg=dict(assigner=dict(type='HungarianAssigner3D',
                                 match_costs=[
                                     dict(type='BinaryFocalLossCost',
                                          weight=1.0),
                                     dict(type='BBox3DL1Cost', weight=2.0),
                                     dict(type='IoU3DCost', weight=2.0)
                                 ]), ),
    test_cfg=None)

dataset_type = 'MultiView3DGroundingDataset'
data_root = 'data'

train_pipeline = [
    dict(type='LoadAnnotations3D'),
    dict(type='MultiViewPipeline',
         n_images=20,
         transforms=[
             dict(type='LoadImageFromFile', backend_args=backend_args),
             dict(type='LoadDepthFromFile', backend_args=backend_args),
             dict(type='ConvertRGBDToPoints', coord_type='CAMERA'),
             dict(type='PointSample', num_points=n_points // 10),
             dict(type='Resize', scale=(480, 480), keep_ratio=False)
         ]),
    dict(type='AggregateMultiViewPoints', coord_type='DEPTH'),
    dict(type='PointSample', num_points=n_points),
    dict(type='GlobalRotScaleTrans',
         rot_range=[-0.087266, 0.087266],
         scale_ratio_range=[.9, 1.1],
         translation_std=[.1, .1, .1],
         shift_height=False),
    dict(type='Pack3DDetInputs',
         keys=['img', 'points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
    dict(type='LoadAnnotations3D'),
    dict(type='MultiViewPipeline',
         n_images=50,
         ordered=True,
         transforms=[
             dict(type='LoadImageFromFile', backend_args=backend_args),
             dict(type='LoadDepthFromFile', backend_args=backend_args),
             dict(type='ConvertRGBDToPoints', coord_type='CAMERA'),
             dict(type='PointSample', num_points=n_points // 10),
             dict(type='Resize', scale=(480, 480), keep_ratio=False)
         ]),
    dict(type='AggregateMultiViewPoints', coord_type='DEPTH'),
    dict(type='PointSample', num_points=n_points),
    dict(type='Pack3DDetInputs',
         keys=['img', 'points', 'gt_bboxes_3d', 'gt_labels_3d'])
]

# TODO: to determine a reasonable batch size
train_dataloader = dict(
    batch_size=12,
    num_workers=12,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(type='RepeatDataset',
                 times=1,
                 dataset=dict(type=dataset_type,
                              data_root=data_root,
                              ann_file='embodiedscan_infos_train.pkl',
                              vg_file='embodiedscan_train_vg_all.json',
                              metainfo=metainfo,
                              pipeline=train_pipeline,
                              test_mode=False,
                              filter_empty_gt=True,
                              box_type_3d='Euler-Depth')))

val_dataloader = dict(batch_size=12,
                      num_workers=12,
                      persistent_workers=True,
                      drop_last=False,
                      sampler=dict(type='DefaultSampler', shuffle=False),
                      dataset=dict(type=dataset_type,
                                   data_root=data_root,
                                   ann_file='embodiedscan_infos_val.pkl',
                                   vg_file='embodiedscan_val_vg_all.json',
                                   metainfo=metainfo,
                                   pipeline=test_pipeline,
                                   test_mode=True,
                                   filter_empty_gt=True,
                                   box_type_3d='Euler-Depth'))

test_dataloader = dict(batch_size=12,
                       num_workers=12,
                       persistent_workers=True,
                       drop_last=False,
                       sampler=dict(type='DefaultSampler', shuffle=False),
                       dataset=dict(type=dataset_type,
                                    data_root=data_root,
                                    ann_file='embodiedscan_infos_test.pkl',
                                    vg_file='embodiedscan_test_vg_all.json',
                                    metainfo=metainfo,
                                    pipeline=test_pipeline,
                                    test_mode=True,
                                    filter_empty_gt=True,
                                    box_type_3d='Euler-Depth'))

val_evaluator = dict(type='GroundingMetric')
test_evaluator = dict(type='GroundingMetric', format_only=True)

# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=12)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# optimizer
lr = 5e-4
optim_wrapper = dict(type='OptimWrapper',
                     optimizer=dict(type='AdamW', lr=lr, weight_decay=0.0005),
                     paramwise_cfg=dict(
                         custom_keys={
                             'text_encoder': dict(lr_mult=0.0),
                             'decoder': dict(lr_mult=0.1, decay_mult=1.0)
                         }),
                     clip_grad=dict(max_norm=10, norm_type=2))

# learning rate
param_scheduler = dict(type='MultiStepLR',
                       begin=0,
                       end=12,
                       by_epoch=True,
                       milestones=[8, 11],
                       gamma=0.1)

custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)]

# hooks
default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))

# vis_backends = [
#     dict(type='TensorboardVisBackend'),
#     dict(type='LocalVisBackend')
# ]
# visualizer = dict(
#     type='Det3DLocalVisualizer',
#     vis_backends=vis_backends, name='visualizer')

find_unused_parameters = True
load_from = '/mnt/petrelfs/wangtai/EmbodiedScan/work_dirs/mv-3ddet-challenge/epoch_12.pth'  # noqa
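Since this is a standard mmengine config, it can be loaded and inspected programmatically before training. A minimal sketch, assuming a local checkout of EmbodiedScan with the new file saved under configs/grounding/ (the path below is a placeholder, since the new file's name is not shown on this page):

from mmengine.config import Config

# Placeholder path: substitute the real filename of the new config.
cfg = Config.fromfile('configs/grounding/NEW_VG_ALL_CONFIG.py')

# Inspect the fields this commit introduces: the *_vg_all.json splits
# (which include the complex prompts) and the detection checkpoint the
# grounder warm-starts from.
print(cfg.train_dataloader.dataset.dataset.vg_file)  # embodiedscan_train_vg_all.json
print(cfg.load_from)

Training would then go through the repo's usual mm-style launcher (e.g. a tools/train.py entry point; the exact script name is an assumption here). Note that load_from points at a cluster-local path, so it needs to be adjusted to wherever the pretrained mv-3ddet checkpoint actually lives.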

