
Commit ece9c26

Add flag for xformers use during LoRA training.

1 parent d111156 · commit ece9c26

3 files changed: +41 lines, -14 lines

invokeai/backend/training/lora/lora_training.py (8 additions, 1 deletion)

@@ -1,14 +1,15 @@
 import json
 import logging
-import torch
 
 import datasets
 import diffusers
+import torch
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import MultiProcessAdapter, get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+from packaging import version
 from transformers import CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig

@@ -199,3 +200,9 @@ def run_lora_training(
     tokenizer, noise_scheduler, text_encoder, vae, unet = _load_models(
         accelerator, app_config, train_config, logger
     )
+
+    if train_config.xformers:
+        import xformers
+
+        unet.enable_xformers_memory_efficient_attention()
+        vae.enable_xformers_memory_efficient_attention()
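In diffusers, enable_xformers_memory_efficient_attention() is the standard switch for xformers attention on models such as UNet2DConditionModel and AutoencoderKL, which is what the guarded block above calls. A minimal sketch of the same pattern (maybe_enable_xformers is a hypothetical helper, not part of this commit) that fails early with a clear message when the package is missing:

import importlib.util

def maybe_enable_xformers(models, use_xformers: bool) -> None:
    """If requested, switch each model to xformers memory-efficient attention."""
    if not use_xformers:
        return
    if importlib.util.find_spec("xformers") is None:
        raise RuntimeError(
            "xformers was requested but the xformers package is not installed."
        )
    for model in models:
        # enable_xformers_memory_efficient_attention() is the real diffusers
        # method on UNet2DConditionModel and AutoencoderKL.
        model.enable_xformers_memory_efficient_attention()

Importing xformers up front, as the commit does, surfaces a missing installation as an ImportError at startup rather than a failure deep inside the first forward pass.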

invokeai/backend/training/lora/lora_training_config.py (3 additions, 0 deletions)

@@ -32,3 +32,6 @@ class LoraTrainingConfig(BaseModel):
     report_to: typing.Optional[
         typing.Literal["all", "tensorboard", "wandb", "comet_ml"]
     ] = "tensorboard"
+
+    # If true, use xformers for more efficient attention blocks.
+    xformers: bool
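Because the new xformers field is annotated without a default, pydantic treats it as required: any code constructing a LoraTrainingConfig must now pass it explicitly. A toy sketch of that behavior (ExampleConfig stands in for the real class):

from pydantic import BaseModel, ValidationError

class ExampleConfig(BaseModel):
    # A field annotated without a default is required in pydantic.
    xformers: bool

ok = ExampleConfig(xformers=True)  # validates fine
try:
    ExampleConfig()                # no value supplied
except ValidationError:
    print("xformers must be provided explicitly")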

invokeai/frontend/training/lora.py (30 additions, 13 deletions)

@@ -28,8 +28,10 @@ def parse_args():
         "--model",
         type=str,
         required=True,
-        help="Name of the diffusers model to train against, as defined in "
-        "'configs/models.yaml' (e.g. 'sd-1/main/stable-diffusion-v1-5').",
+        help=(
+            "Name of the diffusers model to train against, as defined in "
+            "'configs/models.yaml' (e.g. 'sd-1/main/stable-diffusion-v1-5')."
+        ),
     )
 
     # Training Group
@@ -38,27 +40,42 @@ def parse_args():
         "--gradient_accumulation_steps",
         type=int,
         default=1,
-        help="The number of gradient steps to accumulate before each weight "
-        "update. This value is passed to Hugging Face Accelerate. This is an "
-        "alternative to increasing the batch size when training with limited "
-        "VRAM.",
+        help=(
+            "The number of gradient steps to accumulate before each weight"
+            " update. This value is passed to Hugging Face Accelerate. This is"
+            " an alternative to increasing the batch size when training with"
+            " limited VRAM."
+        ),
     )
     training_group.add_argument(
         "--mixed_precision",
         type=str,
         default=None,
-        help="The mixed precision mode to use ('no','fp16','bf16 or 'fp8'). "
-        "This value is passed to Hugging Face Accelerate. See "
-        "accelerate.Accelerator for more details.",
+        help=(
+            "The mixed precision mode to use ('no', 'fp16', 'bf16', or 'fp8'). "
+            "This value is passed to Hugging Face Accelerate. See "
+            "accelerate.Accelerator for more details."
+        ),
     )
     training_group.add_argument(
         "--report_to",
         type=str,
         default="tensorboard",
-        help="The integration to report results and logs to ('all', "
-        "'tensorboard', 'wandb', or 'comet_ml'). This value is passed to "
-        "Hugging Face Accelerate. See accelerate.Accelerator.log_with for more "
-        "details.",
+        help=(
+            "The integration to report results and logs to ('all',"
+            " 'tensorboard', 'wandb', or 'comet_ml'). This value is passed to"
+            " Hugging Face Accelerate. See accelerate.Accelerator.log_with for"
+            " more details."
+        ),
+    )
+    training_group.add_argument(
+        "--xformers",
+        action="store_true",
+        default=False,
+        help=(
+            "If set, xformers will be used for faster and more memory-efficient"
+            " attention blocks."
+        ),
     )
 
     return parser.parse_args()
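With action="store_true", the new flag is a boolean switch: passing --xformers on the command line sets args.xformers to True, and omitting it leaves the False default. A self-contained check of that argparse behavior:

import argparse

parser = argparse.ArgumentParser()
# Mirrors the new flag: a boolean switch that defaults to False.
parser.add_argument("--xformers", action="store_true", default=False)

print(parser.parse_args(["--xformers"]).xformers)  # True
print(parser.parse_args([]).xformers)              # False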

