@@ -619,14 +619,20 @@ def q_posterior(self, x_start, x_t, t):
619619posterior_log_variance_clipped = extract (self .posterior_log_variance_clipped ,t ,x_t .shape )
620620return posterior_mean ,posterior_variance ,posterior_log_variance_clipped
621621
622- def q_sample (self ,x_start ,t ,noise = None ):
622+ def q_sample (self ,x_start ,t ,noise = None ):
623623noise = default (noise ,lambda :torch .randn_like (x_start ))
624624
625625return (
626626extract (self .sqrt_alphas_cumprod ,t ,x_start .shape )* x_start +
627627extract (self .sqrt_one_minus_alphas_cumprod ,t ,x_start .shape )* noise
628628 )
629629
630+ def calculate_v (self ,x_start ,t ,noise = None ):
631+ return (
632+ extract (self .sqrt_alphas_cumprod ,t ,x_start .shape )* noise -
633+ extract (self .sqrt_one_minus_alphas_cumprod ,t ,x_start .shape )* x_start
634+ )
635+
630636def q_sample_from_to (self ,x_from ,from_t ,to_t ,noise = None ):
631637shape = x_from .shape
632638noise = default (noise ,lambda :torch .randn_like (x_from ))
@@ -638,6 +644,12 @@ def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
638644
639645return x_from * (alpha_next / alpha )+ noise * (sigma_next * alpha - sigma * alpha_next )/ alpha
640646
647+ def predict_start_from_v (self ,x_t ,t ,v ):
648+ return (
649+ extract (self .sqrt_alphas_cumprod ,t ,x_t .shape )* x_t -
650+ extract (self .sqrt_one_minus_alphas_cumprod ,t ,x_t .shape )* v
651+ )
652+
641653def predict_start_from_noise (self ,x_t ,t ,noise ):
642654return (
643655extract (self .sqrt_recip_alphas_cumprod ,t ,x_t .shape )* x_t -
@@ -1146,6 +1158,7 @@ def __init__(
11461158image_cond_drop_prob = None ,
11471159loss_type = "l2" ,
11481160predict_x_start = True ,
1161+ predict_v = False ,
11491162beta_schedule = "cosine" ,
11501163condition_on_text_encodings = True ,# the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
11511164sampling_clamp_l2norm = False ,# whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1197,6 +1210,7 @@ def __init__(
11971210# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
11981211
11991212self .predict_x_start = predict_x_start
1213+ self .predict_v = predict_v # takes precedence over predict_x_start
12001214
12011215# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
12021216
@@ -1226,7 +1240,9 @@ def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = Fal
12261240
12271241pred = self .net .forward_with_cond_scale (x ,t ,cond_scale = cond_scale ,self_cond = self_cond ,** text_cond )
12281242
1229- if self .predict_x_start :
1243+ if self .predict_v :
1244+ x_start = self .noise_scheduler .predict_start_from_v (x ,t = t ,v = pred )
1245+ elif self .predict_x_start :
12301246x_start = pred
12311247else :
12321248x_start = self .noise_scheduler .predict_start_from_noise (x ,t = t ,noise = pred )
@@ -1299,7 +1315,9 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
12991315
13001316# derive x0
13011317
1302- if self .predict_x_start :
1318+ if self .predict_v :
1319+ x_start = self .noise_scheduler .predict_start_from_v (image_embed ,t = time_cond ,v = pred )
1320+ elif self .predict_x_start :
13031321x_start = pred
13041322else :
13051323x_start = self .noise_scheduler .predict_start_from_noise (image_embed ,t = time_cond ,noise = pred_noise )
@@ -1314,7 +1332,7 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
13141332
13151333# predict noise
13161334
1317- if self .predict_x_start :
1335+ if self .predict_x_start or self . predict_v :
13181336pred_noise = self .noise_scheduler .predict_noise_from_start (image_embed ,t = time_cond ,x0 = x_start )
13191337else :
13201338pred_noise = pred
@@ -1372,7 +1390,12 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
13721390if self .predict_x_start and self .training_clamp_l2norm :
13731391pred = self .l2norm_clamp_embed (pred )
13741392
1375- target = noise if not self .predict_x_start else image_embed
1393+ if self .predict_v :
1394+ target = self .noise_scheduler .calculate_v (image_embed ,times ,noise )
1395+ elif self .predict_x_start :
1396+ target = image_embed
1397+ else :
1398+ target = noise
13761399
13771400loss = self .noise_scheduler .loss_fn (pred ,target )
13781401return loss
@@ -2448,6 +2471,7 @@ def __init__(
24482471loss_type = 'l2' ,
24492472beta_schedule = None ,
24502473predict_x_start = False ,
2474+ predict_v = False ,
24512475predict_x_start_for_latent_diffusion = False ,
24522476image_sizes = None ,# for cascading ddpm, image size at each stage
24532477random_crop_sizes = None ,# whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2620,6 +2644,10 @@ def __init__(
26202644
26212645self .predict_x_start = cast_tuple (predict_x_start ,len (unets ))if not predict_x_start_for_latent_diffusion else tuple (map (lambda t :isinstance (t ,VQGanVAE ),self .vaes ))
26222646
2647+ # predict v
2648+
2649+ self .predict_v = cast_tuple (predict_v ,len (unets ))
2650+
26232651# input image range
26242652
26252653self .input_image_range = (- 1. if not auto_normalize_img else 0. ,1. )
@@ -2731,14 +2759,16 @@ def dynamic_threshold(self, x):
27312759x = x .clamp (- s ,s )/ s
27322760return x
27332761
2734- def p_mean_variance (self ,unet ,x ,t ,image_embed ,noise_scheduler ,text_encodings = None ,lowres_cond_img = None ,self_cond = None ,clip_denoised = True ,predict_x_start = False ,learned_variance = False ,cond_scale = 1. ,model_output = None ,lowres_noise_level = None ):
2762+ def p_mean_variance (self ,unet ,x ,t ,image_embed ,noise_scheduler ,text_encodings = None ,lowres_cond_img = None ,self_cond = None ,clip_denoised = True ,predict_x_start = False ,predict_v = False , learned_variance = False ,cond_scale = 1. ,model_output = None ,lowres_noise_level = None ):
27352763assert not (cond_scale != 1. and not self .can_classifier_guidance ),'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
27362764
27372765model_output = default (model_output ,lambda :unet .forward_with_cond_scale (x ,t ,image_embed = image_embed ,text_encodings = text_encodings ,cond_scale = cond_scale ,lowres_cond_img = lowres_cond_img ,self_cond = self_cond ,lowres_noise_level = lowres_noise_level ))
27382766
27392767pred ,var_interp_frac_unnormalized = self .parse_unet_output (learned_variance ,model_output )
27402768
2741- if predict_x_start :
2769+ if predict_v :
2770+ x_start = noise_scheduler .predict_start_from_v (x ,t = t ,v = pred )
2771+ elif predict_x_start :
27422772x_start = pred
27432773else :
27442774x_start = noise_scheduler .predict_start_from_noise (x ,t = t ,noise = pred )
@@ -2765,9 +2795,9 @@ def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodin
27652795return model_mean ,posterior_variance ,posterior_log_variance ,x_start
27662796
27672797@torch .no_grad ()
2768- def p_sample (self ,unet ,x ,t ,image_embed ,noise_scheduler ,text_encodings = None ,cond_scale = 1. ,lowres_cond_img = None ,self_cond = None ,predict_x_start = False ,learned_variance = False ,clip_denoised = True ,lowres_noise_level = None ):
2798+ def p_sample (self ,unet ,x ,t ,image_embed ,noise_scheduler ,text_encodings = None ,cond_scale = 1. ,lowres_cond_img = None ,self_cond = None ,predict_x_start = False ,predict_v = False , learned_variance = False ,clip_denoised = True ,lowres_noise_level = None ):
27692799b ,* _ ,device = * x .shape ,x .device
2770- model_mean ,_ ,model_log_variance ,x_start = self .p_mean_variance (unet ,x = x ,t = t ,image_embed = image_embed ,text_encodings = text_encodings ,cond_scale = cond_scale ,lowres_cond_img = lowres_cond_img ,self_cond = self_cond ,clip_denoised = clip_denoised ,predict_x_start = predict_x_start ,noise_scheduler = noise_scheduler ,learned_variance = learned_variance ,lowres_noise_level = lowres_noise_level )
2800+ model_mean ,_ ,model_log_variance ,x_start = self .p_mean_variance (unet ,x = x ,t = t ,image_embed = image_embed ,text_encodings = text_encodings ,cond_scale = cond_scale ,lowres_cond_img = lowres_cond_img ,self_cond = self_cond ,clip_denoised = clip_denoised ,predict_x_start = predict_x_start ,predict_v = predict_v , noise_scheduler = noise_scheduler ,learned_variance = learned_variance ,lowres_noise_level = lowres_noise_level )
27712801noise = torch .randn_like (x )
27722802# no noise when t == 0
27732803nonzero_mask = (1 - (t == 0 ).float ()).reshape (b ,* ((1 ,)* (len (x .shape )- 1 )))
@@ -2782,6 +2812,7 @@ def p_sample_loop_ddpm(
27822812image_embed ,
27832813noise_scheduler ,
27842814predict_x_start = False ,
2815+ predict_v = False ,
27852816learned_variance = False ,
27862817clip_denoised = True ,
27872818lowres_cond_img = None ,
@@ -2840,6 +2871,7 @@ def p_sample_loop_ddpm(
28402871lowres_cond_img = lowres_cond_img ,
28412872lowres_noise_level = lowres_noise_level ,
28422873predict_x_start = predict_x_start ,
2874+ predict_v = predict_v ,
28432875noise_scheduler = noise_scheduler ,
28442876learned_variance = learned_variance ,
28452877clip_denoised = clip_denoised
@@ -2865,6 +2897,7 @@ def p_sample_loop_ddim(
28652897timesteps ,
28662898eta = 1. ,
28672899predict_x_start = False ,
2900+ predict_v = False ,
28682901learned_variance = False ,
28692902clip_denoised = True ,
28702903lowres_cond_img = None ,
@@ -2926,7 +2959,9 @@ def p_sample_loop_ddim(
29262959
29272960# predict x0
29282961
2929- if predict_x_start :
2962+ if predict_v :
2963+ x_start = noise_scheduler .predict_start_from_v (img ,t = time_cond ,v = pred )
2964+ elif predict_x_start :
29302965x_start = pred
29312966else :
29322967x_start = noise_scheduler .predict_start_from_noise (img ,t = time_cond ,noise = pred )
@@ -2938,8 +2973,8 @@ def p_sample_loop_ddim(
29382973
29392974# predict noise
29402975
2941- if predict_x_start :
2942- pred_noise = noise_scheduler .predict_noise_from_start (img ,t = time_cond ,x0 = pred )
2976+ if predict_x_start or predict_v :
2977+ pred_noise = noise_scheduler .predict_noise_from_start (img ,t = time_cond ,x0 = x_start )
29432978else :
29442979pred_noise = pred
29452980
@@ -2975,7 +3010,7 @@ def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
29753010
29763011return self .p_sample_loop_ddim (* args ,noise_scheduler = noise_scheduler ,timesteps = timesteps ,** kwargs )
29773012
2978- def p_losses (self ,unet ,x_start ,times ,* ,image_embed ,noise_scheduler ,lowres_cond_img = None ,text_encodings = None ,predict_x_start = False ,noise = None ,learned_variance = False ,clip_denoised = False ,is_latent_diffusion = False ,lowres_noise_level = None ):
3013+ def p_losses (self ,unet ,x_start ,times ,* ,image_embed ,noise_scheduler ,lowres_cond_img = None ,text_encodings = None ,predict_x_start = False ,predict_v = False , noise = None ,learned_variance = False ,clip_denoised = False ,is_latent_diffusion = False ,lowres_noise_level = None ):
29793014noise = default (noise ,lambda :torch .randn_like (x_start ))
29803015
29813016# normalize to [-1, 1]
@@ -3020,7 +3055,12 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
30203055
30213056pred ,_ = self .parse_unet_output (learned_variance ,unet_output )
30223057
3023- target = noise if not predict_x_start else x_start
3058+ if predict_v :
3059+ target = noise_scheduler .calculate_v (x_start ,times ,noise )
3060+ elif predict_x_start :
3061+ target = x_start
3062+ else :
3063+ target = noise
30243064
30253065loss = noise_scheduler .loss_fn (pred ,target ,reduction = 'none' )
30263066loss = reduce (loss ,'b ... -> b (...)' ,'mean' )
@@ -3106,7 +3146,7 @@ def sample(
31063146num_unets = self .num_unets
31073147cond_scale = cast_tuple (cond_scale ,num_unets )
31083148
3109- for unet_number ,unet ,vae ,channel ,image_size ,predict_x_start ,learned_variance ,noise_scheduler ,lowres_cond ,sample_timesteps ,unet_cond_scale in tqdm (zip (range (1 ,num_unets + 1 ),self .unets ,self .vaes ,self .sample_channels ,self .image_sizes ,self .predict_x_start ,self .learned_variance ,self .noise_schedulers ,self .lowres_conds ,self .sample_timesteps ,cond_scale )):
3149+ for unet_number ,unet ,vae ,channel ,image_size ,predict_x_start ,predict_v , learned_variance ,noise_scheduler ,lowres_cond ,sample_timesteps ,unet_cond_scale in tqdm (zip (range (1 ,num_unets + 1 ),self .unets ,self .vaes ,self .sample_channels ,self .image_sizes ,self .predict_x_start , self . predict_v ,self .learned_variance ,self .noise_schedulers ,self .lowres_conds ,self .sample_timesteps ,cond_scale )):
31103150if unet_number < start_at_unet_number :
31113151continue # It's the easiest way to do it
31123152
@@ -3142,6 +3182,7 @@ def sample(
31423182text_encodings = text_encodings ,
31433183cond_scale = unet_cond_scale ,
31443184predict_x_start = predict_x_start ,
3185+ predict_v = predict_v ,
31453186learned_variance = learned_variance ,
31463187clip_denoised = not is_latent_diffusion ,
31473188lowres_cond_img = lowres_cond_img ,
@@ -3181,6 +3222,7 @@ def forward(
31813222lowres_conditioner = self .lowres_conds [unet_index ]
31823223target_image_size = self .image_sizes [unet_index ]
31833224predict_x_start = self .predict_x_start [unet_index ]
3225+ predict_v = self .predict_v [unet_index ]
31843226random_crop_size = self .random_crop_sizes [unet_index ]
31853227learned_variance = self .learned_variance [unet_index ]
31863228b ,c ,h ,w ,device ,= * image .shape ,image .device
@@ -3219,7 +3261,7 @@ def forward(
32193261image = vae .encode (image )
32203262lowres_cond_img = maybe (vae .encode )(lowres_cond_img )
32213263
3222- losses = self .p_losses (unet ,image ,times ,image_embed = image_embed ,text_encodings = text_encodings ,lowres_cond_img = lowres_cond_img ,predict_x_start = predict_x_start ,learned_variance = learned_variance ,is_latent_diffusion = is_latent_diffusion ,noise_scheduler = noise_scheduler ,lowres_noise_level = lowres_noise_level )
3264+ losses = self .p_losses (unet ,image ,times ,image_embed = image_embed ,text_encodings = text_encodings ,lowres_cond_img = lowres_cond_img ,predict_x_start = predict_x_start ,predict_v = predict_v , learned_variance = learned_variance ,is_latent_diffusion = is_latent_diffusion ,noise_scheduler = noise_scheduler ,lowres_noise_level = lowres_noise_level )
32233265
32243266if not return_lowres_cond_image :
32253267return losses