Commit 08eb9a4

[Doc] Visualization of feature maps using wandb backend in dev-1.x (#2557)

## Motivation

Docs for visualizing feature maps using the wandb backend.

## Modification

Add a new markdown file and a result demo of wandb.

---------

Co-authored-by: MeowZheng <meowzheng@outlook.com>

1 parent 916ed2b commit 08eb9a4

File tree

2 files changed: +202 −0 lines changed

docs/zh_cn/user_guides/index.rst

Lines changed: 1 addition & 0 deletions

@@ -18,3 +18,4 @@
    visualization.md
    useful_tools.md
    deployment.md
+   visualization_feature_map.md
docs/zh_cn/user_guides/visualization_feature_map.md

Lines changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
# Feature map visualization with the wandb backend

MMSegmentation 1.x supports Weights & Biases as a visualization backend, which makes it convenient to visualize and manage the results of your project.

## Configuring wandb

To install Weights & Biases, follow the [official installation guide](https://docs.wandb.ai/quickstart). The steps are:

```shell
pip install wandb
wandb login
```
Add `WandbVisBackend` to `vis_backends`:

```python
vis_backends = [dict(type='LocalVisBackend'),
                dict(type='TensorboardVisBackend'),
                dict(type='WandbVisBackend')]
```
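This list is wired into the visualizer through the config file. A minimal sketch, assuming the layout of MMSegmentation's default runtime config (`configs/_base_/default_runtime.py`); adapt it to your own config:

```python
# A minimal sketch, assuming the standard MMSegmentation config layout:
# the backend list is passed to SegLocalVisualizer via the `visualizer` field.
vis_backends = [dict(type='LocalVisBackend'),
                dict(type='TensorboardVisBackend'),
                dict(type='WandbVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
```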
## Visualizing test data, results, and feature maps

`SegLocalVisualizer` is a subclass of the `Visualizer` class in MMEngine and is tailored to MMSegmentation visualization. For details on `Visualizer`, see the [visualization tutorial](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/visualization.html) in MMEngine.
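Among the methods inherited from `Visualizer` is `draw_featmap`, which the script below uses to render feature maps. A minimal standalone sketch, using a random tensor as an assumed stand-in for a real `(C, H, W)` feature map:

```python
# A minimal sketch (not from the original doc): render a dummy (C, H, W)
# feature tensor with draw_featmap, inherited from MMEngine's Visualizer.
import torch
from mmseg.visualization import SegLocalVisualizer

visualizer = SegLocalVisualizer(
    vis_backends=[dict(type='LocalVisBackend')], save_dir='temp_dir')
featmap = torch.rand(64, 32, 32)  # assumed dummy feature map (C, H, W)
drawn_img = visualizer.draw_featmap(featmap, channel_reduction='select_max')
visualizer.add_image('dummy_feature_map', drawn_img)
```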
The following is an example of using `SegLocalVisualizer`. First, download the data for this example with the commands below:

<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/189833109-eddad58f-f777-4fc0-b98a-6bd429143b06.png" width="70%"/>
</div>

```shell
wget https://user-images.githubusercontent.com/24582831/189833109-eddad58f-f777-4fc0-b98a-6bd429143b06.png --output-document aachen_000000_000019_leftImg8bit.png
wget https://user-images.githubusercontent.com/24582831/189833143-15f60f8a-4d1e-4cbb-a6e7-5e2233869fac.png --output-document aachen_000000_000019_gtFine_labelTrainIds.png

wget https://download.openmmlab.com/mmsegmentation/v0.5/ann/ann_r50-d8_512x1024_40k_cityscapes/ann_r50-d8_512x1024_40k_cityscapes_20200605_095211-049fc292.pth
```
```python
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser
from typing import Type

import mmcv
import torch
import torch.nn as nn

from mmengine.model import revert_sync_batchnorm
from mmengine.structures import PixelData
from mmseg.apis import inference_model, init_model
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
from mmseg.visualization import SegLocalVisualizer


class Recorder:
    """Record the forward output feature maps and save them to data_buffer."""

    def __init__(self) -> None:
        self.data_buffer = list()

    def __enter__(self, ):
        # reset the buffer when entering the context
        self.data_buffer = list()

    def record_data_hook(self, model: nn.Module, input: Type, output: Type):
        self.data_buffer.append(output)

    def __exit__(self, *args, **kwargs):
        pass


def visualize(args, model, recorder, result):
    seg_visualizer = SegLocalVisualizer(
        vis_backends=[dict(type='WandbVisBackend')],
        save_dir='temp_dir',
        alpha=0.5)
    seg_visualizer.dataset_meta = dict(
        classes=model.dataset_meta['classes'],
        palette=model.dataset_meta['palette'])

    image = mmcv.imread(args.img, 'color')

    seg_visualizer.add_datasample(
        name='predict',
        image=image,
        data_sample=result,
        draw_gt=False,
        draw_pred=True,
        wait_time=0,
        out_file=None,
        show=False)

    # add the recorded feature maps to the wandb visualizer
    for i in range(len(recorder.data_buffer)):
        feature = recorder.data_buffer[i][0]  # remove the batch dimension
        drawn_img = seg_visualizer.draw_featmap(
            feature, image, channel_reduction='select_max')
        seg_visualizer.add_image(f'feature_map{i}', drawn_img)

    if args.gt_mask:
        sem_seg = mmcv.imread(args.gt_mask, 'unchanged')
        sem_seg = torch.from_numpy(sem_seg)
        gt_mask = dict(data=sem_seg)
        gt_mask = PixelData(**gt_mask)
        data_sample = SegDataSample()
        data_sample.gt_sem_seg = gt_mask

        seg_visualizer.add_datasample(
            name='gt_mask',
            image=image,
            data_sample=data_sample,
            draw_gt=True,
            draw_pred=False,
            wait_time=0,
            out_file=None,
            show=False)

    seg_visualizer.add_image('image', image)


def main():
    parser = ArgumentParser(
        description='Draw the Feature Map During Inference')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--gt_mask', default=None, help='Path of gt mask file')
    parser.add_argument('--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    parser.add_argument(
        '--title', default='result', help='The image identifier.')
    args = parser.parse_args()

    register_all_modules()

    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)

    # print all named modules in the model and pick the ones to record
    # in the `source` list below
    for name, module in model.named_modules():
        print(name)

    source = [
        'decode_head.fusion.stages.0.query_project.activate',
        'decode_head.context.stages.0.key_project.activate',
        'decode_head.context.bottleneck.activate'
    ]
    source = dict.fromkeys(source)

    count = 0
    recorder = Recorder()
    # register the forward hook on each module listed in `source`
    for name, module in model.named_modules():
        if name in source:
            count += 1
            module.register_forward_hook(recorder.record_data_hook)
            if count == len(source):
                break

    with recorder:
        # test a single image and record its feature maps to data_buffer
        result = inference_model(model, args.img)

    visualize(args, model, recorder, result)


if __name__ == '__main__':
    main()
```
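The script relies on PyTorch forward hooks: `Recorder.record_data_hook` is attached to the chosen modules, so every forward pass through them appends the output feature map to `data_buffer`. As an aside, a toy standalone sketch of this mechanism (not part of the script above):

```python
# A toy sketch of the forward-hook mechanism used by Recorder above.
import torch
import torch.nn as nn

buffer = []
layer = nn.Conv2d(3, 8, kernel_size=3, padding=1)
# the hook receives (module, input, output) and stores the output
handle = layer.register_forward_hook(lambda m, inp, out: buffer.append(out))
layer(torch.randn(1, 3, 16, 16))  # the forward pass triggers the hook
handle.remove()
print(buffer[0].shape)  # torch.Size([1, 8, 16, 16])
```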
Save the full script above as feature_map_visual.py and run the following command in a terminal:

```shell
python feature_map_visual.py ${IMAGE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]
```
Example:

```shell
python feature_map_visual.py \
aachen_000000_000019_leftImg8bit.png \
configs/ann/ann_r50-d8_4xb2-40k_cityscapes-512x1024.py \
ann_r50-d8_512x1024_40k_cityscapes_20200605_095211-049fc292.pth \
--gt_mask aachen_000000_000019_gtFine_labelTrainIds.png
```
The visualized image results and their corresponding feature map images will appear in your wandb account:

<div align=center>
<img src="https://user-images.githubusercontent.com/76149310/217520321-647f5bf9-eef2-446d-a9e8-5ca7b621d500.png">
</div>

0 commit comments