Commit 536f3af

feat: add lcm sampler support

This follows the issue discussion in stable-diffusion-webui at AUTOMATIC1111/stable-diffusion-webui#13952; the implementation may not be perfect.

1 parent 3bf1665 commit 536f3af
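For context, LCM sampling differs from the other k-diffusion-style samplers in that each step takes the model's denoised prediction directly and, if further steps remain, re-noises it at the next sigma. A rough Python sketch of that loop, in the style of the k-diffusion samplers referenced in the webui discussion (an illustration of the technique, not the C++ code added in this commit; `torch.randn_like` stands in for a pluggable noise sampler):

```python
import torch

def sample_lcm(model, x, sigmas, extra_args=None):
    """Sketch of an LCM sampling loop (k-diffusion style), for illustration only.

    x: initial noisy latents, sigmas: 1-D tensor of noise levels ending at 0.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
        # The consistency model predicts the clean sample in a single call...
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        x = denoised
        # ...and, unless this is the last step, fresh noise is added at the next sigma.
        if sigmas[i + 1] > 0:
            x = x + sigmas[i + 1] * torch.randn_like(x)
    return x
```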

5 files changed, +384 −4 lines changed

‎README.md‎

Lines changed: 4 additions & 1 deletion
@@ -26,6 +26,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - `DPM++ 2M`
 - [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457)
 - `DPM++ 2S a`
+- [`LCM`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13952)
 - Cross-platform reproducibility (`--rng cuda`, consistent with the `stable-diffusion-webui GPU RNG`)
 - Embedds generation parameters into png output as webui-compatible text string
 - Supported platforms
@@ -80,6 +81,7 @@ git submodule update
 ```shell
 cd models
 pip install -r requirements.txt
+# (optional) python convert_diffusers_to_original_stable_diffusion.py --model_path [path to diffusers weights] --checkpoint_path [path to weights]
 python convert.py [path to weights] --out_type [output precision]
 # For example, python convert.py sd-v1-4.ckpt --out_type f16
 ```
@@ -132,7 +134,7 @@ arguments:
                          1.0 corresponds to full destruction of information in init image
   -H, --height H         image height, in pixel space (default: 512)
   -W, --width W          image width, in pixel space (default: 512)
-  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2}
+  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2, lcm}
                          sampling method (default: "euler_a")
   --steps STEPS          number of sample steps (default: 20)
   --rng {std_default, cuda} RNG (default: cuda)
@@ -196,3 +198,4 @@ docker run -v /path/to/models:/models -v /path/to/output/:/output sd [args...]
 - [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
 - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
+- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)

‎examples/main.cpp‎

Lines changed: 6 additions & 3 deletions
@@ -80,13 +80,16 @@ const char* sample_method_str[] = {
     "dpm2",
     "dpm++2s_a",
     "dpm++2m",
-    "dpm++2mv2"};
+    "dpm++2mv2",
+    "lcm",
+};
 
 // Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h
 const char* schedule_str[] = {
     "default",
     "discrete",
-    "karras"};
+    "karras"
+};
 
 struct Option {
     int n_threads = -1;
@@ -146,7 +149,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("                          1.0 corresponds to full destruction of information in init image\n");
     printf("  -H, --height H          image height, in pixel space (default: 512)\n");
     printf("  -W, --width W           image width, in pixel space (default: 512)\n");
-    printf("  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2}\n");
+    printf("  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}\n");
     printf("                          sampling method (default: \"euler_a\")\n");
     printf("  --steps STEPS           number of sample steps (default: 20)\n");
     printf("  --rng {std_default, cuda} RNG (default: cuda)\n");
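As an aside, `sample_method_str` is a parallel name table: a sampler name's position in the array presumably corresponds to its enum value, so appending "lcm" extends the accepted `--sampling-method` values without changing the parsing logic. A small Python illustration of that lookup pattern (not code from this repository; the ordering of the leading entries is assumed here):

```python
# Name table mirroring sample_method_str; index == enum value (ordering assumed).
SAMPLE_METHOD_STR = [
    "euler_a", "euler", "heun", "dpm2",
    "dpm++2s_a", "dpm++2m", "dpm++2mv2", "lcm",
]

def parse_sampling_method(name: str, default: str = "euler_a") -> int:
    """Return the table index for a sampler name, falling back to the default."""
    return SAMPLE_METHOD_STR.index(name if name in SAMPLE_METHOD_STR else default)
```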
‎models/convert_diffusers_to_original_stable_diffusion.py‎ (new file; path inferred from the README snippet above)

Lines changed: 335 additions & 0 deletions
@@ -0,0 +1,335 @@

# Copy from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# LICENSE: https://github.com/huggingface/diffusers/blob/main/LICENSE
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.

import argparse
import os.path as osp
import re

import torch
from safetensors.torch import load_file, save_file


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#


textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}


def convert_text_enc_state_dict_v20(text_enc_dict):
    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if (
            k.endswith(".self_attn.q_proj.weight")
            or k.endswith(".self_attn.k_proj.weight")
            or k.endswith(".self_attn.v_proj.weight")
        ):
            k_pre = k[: -len(".q_proj.weight")]
            k_code = k[-len("q_proj.weight")]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if (
            k.endswith(".self_attn.q_proj.bias")
            or k.endswith(".self_attn.k_proj.bias")
            or k.endswith(".self_attn.v_proj.bias")
        ):
            k_pre = k[: -len(".q_proj.bias")]
            k_code = k[-len("q_proj.bias")]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        new_state_dict[relabelled_key] = v

    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)

    return new_state_dict


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
    )

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    # Path for safetensors
    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    if args.use_safetensors:
        save_file(state_dict, args.checkpoint_path)
    else:
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)
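
The script is invoked as shown in the README snippet above (`python convert_diffusers_to_original_stable_diffusion.py --model_path [path to diffusers weights] --checkpoint_path [path to weights]`). To make the key renaming concrete, here is a short sketch (not part of the commit) that exercises convert_unet_state_dict on a dummy state dict; it assumes the definitions above are in scope (e.g. the script has been imported), and the tensors are placeholders since only the keys are rewritten:

```python
import torch

# All top-level keys from unet_conversion_map must be present, because the
# mapping assumes a complete UNet state dict; tensor values are placeholders.
dummy_unet = {hf_name: torch.zeros(1) for _, hf_name in unet_conversion_map}
# One hypothetical Diffusers-style resnet key to show the layer/resnet remapping.
dummy_unet["down_blocks.0.resnets.0.conv1.weight"] = torch.zeros(1)

converted = convert_unet_state_dict(dummy_unet)
print("down_blocks.0.resnets.0.conv1.weight ->",
      [k for k in converted if "in_layers" in k][0])
# Expected, per the maps above: input_blocks.1.0.in_layers.2.weight
```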
