Commit fbba0f9
bring in prediction of the v-objective, combining the findings from the progressive distillation paper and imagen-video, for the eventual extension of dalle2 to make-a-video
1 parent 9f37705 · commit fbba0f9
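
For reference: the v-objective, from Salimans & Ho's Progressive Distillation paper (cited in the README hunk below) and used by Imagen Video, reparameterizes what the network predicts. Writing the scheduler's `sqrt_alphas_cumprod` buffer as α_t and `sqrt_one_minus_alphas_cumprod` as σ_t, the relations implemented by the new `calculate_v` and `predict_start_from_v` helpers are:

```latex
x_t = \alpha_t x_0 + \sigma_t \epsilon     % forward process (q_sample)
v = \alpha_t \epsilon - \sigma_t x_0       % training target (calculate_v)
\hat{x}_0 = \alpha_t x_t - \sigma_t v      % reconstruction (predict_start_from_v), since \alpha_t^2 + \sigma_t^2 = 1
```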

File tree

3 files changed: +69 −17 lines changed

README.md

Lines changed: 10 additions & 0 deletions
@@ -1298,4 +1298,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Salimans2022ProgressiveDF,
+    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
+    author  = {Tim Salimans and Jonathan Ho},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2202.00512}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>

dalle2_pytorch/dalle2_pytorch.py

Lines changed: 58 additions & 16 deletions
@@ -619,14 +619,20 @@ def q_posterior(self, x_start, x_t, t):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def q_sample(self, x_start, t, noise=None):
+    def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         return (
             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
+    def calculate_v(self, x_start, t, noise = None):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+        )
+
     def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
         shape = x_from.shape
         noise = default(noise, lambda: torch.randn_like(x_from))
@@ -638,6 +644,12 @@ def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
 
         return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
 
+    def predict_start_from_v(self, x_t, t, v):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
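
As a quick sanity check on the two new helpers (not part of the commit): a minimal standalone sketch, assuming a stand-in `extract` gather helper and a toy linear-beta schedule in place of the scheduler's real buffers.

```python
import torch

def extract(a, t, x_shape):
    # gather per-timestep scalars and reshape for broadcasting, as in the repo
    b = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

# toy linear-beta schedule (the repo also offers cosine, etc.)
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()

x_start = torch.randn(4, 3, 8, 8)
noise = torch.randn_like(x_start)
t = torch.randint(0, 1000, (4,))

# q_sample: x_t = alpha_t * x0 + sigma_t * eps
x_t = extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start \
    + extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise

# calculate_v: v = alpha_t * eps - sigma_t * x0
v = extract(sqrt_alphas_cumprod, t, x_start.shape) * noise \
  - extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start

# predict_start_from_v: alpha_t * x_t - sigma_t * v recovers x0 exactly,
# because alpha_t^2 + sigma_t^2 = 1
x_recon = extract(sqrt_alphas_cumprod, t, x_t.shape) * x_t \
        - extract(sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v

assert torch.allclose(x_recon, x_start, atol = 1e-5)
```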
@@ -1146,6 +1158,7 @@ def __init__(
         image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
+        predict_v = False,
         beta_schedule = "cosine",
         condition_on_text_encodings = True,  # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False,  # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1197,6 +1210,7 @@ def __init__(
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
         self.predict_x_start = predict_x_start
+        self.predict_v = predict_v  # takes precedence over predict_x_start
 
         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
 
@@ -1226,7 +1240,9 @@ def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = Fal
 
         pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
-        if self.predict_x_start:
+        if self.predict_v:
+            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif self.predict_x_start:
             x_start = pred
         else:
             x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1299,7 +1315,9 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
 
             # derive x0
 
-            if self.predict_x_start:
+            if self.predict_v:
+                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+            elif self.predict_x_start:
                 x_start = pred
             else:
                 x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
@@ -1314,7 +1332,7 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
 
             # predict noise
 
-            if self.predict_x_start:
+            if self.predict_x_start or self.predict_v:
                 pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
@@ -1372,7 +1390,12 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
         if self.predict_x_start and self.training_clamp_l2norm:
             pred = self.l2norm_clamp_embed(pred)
 
-        target = noise if not self.predict_x_start else image_embed
+        if self.predict_v:
+            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+        elif self.predict_x_start:
+            target = image_embed
+        else:
+            target = noise
 
         loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
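
On the prior side, the change reduces to a single new constructor flag. A sketch of enabling it, following the README's prior example (all hyperparameters below are illustrative; only `predict_v` is new in this commit):

```python
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter

clip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    predict_v = True  # the new flag; per the code above, takes precedence over predict_x_start
)
```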
@@ -2448,6 +2471,7 @@ def __init__(
         loss_type = 'l2',
         beta_schedule = None,
         predict_x_start = False,
+        predict_v = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None,  # for cascading ddpm, image size at each stage
         random_crop_sizes = None,  # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2620,6 +2644,10 @@ def __init__(
 
         self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
 
+        # predict v
+
+        self.predict_v = cast_tuple(predict_v, len(unets))
+
         # input image range
 
         self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
@@ -2731,14 +2759,16 @@ def dynamic_threshold(self, x):
         x = x.clamp(-s, s) / s
         return x
 
-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
         model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
         pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
-        if predict_x_start:
+        if predict_v:
+            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif predict_x_start:
             x_start = pred
         else:
             x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2765,9 +2795,9 @@ def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodin
         return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2782,6 +2812,7 @@ def p_sample_loop_ddpm(
         image_embed,
         noise_scheduler,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2840,6 +2871,7 @@ def p_sample_loop_ddpm(
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
+                predict_v = predict_v,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
@@ -2865,6 +2897,7 @@ def p_sample_loop_ddim(
         timesteps,
         eta = 1.,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2926,7 +2959,9 @@ def p_sample_loop_ddim(
 
             # predict x0
 
-            if predict_x_start:
+            if predict_v:
+                x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+            elif predict_x_start:
                 x_start = pred
             else:
                 x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
@@ -2938,8 +2973,8 @@ def p_sample_loop_ddim(
 
             # predict noise
 
-            if predict_x_start:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            if predict_x_start or predict_v:
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
            else:
                 pred_noise = pred
 
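
Worth noting: the hunk above also swaps `x0 = pred` for `x0 = x_start`, so the DDIM step always derives its noise from the reconstructed x0, whichever parameterization produced it. Up to the scheduler's algebraically equivalent reciprocal buffers (`sqrt_recip_alphas_cumprod`, visible in an earlier hunk), `predict_noise_from_start` computes:

```latex
\hat{\epsilon} = \frac{x_t - \alpha_t \hat{x}_0}{\sigma_t}
```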
@@ -2975,7 +3010,7 @@ def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
 
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
 
-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
@@ -3020,7 +3055,12 @@ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres
 
         pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
-        target = noise if not predict_x_start else x_start
+        if predict_v:
+            target = noise_scheduler.calculate_v(x_start, times, noise)
+        elif predict_x_start:
+            target = x_start
+        else:
+            target = noise
 
         loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3106,7 +3146,7 @@ def sample(
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             if unet_number < start_at_unet_number:
                 continue  # It's the easiest way to do it
 
@@ -3142,6 +3182,7 @@ def sample(
                     text_encodings = text_encodings,
                     cond_scale = unet_cond_scale,
                     predict_x_start = predict_x_start,
+                    predict_v = predict_v,
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
@@ -3181,6 +3222,7 @@ def forward(
         lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
+        predict_v = self.predict_v[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
@@ -3219,7 +3261,7 @@ def forward(
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
 
         if not return_lowres_cond_image:
             return losses
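
Because `predict_v` is cast to a per-unet tuple (mirroring `predict_x_start`), the objective can be chosen per stage of the cascade. A sketch following the README's example decoder configuration; every hyperparameter here is illustrative, and only `predict_v` is new in this commit:

```python
import torch
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

clip = OpenAIClipAdapter()

# two-stage cascade, dimensions taken from the README example
unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
unet2 = Unet(dim = 16, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8, 16))

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000,
    predict_v = (True, False)  # v-objective for the base unet, noise prediction for the upsampler
)

images = torch.randn(4, 3, 256, 256)
loss = decoder(images, unet_number = 1)
loss.backward()
```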

dalle2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.10.9'
+__version__ = '1.11.1'
