@@ -978,7 +978,10 @@ def __init__(
         # dalle1 learned padding strategy

         self.max_text_len = max_text_len
-        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+
+        self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
+        self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
+        self.null_image_embed = nn.Parameter(torch.randn(1, dim))

         # whether to use self conditioning, Hinton's group's new ddpm technique

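The single `null_text_embed` becomes three learned null parameters, one per conditioning signal, each shaped to broadcast against the tensor it stands in for. A minimal shape sketch (the sizes here are illustrative, not from the diff):

```python
import torch
import torch.nn as nn

dim, max_text_len, num_text_embeds = 512, 256, 1

# one learned "null" per conditioning signal
null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))    # stands in for (b, max_text_len, dim) encodings
null_text_embeds    = nn.Parameter(torch.randn(1, num_text_embeds, dim)) # stands in for (b, num_text_embeds, dim) embeds
null_image_embed    = nn.Parameter(torch.randn(1, dim))                  # broadcasts over batch and sequence dims
```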
@@ -995,7 +998,7 @@ def forward_with_cond_scale(
         if cond_scale == 1:
             return logits

-        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
         return null_logits + (logits - null_logits) * cond_scale

     def forward(
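For reference, this is the standard classifier-free guidance extrapolation: with both dropout probabilities forced to 1, the second forward pass is fully unconditional, and the output is pushed away from it. A standalone sketch of the arithmetic (`model` here is a hypothetical stand-in for the prior network):

```python
def guided_prediction(model, x, t, cond, cond_scale = 3.):
    # conditional prediction
    logits = model(x, t, **cond)
    if cond_scale == 1:
        return logits
    # fully unconditional prediction: drop both text and image conditioning
    null_logits = model(x, t, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **cond)
    # extrapolate away from the unconditional prediction
    return null_logits + (logits - null_logits) * cond_scale
```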
@@ -1006,7 +1009,8 @@ def forward(
         text_embed,
         text_encodings = None,
         self_cond = None,
-        cond_drop_prob = 0.
+        text_cond_drop_prob = 0.,
+        image_cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype

@@ -1024,6 +1028,14 @@ def forward(
         text_embed = self.to_text_embeds(text_embed)
         image_embed = self.to_image_embeds(image_embed)

+        # classifier free guidance masks
+
+        text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
+        text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+        image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
+        image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
+
         # make text encodings optional
         # although the paper seems to suggest it is present <--

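`prob_mask_like` is the repository's existing helper for sampling a per-batch-element boolean keep-mask; it presumably behaves roughly like the sketch below, returning `True` with the given probability:

```python
import torch

def prob_mask_like(shape, prob, device):
    # boolean mask, True with probability `prob` per element
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
```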
@@ -1044,32 +1056,39 @@ def forward(
             text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
             mask = F.pad(mask, (0, remainder), value = False)

-        null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
+        # mask out text encodings with null encodings
+
+        null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)

         text_encodings = torch.where(
-            rearrange(mask, 'b n -> b n 1').clone(),
+            rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
             text_encodings,
-            null_text_embeds
+            null_text_encodings
         )

-        # classifier free guidance
+        # mask out text embeddings with null text embeddings

-        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
-        keep_mask = rearrange(keep_mask, 'b -> b 1')
+        null_text_embeds = self.null_text_embeds.to(text_embed.dtype)

-        mask &= keep_mask
+        text_embed = torch.where(
+            text_keep_mask,
+            text_embed,
+            null_text_embeds
+        )
+
+        # mask out image embeddings with null image embeddings

-        # whether text embedding is masked or not depends on the classifier free guidance conditional masking
+        null_image_embed = self.null_image_embed.to(image_embed.dtype)

-        keep_mask = repeat(keep_mask, 'b 1 -> b n', n = num_text_embeds)
-        mask = torch.cat((mask, keep_mask), dim = 1)
+        image_embed = torch.where(
+            image_keep_mask,
+            image_embed,
+            null_image_embed
+        )

         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right

-        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
-        mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
-
         time_embed = self.to_time_embeds(diffusion_timesteps)

         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
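The three `torch.where` calls lean on broadcasting: the `(b, 1, 1)` keep-masks expand across the sequence and feature dimensions, and the null parameters expand across the batch. A quick shape check under assumed sizes:

```python
import torch

b, n, d = 2, 4, 512

image_keep_mask = torch.rand(b) < 0.8             # keep ~80% of the batch conditioned
image_keep_mask = image_keep_mask[:, None, None]  # (b, 1, 1), same as rearrange('b -> b 1 1')

image_embed      = torch.randn(b, n, d)
null_image_embed = torch.randn(1, d)              # broadcasts to (1, 1, d), then (b, n, d)

out = torch.where(image_keep_mask, image_embed, null_image_embed)
assert out.shape == (b, n, d)
```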
@@ -1107,6 +1126,8 @@ def __init__(
         timesteps = 1000,
         sample_timesteps = None,
         cond_drop_prob = 0.,
+        text_cond_drop_prob = None,
+        image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
@@ -1147,8 +1168,10 @@ def __init__(
         self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
         self.channels = default(image_channels, lambda: clip.image_channels)

-        self.cond_drop_prob = cond_drop_prob
-        self.can_classifier_guidance = cond_drop_prob > 0.
+        self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
+        self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
+
+        self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
         self.condition_on_text_encodings = condition_on_text_encodings

         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
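Because both new probabilities go through `default`, they inherit the shared `cond_drop_prob` when left as `None`, so existing configs keep working unchanged. A sketch of the repository's assumed fallback helper and the resulting behavior:

```python
def exists(val):
    return val is not None

def default(val, d):
    # return val if given, else the fallback (calling it if it is a function)
    if exists(val):
        return val
    return d() if callable(d) else d

cond_drop_prob = 0.2
text_cond_drop_prob  = default(None, cond_drop_prob)  # -> 0.2 (falls back)
image_cond_drop_prob = default(0.5, cond_drop_prob)   # -> 0.5 (explicit override wins)
```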
@@ -1308,7 +1331,8 @@ def p_losses(self, image_embed, times, text_cond, noise = None):
             image_embed_noisy,
             times,
             self_cond = self_cond,
-            cond_drop_prob = self.cond_drop_prob,
+            text_cond_drop_prob = self.text_cond_drop_prob,
+            image_cond_drop_prob = self.image_cond_drop_prob,
             **text_cond
         )

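Taken together, training code can now drop text and image conditioning at different rates. A hedged end-to-end sketch; the constructor arguments other than the two new probabilities are illustrative and may need adjusting to the installed version:

```python
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    image_embed_dim = 512,
    timesteps = 1000,
    text_cond_drop_prob = 0.2,   # drop text conditioning 20% of the time
    image_cond_drop_prob = 0.1   # drop image conditioning 10% of the time
)
```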