Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit8f05f5b

Browse files
rmatifleejet
andauthored
feat: add support for custom scheduler (leejet#694)
---------Co-authored-by: leejet <leejet714@gmail.com>
1 parent15d0f82 commit8f05f5b

File tree

6 files changed

+101
-9
lines changed

6 files changed

+101
-9
lines changed

‎examples/cli/README.md‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ Generation Options:
121121
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
122122
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
123123
default: discrete
124+
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
124125
--skip-layers layers to skip for SLG steps (default: [7,8,9])
125126
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126127
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)

‎examples/cli/main.cpp‎

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,15 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
258258
parameter_string +="Sampler RNG:" +std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) +",";
259259
}
260260
parameter_string +="Sampler:" +std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
261-
if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) {
261+
if (!gen_params.custom_sigmas.empty()) {
262+
parameter_string +=", Custom Sigmas: [";
263+
for (size_t i =0; i < gen_params.custom_sigmas.size(); ++i) {
264+
std::ostringstream oss;
265+
oss << std::fixed <<std::setprecision(4) << gen_params.custom_sigmas[i];
266+
parameter_string += oss.str() + (i == gen_params.custom_sigmas.size() -1 ?"" :",");
267+
}
268+
parameter_string +="]";
269+
}elseif (gen_params.sample_params.scheduler != SCHEDULER_COUNT) {// Only show schedule if not using custom sigmas
262270
parameter_string +="" +std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
263271
}
264272
parameter_string +=",";
@@ -806,4 +814,4 @@ int main(int argc, const char* argv[]) {
806814
release_all_resources();
807815

808816
return0;
809-
}
817+
}

