Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit9f37705

Browse files
authored
Add static graph param (lucidrains#226)
* Add static graph param* use static graph param
1 parentc3df46e commit9f37705

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

‎dalle2_pytorch/train_configs.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ class DecoderTrainConfig(BaseModel):
307307
wd:SingularOrIterable[float]=0.01
308308
warmup_steps:Optional[SingularOrIterable[int]]=None
309309
find_unused_parameters:bool=True
310+
static_graph:bool=True
310311
max_grad_norm:SingularOrIterable[float]=0.5
311312
save_every_n_samples:int=100000
312313
n_sample_images:int=6# The number of example images to produce when sampling the train and test dataset

‎train_decoder.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
556556
torch.manual_seed(config.seed)
557557

558558
# Set up accelerator for configurable distributed training
559-
ddp_kwargs=DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters)
559+
ddp_kwargs=DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters,static_graph=config.train.static_graph)
560560
init_kwargs=InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
561561
accelerator=Accelerator(kwargs_handlers=[ddp_kwargs,init_kwargs])
562562

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp