55from onnx import ModelProto ,TensorProto ,load
66from onnx .defs import onnx_opset_version
77from onnx .reference import ReferenceEvaluator
8+ from onnx .reference .op_run import OpRun
89from onnx .helper import (
910make_tensor_value_info ,
1011make_node ,
@@ -68,7 +69,7 @@ def test_exp(self):
6869 functions = []
6970 inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
7071 nodes.append(
71- make_node (
72+ make_node_extended (
7273 'Exp',
7374 ['X'],
7475 ['Y']
@@ -144,14 +145,14 @@ def test_transpose(self):
144145 )
145146 inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
146147 nodes.append(
147- make_node (
148+ make_node_extended (
148149 'Reshape',
149150 ['X', 'r'],
150151 ['r0_0']
151152 )
152153 )
153154 nodes.append(
154- make_node (
155+ make_node_extended (
155156 'Transpose',
156157 ['r0_0'],
157158 ['Y'],
@@ -210,7 +211,7 @@ def test_topk_reverse(self):
210211 inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
211212 inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[]))
212213 nodes.append(
213- make_node (
214+ make_node_extended (
214215 'TopK',
215216 ['X', 'K'],
216217 ['Values', 'Indices'],
@@ -284,14 +285,14 @@ def test_aionnxml(self):
284285 )
285286 inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
286287 nodes.append(
287- make_node (
288+ make_node_extended (
288289 'Reshape',
289290 ['X', 'r'],
290291 ['USE']
291292 )
292293 )
293294 nodes.append(
294- make_node (
295+ make_node_extended (
295296 'Normalizer',
296297 ['USE'],
297298 ['Y'],
@@ -317,16 +318,115 @@ def test_aionnxml(self):
317318self .maxDiff = None
318319self .assertEqual (expected ,code )
319320
321+ @classmethod
322+ def _code_line (cls ,code ):
323+ lines = code .split ("\n " )
324+ return "\n " .join (f"{ i + 1 :03d} { line } " for i ,line in enumerate (lines ))
325+
326+ @classmethod
327+ def _run (cls ,code ):
328+ try :
329+ code_compiled = compile (code ,"<string>" ,mode = "exec" )
330+ except Exception as e :
331+ raise AssertionError (
332+ f"Compilation failed due to{ e } \n ---\n { cls ._code_line (code )} \n ---\n { e } "
333+ )from e
334+
335+ import onnx
336+ import onnx .helper
337+ import onnx .numpy_helper
338+ import onnx_array_api .light_api .make_helper
339+ import onnx .reference .custom_element_types
340+
341+ def from_array_extended (tensor ,name = None ):
342+ dt = tensor .dtype
343+ if (
344+ dt == onnx .reference .custom_element_types .float8e4m3fn
345+ and dt .descr [0 ][0 ]== "e4m3fn"
346+ ):
347+ to = TensorProto .FLOAT8E4M3FN
348+ dt_to = np .uint8
349+ elif (
350+ dt == onnx .reference .custom_element_types .bfloat16
351+ and dt .descr [0 ][0 ]== "bfloat16"
352+ ):
353+ to = TensorProto .BFLOAT16
354+ dt_to = np .uint16
355+ else :
356+ return onnx .numpy_helper .from_array (tensor ,name )
357+
358+ t = onnx .numpy_helper .from_array (tensor .astype (dt_to ),name )
359+ t .data_type = to
360+ return t
361+
362+ globs = onnx .__dict__ .copy ()
363+ globs .update (onnx .helper .__dict__ )
364+ globs .update (onnx .numpy_helper .__dict__ )
365+ globs .update (onnx_array_api .light_api .make_helper .__dict__ )
366+ globs .update (onnx .reference .custom_element_types .__dict__ )
367+ globs ["from_array_extended" ]= from_array_extended
368+ locs = {}
369+ try :
370+ exec (code_compiled ,globs ,locs )
371+ except Exception as e :
372+ raise AssertionError (
373+ f"Execution failed due to{ e } \n ---\n { cls ._code_line (code )} \n ---\n { e } "
374+ )from e
375+ return globs ,locs
376+
320377def test_remove_nodes (self ):
321378path = os .path .join (
322379os .path .dirname (__file__ ),"_data" ,"custom_ops_type_inference_fails_0.onnx"
323380 )
324381onx = load (path )
325- text = translate (onx ,api = "onnx" )
326- with open ("debug_test_remove_nodes.py" ,"w" )as f :
327- f .write (text )
382+ code = translate (onx ,api = "onnx" )
383+ _ ,locs = self ._run (code )
384+ self .assertIn ("model" ,locs )
385+ model = locs ["model" ]
386+ x = np .arange (4 ).reshape ((- 1 ,2 )).astype (np .float32 )
387+ feeds = {"X" :x }
388+
389+ class CustomGemmFloat8E4M3FN (OpRun ):
390+ op_domain = "onnx_extented.ortops.tutorial.cpu"
391+
392+ def _run (
393+ self ,
394+ x ,
395+ y ,
396+ bias = None ,
397+ scale_x = None ,
398+ scale_y = None ,
399+ scale_z = None ,
400+ transA = False ,
401+ transB = False ,
402+ dtype = None ,
403+ rowMajor = None ,
404+ computeType = None ,
405+ ):
406+ if scale_x is not None :
407+ x = x * scale_x
408+ if transA :
409+ x = x .T
410+ if scale_y is not None :
411+ y = y * scale_y
412+ if transB :
413+ y = y .T
414+ z = x @y
415+ if bias is not None :
416+ z += bias
417+ if scale_z is not None :
418+ z = z / scale_z
419+ return (z ,)
420+
421+ ref = ReferenceEvaluator (onx ,new_ops = [CustomGemmFloat8E4M3FN ])
422+ expected = ref .run (None ,feeds )[0 ]
423+ ref2 = ReferenceEvaluator (model ,new_ops = [CustomGemmFloat8E4M3FN ])
424+ got = ref2 .run (None ,feeds )[0 ]
425+ self .assertEqualArray (expected ,got )
426+
427+ # with open("debug_test_remove_nodes.py", "w") as f:
428+ # f.write(code)
328429
329430
330431if __name__ == "__main__" :
331- # TestLightApi().test_topk()
332432unittest .main (verbosity = 2 )