‎examples/common/common.hpp‎

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,8 @@ struct SDGenerationParams {
883883
std::vector<int> high_noise_skip_layers = {7,8,9};
884884
sd_sample_params_t high_noise_sample_params;
885885

886+
std::vector<float> custom_sigmas;
887+
886888
std::string easycache_option;
887889
sd_easycache_params_t easycache_params;
888890

@@ -1201,6 +1203,43 @@ struct SDGenerationParams {
12011203
return1;
12021204
};
12031205

1206+
auto on_sigmas_arg = [&](int argc,constchar** argv,int index) {
1207+
if (++index >= argc) {
1208+
return -1;
1209+
}
1210+
std::string sigmas_str = argv[index];
1211+
if (!sigmas_str.empty() && sigmas_str.front() =='[') {
1212+
sigmas_str.erase(0,1);
1213+
}
1214+
if (!sigmas_str.empty() && sigmas_str.back() ==']') {
1215+
sigmas_str.pop_back();
1216+
}
1217+
1218+
std::stringstreamss(sigmas_str);
1219+
std::string item;
1220+
while (std::getline(ss, item,',')) {
1221+
item.erase(0, item.find_first_not_of("\t\n\r\f\v"));
1222+
item.erase(item.find_last_not_of("\t\n\r\f\v") +1);
1223+
if (!item.empty()) {
1224+
try {
1225+
custom_sigmas.push_back(std::stof(item));
1226+
}catch (const std::invalid_argument& e) {
1227+
fprintf(stderr,"error: invalid float value '%s' in --sigmas\n", item.c_str());
1228+
return -1;
1229+
}catch (const std::out_of_range& e) {
1230+
fprintf(stderr,"error: float value '%s' out of range in --sigmas\n", item.c_str());
1231+
return -1;
1232+
}
1233+
}
1234+
}
1235+
1236+
if (custom_sigmas.empty() && !sigmas_str.empty()) {
1237+
fprintf(stderr,"error: could not parse any sigma values from '%s'\n", argv[index]);
1238+
return -1;
1239+
}
1240+
return1;
1241+
};
1242+
12041243
auto on_ref_image_arg = [&](int argc,constchar** argv,int index) {
12051244
if (++index >= argc) {
12061245
return -1;
@@ -1260,6 +1299,10 @@ struct SDGenerationParams {
12601299
"--scheduler",
12611300
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
12621301
on_scheduler_arg},
1302+
{"",
1303+
"--sigmas",
1304+
"custom sigma values for the sampler, comma-separated (e.g.,\"14.61,7.8,3.5,0.0\").",
1305+
on_sigmas_arg},
12631306
{"",
12641307
"--skip-layers",
12651308
"layers to skip for SLG steps (default: [7,8,9])",
@@ -1512,6 +1555,8 @@ struct SDGenerationParams {
15121555

15131556
sample_params.guidance.slg.layers = skip_layers.data();
15141557
sample_params.guidance.slg.layer_count = skip_layers.size();
1558+
sample_params.custom_sigmas = custom_sigmas.data();
1559+
sample_params.custom_sigmas_count =static_cast<int>(custom_sigmas.size());
15151560
high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data();
15161561
high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size();
15171562

@@ -1606,6 +1651,7 @@ struct SDGenerationParams {
16061651
<<" sample_params:" << sample_params_str <<",\n"
16071652
<<" high_noise_skip_layers:" <<vec_to_string(high_noise_skip_layers) <<",\n"
16081653
<<" high_noise_sample_params:" << high_noise_sample_params_str <<",\n"
1654+
<<" custom_sigmas:" <<vec_to_string(custom_sigmas) <<",\n"
16091655
<<" easycache_option:\"" << easycache_option <<"\",\n"
16101656
<<" easycache:"
16111657
<< (easycache_params.enabled ?"enabled" :"disabled")

‎examples/server/README.md‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Default Generation Options:
115115
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
116116
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
117117
default: discrete
118+
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
118119
--skip-layers layers to skip for SLG steps (default: [7,8,9])
119120
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
120121
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)

‎stable-diffusion.cpp‎

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2600,6 +2600,8 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) {
26002600
sample_params->scheduler = SCHEDULER_COUNT;
26012601
sample_params->sample_method = SAMPLE_METHOD_COUNT;
26022602
sample_params->sample_steps =20;
2603+
sample_params->custom_sigmas =nullptr;
2604+
sample_params->custom_sigmas_count =0;
26032605
}
26042606

26052607
char*sd_sample_params_to_str(constsd_sample_params_t* sample_params) {
@@ -3194,11 +3196,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
31943196
}
31953197
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
31963198

3197-
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
3198-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
3199-
sd_ctx->sd->get_image_seq_len(height, width),
3200-
sd_img_gen_params->sample_params.scheduler,
3201-
sd_ctx->sd->version);
3199+
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
3200+
std::vector<float> sigmas;
3201+
if (sd_img_gen_params->sample_params.custom_sigmas_count >0) {
3202+
sigmas = std::vector<float>(sd_img_gen_params->sample_params.custom_sigmas,
3203+
sd_img_gen_params->sample_params.custom_sigmas + sd_img_gen_params->sample_params.custom_sigmas_count);
3204+
if (sample_steps != sigmas.size() -1) {
3205+
sample_steps =static_cast<int>(sigmas.size()) -1;
3206+
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
3207+
}
3208+
}else {
3209+
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
3210+
sd_ctx->sd->get_image_seq_len(height, width),
3211+
sd_img_gen_params->sample_params.scheduler,
3212+
sd_ctx->sd->version);
3213+
}
32023214

32033215
ggml_tensor* init_latent =nullptr;
32043216
ggml_tensor* concat_latent =nullptr;
@@ -3461,7 +3473,29 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
34613473
if (high_noise_sample_steps >0) {
34623474
total_steps += high_noise_sample_steps;
34633475
}
3464-
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,0, sd_vid_gen_params->sample_params.scheduler, sd_ctx->sd->version);
3476+
3477+
std::vector<float> sigmas;
3478+
if (sd_vid_gen_params->sample_params.custom_sigmas_count >0) {
3479+
sigmas = std::vector<float>(sd_vid_gen_params->sample_params.custom_sigmas,
3480+
sd_vid_gen_params->sample_params.custom_sigmas + sd_vid_gen_params->sample_params.custom_sigmas_count);
3481+
if (total_steps != sigmas.size() -1) {
3482+
total_steps =static_cast<int>(sigmas.size()) -1;
3483+
LOG_WARN("total_steps != custom_sigmas_count - 1, set total_steps to %d", total_steps);
3484+
if (sample_steps >= total_steps) {
3485+
sample_steps = total_steps;
3486+
LOG_WARN("total_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
3487+
}
3488+
if (high_noise_sample_steps >0) {
3489+
high_noise_sample_steps = total_steps - sample_steps;
3490+
LOG_WARN("total_steps != custom_sigmas_count - 1, set high_noise_sample_steps to %d", high_noise_sample_steps);
3491+
}
3492+
}
3493+
}else {
3494+
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
3495+
0,
3496+
sd_vid_gen_params->sample_params.scheduler,
3497+
sd_ctx->sd->version);
3498+
}
34653499

34663500
if (high_noise_sample_steps <0) {
34673501
// timesteps ∝ sigmas for Flow models (like wan2.2 a14b)
@@ -3841,4 +3875,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38413875
LOG_INFO("generate_video completed in %.2fs", (t5 - t0) *1.0f /1000);
38423876

38433877
return result_images;
3844-
}
3878+
}

‎stable-diffusion.h‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ typedef struct {
225225
intsample_steps;
226226
floateta;
227227
intshifted_timestep;
228+
float*custom_sigmas;
229+
intcustom_sigmas_count;
228230
}sd_sample_params_t;
229231

230232
typedefstruct {

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp