Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0d82dff

Browse files
committed
in ddim, noise should be predicted after x0 is maybe clipped, thanks to@lukovnikov for pointing this out in another repository
1 parent8bbc956 commit0d82dff

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

‎dalle2_pytorch/dalle2_pytorch.py‎

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,19 +1281,28 @@ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scal
12811281

12821282
pred=self.net.forward_with_cond_scale(image_embed,time_cond,self_cond=self_cond,cond_scale=cond_scale,**text_cond)
12831283

1284+
# derive x0
1285+
12841286
ifself.predict_x_start:
12851287
x_start=pred
1286-
pred_noise=self.noise_scheduler.predict_noise_from_start(image_embed,t=time_cond,x0=pred)
12871288
else:
1288-
x_start=self.noise_scheduler.predict_start_from_noise(image_embed,t=time_cond,noise=pred)
1289-
pred_noise=pred
1289+
x_start=self.noise_scheduler.predict_start_from_noise(image_embed,t=time_cond,noise=pred_noise)
1290+
1291+
# clip x0 before maybe predicting noise
12901292

12911293
ifnotself.predict_x_start:
12921294
x_start.clamp_(-1.,1.)
12931295

12941296
ifself.predict_x_startandself.sampling_clamp_l2norm:
12951297
x_start=self.l2norm_clamp_embed(x_start)
12961298

1299+
# predict noise
1300+
1301+
ifself.predict_x_start:
1302+
pred_noise=self.noise_scheduler.predict_noise_from_start(image_embed,t=time_cond,x0=x_start)
1303+
else:
1304+
pred_noise=pred
1305+
12971306
iftime_next<0:
12981307
image_embed=x_start
12991308
continue
@@ -2897,16 +2906,25 @@ def p_sample_loop_ddim(
28972906

28982907
pred,_=self.parse_unet_output(learned_variance,unet_output)
28992908

2909+
# predict x0
2910+
29002911
ifpredict_x_start:
29012912
x_start=pred
2902-
pred_noise=noise_scheduler.predict_noise_from_start(img,t=time_cond,x0=pred)
29032913
else:
29042914
x_start=noise_scheduler.predict_start_from_noise(img,t=time_cond,noise=pred)
2905-
pred_noise=pred
2915+
2916+
# maybe clip x0
29062917

29072918
ifclip_denoised:
29082919
x_start=self.dynamic_threshold(x_start)
29092920

2921+
# predict noise
2922+
2923+
ifpredict_x_start:
2924+
pred_noise=noise_scheduler.predict_noise_from_start(img,t=time_cond,x0=pred)
2925+
else:
2926+
pred_noise=pred
2927+
29102928
c1=eta* ((1-alpha/alpha_next)* (1-alpha_next)/ (1-alpha)).sqrt()
29112929
c2= ((1-alpha_next)-torch.square(c1)).sqrt()
29122930
noise=torch.randn_like(img)ifnotis_last_timestepelse0.

‎dalle2_pytorch/version.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__='1.10.4'
1+
__version__='1.10.5'

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp