Commit 370bade

Add dataloader to LoRA training script, and load configs from YAML rather than command-line params.

1 parent 8941f2e · commit 370bade
File tree: 4 files changed, +218 −102 lines changed

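Taken together, the change means training is now configured by a YAML file rather than a long argparse command line. A rough sketch of the new flow (the config filename below is hypothetical; `yaml.safe_load`, `LoraTrainingConfig`, and `run_lora_training` are the pieces this commit wires together):

    # Sketch of the config-driven flow this commit introduces. Assumes the
    # invokeai package from this branch is importable; 'my_lora_config.yaml'
    # is a hypothetical path.
    import yaml

    from invokeai.app.services.config import InvokeAIAppConfig
    from invokeai.backend.training.lora.lora_training import run_lora_training
    from invokeai.backend.training.lora.lora_training_config import (
        LoraTrainingConfig,
    )

    with open("my_lora_config.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    train_config = LoraTrainingConfig(**cfg)  # pydantic validates field types
    run_lora_training(InvokeAIAppConfig.get_config(), train_config)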

invokeai/backend/training/lora/lora_training.py

Lines changed: 139 additions & 1 deletion
@@ -1,15 +1,18 @@
 import json
 import logging
+import os
+import random
 
 import datasets
 import diffusers
+import numpy as np
 import torch
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import MultiProcessAdapter, get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
-from packaging import version
+from torchvision import transforms
 from transformers import CLIPTextModel, CLIPTokenizer
 
 import invokeai.backend.training.lora.networks.lora as kohya_lora_module
@@ -217,6 +220,134 @@ def _initialize_optimizer(
     )
 
 
+def _initialize_dataset(
+    train_config: LoraTrainingConfig,
+    accelerator: Accelerator,
+    tokenizer: CLIPTokenizer,
+) -> torch.utils.data.DataLoader:
+    # In distributed training, the load_dataset function guarantees that only
+    # one local process will download the dataset.
+    if train_config.dataset_name is not None:
+        # Download the dataset from the Hugging Face hub.
+        dataset = datasets.load_dataset(
+            train_config.dataset_name,
+            train_config.dataset_config_name,
+            cache_dir=train_config.hf_cache_dir,
+        )
+    elif train_config.dataset_dir is not None:
+        data_files = {}
+        data_files["train"] = os.path.join(train_config.dataset_dir, "**")
+        # See more about loading custom images at
+        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        dataset = datasets.load_dataset(
+            "imagefolder",
+            data_files=data_files,
+            cache_dir=train_config.hf_cache_dir,
+        )
+    else:
+        raise ValueError(
+            "At least one of 'dataset_name' or 'dataset_dir' must be set."
+        )
+
+    # Preprocessing the datasets.
+    # We need to tokenize inputs and targets.
+    column_names = dataset["train"].column_names
+
+    # Get the column names for input/target.
+    if train_config.dataset_image_column not in column_names:
+        raise ValueError(
+            f"The dataset_image_column='{train_config.dataset_image_column}' is"
+            f" not in the set of dataset column names: '{column_names}'."
+        )
+    if train_config.dataset_caption_column not in column_names:
+        raise ValueError(
+            f"The dataset_caption_column='{train_config.dataset_caption_column}'"
+            f" is not in the set of dataset column names: '{column_names}'."
+        )
+
+    # Preprocessing the datasets.
+    # We need to tokenize input captions and transform the images.
+    def tokenize_captions(examples, is_train=True):
+        captions = []
+        for caption in examples[train_config.dataset_caption_column]:
+            if isinstance(caption, str):
+                captions.append(caption)
+            elif isinstance(caption, (list, np.ndarray)):
+                # take a random caption if there are multiple
+                captions.append(
+                    random.choice(caption) if is_train else caption[0]
+                )
+            else:
+                raise ValueError(
+                    f"Caption column `{train_config.dataset_caption_column}`"
+                    " should contain either strings or lists of strings."
+                )
+        inputs = tokenizer(
+            captions,
+            max_length=tokenizer.model_max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+        )
+        return inputs.input_ids
+
+    # Preprocessing the datasets.
+    train_transforms = transforms.Compose(
+        [
+            transforms.Resize(
+                train_config.resolution,
+                interpolation=transforms.InterpolationMode.BILINEAR,
+            ),
+            (
+                transforms.CenterCrop(train_config.resolution)
+                if train_config.center_crop
+                else transforms.RandomCrop(train_config.resolution)
+            ),
+            (
+                transforms.RandomHorizontalFlip()
+                if train_config.random_flip
+                else transforms.Lambda(lambda x: x)
+            ),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    def preprocess_train(examples):
+        images = [
+            image.convert("RGB")
+            for image in examples[train_config.dataset_image_column]
+        ]
+        examples["pixel_values"] = [train_transforms(image) for image in images]
+        examples["input_ids"] = tokenize_captions(examples)
+        return examples
+
+    with accelerator.main_process_first():
+        # Set the training transforms
+        train_dataset = dataset["train"].with_transform(preprocess_train)
+
+    def collate_fn(examples):
+        pixel_values = torch.stack(
+            [example["pixel_values"] for example in examples]
+        )
+        pixel_values = pixel_values.to(
+            memory_format=torch.contiguous_format
+        ).float()
+        input_ids = torch.stack([example["input_ids"] for example in examples])
+        return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+    # DataLoaders creation:
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=train_config.train_batch_size,
+        num_workers=train_config.dataloader_num_workers,
+    )
+
+    return train_dataloader
+
+
 def run_lora_training(
     app_config: InvokeAIAppConfig, train_config: LoraTrainingConfig
 ):
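For a concrete sense of the contract between `collate_fn` and the training loop, here is a self-contained sketch with dummy tensors standing in for real preprocessed examples (the shapes assume the default resolution of 512 and CLIP's 77-token context window; 49408 is the CLIP vocabulary size):

    # Hedged demonstration of the collate_fn contract, using dummy tensors
    # instead of real dataset examples.
    import torch

    def collate_fn(examples):
        pixel_values = torch.stack([e["pixel_values"] for e in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([e["input_ids"] for e in examples])
        return {"pixel_values": pixel_values, "input_ids": input_ids}

    examples = [
        {
            "pixel_values": torch.randn(3, 512, 512),     # normalized image tensor
            "input_ids": torch.randint(0, 49408, (77,)),  # tokenized caption
        }
        for _ in range(4)  # train_batch_size defaults to 4
    ]

    batch = collate_fn(examples)
    print(batch["pixel_values"].shape)  # torch.Size([4, 3, 512, 512])
    print(batch["input_ids"].shape)     # torch.Size([4, 77])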
@@ -268,3 +399,10 @@ def run_lora_training(
     )
 
     optimizer = _initialize_optimizer(train_config, trainable_params)
+
+    data_loader = _initialize_dataset(train_config, accelerator, tokenizer)
+
+    x = train_features, train_labels = next(iter(data_loader))
+    logger.info(x.keys())
+    logger.info(x["pixel_values"].shape)
+    logger.info(x["input_ids"].shape)
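The `train_transforms` pipeline above maps arbitrary-size PIL images to normalized tensors in [-1, 1], the range the Stable Diffusion VAE expects. A minimal sketch of the same pipeline instantiated with the default config values (resolution=512, center_crop=False, random_flip=False); the input image is a stand-in:

    # Sketch of the preprocessing pipeline from _initialize_dataset, with
    # the default config values substituted in.
    from PIL import Image
    from torchvision import transforms

    resolution = 512
    train_transforms = transforms.Compose(
        [
            transforms.Resize(
                resolution, interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomCrop(resolution),   # center_crop=False (default)
            transforms.Lambda(lambda x: x),      # random_flip=False -> no-op
            transforms.ToTensor(),               # PIL image -> CHW float in [0, 1]
            transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
        ]
    )

    image = Image.new("RGB", (768, 1024))  # dummy image
    tensor = train_transforms(image)
    print(tensor.shape)  # torch.Size([3, 512, 512])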

invokeai/backend/training/lora/lora_training_config.py

Lines changed: 55 additions & 0 deletions
@@ -40,6 +40,10 @@ class LoraTrainingConfig(BaseModel):
     # of a slower backward pass.
     gradient_checkpointing: bool = False
 
+    #####################
+    # Optimizer Configs
+    #####################
+
     # Initial learning rate (after the potential warmup period) to use.
     learning_rate: float = 1e-4

@@ -48,3 +52,54 @@ class LoraTrainingConfig(BaseModel):
     adam_beta2: float = 0.999
     adam_weight_decay: float = 1e-2
     adam_epsilon: float = 1e-8
+
+    ##################
+    # Dataset Configs
+    ##################
+
+    # The name of a Hugging Face dataset.
+    # One of dataset_name and dataset_dir should be set (dataset_name takes
+    # precedence).
+    # See also: dataset_config_name.
+    dataset_name: typing.Optional[str] = None
+
+    # The directory to load a dataset from. The dataset is expected to be in
+    # Hugging Face imagefolder format
+    # (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).
+    # One of dataset_name and dataset_dir should be set (dataset_name takes
+    # precedence).
+    dataset_dir: typing.Optional[str] = None
+
+    # The Hugging Face dataset config name. Leave as None if there's only one
+    # config.
+    # This parameter is only used if dataset_name is set.
+    dataset_config_name: typing.Optional[str] = None
+
+    # The Hugging Face cache directory to use for dataset downloads.
+    # If None, the default value will be used (usually
+    # '~/.cache/huggingface/datasets').
+    hf_cache_dir: typing.Optional[str] = None
+
+    # The name of the dataset column that contains image paths.
+    dataset_image_column: str = "image"
+
+    # The name of the dataset column that contains captions.
+    dataset_caption_column: str = "text"
+
+    # The resolution for input images. All of the images in the dataset will be
+    # resized to this (square) resolution.
+    resolution: int = 512
+
+    # If True, input images will be center-cropped to resolution.
+    # If False, input images will be randomly cropped to resolution.
+    center_crop: bool = False
+
+    # Whether random flip augmentations should be applied to input images.
+    random_flip: bool = False
+
+    # The training batch size.
+    train_batch_size: int = 4
+
+    # Number of subprocesses to use for data loading. 0 means that the data will
+    # be loaded in the main process.
+    dataloader_num_workers: int = 0
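Because `LoraTrainingConfig` is a pydantic `BaseModel`, a YAML document deserializes into it directly, and mistyped values fail at construction time. A hedged sketch (assumes this branch of invokeai is installed; fields defined earlier in the class, outside this diff, may be required and would also need values):

    # Sketch: constructing the config from a plain dict, as main() now does.
    from invokeai.backend.training.lora.lora_training_config import (
        LoraTrainingConfig,
    )

    cfg = {
        "dataset_name": "lambdalabs/pokemon-blip-captions",
        "resolution": 512,
        "train_batch_size": 4,
        "dataloader_num_workers": 2,
    }
    train_config = LoraTrainingConfig(**cfg)
    print(train_config.dataset_name)  # lambdalabs/pokemon-blip-captions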

invokeai/frontend/training/lora.py

Lines changed: 23 additions & 101 deletions
@@ -1,5 +1,7 @@
 from pathlib import Path
 
+import yaml
+
 from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
 from invokeai.backend.training.lora.lora_training import run_lora_training
 from invokeai.backend.training.lora.lora_training_config import (
@@ -12,114 +14,27 @@ def parse_args():
 
     parser = PagingArgumentParser(description="LoRA model training.")
 
-    # General configs
-    general_group = parser.add_argument_group("General")
-    general_group.add_argument(
-        "--output_dir",
+    parser.add_argument(
+        "--cfg_file",
         type=Path,
-        # TODO(ryand): Decide on a training directory structure and update for
-        # consistency with TI training.
-        default=config.root / "training/lora/output",
-    )
-
-    # Base model configs
-    model_group = parser.add_argument_group("Model")
-    model_group.add_argument(
-        "--model",
-        type=str,
         required=True,
         help=(
-            "Name of the diffusers model to train against, as defined in "
-            "'configs/models.yaml' (e.g. 'sd-1/main/stable-diffusion-v1-5')."
-        ),
-    )
-
-    # Training Group
-    training_group = parser.add_argument_group("Training")
-    training_group.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help=(
-            "The number of gradient steps to accumulate before each weight"
-            " update. This value is passed to Hugging Face Accelerate. This is"
-            " an alternative to increasing the batch size when training with"
-            " limited VRAM."
-        ),
-    )
-    training_group.add_argument(
-        "--mixed_precision",
-        type=str,
-        default=None,
-        help=(
-            "The mixed precision mode to use ('no','fp16','bf16 or 'fp8'). "
-            "This value is passed to Hugging Face Accelerate. See "
-            "accelerate.Accelerator for more details."
-        ),
-    )
-    training_group.add_argument(
-        "--report_to",
-        type=str,
-        default="tensorboard",
-        help=(
-            "The integration to report results and logs to ('all',"
-            " 'tensorboard', 'wandb', or 'comet_ml'). This value is passed to"
-            " Hugging Face Accelerate. See accelerate.Accelerator.log_with for"
-            " more details."
-        ),
-    )
-    training_group.add_argument(
-        "--xformers",
-        action="store_true",
-        default=False,
-        help=(
-            "If set, xformers will be used for faster and more memory-efficient"
-            " attention blocks."
+            "Path to the YAML training config file. See `LoraTrainingConfig`"
+            " for the supported fields."
         ),
     )
-    training_group.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help=(
-            "Whether or not to use gradient checkpointing to save memory at the"
-            " expense of slower backward pass."
-        ),
-    )
-
-    # Optimizer Group
-    optimizer_group = parser.add_argument_group("Optimizer")
-    optimizer_group.add_argument(
-        "--learning_rate",
-        type=float,
-        default=1e-4,
+    parser.add_argument(
+        "--output_dir",
+        type=Path,
+        # TODO(ryand): Decide on a training directory structure and update for
+        # consistency with TI training.
+        default=config.root / "training/lora/output",
         help=(
-            "Initial learning rate (after the potential warmup period) to use."
+            "The output directory where the training outputs (model"
+            " checkpoints, logs, intermediate predictions) will be written."
+            " Defaults to `$INVOKEAI_HOME/training/lora/output`."
         ),
     )
-    parser.add_argument(
-        "--adam_beta1",
-        type=float,
-        default=0.9,
-        help="The beta1 parameter for the Adam optimizer.",
-    )
-    parser.add_argument(
-        "--adam_beta2",
-        type=float,
-        default=0.999,
-        help="The beta2 parameter for the Adam optimizer.",
-    )
-    parser.add_argument(
-        "--adam_weight_decay",
-        type=float,
-        default=1e-2,
-        help="Weight decay parameter for the Adam optimizer.",
-    )
-    parser.add_argument(
-        "--adam_epsilon",
-        type=float,
-        default=1e-08,
-        help="Epsilon value for the Adam optimizer",
-    )
 
     return parser.parse_args()
 
@@ -128,7 +43,14 @@ def main():
     app_config = InvokeAIAppConfig.get_config()
     args = parse_args()
 
-    train_config = LoraTrainingConfig(**vars(args))
+    # Load YAML config file.
+    with open(args.cfg_file, "r") as f:
+        cfg = yaml.safe_load(f)
+
+    # Override 'output_dir' config.
+    cfg["output_dir"] = args.output_dir
+
+    train_config = LoraTrainingConfig(**cfg)
     run_lora_training(app_config, train_config)
 
 
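The new `main()` body is easy to exercise in isolation: write a YAML file, load it with `yaml.safe_load`, then apply the `--output_dir` override before validation. A runnable sketch of that sequence using only PyYAML and a temporary file (no invokeai imports):

    # Standalone sketch of the YAML-loading flow in main().
    import tempfile
    from pathlib import Path

    import yaml

    yaml_text = "dataset_name: lambdalabs/pokemon-blip-captions\nresolution: 512\n"
    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
        f.write(yaml_text)
        cfg_file = Path(f.name)

    with open(cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    # Mirror the '--output_dir' override that main() applies from argparse.
    cfg["output_dir"] = Path("training/lora/output")

    print(cfg)  # dict ready to pass to LoraTrainingConfig(**cfg)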

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+dataset_name: lambdalabs/pokemon-blip-captions
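With this one-line sample config saved, training would presumably be launched by pointing the new flag at it, along the lines of `--cfg_file path/to/this.yaml`; any other `LoraTrainingConfig` field can be added to the same file, while `--output_dir` remains a command-line override.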
