Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitafea457

Browse files
bssrdfbssrdfleejet
authored
fix: support more SDXL LoRA names (leejet#216)
* apply pmid lora only once for multiple txt2img calls* add better support for SDXL LoRA* fix for some sdxl lora, like lcm-lora-xl---------Co-authored-by: bssrdf <bssrdf@gmail.com>Co-authored-by: leejet <leejet714@gmail.com>
1 parent646e776 commitafea457

File tree

4 files changed

+33
-20
lines changed

4 files changed

+33
-20
lines changed

‎examples/cli/main.cpp‎

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -686,27 +686,26 @@ int main(int argc, const char* argv[]) {
686686
// Resize input image ...
687687
if (params.height %64 !=0 || params.width %64 !=0) {
688688
int resized_height = params.height + (64 - params.height %64);
689-
int resized_width = params.width + (64 - params.width %64);
689+
int resized_width= params.width + (64 - params.width %64);
690690

691-
uint8_t *resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width *3);
691+
uint8_t*resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width *3);
692692
if (resized_image_buffer ==NULL) {
693693
fprintf(stderr,"error: allocate memory for resize input image\n");
694694
free(input_image_buffer);
695695
return1;
696696
}
697-
stbir_resize(input_image_buffer, params.width, params.height,0,
698-
resized_image_buffer, resized_width, resized_height,0, STBIR_TYPE_UINT8,
699-
3/*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE,0,
700-
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
701-
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
702-
STBIR_COLORSPACE_SRGB,nullptr
703-
);
697+
stbir_resize(input_image_buffer, params.width, params.height,0,
698+
resized_image_buffer, resized_width, resized_height,0, STBIR_TYPE_UINT8,
699+
3/*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE,0,
700+
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
701+
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
702+
STBIR_COLORSPACE_SRGB,nullptr);
704703

705704
// Save resized result
706705
free(input_image_buffer);
707706
input_image_buffer = resized_image_buffer;
708-
params.height = resized_height;
709-
params.width = resized_width;
707+
params.height= resized_height;
708+
params.width= resized_width;
710709
}
711710
}
712711

‎lora.hpp‎

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ struct LoraModel : public GGMLModule {
1111
std::string file_path;
1212
ModelLoader model_loader;
1313
bool load_failed =false;
14-
bool applied =false;
14+
bool applied=false;
1515

1616
LoraModel(ggml_backend_t backend,
1717
ggml_type wtype,
@@ -91,10 +91,15 @@ struct LoraModel : public GGMLModule {
9191
k_tensor = k_tensor.substr(0, k_pos);
9292
replace_all_chars(k_tensor,'.','_');
9393
// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
94-
if (k_tensor =="model_diffusion_model_output_blocks_2_2_conv") {// fix for SDXL
95-
k_tensor ="model_diffusion_model_output_blocks_2_1_conv";
94+
std::string lora_up_name ="lora." + k_tensor +".lora_up.weight";
95+
if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
96+
if (k_tensor =="model_diffusion_model_output_blocks_2_2_conv") {
97+
// fix for some sdxl lora, like lcm-lora-xl
98+
k_tensor ="model_diffusion_model_output_blocks_2_1_conv";
99+
lora_up_name ="lora." + k_tensor +".lora_up.weight";
100+
}
96101
}
97-
std::string lora_up_name ="lora." + k_tensor +".lora_up.weight";
102+
98103
std::string lora_down_name ="lora." + k_tensor +".lora_down.weight";
99104
std::string alpha_name ="lora." + k_tensor +".alpha";
100105
std::string scale_name ="lora." + k_tensor +".scale";

‎model.cpp‎

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ std::string convert_sdxl_lora_name(std::string tensor_name) {
211211
{"unet","model_diffusion_model"},
212212
{"te2","cond_stage_model_1_transformer"},
213213
{"te1","cond_stage_model_transformer"},
214+
{"text_encoder_2","cond_stage_model_1_transformer"},
215+
{"text_encoder","cond_stage_model_transformer"},
214216
};
215217
for (auto& pair_i : sdxl_lora_name_lookup) {
216218
if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) ==0) {
@@ -446,18 +448,25 @@ std::string convert_tensor_name(const std::string& name) {
446448
}else {
447449
new_name = name;
448450
}
449-
}elseif (contains(name,"lora_up") ||contains(name,"lora_down") ||contains(name,"lora.up") ||contains(name,"lora.down")) {
451+
}elseif (contains(name,"lora_up") ||contains(name,"lora_down") ||
452+
contains(name,"lora.up") ||contains(name,"lora.down") ||
453+
contains(name,"lora_linear")) {
450454
size_t pos = new_name.find(".processor");
451455
if (pos != std::string::npos) {
452456
new_name.replace(pos,strlen(".processor"),"");
453457
}
454-
pos = new_name.find_last_of('_');
458+
pos = new_name.rfind("lora");
455459
if (pos != std::string::npos) {
456-
std::string name_without_network_parts = new_name.substr(0, pos);
457-
std::string network_part = new_name.substr(pos +1);
460+
std::string name_without_network_parts = new_name.substr(0, pos -1);
461+
std::string network_part = new_name.substr(pos);
458462
// LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
459463
std::string new_key =convert_diffusers_name_to_compvis(name_without_network_parts,'.');
464+
new_key =convert_sdxl_lora_name(new_key);
460465
replace_all_chars(new_key,'.','_');
466+
size_t npos = network_part.rfind("_linear_layer");
467+
if (npos != std::string::npos) {
468+
network_part.replace(npos,strlen("_linear_layer"),"");
469+
}
461470
if (starts_with(network_part,"lora.")) {
462471
network_part ="lora_" + network_part.substr(5);
463472
}

‎stable-diffusion.cpp‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1610,7 +1610,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16101610
if (sd_ctx->sd->stacked_id && !sd_ctx->sd->pmid_lora->applied) {
16111611
t0 =ggml_time_ms();
16121612
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
1613-
t1 =ggml_time_ms();
1613+
t1=ggml_time_ms();
16141614
sd_ctx->sd->pmid_lora->applied =true;
16151615
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) *1.0f /1000);
16161616
if (sd_ctx->sd->free_params_immediately) {

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp