@@ -1281,19 +1281,28 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
12811281
12821282pred = self .net .forward_with_cond_scale (image_embed ,time_cond ,self_cond = self_cond ,cond_scale = cond_scale ,** text_cond )
12831283
1284+ # derive x0
1285+
12841286if self .predict_x_start :
12851287x_start = pred
1286- pred_noise = self .noise_scheduler .predict_noise_from_start (image_embed ,t = time_cond ,x0 = pred )
12871288else :
1288- x_start = self .noise_scheduler .predict_start_from_noise (image_embed ,t = time_cond ,noise = pred )
1289- pred_noise = pred
1289+ x_start = self .noise_scheduler .predict_start_from_noise (image_embed ,t = time_cond ,noise = pred_noise )
1290+
1291+ # clip x0 before maybe predicting noise
12901292
12911293if not self .predict_x_start :
12921294x_start .clamp_ (- 1. ,1. )
12931295
12941296if self .predict_x_start and self .sampling_clamp_l2norm :
12951297x_start = self .l2norm_clamp_embed (x_start )
12961298
1299+ # predict noise
1300+
1301+ if self .predict_x_start :
1302+ pred_noise = self .noise_scheduler .predict_noise_from_start (image_embed ,t = time_cond ,x0 = x_start )
1303+ else :
1304+ pred_noise = pred
1305+
12971306if time_next < 0 :
12981307image_embed = x_start
12991308continue
@@ -2897,16 +2906,25 @@ def p_sample_loop_ddim(
28972906
28982907pred ,_ = self .parse_unet_output (learned_variance ,unet_output )
28992908
2909+ # predict x0
2910+
29002911if predict_x_start :
29012912x_start = pred
2902- pred_noise = noise_scheduler .predict_noise_from_start (img ,t = time_cond ,x0 = pred )
29032913else :
29042914x_start = noise_scheduler .predict_start_from_noise (img ,t = time_cond ,noise = pred )
2905- pred_noise = pred
2915+
2916+ # maybe clip x0
29062917
29072918if clip_denoised :
29082919x_start = self .dynamic_threshold (x_start )
29092920
2921+ # predict noise
2922+
2923+ if predict_x_start :
2924+ pred_noise = noise_scheduler .predict_noise_from_start (img ,t = time_cond ,x0 = pred )
2925+ else :
2926+ pred_noise = pred
2927+
29102928c1 = eta * ((1 - alpha / alpha_next )* (1 - alpha_next )/ (1 - alpha )).sqrt ()
29112929c2 = ((1 - alpha_next )- torch .square (c1 )).sqrt ()
29122930noise = torch .randn_like (img )if not is_last_timestep else 0.