 import os
 import random
 import shutil
+import time
 
 import datasets
 import diffusers
 from invokeai.backend.training.lora.networks.lora import LoRANetwork
 
 
-def _initialize_accelerator(train_config: LoraTrainingConfig) -> Accelerator:
+def _initialize_accelerator(
+    out_dir: str, train_config: LoraTrainingConfig
+) -> Accelerator:
     """Configure Hugging Face accelerate and return an Accelerator.
 
     Args:
+        out_dir (str): The output directory where results will be written.
         train_config (LoraTrainingConfig): LoRA training configuration.
 
     Returns:
         Accelerator
     """
     accelerator_project_config = ProjectConfiguration(
-        project_dir=train_config.output_dir,
-        logging_dir=os.path.join(train_config.output_dir, "logs"),
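+        # Store accelerate project data and logs under this run's output directory.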
+        project_dir=out_dir,
+        logging_dir=os.path.join(out_dir, "logs"),
     )
     return Accelerator(
         project_config=accelerator_project_config,
@@ -379,6 +383,7 @@ def collate_fn(examples):
 def _save_checkpoint(
     idx: int,
     prefix: str,
+    out_dir: str,
     network: LoRANetwork,
     save_dtype: torch.dtype,
     train_config: LoraTrainingConfig,
@@ -399,7 +404,7 @@ def _save_checkpoint(
     # Before saving a checkpoint, check if this save would put us over the
     # max_checkpoints limit.
     if train_config.max_checkpoints is not None:
-        checkpoints = os.listdir(train_config.output_dir)
+        checkpoints = os.listdir(out_dir)
         checkpoints = [d for d in checkpoints if d.startswith(full_prefix)]
         checkpoints = sorted(
             checkpoints,
@@ -419,7 +424,7 @@ def _save_checkpoint(
 
         for checkpoint_to_remove in checkpoints_to_remove:
             checkpoint_to_remove = os.path.join(
-                train_config.output_dir, checkpoint_to_remove
+                out_dir, checkpoint_to_remove
             )
             if os.path.isfile(checkpoint_to_remove):
                 # Delete checkpoint file.
@@ -428,7 +433,7 @@ def _save_checkpoint(
                 # Delete checkpoint directory.
                 shutil.rmtree(checkpoint_to_remove)
 
-    save_path = os.path.join(train_config.output_dir, f"{full_prefix}{idx:0>8}")
+    save_path = os.path.join(out_dir, f"{full_prefix}{idx:0>8}")
     network.save_weights(save_path, save_dtype, None)
     # accelerator.save_state(save_path)
     logger.info(f"Saved state to {save_path}")
@@ -437,7 +442,10 @@ def _save_checkpoint(
 def run_lora_training(
     app_config: InvokeAIAppConfig, train_config: LoraTrainingConfig
 ):
-    accelerator = _initialize_accelerator(train_config)
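+    # Each training run gets its own timestamped output directory under base_output_dir.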
+    out_dir = os.path.join(train_config.base_output_dir, f"{time.time()}")
+
+    accelerator = _initialize_accelerator(out_dir, train_config)
     logger = _initialize_logging(accelerator)
 
     # Set the accelerate seed.
@@ -552,7 +559,7 @@ def run_lora_training(
 
     # Initialize the trackers we use, and store the training configuration.
     if accelerator.is_main_process:
-        accelerator.init_trackers(__name__, config=train_config.dict())
+        accelerator.init_trackers("lora_training", config=train_config.dict())
 
     # Train!
     total_batch_size = (
@@ -699,6 +706,7 @@ def run_lora_training(
                 _save_checkpoint(
                     idx=global_step,
                     prefix="step",
+                    out_dir=out_dir,
                     network=accelerator.unwrap_model(lora_network),
                     save_dtype=weight_dtype,
                     train_config=train_config,