Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 968fbf0

Browse files
authored
feat: add option to switch the sigma schedule (leejet#51)
Concretely, this allows switching to the "Karras" schedule from the Karras et al. 2022 paper, equivalent to the samplers marked as "Karras" in the AUTOMATIC1111 WebUI. This choice is in principle orthogonal to the sampler choice and can be given independently.
1 parent b6899e8, commit 968fbf0

File tree

3 files changed

+117
-37
lines changed

3 files changed

+117
-37
lines changed

‎examples/main.cpp‎

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,12 @@ const char* sample_method_str[] = {
8080
"dpm++2m",
8181
"dpm++2mv2"};
8282

83+
// Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h
84+
constchar* schedule_str[] = {
85+
"default",
86+
"discrete",
87+
"karras"};
88+
8389
structOption {
8490
int n_threads = -1;
8591
std::string mode = TXT2IMG;
@@ -92,6 +98,7 @@ struct Option {
9298
int w =512;
9399
int h =512;
94100
SampleMethod sample_method = EULER_A;
101+
Schedule schedule = DEFAULT;
95102
int sample_steps =20;
96103
float strength =0.75f;
97104
RNGType rng_type = CUDA_RNG;
@@ -111,6 +118,7 @@ struct Option {
111118
printf(" width: %d\n", w);
112119
printf(" height: %d\n", h);
113120
printf(" sample_method: %s\n", sample_method_str[sample_method]);
121+
printf(" schedule: %s\n", schedule_str[schedule]);
114122
printf(" sample_steps: %d\n", sample_steps);
115123
printf(" strength: %.2f\n", strength);
116124
printf(" rng: %s\n", rng_type_to_str[rng_type]);
@@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) {
141149
printf(" --steps STEPS number of sample steps (default: 20)\n");
142150
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
143151
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
152+
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
144153
printf(" -v, --verbose print extra info\n");
145154
}
146155

@@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
237246
invalid_arg =true;
238247
break;
239248
}
249+
}elseif (arg =="--schedule") {
250+
if (++i >= argc) {
251+
invalid_arg =true;
252+
break;
253+
}
254+
constchar* schedule_selected = argv[i];
255+
int schedule_found = -1;
256+
for (int d =0; d < N_SCHEDULES; d++) {
257+
if (!strcmp(schedule_selected, schedule_str[d])) {
258+
schedule_found = d;
259+
}
260+
}
261+
if (schedule_found == -1) {
262+
invalid_arg =true;
263+
break;
264+
}
265+
opt->schedule = (Schedule)schedule_found;
240266
}elseif (arg =="-s" || arg =="--seed") {
241267
if (++i >= argc) {
242268
invalid_arg =true;
@@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) {
377403
}
378404

379405
StableDiffusionsd(opt.n_threads, vae_decode_only,true, opt.rng_type);
380-
if (!sd.load_from_file(opt.model_path)) {
406+
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
381407
return1;
382408
}
383409

@@ -413,4 +439,4 @@ int main(int argc, const char* argv[]) {
413439
printf("save result image to '%s'\n", opt.output_path.c_str());
414440

415441
return0;
416-
}
442+
}

‎stable-diffusion.cpp‎

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2654,32 +2654,12 @@ struct AutoEncoderKL {
26542654

26552655
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
26562656

2657-
structDiscreteSchedule {
2657+
structSigmaSchedule {
26582658
float alphas_cumprod[TIMESTEPS];
26592659
float sigmas[TIMESTEPS];
26602660
float log_sigmas[TIMESTEPS];
26612661

2662-
std::vector<float>get_sigmas(uint32_t n) {
2663-
std::vector<float> result;
2664-
2665-
int t_max = TIMESTEPS -1;
2666-
2667-
if (n ==0) {
2668-
return result;
2669-
}elseif (n ==1) {
2670-
result.push_back(t_to_sigma(t_max));
2671-
result.push_back(0);
2672-
return result;
2673-
}
2674-
2675-
float step =static_cast<float>(t_max) /static_cast<float>(n -1);
2676-
for (int i =0; i < n; ++i) {
2677-
float t = t_max - step * i;
2678-
result.push_back(t_to_sigma(t));
2679-
}
2680-
result.push_back(0);
2681-
return result;
2682-
}
2662+
virtual std::vector<float>get_sigmas(uint32_t n) = 0;
26832663

26842664
floatsigma_to_t(float sigma) {
26852665
float log_sigma =std::log(sigma);
@@ -2714,11 +2694,59 @@ struct DiscreteSchedule {
27142694
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
27152695
returnstd::exp(log_sigma);
27162696
}
2697+
};
27172698

2699+
structDiscreteSchedule : SigmaSchedule {
2700+
std::vector<float>get_sigmas(uint32_t n) {
2701+
std::vector<float> result;
2702+
2703+
int t_max = TIMESTEPS -1;
2704+
2705+
if (n ==0) {
2706+
return result;
2707+
}elseif (n ==1) {
2708+
result.push_back(t_to_sigma(t_max));
2709+
result.push_back(0);
2710+
return result;
2711+
}
2712+
2713+
float step =static_cast<float>(t_max) /static_cast<float>(n -1);
2714+
for (int i =0; i < n; ++i) {
2715+
float t = t_max - step * i;
2716+
result.push_back(t_to_sigma(t));
2717+
}
2718+
result.push_back(0);
2719+
return result;
2720+
}
2721+
};
2722+
2723+
structKarrasSchedule : SigmaSchedule {
2724+
std::vector<float>get_sigmas(uint32_t n) {
2725+
// These *COULD* be function arguments here,
2726+
// but does anybody ever bother to touch them?
2727+
float sigma_min =0.1;
2728+
float sigma_max =10.;
2729+
float rho =7.;
2730+
2731+
std::vector<float>result(n +1);
2732+
2733+
float min_inv_rho =pow(sigma_min, (1. / rho));
2734+
float max_inv_rho =pow(sigma_max, (1. / rho));
2735+
for (int i =0; i < n; i++) {
2736+
// Eq. (5) from Karras et al 2022
2737+
result[i] =pow(max_inv_rho + (float)i / ((float)n -1.) * (min_inv_rho - max_inv_rho), rho);
2738+
}
2739+
result[n] =0.;
2740+
return result;
2741+
}
2742+
};
2743+
2744+
structDenoiser {
2745+
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
27182746
virtual std::vector<float>get_scalings(float sigma) = 0;
27192747
};
27202748

2721-
structCompVisDenoiser :publicDiscreteSchedule {
2749+
structCompVisDenoiser :publicDenoiser {
27222750
float sigma_data =1.0f;
27232751

27242752
std::vector<float>get_scalings(float sigma) {
@@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule {
27282756
}
27292757
};
27302758

2731-
structCompVisVDenoiser :publicDiscreteSchedule {
2759+
structCompVisVDenoiser :publicDenoiser {
27322760
float sigma_data =1.0f;
27332761

27342762
std::vector<float>get_scalings(float sigma) {
@@ -2764,7 +2792,7 @@ class StableDiffusionGGML {
27642792
UNetModel diffusion_model;
27652793
AutoEncoderKL first_stage_model;
27662794

2767-
std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
2795+
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
27682796

27692797
StableDiffusionGGML() =default;
27702798

@@ -2798,7 +2826,7 @@ class StableDiffusionGGML {
27982826
}
27992827
}
28002828

2801-
boolload_from_file(const std::string& file_path) {
2829+
boolload_from_file(const std::string& file_path, Schedule schedule) {
28022830
LOG_INFO("loading model from '%s'", file_path.c_str());
28032831

28042832
std::ifstreamfile(file_path, std::ios::binary);
@@ -3093,10 +3121,29 @@ class StableDiffusionGGML {
30933121
LOG_INFO("running in eps-prediction mode");
30943122
}
30953123

3124+
if (schedule != DEFAULT) {
3125+
switch (schedule) {
3126+
case DISCRETE:
3127+
LOG_INFO("running with discrete schedule");
3128+
denoiser->schedule = std::make_shared<DiscreteSchedule>();
3129+
break;
3130+
case KARRAS:
3131+
LOG_INFO("running with Karras schedule");
3132+
denoiser->schedule = std::make_shared<KarrasSchedule>();
3133+
break;
3134+
case DEFAULT:
3135+
// Don't touch anything.
3136+
break;
3137+
default:
3138+
LOG_ERROR("Unknown schedule %i", schedule);
3139+
abort();
3140+
}
3141+
}
3142+
30963143
for (int i =0; i < TIMESTEPS; i++) {
3097-
denoiser->alphas_cumprod[i] = alphas_cumprod[i];
3098-
denoiser->sigmas[i] =std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
3099-
denoiser->log_sigmas[i] =std::log(denoiser->sigmas[i]);
3144+
denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i];
3145+
denoiser->schedule->sigmas[i] =std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
3146+
denoiser->schedule->log_sigmas[i] =std::log(denoiser->schedule->sigmas[i]);
31003147
}
31013148

31023149
returntrue;
@@ -3445,7 +3492,7 @@ class StableDiffusionGGML {
34453492
c_in = scaling[1];
34463493
}
34473494

3448-
float t = denoiser->sigma_to_t(sigma);
3495+
float t = denoiser->schedule->sigma_to_t(sigma);
34493496
ggml_set_f32(timesteps, t);
34503497
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
34513498

@@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads,
40104057
rng_type);
40114058
}
40124059

4013-
boolStableDiffusion::load_from_file(const std::string& file_path) {
4014-
return sd->load_from_file(file_path);
4060+
boolStableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
4061+
return sd->load_from_file(file_path, s);
40154062
}
40164063

40174064
std::vector<uint8_t>StableDiffusion::txt2img(const std::string& prompt,
@@ -4061,7 +4108,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
40614108
structggml_tensor*x_t =ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C,1);
40624109
ggml_tensor_set_f32_randn(x_t, sd->rng);
40634110

4064-
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
4111+
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
40654112

40664113
LOG_INFO("start sampling");
40674114
structggml_tensor* x_0 = sd->sample(ctx,x_t, c, uc, cfg_scale, sample_method, sigmas);
@@ -4117,7 +4164,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
41174164
}
41184165
LOG_INFO("img2img %dx%d", width, height);
41194166

4120-
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
4167+
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
41214168
size_t t_enc =static_cast<size_t>(sample_steps * strength);
41224169
LOG_INFO("target t_enc is %zu steps", t_enc);
41234170
std::vector<float> sigma_sched;

‎stable-diffusion.h‎

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ enum SampleMethod {
2525
N_SAMPLE_METHODS
2626
};
2727

28+
enum Schedule {
29+
DEFAULT,
30+
DISCRETE,
31+
KARRAS,
32+
N_SCHEDULES
33+
};
34+
2835
classStableDiffusionGGML;
2936

3037
classStableDiffusion {
@@ -36,7 +43,7 @@ class StableDiffusion {
3643
bool vae_decode_only =false,
3744
bool free_params_immediately =false,
3845
RNGType rng_type = STD_DEFAULT_RNG);
39-
boolload_from_file(const std::string& file_path);
46+
boolload_from_file(const std::string& file_path, Schedule d = DEFAULT);
4047
std::vector<uint8_t>txt2img(
4148
const std::string& prompt,
4249
const std::string& negative_prompt,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp