Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
/sparsezooPublic archive

Commitc1a096f

Browse files
dsikkaDipika Sikkadbogunowicz
authored
[sparsezoo.analyze] Fix pathway such that it works for larger models (#437)
* fix analyze to work with larger models* update for failing tests; add comments* Update src/sparsezoo/utils/onnx/external_data.pyCo-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>---------Co-authored-by: Dipika Sikka <dipikasikka1@gmail.coom>Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
1 parente6b12f6 commitc1a096f

File tree

3 files changed

+41
-15
lines changed

3 files changed

+41
-15
lines changed

‎src/sparsezoo/analyze_v2/model_analysis.py‎

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,25 @@ def analyze(path: str, download_path: Optional[str] = None) -> "ModelAnalysis":
146146
:param path: .onnx path or stub
147147
"""
148148
ifpath.endswith(".onnx"):
149-
onnx_model=load_model(path)
149+
onnx_model=load_model(path,load_external_data=False)
150+
onnx_model_path=path
150151
elifis_stub(path):
151152
model=Model(path,download_path)
152153
onnx_model_path=model.onnx_model.path
153-
onnx_model=onnx.load(onnx_model_path)
154+
onnx_model=onnx.load(onnx_model_path,load_external_data=False)
154155
else:
155156
raiseValueError(f"{path} is not a valid argument")
156157

157-
model_graph=ONNXGraph(onnx_model)
158-
node_shapes,_=extract_node_shapes_and_dtypes(model_graph.model)
158+
# just need graph to get shape information; dont load external data
159+
node_shapes,_=extract_node_shapes_and_dtypes(onnx_model,onnx_model_path)
159160

160161
summary_analysis=SummaryAnalysis()
161162
node_analyses= {}
162163

164+
# load external data for node analysis
165+
onnx_model=onnx.load(onnx_model_path)
166+
model_graph=ONNXGraph(onnx_model)
167+
163168
forgraph_order,nodeinenumerate(model_graph.nodes):
164169
node_id=extract_node_id(node)
165170
node_shape=node_shapes.get(node_id)

‎src/sparsezoo/utils/node_inference.py‎

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
importlogging
2020
fromcopyimportdeepcopy
21-
fromtypingimportAny,Dict,List,NamedTuple,Tuple,Union
21+
frompathlibimportPath
22+
fromtypingimportAny,Dict,List,NamedTuple,Optional,Tuple,Union
2223

2324
importnumpy
2425
importonnx
@@ -60,13 +61,14 @@
6061

6162

6263
defextract_nodes_shapes_and_dtypes_ort(
63-
model:ModelProto,
64+
model:ModelProto,path:Optional[str]=None
6465
)->Tuple[Dict[str,List[List[int]]],Dict[str,numpy.dtype]]:
6566
"""
6667
Creates a modified model to expose intermediate outputs and runs an ONNX Runtime
6768
InferenceSession to obtain the output shape of each node.
6869
6970
:param model: an ONNX model
71+
:param path: absolute path to the original onnx model
7072
:return: a list of NodeArg with their shape exposed
7173
"""
7274
importonnxruntime
@@ -79,11 +81,24 @@ def extract_nodes_shapes_and_dtypes_ort(
7981
)
8082
model_copy.graph.output.append(intermediate_layer_value_info)
8183

84+
# using the ModelProto does not work for large models when running the session
85+
# have to save again and pass the new path to the inference session
8286
sess_options=onnxruntime.SessionOptions()
8387
sess_options.log_severity_level=3
84-
sess=onnxruntime.InferenceSession(
85-
model_copy.SerializeToString(),sess_options,providers=["CPUExecutionProvider"]
86-
)
88+
89+
ifpath:
90+
parent_dir=Path(path).parent.absolute()
91+
new_path=parent_dir/"model_new.onnx"
92+
onnx.save(model_copy,new_path,save_as_external_data=True)
93+
sess=onnxruntime.InferenceSession(
94+
new_path,sess_options,providers=onnxruntime.get_available_providers()
95+
)
96+
else:
97+
sess=onnxruntime.InferenceSession(
98+
model_copy.SerializeToString(),
99+
sess_options,
100+
providers=onnxruntime.get_available_providers(),
101+
)
87102

88103
input_value_dict= {}
89104
forinputinmodel_copy.graph.input:
@@ -166,19 +181,20 @@ def extract_nodes_shapes_and_dtypes_shape_inference(
166181

167182

168183
defextract_nodes_shapes_and_dtypes(
169-
model:ModelProto,
184+
model:ModelProto,path:Optional[str]=None
170185
)->Tuple[Dict[str,List[List[int]]],Dict[str,numpy.dtype]]:
171186
"""
172187
Uses ONNX Runtime or shape inference to infer output shapes and dtypes from model
173188
174189
:param model: model to extract output values from
190+
:param path: absolute path to the original onnx model
175191
:return: output shapes and output data types
176192
"""
177193
output_shapes=None
178194
output_dtypes=None
179195

180196
try:
181-
output_shapes,output_dtypes=extract_nodes_shapes_and_dtypes_ort(model)
197+
output_shapes,output_dtypes=extract_nodes_shapes_and_dtypes_ort(model,path)
182198
exceptExceptionaserr:
183199
_LOGGER.warning(f"Extracting shapes using ONNX Runtime session failed:{err}")
184200

@@ -306,18 +322,19 @@ def collate_output_dtypes(
306322

307323

308324
defextract_node_shapes_and_dtypes(
309-
model:ModelProto,
325+
model:ModelProto,path:Optional[str]=None
310326
)->Tuple[Dict[str,NodeShape],Dict[str,NodeDataType]]:
311327
"""
312328
Extracts the shape and dtype information for each node as NodeShape objects
313329
and numpy dtypes.
314330
315331
:param model: the loaded onnx.ModelProto to extract node shape information from
332+
:param path: absolute path to the original onnx model
316333
:return: a mapping of node id to a NodeShape object
317334
"""
318335

319336
# Obtains output shapes for each model's node
320-
output_shapes,output_dtypes=extract_nodes_shapes_and_dtypes(model)
337+
output_shapes,output_dtypes=extract_nodes_shapes_and_dtypes(model,path)
321338

322339
# Package output shapes into each node's inputs and outputs
323340
node_shapes=collate_output_shapes(model,output_shapes)

‎src/sparsezoo/utils/onnx/external_data.py‎

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,23 @@ def validate_onnx(model: Union[str, ModelProto]):
174174
raiseValueError(f"Invalid onnx model:{err}")
175175

176176

177-
defload_model(model:Union[str,ModelProto,Path])->ModelProto:
177+
defload_model(
178+
model:Union[str,ModelProto,Path],load_external_data:bool=True
179+
)->ModelProto:
178180
"""
179181
Load an ONNX model from an onnx model file path. If a ModelProto
180182
is given, then it is returned.
181183
182184
:param model: the model proto or path to the model ONNX file to check for loading
185+
:param load_external_data: if a path is given, whether or not to also load the
186+
external model data
183187
:return: the loaded ONNX ModelProto
184188
"""
185189
ifisinstance(model,ModelProto):
186190
returnmodel
187191

188192
ifisinstance(model, (Path,str)):
189-
returnonnx.load(clean_path(model))
193+
returnonnx.load(clean_path(model),load_external_data=load_external_data)
190194

191195
raiseTypeError(f"unknown type given for model:{type(model)}")
192196

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp