 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit
 
 from lora_logger import logger
 
@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
 
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None
 
     matched_networks = {}
     bundle_embeddings = {}
 
     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")
 
         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)
@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
                 emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict
 
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
 
         if sd_module is None:
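A note on the new key handling: when the model exposes a diffusers-style weight map, a LoRA key is expected to end in a two-part suffix such as `lora_A.weight`, so `rsplit(".", 2)` peels that suffix off and the remaining module path is looked up directly in the map; the legacy branch keeps splitting at the first dot. A small sketch of the two branches, using made-up example keys rather than ones taken from a real checkpoint:

```python
# Hypothetical keys, for illustration only.

# Diffusers-style key: module path uses dots, the LoRA suffix is the last two parts.
key = "transformer.joint_blocks.0.x_block.attn.qkv.lora_A.weight"
module_key, lora_part, weight_name = key.rsplit(".", 2)
network_part = lora_part + "." + weight_name
assert module_key == "transformer.joint_blocks.0.x_block.attn.qkv"
assert network_part == "lora_A.weight"

# Legacy key: module path uses underscores, so the first dot separates it from the suffix.
key = "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight"
module_key, _, network_part = key.partition(".")
assert module_key == "lora_unet_down_blocks_0_attentions_0_proj_in"
assert network_part == "lora_up.weight"
```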
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()
 
 
+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
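The two new helpers form a simple round-trip: `store_weights_backup` keeps a CPU copy of a tensor (or `None` when there is nothing to copy), and `restore_weights_backup` copies it back in place, clearing the attribute when the backup is `None`. A minimal standalone sketch in plain PyTorch, outside the webui codebase:

```python
import torch

def store_weights_backup(weight):
    # CPU copy of the tensor, or None if the layer has no such tensor
    return None if weight is None else weight.to(torch.device("cpu"), copy=True)

def restore_weights_backup(obj, field, weight):
    if weight is None:
        setattr(obj, field, None)
        return
    getattr(obj, field).copy_(weight)

layer = torch.nn.Linear(4, 4)
backup = store_weights_backup(layer.weight)       # copy of the original weight

with torch.no_grad():
    layer.weight += 1.0                           # simulate an applied network
    restore_weights_backup(layer, "weight", backup)

assert torch.equal(layer.weight, backup)          # original values are back
```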
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
 
     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)
 
-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)
 
 
 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
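One subtle point in the refactor above: for `nn.MultiheadAttention` the weights backup is a 2-tuple ordered as `(in_proj_weight, out_proj.weight)`, so each element must be restored into its matching field, and a `None` bias backup is now routed through `restore_weights_backup`, which clears the attribute rather than copying. A short sketch under that assumed tuple ordering:

```python
import torch

mha = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)

weights_backup = (
    mha.in_proj_weight.to("cpu", copy=True),    # index 0 -> in_proj_weight
    mha.out_proj.weight.to("cpu", copy=True),   # index 1 -> out_proj.weight
)
bias_backup = mha.out_proj.bias.to("cpu", copy=True) if mha.out_proj.bias is not None else None

with torch.no_grad():
    mha.in_proj_weight += 1.0                   # simulate an applied network
    mha.out_proj.weight += 1.0
    mha.in_proj_weight.copy_(weights_backup[0])  # restore each element into its own field
    mha.out_proj.weight.copy_(weights_backup[1])
    if bias_backup is None:
        mha.out_proj.bias = None                # a None backup clears the bias
    else:
        mha.out_proj.bias.copy_(bias_backup)
```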
@@ -389,37 +424,38 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
 
         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)
 
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj.bias)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None
 
         # Unlike weight which always has value, some modules don't have bias.
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
             raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)
 
     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight'):
+        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
             try:
                 with torch.no_grad():
                     if getattr(self, 'fp16_weight', None) is None:
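The `allowed_layer_without_weight` guard added above exists because a `LayerNorm` created with `elementwise_affine=False` has both `weight` and `bias` set to `None`, so there is genuinely nothing to back up and the "no backup weights found" error would be a false alarm. For illustration:

```python
import torch

ln = torch.nn.LayerNorm(8, elementwise_affine=False)
print(ln.weight, ln.bias)   # None None -> nothing to back up, so the safety check is skipped
```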
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
 
             continue
 
+        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+            try:
+                with torch.no_grad():
+                    # Send "real" orig_weight into MHA's lora module
+                    qw, kw, vw = self.weight.chunk(3, 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    del qw, kw, vw
+                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                    self.weight += updown_qkv
+
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+            continue
+
         if module is None:
             continue
 
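To make the shape handling in the QkvLinear branch concrete: the fused projection stores Q, K and V stacked along dim 0, so `chunk(3, 0)` yields three equal slices, a LoRA delta is computed per slice, and `torch.vstack` reassembles an update whose shape matches the fused weight. A toy sketch of that round-trip, where a scaled copy stands in for `calc_updown`:

```python
import torch

dim = 4
qkv_weight = torch.randn(3 * dim, dim)      # fused qkv projection, Q/K/V stacked along dim 0

qw, kw, vw = qkv_weight.chunk(3, 0)         # three [dim, dim] slices
assert qw.shape == kw.shape == vw.shape == (dim, dim)

# stand-in for module_q/k/v.calc_updown(...): any delta with the same shape as its slice
updown_q, updown_k, updown_v = qw * 0.1, kw * 0.1, vw * 0.1

updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
assert updown_qkv.shape == qkv_weight.shape  # reassembled update matches the fused weight

qkv_weight += updown_qkv
```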