Commit 916ece1

Merge pull request lucidrains#234 from Veldrovive/deepspeed_fp16

Fixed issues with clip and deepspeed fp16

2 parents 083508f + cbaadb6, commit 916ece1

File tree

3 files changed: +40 -24 lines changed

dalle2_pytorch/train_configs.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ class DecoderConfig(BaseModel):
     clip: Optional[AdapterConfig]  # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
-    sample_timesteps: Optional[SingularOrIterable[int]] = None
+    sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
     loss_type: str = 'l2'
     beta_schedule: ListOrTuple[str] = None  # None means all cosine
     learned_variance: SingularOrIterable[bool] = True
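The widened annotation lets `sample_timesteps` be a single value, a per-unet iterable, or `None` at either level, so one unet can override the sampling step count while another falls back to the full schedule. A minimal sketch of what the new type accepts, assuming a hypothetical pydantic model and a simplified stand-in for the repo's `SingularOrIterable` alias:

```python
from typing import Optional, Tuple, Union

from pydantic import BaseModel

# Simplified stand-in for SingularOrIterable[Optional[int]]: one value or a tuple of per-unet values.
SampleTimesteps = Union[Optional[int], Tuple[Optional[int], ...]]

class DecoderConfigSketch(BaseModel):
    sample_timesteps: Optional[SampleTimesteps] = None

print(DecoderConfigSketch(sample_timesteps=64))          # one setting shared by every unet
print(DecoderConfigSketch(sample_timesteps=(64, None)))  # second unet keeps the full schedule
print(DecoderConfigSketch(sample_timesteps=None))        # no override at all
```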

dalle2_pytorch/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -519,7 +519,7 @@ def __init__(
             clip = decoder.clip
             clip.to(precision_type)

-        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+        decoder, train_dataloader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders['train'], *optimizers))

         self.decoder = decoder
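The single change here also routes the training dataloader through `accelerator.prepare`. A minimal, self-contained sketch of that pattern (a toy model and dataset, not the repo's trainer); preparing the dataloader together with the model is what lets DeepSpeed-backed Accelerate runs infer the per-device batch size and handle fp16 batches consistently:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

decoder = torch.nn.Linear(8, 8)                                       # placeholder for the real decoder
optimizers = [torch.optim.Adam(decoder.parameters(), lr=1e-4)]
dataloaders = {'train': DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=4)}

# Mirror of the changed line above: model, train dataloader, and optimizers prepared together.
decoder, train_dataloader, *optimizers = accelerator.prepare(decoder, dataloaders['train'], *optimizers)

for (batch,) in train_dataloader:
    loss = decoder(batch).pow(2).mean()
    accelerator.backward(loss)   # use Accelerate's backward so mixed-precision/DeepSpeed hooks apply
    optimizers[0].step()
    optimizers[0].zero_grad()
```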

train_decoder.py

Lines changed: 38 additions & 22 deletions
@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
         break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

-def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if img_embeddings[0] is None:
         # Generate image embeddings from clip
         imgs_tensor = torch.stack(real_images)
-        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        assert clip is not None, "clip is None, but img_embeddings is None"
+        imgs_tensor.to(device=device)
+        img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
         sample_params["image_embed"] = img_embeddings
     else:
         # Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if condition_on_text_encodings:
         if text_embeddings[0] is None:
             # Generate text embeddings from text
+            assert clip is not None, "clip is None, but text_embeddings is None"
             tokenized_texts = tokenize(txts, truncate=True)
-            sample_params["text"] = tokenized_texts
+            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+            sample_params["text_encodings"] = text_encodings
         else:
             # Then we are using precomputed text embeddings
             text_embeddings = torch.stack(text_embeddings)
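Taken together, the two hunks above make `generate_samples` fall back to an external CLIP adapter whenever precomputed embeddings are missing. A condensed sketch of that control flow, assuming a hypothetical `FakeClip` stand-in with the same `embed_image` / `embed_text` shape as the adapter used in the diff:

```python
import torch

class FakeClip:
    """Hypothetical adapter returning (embedding, encoding) pairs, as in the diff."""
    def embed_image(self, images):
        return torch.randn(images.shape[0], 512), None
    def embed_text(self, tokens):
        return torch.randn(tokens.shape[0], 512), torch.randn(tokens.shape[0], 77, 512)

def build_sample_params(real_images, img_embeddings=None, tokenized_texts=None, clip=None):
    sample_params = {}
    if img_embeddings is None:
        # No precomputed image embeddings: an adapter is mandatory.
        assert clip is not None, "clip is None, but img_embeddings is None"
        img_embeddings, _ = clip.embed_image(torch.stack(real_images))
    sample_params["image_embed"] = img_embeddings
    if tokenized_texts is not None:
        # Text conditioning without precomputed encodings also requires the adapter.
        assert clip is not None, "clip is None, but text_embeddings is None"
        _, text_encodings = clip.embed_text(tokenized_texts)
        sample_params["text_encodings"] = text_encodings
    return sample_params

params = build_sample_params([torch.rand(3, 64, 64)], tokenized_texts=torch.ones(1, 77, dtype=torch.long), clip=FakeClip())
print(sorted(params))  # ['image_embed', 'text_encodings']
```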
@@ -166,23 +170,23 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
         sample_params["image"] = torch.stack(real_images)
     if device is not None:
         sample_params["_device"] = device
-    samples = trainer.sample(**sample_params)
+    samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False)  # At sampling time we don't want to cast to FP16
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     if match_image_size:
         generated_image_size = generated_images[0].shape[-1]
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions

-def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions

-def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
@@ -265,6 +269,7 @@ def train(
     accelerator: Accelerator,
     tracker: Tracker,
     inference_device,
+    clip=None,
     evaluate_config=None,
     epoch_samples=None,  # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples=None,
@@ -371,15 +376,19 @@ def move_unets(unet_training_mask):
                     forward_params['image_embed'] = img_emb
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb
                     else:
                         # Then we need to pass the text instead
-                        tokenized_texts = tokenize(txt, truncate=True)
+                        assert clip is not None
+                        tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
                 trainer.update(unet_number=unet)
                 unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet - 1] = loss
@@ -419,7 +428,7 @@ def move_unets(unet_training_mask):
                 save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
-                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

             if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def move_unets(unet_training_mask):
                     forward_params['image_embed'] = img_emb.float()
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb.float()
                     else:
                         # Then we need to pass the text instead
+                        assert clip is not None
                         tokenized_texts = tokenize(txt, truncate=True)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                 average_val_loss_tensor[0, unet - 1] += loss

@@ -498,7 +511,7 @@ def move_unets(unet_training_mask):
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                 if is_master:
                     tracker.log(evaluation, step=step())
                 next_task = 'sample'
@@ -509,8 +522,8 @@ def move_unets(unet_training_mask):
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
         "NumProcesses": accelerator.num_processes,
         "MixedPrecision": accelerator.mixed_precision
     }
+    accelerator.wait_for_everyone()  # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
     tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
     tracker.save_config(config_path, config_name='decoder_config.json')
     tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
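`wait_for_everyone()` is a collective barrier: every process blocks there until all ranks arrive, so no rank creates or auto-resumes the tracker run ahead of the others. A minimal sketch of the same call in isolation, assuming a script launched with Accelerate:

```python
from accelerate import Accelerator

accelerator = Accelerator()

# All processes block here until the slowest one catches up.
accelerator.wait_for_everyone()

if accelerator.is_main_process:
    print("every rank has reached the tracker-creation point")
```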
@@ -555,10 +569,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     # If we are in deepspeed fp16 mode, we must ensure learned variance is off
     if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
         raise ValueError("DeepSpeed fp16 mode does not support learned variance")
-
-    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
-        # This is an invalid configuration until we figure out how to handle this
-        raise ValueError("DeepSpeed does not support multi-node distributed training")

     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         seed=config.seed,
     )

+    # If clip is in the model, we need to remove it for compatibility with deepspeed
+    clip = None
+    if config.decoder.clip is not None:
+        clip = config.decoder.clip.create()  # Of course we keep it to use it during training, just not in the decoder as that causes issues
+        config.decoder.clip = None
     # Create the decoder model and print basic info
     decoder = config.decoder.create()
     get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
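The config surgery above keeps the CLIP adapter available to the training script while ensuring the decoder handed to `accelerator.prepare` (and thus to DeepSpeed) contains no CLIP submodule. A small, hypothetical illustration of the resulting split, with plain `nn.Module` stand-ins for the real decoder and adapter:

```python
import torch

decoder = torch.nn.Linear(512, 512)  # stand-in for config.decoder.create() without a clip submodule
clip = torch.nn.Linear(512, 512)     # stand-in for config.decoder.clip.create(), kept outside the decoder

# Only the decoder's parameters are wrapped by the distributed/fp16 engine;
# clip remains a plain module used for on-the-fly image and text embeddings.
print("decoder params:", sum(p.numel() for p in decoder.parameters()))
print("clip params kept outside the prepared model:", sum(p.numel() for p in clip.parameters()))
```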
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     has_text_embeddings = config.data.text_embeddings_url is not None
     conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])

-    has_clip_model = config.decoder.clip is not None
+    has_clip_model = clip is not None
     data_source_string = ""

     if has_img_embeddings:
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

     train(dataloaders, decoder, accelerator,
+        clip=clip,
         tracker=tracker,
         inference_device=accelerator.device,
         evaluate_config=config.evaluate,
