Commit 59fa101

fix classifier free guidance for diffusion prior, thanks to @jaykim9870 for spotting the issue

1 parent 916ece1, commit 59fa101

File tree

2 files changed (+44, -20 lines)

dalle2_pytorch/dalle2_pytorch.py

Lines changed: 43 additions & 19 deletions
@@ -978,7 +978,10 @@ def __init__(
         # dalle1 learned padding strategy

         self.max_text_len = max_text_len
-        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+
+        self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
+        self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
+        self.null_image_embed = nn.Parameter(torch.randn(1, dim))

         # whether to use self conditioning, Hinton's group's new ddpm technique

@@ -995,7 +998,7 @@ def forward_with_cond_scale(
         if cond_scale == 1:
             return logits

-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
@@ -1006,7 +1009,8 @@ def forward(
         text_embed,
         text_encodings = None,
         self_cond = None,
-        cond_drop_prob = 0.
+        text_cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype

@@ -1024,6 +1028,14 @@ def forward(
         text_embed = self.to_text_embeds(text_embed)
         image_embed = self.to_image_embeds(image_embed)

+        # classifier free guidance masks
+
+        text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
+        image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
+
         # make text encodings optional
         # although the paper seems to suggest it is present <--

@@ -1044,32 +1056,39 @@ def forward(
             text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
             mask = F.pad(mask, (0, remainder), value = False)

-        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+        # mask out text encodings with null encodings
+
+        null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)

         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1').clone(),
+            rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
            text_encodings,
-            null_text_embeds
+            null_text_encodings
        )

-        # classifier free guidance
+        # mask out text embeddings with null text embeddings

-        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
-        keep_mask = rearrange(keep_mask, 'b -> b 1')
+        null_text_embeds = self.null_text_embeds.to(text_embed.dtype)

-        mask &= keep_mask
+        text_embeds = torch.where(
+            text_keep_mask,
+            text_embed,
+            null_text_embeds
+        )
+
+        # mask out image embeddings with null image embeddings

-        # whether text embedding is masked or not depends on the classifier free guidance conditional masking
+        null_image_embed = self.null_image_embed.to(image_embed.dtype)

-        keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
-        mask = torch.cat((mask, keep_mask), dim = 1)
+        image_embed = torch.where(
+            image_keep_mask,
+            image_embed,
+            null_image_embed
+        )

         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

-        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
-        mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
-
         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -1107,6 +1126,8 @@ def __init__(
         timesteps = 1000,
         sample_timesteps = None,
         cond_drop_prob = 0.,
+        text_cond_drop_prob = None,
+        image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
@@ -1147,8 +1168,10 @@ def __init__(
         self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
         self.channels = default(image_channels, lambda: clip.image_channels)

-        self.cond_drop_prob = cond_drop_prob
-        self.can_classifier_guidance = cond_drop_prob > 0.
+        self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
+        self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
+
+        self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -1308,7 +1331,8 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
             image_embed_noisy,
             times,
             self_cond = self_cond,
-            cond_drop_prob = self.cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
+            image_cond_drop_prob = self.image_cond_drop_prob,
             **text_cond
         )

dalle2_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.8.4'
+__version__ = '1.9.0'
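With the bump to 1.9.0, the two new keyword arguments can be passed to DiffusionPrior directly; either one left as None falls back to cond_drop_prob, per the constructor hunk above. A hedged usage sketch follows: only text_cond_drop_prob and image_cond_drop_prob come from this diff, while the surrounding constructor calls and values mirror the project README and are assumptions here.

from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 1000,
    # new in this commit: per-modality drop probabilities for classifier free guidance
    text_cond_drop_prob = 0.2,
    image_cond_drop_prob = 0.1
)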

0 commit comments

Comments
 (0)
