Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
/sparsezooPublic archive

Commit6ae1b94

Browse files
authored
[BugFix] Fix Serialization of Computed Properties in BaseModel (#485)
* Add failing test* Fix computed field serialization
1 parentf349af4 commit6ae1b94

File tree

3 files changed

+44
-59
lines changed

3 files changed

+44
-59
lines changed

‎src/sparsezoo/analyze_v1/analysis.py‎

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class YAMLSerializableBaseModel(BaseModel):
103103
A BaseModel that adds a .yaml(...) function to all child classes
104104
"""
105105

106+
model_config=ConfigDict(protected_namespaces=())
107+
108+
defdict(self,*args,**kwargs)->Dict[str,Any]:
109+
# alias for model_dump for pydantic v2 upgrade
110+
# to allow for easier migration
111+
returnself.model_dump(*args,**kwargs)
112+
106113
defyaml(self,file_path:Optional[str]=None)->Union[str,None]:
107114
"""
108115
:param file_path: optional file path to save yaml to
@@ -111,7 +118,7 @@ def yaml(self, file_path: Optional[str] = None) -> Union[str, None]:
111118
"""
112119
file_stream=Noneiffile_pathisNoneelseopen(file_path,"w")
113120
ret=yaml.dump(
114-
self.dict(),stream=file_stream,allow_unicode=True,sort_keys=False
121+
self.model_dump(),stream=file_stream,allow_unicode=True,sort_keys=False
115122
)
116123

117124
iffile_streamisnotNone:
@@ -127,7 +134,7 @@ def parse_yaml_file(cls, file_path: str):
127134
"""
128135
withopen(file_path,"r")asfile:
129136
dict_obj=yaml.safe_load(file)
130-
returncls.parse_obj(dict_obj)
137+
returncls.model_validate(dict_obj)
131138

132139
@classmethod
133140
defparse_yaml_raw(cls,yaml_raw:str):
@@ -136,7 +143,7 @@ def parse_yaml_raw(cls, yaml_raw: str):
136143
:return: instance of ModelAnalysis class
137144
"""
138145
dict_obj=yaml.safe_load(yaml_raw)# unsafe: needs to load numpy
139-
returncls.parse_obj(dict_obj)
146+
returncls.model_validate(dict_obj)
140147

141148

142149
@dataclass

‎src/sparsezoo/analyze_v1/utils/models.py‎

Lines changed: 7 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
importtextwrap
1616
fromtypingimportClassVar,Dict,List,Optional,Tuple,Union
1717

18-
frompydanticimportBaseModel,Field
18+
frompydanticimportBaseModel,Field,computed_field
1919

2020

2121
__all__= [
@@ -33,58 +33,6 @@
3333
PrintOrderType=ClassVar[List[str]]
3434

3535

36-
classPropertyBaseModel(BaseModel):
37-
"""
38-
https://github.com/samuelcolvin/pydantic/issues/935#issuecomment-1152457432
39-
40-
Workaround for serializing properties with pydantic until
41-
https://github.com/samuelcolvin/pydantic/issues/935
42-
is solved
43-
"""
44-
45-
@classmethod
46-
defget_properties(cls):
47-
return [
48-
prop
49-
forpropindir(cls)
50-
ifisinstance(getattr(cls,prop),property)
51-
andpropnotin ("__values__","fields")
52-
]
53-
54-
defdict(
55-
self,
56-
*,
57-
include:Union["AbstractSetIntStr","MappingIntStrAny"]=None,# noqa: F821
58-
exclude:Union["AbstractSetIntStr","MappingIntStrAny"]=None,# noqa: F821
59-
by_alias:bool=False,
60-
skip_defaults:bool=None,
61-
exclude_unset:bool=False,
62-
exclude_defaults:bool=False,
63-
exclude_none:bool=False,
64-
)->"DictStrAny":# noqa: F821
65-
attribs=super().dict(
66-
include=include,
67-
exclude=exclude,
68-
by_alias=by_alias,
69-
skip_defaults=skip_defaults,
70-
exclude_unset=exclude_unset,
71-
exclude_defaults=exclude_defaults,
72-
exclude_none=exclude_none,
73-
)
74-
props=self.get_properties()
75-
# Include and exclude properties
76-
ifinclude:
77-
props= [propforpropinpropsifpropininclude]
78-
ifexclude:
79-
props= [propforpropinpropsifpropnotinexclude]
80-
81-
# Update the attribute dict with the properties
82-
ifprops:
83-
attribs.update({prop:getattr(self,prop)forpropinprops})
84-
85-
returnattribs
86-
87-
8836
classNodeCounts(BaseModel):
8937
"""
9038
Pydantic model for specifying the number zero and non-zero operations and the
@@ -114,7 +62,7 @@ class NodeIO(BaseModel):
11462
)
11563

11664

117-
classZeroNonZeroParams(PropertyBaseModel):
65+
classZeroNonZeroParams(BaseModel):
11866
"""
11967
Pydantic model for specifying the number zero and non-zero operations and the
12068
associated sparsity
@@ -127,20 +75,22 @@ class ZeroNonZeroParams(PropertyBaseModel):
12775
description="The number of parameters whose value is zero",default=0
12876
)
12977

78+
@computed_field(repr=True,return_type=Union[int,float])
13079
@property
13180
defsparsity(self):
13281
total_values=self.total
13382
iftotal_values>0:
13483
returnself.zero/total_values
13584
else:
136-
return0
85+
return0.0
13786

87+
@computed_field(repr=True,return_type=int)
13888
@property
13989
deftotal(self):
14090
returnself.non_zero+self.zero
14191

14292

143-
classDenseSparseOps(PropertyBaseModel):
93+
classDenseSparseOps(BaseModel):
14494
"""
14595
Pydantic model for specifying the number dense and sparse operations and the
14696
associated operation sparsity
@@ -155,6 +105,7 @@ class DenseSparseOps(PropertyBaseModel):
155105
default=0,
156106
)
157107

108+
@computed_field(repr=True,return_type=Union[int,float])
158109
@property
159110
defsparsity(self):
160111
total_ops=self.sparse+self.dense
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
importpytest
16+
17+
fromsparsezoo.analyze_v1.utils.modelsimportDenseSparseOps,ZeroNonZeroParams
18+
19+
20+
@pytest.mark.parametrize("model", [DenseSparseOps,ZeroNonZeroParams])
21+
@pytest.mark.parametrize("computed_fields", [["sparsity"]])
22+
deftest_model_dump_has_computed_fields(model,computed_fields):
23+
model=model()
24+
model_dict=model.model_dump()
25+
forcomputed_fieldincomputed_fields:
26+
assertcomputed_fieldinmodel_dict
27+
assertmodel_dict[computed_field]==getattr(model,computed_field)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp