Commit 7e5cdaa

SD3 lora support

1 parent b2453d2, commit 7e5cdaa

6 files changed: +106 -24 lines

extensions-builtin/Lora/network.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@
 import torch.nn.functional as F

 from modules import sd_models, cache, errors, hashes, shared
+import modules.models.sd3.mmdit

 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

@@ -114,7 +115,10 @@ def __init__(self, net: Network, weights: NetworkWeights):
         self.sd_key = weights.sd_key
         self.sd_module = weights.sd_module

-        if hasattr(self.sd_module, 'weight'):
+        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
+            s = self.sd_module.weight.shape
+            self.shape = (s[0] // 3, s[1])
+        elif hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
         elif isinstance(self.sd_module, nn.MultiheadAttention):
             # For now, only self-attn use Pytorch's MHA
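Note on the new QkvLinear branch above: MMDiT's fused qkv projection stores the query, key and value weights stacked along dim 0, so the module reports the shape of a single projection, hence s[0] // 3. A minimal standalone sketch of that shape arithmetic (example sizes, not code from this commit):

    import torch

    dim = 64                                  # example hidden size
    qkv = torch.nn.Linear(dim, dim * 3)       # fused q/k/v projection, weight shape (3*dim, dim)

    s = qkv.weight.shape
    per_projection_shape = (s[0] // 3, s[1])  # shape of a single q, k or v projection
    assert per_projection_shape == (dim, dim)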

extensions-builtin/Lora/network_lora.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch

 import lyco_helpers
+import modules.models.sd3.mmdit
 import network
 from modules import devices

@@ -10,6 +11,13 @@ def create_module(self, net: network.Network, weights: network.NetworkWeights):
         if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
             return NetworkModuleLora(net, weights)

+        if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
+            w = weights.w.copy()
+            weights.w.clear()
+            weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
+
+            return NetworkModuleLora(net, weights)
+
         return None


@@ -29,7 +37,7 @@ def create_module(self, weights, key, none_ok=False):
         if weight is None and none_ok:
             return None

-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]

         if is_linear:
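Note: the new branch accepts LoRA files that use the PEFT/diffusers key names (lora_A is the down projection, lora_B the up projection) and renames them to the lora_down/lora_up keys the rest of the extension expects. The weight delta is the same product either way; a small standalone sketch under the usual LoRA convention (example sizes, not code from this commit):

    import torch

    in_features, out_features, rank = 64, 64, 4
    lora_A = torch.randn(rank, in_features)   # "lora_down" in the webui naming
    lora_B = torch.randn(out_features, rank)  # "lora_up" in the webui naming

    delta_w = lora_B @ lora_A                 # renaming the keys does not change this product
    assert delta_w.shape == (out_features, in_features)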

extensions-builtin/Lora/networks.py

Lines changed: 75 additions & 21 deletions
@@ -20,6 +20,7 @@

 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit

 from lora_logger import logger

@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):

     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None

     matched_networks = {}
     bundle_embeddings = {}

     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")

         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)
@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
                 emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict

-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)

         if sd_module is None:
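Note: when the loaded model exposes a diffusers weight mapping (as the SD3 model below now does), LoRA keys are assumed to look like <module path>.<lora tensor>.<suffix>, so rsplit(".", 2) splits off the network part, and the map then translates the diffusers-style module path into the key used by network_layer_mapping. A rough illustration with a hypothetical SD3 LoRA key (actual key names depend on the file being loaded):

    # hypothetical SD3 LoRA key; real key names depend on the file being loaded
    key_network = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"

    key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
    network_part = network_name + '.' + network_weight

    print(key_network_without_network_parts)   # transformer.transformer_blocks.0.attn.to_q
    print(network_part)                         # lora_A.weight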
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()


+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
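Note: these helpers factor out logic that used to be inlined below: a LayerNorm without affine parameters legitimately has no weight to back up, backups are copied to the CPU, and restoring a None backup clears the field (for example a bias that never existed). A small self-contained sketch of the restore behaviour on a plain Linear layer (illustrative only, not code from this commit):

    import torch

    def restore(obj, field, backup):
        # mirrors restore_weights_backup: None clears the field, otherwise copy in place
        if backup is None:
            setattr(obj, field, None)
            return
        getattr(obj, field).copy_(backup)

    layer = torch.nn.Linear(4, 4)
    saved = layer.weight.detach().to(torch.device("cpu"), copy=True)  # roughly store_weights_backup

    with torch.no_grad():
        layer.weight.mul_(2)             # simulate a network modifying the weight in place
        restore(layer, 'weight', saved)  # weight is back to its original values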
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li

     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)

-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)


 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -389,37 +424,38 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")

         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)

         self.network_weights_backup = weights_backup

     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj.bias)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None

         # Unlike weight which always has value, some modules don't have bias.
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
             raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup

     if current_names != wanted_names:
         network_restore_weights_from_backup(self)

     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight'):
+        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
             try:
                 with torch.no_grad():
                     if getattr(self, 'fp16_weight', None) is None:
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn

             continue

+        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+            try:
+                with torch.no_grad():
+                    # Send "real" orig_weight into MHA's lora module
+                    qw, kw, vw = self.weight.chunk(3, 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    del qw, kw, vw
+                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                    self.weight += updown_qkv
+
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+            continue
+
         if module is None:
             continue
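Note: this branch applies separate q/k/v LoRA modules to one fused weight: the (3*dim, dim) QkvLinear weight is chunked along dim 0, each LoRA delta is computed against its own slice, and the three deltas are stacked back with vstack before being added. A standalone sketch of the shape bookkeeping (example sizes, independent of calc_updown):

    import torch

    dim = 8
    qkv_weight = torch.zeros(3 * dim, dim)       # fused q/k/v weight, like QkvLinear.weight

    qw, kw, vw = qkv_weight.chunk(3, 0)          # each slice is (dim, dim)
    updown_q = torch.ones_like(qw)               # stand-ins for the per-projection LoRA deltas
    updown_k = torch.ones_like(kw) * 2
    updown_v = torch.ones_like(vw) * 3

    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
    assert updown_qkv.shape == qkv_weight.shape  # the stacked delta lines up with the fused weight
    qkv_weight += updown_qkv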

modules/models/sd3/mmdit.py

Lines changed: 4 additions & 1 deletion
@@ -175,6 +175,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 #################################################################################


+class QkvLinear(torch.nn.Linear):
+    pass
+
 def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]

@@ -202,7 +205,7 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = dim // num_heads

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         if not pre_only:
             self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
         assert attn_mode in self.ATTENTION_MODES
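Note: QkvLinear adds no behaviour of its own; it is a marker subclass so the LoRA code can tell MMDiT's fused qkv projections apart from ordinary nn.Linear layers via isinstance, as used in the networks.py changes above. A trivial sketch of that dispatch (illustrative only):

    import torch

    class QkvLinear(torch.nn.Linear):
        pass

    layers = [torch.nn.Linear(8, 8), QkvLinear(8, 24)]
    for layer in layers:
        if isinstance(layer, QkvLinear):
            print("fused qkv layer: handled by the chunked q/k/v path")
        else:
            print("plain linear layer: handled by the regular LoRA path")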

modules/models/sd3/sd3_impls.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None,
         }
         self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
         self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
+        self.depth = depth

     def apply_model(self, x, sigma, c_crossattn=None, y=None):
         dtype = self.get_dtype()

modules/models/sd3/sd3_model.py

Lines changed: 12 additions & 0 deletions
@@ -82,3 +82,15 @@ def add_noise_to_latent(self, x, noise, amount):

     def fix_dimensions(self, width, height):
         return width // 16 * 16, height // 16 * 16
+
+    def diffusers_weight_mapping(self):
+        for i in range(self.model.depth):
+            yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
+
+            yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"

0 commit comments
