Commit 69e2217

Clean up LoRA training output directory structure.

1 parent 5f6282b · commit 69e2217
File tree: 3 files changed, +21 −12 lines changed

invokeai/backend/training/lora/lora_training.py
Lines changed: 16 additions & 8 deletions

@@ -4,6 +4,7 @@
 import os
 import random
 import shutil
+import time

 import datasets
 import diffusers
@@ -32,18 +33,21 @@
 from invokeai.backend.training.lora.networks.lora import LoRANetwork


-def _initialize_accelerator(train_config: LoraTrainingConfig) -> Accelerator:
+def _initialize_accelerator(
+    out_dir: str, train_config: LoraTrainingConfig
+) -> Accelerator:
     """Configure Hugging Face accelerate and return an Accelerator.

     Args:
+        out_dir (str): The output directory where results will be written.
         train_config (LoraTrainingConfig): LoRA training configuration.

     Returns:
         Accelerator
     """
     accelerator_project_config = ProjectConfiguration(
-        project_dir=train_config.output_dir,
-        logging_dir=os.path.join(train_config.output_dir, "logs"),
+        project_dir=out_dir,
+        logging_dir=os.path.join(out_dir, "logs"),
     )
     return Accelerator(
         project_config=accelerator_project_config,
@@ -379,6 +383,7 @@ def collate_fn(examples):
 def _save_checkpoint(
     idx: int,
     prefix: str,
+    out_dir: str,
     network: LoRANetwork,
     save_dtype: torch.dtype,
     train_config: LoraTrainingConfig,
@@ -399,7 +404,7 @@ def _save_checkpoint(
     # Before saving a checkpoint, check if this save would put us over the
     # max_checkpoints limit.
     if train_config.max_checkpoints is not None:
-        checkpoints = os.listdir(train_config.output_dir)
+        checkpoints = os.listdir(out_dir)
         checkpoints = [d for d in checkpoints if d.startswith(full_prefix)]
         checkpoints = sorted(
             checkpoints,
@@ -419,7 +424,7 @@ def _save_checkpoint(

         for checkpoint_to_remove in checkpoints_to_remove:
             checkpoint_to_remove = os.path.join(
-                train_config.output_dir, checkpoint_to_remove
+                out_dir, checkpoint_to_remove
             )
             if os.path.isfile(checkpoint_to_remove):
                 # Delete checkpoint file.
@@ -428,7 +433,7 @@ def _save_checkpoint(
                 # Delete checkpoint directory.
                 shutil.rmtree(checkpoint_to_remove)

-    save_path = os.path.join(train_config.output_dir, f"{full_prefix}{idx:0>8}")
+    save_path = os.path.join(out_dir, f"{full_prefix}{idx:0>8}")
     network.save_weights(save_path, save_dtype, None)
     # accelerator.save_state(save_path)
     logger.info(f"Saved state to {save_path}")
@@ -437,7 +442,9 @@ def _save_checkpoint(
 def run_lora_training(
     app_config: InvokeAIAppConfig, train_config: LoraTrainingConfig
 ):
-    accelerator = _initialize_accelerator(train_config)
+    out_dir = os.path.join(train_config.base_output_dir, f"{time.time()}")
+
+    accelerator = _initialize_accelerator(out_dir, train_config)
     logger = _initialize_logging(accelerator)

     # Set the accelerate seed.
@@ -552,7 +559,7 @@ def run_lora_training(

     # Initialize the trackers we use, and store the training configuration.
     if accelerator.is_main_process:
-        accelerator.init_trackers(__name__, config=train_config.dict())
+        accelerator.init_trackers("lora_training", config=train_config.dict())

     # Train!
     total_batch_size = (
@@ -699,6 +706,7 @@ def run_lora_training(
             _save_checkpoint(
                 idx=global_step,
                 prefix="step",
+                out_dir=out_dir,
                 network=accelerator.unwrap_model(lora_network),
                 save_dtype=weight_dtype,
                 train_config=train_config,
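The key change is at the top of run_lora_training: each run now gets its own subdirectory of base_output_dir, named with the current Unix timestamp, and that directory is threaded through to the Accelerator, the logging dir, and _save_checkpoint. A minimal sketch of the naming scheme, using a hypothetical make_run_dir helper that is not part of the commit (time.time() returns a float, so the directory name looks like 1700000000.123456):

    import os
    import time

    def make_run_dir(base_output_dir: str) -> str:
        # One subdirectory per training run, named by the current Unix
        # timestamp, mirroring the commit's
        # out_dir = os.path.join(train_config.base_output_dir, f"{time.time()}")
        out_dir = os.path.join(base_output_dir, f"{time.time()}")
        os.makedirs(out_dir, exist_ok=True)  # assumption: created on demand
        return out_dir

    print(make_run_dir("/tmp/lora-runs"))  # e.g. /tmp/lora-runs/1700000000.123456

A side effect worth noting: because checkpoints from different runs now live in different directories, the max_checkpoints pruning in _save_checkpoint only ever counts and deletes checkpoints belonging to the current run.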

invokeai/backend/training/lora/lora_training_config.py
Lines changed: 3 additions & 2 deletions

@@ -12,8 +12,9 @@ class LoraTrainingConfig(BaseModel):
     ##################

     # The output directory where the training outputs (model checkpoints, logs,
-    # intermediate predictions) will be written.
-    output_dir: str
+    # intermediate predictions) will be written. A subdirectory will be created
+    # with a timestamp for each new training run.
+    base_output_dir: str

     # The integration to report results and logs to ('all', 'tensorboard',
     # 'wandb', or 'comet_ml'). This value is passed to Hugging Face Accelerate.
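Because the field is renamed rather than aliased, existing YAML configs that still set output_dir will no longer validate. A minimal sketch of that failure mode, using a stand-in Pydantic model (LoraTrainingConfigSketch is illustrative; the real LoraTrainingConfig has many more fields):

    from pydantic import BaseModel, ValidationError

    class LoraTrainingConfigSketch(BaseModel):
        # Parent directory for all training runs; a timestamped subdirectory
        # is created inside it for each run (see run_lora_training above).
        base_output_dir: str

    LoraTrainingConfigSketch(base_output_dir="/data/lora-runs")  # validates

    try:
        # Old field name: ignored as an unknown key, so the required
        # base_output_dir is missing and validation fails.
        LoraTrainingConfigSketch(output_dir="/data/lora-runs")
    except ValidationError as err:
        print(err)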

invokeai/frontend/training/lora.py
Lines changed: 2 additions & 2 deletions

@@ -25,7 +25,7 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--output_dir",
+        "--base_output_dir",
         type=str,
         # TODO(ryand): Decide on a training directory structure and update for
         # consistency with TI training.
@@ -49,7 +49,7 @@ def main():
         cfg = yaml.safe_load(f)

     # Override 'output_dir' config.
-    cfg["output_dir"] = args.output_dir
+    cfg["base_output_dir"] = args.base_output_dir

     train_config = LoraTrainingConfig(**cfg)
     run_lora_training(app_config, train_config)
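The frontend change keeps the CLI in sync with the renamed config field: the flag is now --base_output_dir, and its value overwrites the corresponding key in the loaded YAML before validation. A condensed sketch of the resulting main() flow, with illustrative values standing in for the argparse results (the config path and directory are assumptions, not from the commit):

    import yaml

    from invokeai.backend.training.lora.lora_training import run_lora_training
    from invokeai.backend.training.lora.lora_training_config import (
        LoraTrainingConfig,
    )

    # Illustrative values; the real script takes these from argparse.
    config_path = "lora_config.yaml"
    base_output_dir = "/data/lora-runs"

    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    # As in main(): the CLI flag overrides the 'base_output_dir' key.
    cfg["base_output_dir"] = base_output_dir
    train_config = LoraTrainingConfig(**cfg)

    # run_lora_training(app_config, train_config)
    # (app_config is an InvokeAIAppConfig; its construction is not shown here.)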
