11import json
22from torchvision import transforms as T
3- from pydantic import BaseModel ,validator ,root_validator
3+ from pydantic import BaseModel ,validator ,model_validator
44from typing import List ,Optional ,Union ,Tuple ,Dict ,Any ,TypeVar
55
66from x_clip import CLIP as XCLIP
@@ -38,12 +38,12 @@ class TrainSplitConfig(BaseModel):
3838val :float = 0.15
3939test :float = 0.1
4040
41- @root_validator
42- def validate_all (cls , fields ):
43- actual_sum = sum ([* fields .values ()])
41+ @model_validator ( mode = 'after' )
42+ def validate_all (self , m ):
43+ actual_sum = sum ([* dict ( self ) .values ()])
4444if actual_sum != 1. :
45- raise ValueError (f'{ fields .keys ()} must sum to 1.0. Found:{ actual_sum } ' )
46- return fields
45+ raise ValueError (f'{ dict ( self ) .keys ()} must sum to 1.0. Found:{ actual_sum } ' )
46+ return self
4747
4848class TrackerLogConfig (BaseModel ):
4949log_type :str = 'console'
@@ -59,6 +59,7 @@ def create(self, data_path: str):
5959kwargs = self .dict ()
6060return create_logger (self .log_type ,data_path ,** kwargs )
6161
62+
6263class TrackerLoadConfig (BaseModel ):
6364load_from :Optional [str ]= None
6465only_auto_resume :bool = False # Only attempt to load if the logger is auto-resuming
@@ -277,9 +278,9 @@ class Config:
277278extra = "allow"
278279
279280class DecoderDataConfig (BaseModel ):
280- webdataset_base_url :str # path to a webdataset with jpg images
281- img_embeddings_url :Optional [str ]# path to .npy files with embeddings
282- text_embeddings_url :Optional [str ]# path to .npy files with embeddings
281+ webdataset_base_url :str # path to a webdataset with jpg images
282+ img_embeddings_url :Optional [str ]= None # path to .npy files with embeddings
283+ text_embeddings_url :Optional [str ]= None # path to .npy files with embeddings
283284num_workers :int = 4
284285batch_size :int = 64
285286start_shard :int = 0
@@ -346,11 +347,14 @@ class TrainDecoderConfig(BaseModel):
346347def from_json_path (cls ,json_path ):
347348with open (json_path )as f :
348349config = json .load (f )
350+ print (config )
349351return cls (** config )
350352
351- @root_validator
352- def check_has_embeddings (cls , values ):
353+ @model_validator ( mode = 'after' )
354+ def check_has_embeddings (self , m ):
353355# Makes sure that enough information is provided to get the embeddings specified for training
356+ values = dict (self )
357+
354358data_config ,decoder_config = values .get ('data' ),values .get ('decoder' )
355359
356360if not exists (data_config )or not exists (decoder_config ):
@@ -375,4 +379,4 @@ def check_has_embeddings(cls, values):
375379if text_emb_url :
376380assert using_text_embeddings ,"Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
377381
378- return values
382+ return m