@@ -134,7 +134,7 @@ def get_example_data(dataloader, device, n=5):
             break
     return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))

-def generate_samples(trainer, example_data, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
     """
     Takes example data and generates images from the embeddings
     Returns three lists: real images, generated images, and captions
@@ -144,7 +144,9 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if img_embeddings[0] is None:
         # Generate image embeddings from clip
         imgs_tensor = torch.stack(real_images)
-        img_embeddings, *_ = trainer.embed_image(imgs_tensor)
+        assert clip is not None, "clip is None, but img_embeddings is None"
+        imgs_tensor = imgs_tensor.to(device=device)
+        img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
         sample_params["image_embed"] = img_embeddings
     else:
         # Then we are using precomputed image embeddings
@@ -153,8 +155,10 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
     if condition_on_text_encodings:
         if text_embeddings[0] is None:
             # Generate text embeddings from text
+            assert clip is not None, "clip is None, but text_embeddings is None"
             tokenized_texts = tokenize(txts, truncate=True)
-            sample_params["text"] = tokenized_texts
+            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+            sample_params["text_encodings"] = text_encodings
         else:
             # Then we are using precomputed text embeddings
             text_embeddings = torch.stack(text_embeddings)
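For reference, here is a minimal standalone sketch of the fallback these two hunks introduce: when the dataloader supplies no precomputed embeddings, the `clip` adapter generates them on the fly. The helper name `build_sample_params` and the `tokenize_fn` parameter are illustrative (the latter stands in for the `tokenize` helper used in this file); it assumes a dalle2-pytorch-style adapter whose `embed_image`/`embed_text` return an `(embedding, encodings)` pair, as in the diff above.

```python
import torch

def build_sample_params(real_images, img_embeddings, text_embeddings, txts,
                        clip=None, tokenize_fn=None, device=None):
    sample_params = {}
    if img_embeddings[0] is None:
        # No precomputed image embeddings: embed the raw images with CLIP
        assert clip is not None, "clip is required when image embeddings are not precomputed"
        imgs_tensor = torch.stack(real_images).to(device=device)
        image_embed, _image_encodings = clip.embed_image(imgs_tensor)
        sample_params["image_embed"] = image_embed
    else:
        # Precomputed embeddings come straight from the dataloader
        sample_params["image_embed"] = torch.stack(img_embeddings)
    if text_embeddings[0] is None:
        # No precomputed text encodings: tokenize the captions and embed them with CLIP
        assert clip is not None, "clip is required when text encodings are not precomputed"
        tokens = tokenize_fn(txts, truncate=True)
        _text_embed, text_encodings = clip.embed_text(tokens)
        sample_params["text_encodings"] = text_encodings
    else:
        sample_params["text_encodings"] = torch.stack(text_embeddings)
    return sample_params
```

Either way `sample_params` ends up with an `image_embed` (and, when conditioning on text, `text_encodings`), so `trainer.sample(**sample_params)` never has to see raw captions.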
@@ -166,23 +170,23 @@ def generate_samples(trainer, example_data, start_unet=1, end_unet=None, conditi
         sample_params["image"] = torch.stack(real_images)
     if device is not None:
         sample_params["_device"] = device
-    samples = trainer.sample(**sample_params)
+    samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False)  # At sampling time we don't want to cast to FP16
     generated_images = list(samples)
     captions = [text_prepend + txt for txt in txts]
     if match_image_size:
         generated_image_size = generated_images[0].shape[-1]
         real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
     return real_images, generated_images, captions

-def generate_grid_samples(trainer, examples, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
     """
     Generates samples and uses torchvision to put them in a side by side grid for easy viewing
     """
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
     grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
     return grid_images, captions

-def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
     """
     Computes evaluation metrics for the decoder
     """
@@ -192,7 +196,7 @@ def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, conditi
     if len(examples) == 0:
         print("No data to evaluate. Check that your dataloader has shards.")
         return metrics
-    real_images, generated_images, captions = generate_samples(trainer, examples, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+    real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
     real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
     generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
     # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
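On the conversion called out in that last context line: the torchmetrics image metrics used by `evaluate_trainer` (FID and friends) expect `uint8` images in `[0, 255]` by default, so the stacked `[0, 1]` float tensors are rescaled before being handed to the metric objects. A hedged sketch of that pattern (random tensors and sample counts are purely illustrative; torchmetrics' FID additionally requires the `torch-fidelity` package to be installed):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Stand-ins for the stacked real/generated image tensors in evaluate_trainer
real_images = torch.rand(16, 3, 64, 64)
generated_images = torch.rand(16, 3, 64, 64)

fid = FrechetInceptionDistance()
real_uint8 = (real_images.clamp(0, 1) * 255).to(torch.uint8)
fake_uint8 = (generated_images.clamp(0, 1) * 255).to(torch.uint8)
fid.update(real_uint8, real=True)
fid.update(fake_uint8, real=False)
print(fid.compute())
```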
@@ -265,6 +269,7 @@ def train(
     accelerator: Accelerator,
     tracker: Tracker,
     inference_device,
+    clip=None,
     evaluate_config=None,
     epoch_samples=None,  # If the training dataset is resampling, we have to manually stop an epoch
     validation_samples=None,
@@ -371,15 +376,19 @@ def move_unets(unet_training_mask):
                     forward_params['image_embed'] = img_emb
                 else:
                     # Forward pass automatically generates embedding
-                    pass
+                    assert clip is not None
+                    img_embed, img_encoding = clip.embed_image(img)
+                    forward_params['image_embed'] = img_embed
                 if condition_on_text_encodings:
                     if has_text_embedding:
                         forward_params['text_encodings'] = text_emb
                     else:
                         # Then we need to pass the text instead
-                        tokenized_texts = tokenize(txt, truncate=True)
+                        assert clip is not None
+                        tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
                         assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                        forward_params['text'] = tokenized_texts
+                        text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                        forward_params['text_encodings'] = text_encodings
                 loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
                 trainer.update(unet_number=unet)
                 unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
@@ -419,7 +428,7 @@ def move_unets(unet_training_mask):
                 save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
                 if exists(n_sample_images) and n_sample_images > 0:
                     trainer.eval()
-                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+                    train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
                     tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

             if epoch_samples is not None and sample >= epoch_samples:
@@ -462,15 +471,19 @@ def move_unets(unet_training_mask):
                         forward_params['image_embed'] = img_emb.float()
                     else:
                         # Forward pass automatically generates embedding
-                        pass
+                        assert clip is not None
+                        img_embed, img_encoding = clip.embed_image(img)
+                        forward_params['image_embed'] = img_embed
                     if condition_on_text_encodings:
                         if has_text_embedding:
                             forward_params['text_encodings'] = text_emb.float()
                         else:
                             # Then we need to pass the text instead
+                            assert clip is not None
                             tokenized_texts = tokenize(txt, truncate=True)
                             assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
-                            forward_params['text'] = tokenized_texts
+                            text_embed, text_encodings = clip.embed_text(tokenized_texts)
+                            forward_params['text_encodings'] = text_encodings
                     loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
                     average_val_loss_tensor[0, unet-1] += loss

@@ -498,7 +511,7 @@ def move_unets(unet_training_mask):
         if next_task == 'eval':
             if exists(evaluate_config):
                 accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
-                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+                evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.dict(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
                 if is_master:
                     tracker.log(evaluation, step=step())
             next_task = 'sample'
@@ -509,8 +522,8 @@ def move_unets(unet_training_mask):
             # Generate examples and save the model if we are the master
             # Generate sample images
             print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
-            test_images, test_captions = generate_grid_samples(trainer, test_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
-            train_images, train_captions = generate_grid_samples(trainer, train_example_data, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+            test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+            train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
             tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
             tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())

@@ -532,6 +545,7 @@ def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_
         "NumProcesses": accelerator.num_processes,
         "MixedPrecision": accelerator.mixed_precision
     }
+    accelerator.wait_for_everyone()  # If nodes arrive at this point at different times, they might try to autoresume the current run, which makes no sense and will cause errors
     tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
     tracker.save_config(config_path, config_name='decoder_config.json')
     tracker.add_save_metadata(state_dict_key='config', metadata=config.dict())
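The barrier added in this hunk is the standard Accelerate idiom for keeping ranks in lockstep around non-idempotent work such as creating or auto-resuming a tracker run. A minimal sketch of the pattern (the simulated per-rank delay is a made-up placeholder):

```python
import random
import time
from accelerate import Accelerator

accelerator = Accelerator()

# Placeholder for per-rank work that can finish at different times
time.sleep(random.random())

# No process continues past this line until every process has reached it,
# so a fast rank cannot create (or auto-resume) the tracker run on its own.
accelerator.wait_for_everyone()

if accelerator.is_main_process:
    print("all ranks synchronized; safe to create the tracker run")
```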
@@ -555,10 +569,6 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     # If we are in deepspeed fp16 mode, we must ensure learned variance is off
     if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
         raise ValueError("DeepSpeed fp16 mode does not support learned variance")
-
-    if accelerator.process_index != accelerator.local_process_index and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED:
-        # This is an invalid configuration until we figure out how to handle this
-        raise ValueError("DeepSpeed does not support multi-node distributed training")

     # Set up data
     all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
@@ -579,6 +589,11 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         seed=config.seed,
     )

+    # If clip is in the model, we need to remove it for compatibility with deepspeed
+    clip = None
+    if config.decoder.clip is not None:
+        clip = config.decoder.clip.create()  # Of course we keep it to use it during training, just not in the decoder as that causes issues
+        config.decoder.clip = None
     # Create the decoder model and print basic info
     decoder = config.decoder.create()
     get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
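A compressed sketch of the idea in this hunk: CLIP is instantiated once from the config and kept outside the decoder (and therefore outside the DeepSpeed engine), while the decoder is built from a config whose `clip` field has been cleared. The helper name below is hypothetical; the config fields are the same `TrainDecoderConfig` attributes used above.

```python
def split_clip_from_config(config):
    # Pull CLIP out of the decoder config so DeepSpeed never wraps it
    clip = None
    if config.decoder.clip is not None:
        clip = config.decoder.clip.create()  # keep CLIP around for embedding images/text during training
        config.decoder.clip = None           # but build the decoder without it
    decoder = config.decoder.create()
    return decoder, clip
```

The separately created `clip` is then threaded through `train`, `generate_grid_samples`, and `evaluate_trainer`, as the remaining hunks show.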
@@ -590,7 +605,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
     has_text_embeddings = config.data.text_embeddings_url is not None
     conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])

-    has_clip_model = config.decoder.clip is not None
+    has_clip_model = clip is not None
     data_source_string = ""

     if has_img_embeddings:
@@ -615,6 +630,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
         accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")

     train(dataloaders, decoder, accelerator,
+        clip=clip,
         tracker=tracker,
         inference_device=accelerator.device,
         evaluate_config=config.evaluate,