Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitc6c3882

Browse files
committed
fix all optional types in train config
1 parent512b52b commitc6c3882

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

‎dalle2_pytorch/train_configs.py‎

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool =
115115
classAdapterConfig(BaseModel):
116116
make:str="openai"
117117
model:str="ViT-L/14"
118-
base_model_kwargs:Dict[str,Any]=None
118+
base_model_kwargs:Optional[Dict[str,Any]]=None
119119

120120
defcreate(self):
121121
ifself.make=="openai":
@@ -134,8 +134,8 @@ def create(self):
134134
classDiffusionPriorNetworkConfig(BaseModel):
135135
dim:int
136136
depth:int
137-
max_text_len:int=None
138-
num_timesteps:int=None
137+
max_text_len:Optional[int]=None
138+
num_timesteps:Optional[int]=None
139139
num_time_embeds:int=1
140140
num_image_embeds:int=1
141141
num_text_embeds:int=1
@@ -158,7 +158,7 @@ def create(self):
158158
returnDiffusionPriorNetwork(**kwargs)
159159

160160
classDiffusionPriorConfig(BaseModel):
161-
clip:AdapterConfig=None
161+
clip:Optional[AdapterConfig]=None
162162
net:DiffusionPriorNetworkConfig
163163
image_embed_dim:int
164164
image_size:int
@@ -195,7 +195,7 @@ class DiffusionPriorTrainConfig(BaseModel):
195195
use_ema:bool=True
196196
ema_beta:float=0.99
197197
amp:bool=False
198-
warmup_steps:int=None# number of warmup steps
198+
warmup_steps:Optional[int]=None# number of warmup steps
199199
save_every_seconds:int=3600# how often to save
200200
eval_timesteps:List[int]= [64]# which sampling timesteps to evaluate with
201201
best_validation_loss:float=1e9# the current best valudation loss observed
@@ -228,10 +228,10 @@ def from_json_path(cls, json_path):
228228
classUnetConfig(BaseModel):
229229
dim:int
230230
dim_mults:ListOrTuple[int]
231-
image_embed_dim:int=None
232-
text_embed_dim:int=None
233-
cond_on_text_encodings:bool=None
234-
cond_dim:int=None
231+
image_embed_dim:Optional[int]=None
232+
text_embed_dim:Optional[int]=None
233+
cond_on_text_encodings:Optional[bool]=None
234+
cond_dim:Optional[int]=None
235235
channels:int=3
236236
self_attn:ListOrTuple[int]
237237
attn_dim_head:int=32
@@ -243,14 +243,14 @@ class Config:
243243

244244
classDecoderConfig(BaseModel):
245245
unets:ListOrTuple[UnetConfig]
246-
image_size:int=None
246+
image_size:Optional[int]=None
247247
image_sizes:ListOrTuple[int]=None
248248
clip:Optional[AdapterConfig]# The clip model to use if embeddings are not provided
249249
channels:int=3
250250
timesteps:int=1000
251251
sample_timesteps:Optional[SingularOrIterable[Optional[int]]]=None
252252
loss_type:str='l2'
253-
beta_schedule:ListOrTuple[str]=None# None means all cosine
253+
beta_schedule:Optional[ListOrTuple[str]]=None# None means all cosine
254254
learned_variance:SingularOrIterable[bool]=True
255255
image_cond_drop_prob:float=0.1
256256
text_cond_drop_prob:float=0.5
@@ -320,20 +320,20 @@ class DecoderTrainConfig(BaseModel):
320320
n_sample_images:int=6# The number of example images to produce when sampling the train and test dataset
321321
cond_scale:Union[float,List[float]]=1.0
322322
device:str='cuda:0'
323-
epoch_samples:int=None# Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
324-
validation_samples:int=None# Same as above but for validation.
323+
epoch_samples:Optional[int]=None# Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
324+
validation_samples:Optional[int]=None# Same as above but for validation.
325325
save_immediately:bool=False
326326
use_ema:bool=True
327327
ema_beta:float=0.999
328328
amp:bool=False
329-
unet_training_mask:ListOrTuple[bool]=None# If None, use all unets
329+
unet_training_mask:Optional[ListOrTuple[bool]]=None# If None, use all unets
330330

331331
classDecoderEvaluateConfig(BaseModel):
332332
n_evaluation_samples:int=1000
333-
FID:Dict[str,Any]=None
334-
IS:Dict[str,Any]=None
335-
KID:Dict[str,Any]=None
336-
LPIPS:Dict[str,Any]=None
333+
FID:Optional[Dict[str,Any]]=None
334+
IS:Optional[Dict[str,Any]]=None
335+
KID:Optional[Dict[str,Any]]=None
336+
LPIPS:Optional[Dict[str,Any]]=None
337337

338338
classTrainDecoderConfig(BaseModel):
339339
decoder:DecoderConfig

‎dalle2_pytorch/version.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__='1.15.2'
1+
__version__='1.15.3'

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp