Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit40843bc

Browse files
committed
pydantic 2
1 parent00e07b7 commit40843bc

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

‎dalle2_pytorch/train_configs.py‎

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
importjson
22
fromtorchvisionimporttransformsasT
3-
frompydanticimportBaseModel,validator,root_validator
3+
frompydanticimportBaseModel,validator,model_validator
44
fromtypingimportList,Optional,Union,Tuple,Dict,Any,TypeVar
55

66
fromx_clipimportCLIPasXCLIP
@@ -38,12 +38,12 @@ class TrainSplitConfig(BaseModel):
3838
val:float=0.15
3939
test:float=0.1
4040

41-
@root_validator
42-
defvalidate_all(cls,fields):
43-
actual_sum=sum([*fields.values()])
41+
@model_validator(mode='after')
42+
defvalidate_all(self,m):
43+
actual_sum=sum([*dict(self).values()])
4444
ifactual_sum!=1.:
45-
raiseValueError(f'{fields.keys()} must sum to 1.0. Found:{actual_sum}')
46-
returnfields
45+
raiseValueError(f'{dict(self).keys()} must sum to 1.0. Found:{actual_sum}')
46+
returnself
4747

4848
classTrackerLogConfig(BaseModel):
4949
log_type:str='console'
@@ -59,6 +59,7 @@ def create(self, data_path: str):
5959
kwargs=self.dict()
6060
returncreate_logger(self.log_type,data_path,**kwargs)
6161

62+
6263
classTrackerLoadConfig(BaseModel):
6364
load_from:Optional[str]=None
6465
only_auto_resume:bool=False# Only attempt to load if the logger is auto-resuming
@@ -277,9 +278,9 @@ class Config:
277278
extra="allow"
278279

279280
classDecoderDataConfig(BaseModel):
280-
webdataset_base_url:str# path to a webdataset with jpg images
281-
img_embeddings_url:Optional[str]# path to .npy files with embeddings
282-
text_embeddings_url:Optional[str]# path to .npy files with embeddings
281+
webdataset_base_url:str# path to a webdataset with jpg images
282+
img_embeddings_url:Optional[str]=None# path to .npy files with embeddings
283+
text_embeddings_url:Optional[str]=None# path to .npy files with embeddings
283284
num_workers:int=4
284285
batch_size:int=64
285286
start_shard:int=0
@@ -346,11 +347,14 @@ class TrainDecoderConfig(BaseModel):
346347
deffrom_json_path(cls,json_path):
347348
withopen(json_path)asf:
348349
config=json.load(f)
350+
print(config)
349351
returncls(**config)
350352

351-
@root_validator
352-
defcheck_has_embeddings(cls,values):
353+
@model_validator(mode='after')
354+
defcheck_has_embeddings(self,m):
353355
# Makes sure that enough information is provided to get the embeddings specified for training
356+
values=dict(self)
357+
354358
data_config,decoder_config=values.get('data'),values.get('decoder')
355359

356360
ifnotexists(data_config)ornotexists(decoder_config):
@@ -375,4 +379,4 @@ def check_has_embeddings(cls, values):
375379
iftext_emb_url:
376380
assertusing_text_embeddings,"Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
377381

378-
returnvalues
382+
returnm

‎dalle2_pytorch/version.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__='1.14.2'
1+
__version__='1.15.1'

‎setup.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
'numpy',
3737
'packaging',
3838
'pillow',
39-
'pydantic',
39+
'pydantic>=2',
4040
'pytorch-warmup',
4141
'resize-right>=0.0.2',
4242
'rotary-embedding-torch',

‎train_decoder.py‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ def initialize_training(config: TrainDecoderConfig, config_path):
577577
shards_per_process=len(all_shards)//world_size
578578
assertshards_per_process>0,"Not enough shards to split evenly"
579579
my_shards=all_shards[rank*shards_per_process: (rank+1)*shards_per_process]
580+
580581
dataloaders=create_dataloaders (
581582
available_shards=my_shards,
582583
img_preproc=config.data.img_preproc,

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp