|
| 1 | +importos |
| 2 | +importplatform |
| 3 | +importunittest |
| 4 | +fromtypingimportAny |
| 5 | +importnumpy |
| 6 | +importonnx.backend.base |
| 7 | +importonnx.backend.test |
| 8 | +importonnx.shape_inference |
| 9 | +importonnx.version_converter |
| 10 | +fromonnximportModelProto |
| 11 | +fromonnx.backend.baseimportDevice,DeviceType |
| 12 | +fromonnx.defsimportonnx_opset_version |
| 13 | +fromonnx_array_api.referenceimportExtendedReferenceEvaluator |
| 14 | + |
| 15 | + |
| 16 | +classExtendedReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep): |
| 17 | +def__init__(self,session): |
| 18 | +self._session=session |
| 19 | + |
| 20 | +defrun(self,inputs,**kwargs): |
| 21 | +ifisinstance(inputs,numpy.ndarray): |
| 22 | +inputs= [inputs] |
| 23 | +ifisinstance(inputs,list): |
| 24 | +iflen(inputs)==len(self._session.input_names): |
| 25 | +feeds=dict(zip(self._session.input_names,inputs)) |
| 26 | +else: |
| 27 | +feeds= {} |
| 28 | +pos_inputs=0 |
| 29 | +forinp,tshapeinzip( |
| 30 | +self._session.input_names,self._session.input_types |
| 31 | + ): |
| 32 | +shape=tuple(d.dim_valuefordintshape.tensor_type.shape.dim) |
| 33 | +ifshape==inputs[pos_inputs].shape: |
| 34 | +feeds[inp]=inputs[pos_inputs] |
| 35 | +pos_inputs+=1 |
| 36 | +ifpos_inputs>=len(inputs): |
| 37 | +break |
| 38 | +elifisinstance(inputs,dict): |
| 39 | +feeds=inputs |
| 40 | +else: |
| 41 | +raiseTypeError(f"Unexpected input type{type(inputs)!r}.") |
| 42 | +outs=self._session.run(None,feeds) |
| 43 | +returnouts |
| 44 | + |
| 45 | + |
| 46 | +classExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend): |
| 47 | +@classmethod |
| 48 | +defis_opset_supported(cls,model):# pylint: disable=unused-argument |
| 49 | +returnTrue,"" |
| 50 | + |
| 51 | +@classmethod |
| 52 | +defsupports_device(cls,device:str)->bool: |
| 53 | +d=Device(device) |
| 54 | +returnd.type==DeviceType.CPU# type: ignore[no-any-return] |
| 55 | + |
| 56 | +@classmethod |
| 57 | +defcreate_inference_session(cls,model): |
| 58 | +returnExtendedReferenceEvaluator(model) |
| 59 | + |
| 60 | +@classmethod |
| 61 | +defprepare( |
| 62 | +cls,model:Any,device:str="CPU",**kwargs:Any |
| 63 | + )->ExtendedReferenceEvaluatorBackendRep: |
| 64 | +# if isinstance(model, ExtendedReferenceEvaluatorBackendRep): |
| 65 | +# return model |
| 66 | +ifisinstance(model,ExtendedReferenceEvaluator): |
| 67 | +returnExtendedReferenceEvaluatorBackendRep(model) |
| 68 | +ifisinstance(model, (str,bytes,ModelProto)): |
| 69 | +inf=cls.create_inference_session(model) |
| 70 | +returncls.prepare(inf,device,**kwargs) |
| 71 | +raiseTypeError(f"Unexpected type{type(model)} for model.") |
| 72 | + |
| 73 | +@classmethod |
| 74 | +defrun_model(cls,model,inputs,device=None,**kwargs): |
| 75 | +rep=cls.prepare(model,device,**kwargs) |
| 76 | +returnrep.run(inputs,**kwargs) |
| 77 | + |
| 78 | +@classmethod |
| 79 | +defrun_node(cls,node,inputs,device=None,outputs_info=None,**kwargs): |
| 80 | +raiseNotImplementedError("Unable to run the model node by node.") |
| 81 | + |
| 82 | + |
| 83 | +backend_test=onnx.backend.test.BackendTest( |
| 84 | +ExtendedReferenceEvaluatorBackend,__name__ |
| 85 | +) |
| 86 | + |
| 87 | +ifos.getenv("APPVEYOR"): |
| 88 | +backend_test.exclude("(test_vgg19|test_zfnet)") |
| 89 | +ifplatform.architecture()[0]=="32bit": |
| 90 | +backend_test.exclude("(test_vgg19|test_zfnet|test_bvlc_alexnet)") |
| 91 | +ifplatform.system()=="Windows": |
| 92 | +backend_test.exclude("test_sequence_model") |
| 93 | + |
| 94 | +ifonnx_opset_version()<21: |
| 95 | +backend_test.exclude( |
| 96 | +"(test_averagepool_2d_dilations" |
| 97 | +"|test_if*" |
| 98 | +"|test_loop*" |
| 99 | +"|test_scan*" |
| 100 | +"|test_sequence_map*" |
| 101 | +")" |
| 102 | + ) |
| 103 | + |
| 104 | +ifonnx_opset_version()<19: |
| 105 | +backend_test.exclude( |
| 106 | +"(test_argm[ai][nx]_default_axis_example" |
| 107 | +"|test_argm[ai][nx]_default_axis_random" |
| 108 | +"|test_argm[ai][nx]_keepdims_example" |
| 109 | +"|test_argm[ai][nx]_keepdims_random" |
| 110 | +"|test_argm[ai][nx]_negative_axis_keepdims_example" |
| 111 | +"|test_argm[ai][nx]_negative_axis_keepdims_random" |
| 112 | +"|test_argm[ai][nx]_no_keepdims_example" |
| 113 | +"|test_argm[ai][nx]_no_keepdims_random" |
| 114 | +"|test_col2im_pads" |
| 115 | +"|test_gru_batchwise" |
| 116 | +"|test_gru_defaults" |
| 117 | +"|test_gru_seq_length" |
| 118 | +"|test_gru_with_initial_bias" |
| 119 | +"|test_layer_normalization_2d_axis1_expanded" |
| 120 | +"|test_layer_normalization_2d_axis_negative_1_expanded" |
| 121 | +"|test_layer_normalization_3d_axis1_epsilon_expanded" |
| 122 | +"|test_layer_normalization_3d_axis2_epsilon_expanded" |
| 123 | +"|test_layer_normalization_3d_axis_negative_1_epsilon_expanded" |
| 124 | +"|test_layer_normalization_3d_axis_negative_2_epsilon_expanded" |
| 125 | +"|test_layer_normalization_4d_axis1_expanded" |
| 126 | +"|test_layer_normalization_4d_axis2_expanded" |
| 127 | +"|test_layer_normalization_4d_axis3_expanded" |
| 128 | +"|test_layer_normalization_4d_axis_negative_1_expanded" |
| 129 | +"|test_layer_normalization_4d_axis_negative_2_expanded" |
| 130 | +"|test_layer_normalization_4d_axis_negative_3_expanded" |
| 131 | +"|test_layer_normalization_default_axis_expanded" |
| 132 | +"|test_logsoftmax_large_number_expanded" |
| 133 | +"|test_lstm_batchwise" |
| 134 | +"|test_lstm_defaults" |
| 135 | +"|test_lstm_with_initial_bias" |
| 136 | +"|test_lstm_with_peepholes" |
| 137 | +"|test_mvn" |
| 138 | +"|test_mvn_expanded" |
| 139 | +"|test_softmax_large_number_expanded" |
| 140 | +"|test_operator_reduced_mean" |
| 141 | +"|test_operator_reduced_mean_keepdim)" |
| 142 | + ) |
| 143 | + |
| 144 | +# The following tests are not supported. |
| 145 | +backend_test.exclude( |
| 146 | +"(test_gradient" |
| 147 | +"|test_if_opt" |
| 148 | +"|test_loop16_seq_none" |
| 149 | +"|test_range_float_type_positive_delta_expanded" |
| 150 | +"|test_range_int32_type_negative_delta_expanded" |
| 151 | +"|test_scan_sum)" |
| 152 | +) |
| 153 | + |
| 154 | +ifonnx_opset_version()<21: |
| 155 | +# The following tests are using types not supported by NumPy. |
| 156 | +# They could be if method to_array is extended to support custom |
| 157 | +# types the same as the reference implementation does |
| 158 | +# (see onnx.reference.op_run.to_array_extended). |
| 159 | +backend_test.exclude( |
| 160 | +"(test_cast_FLOAT_to_BFLOAT16" |
| 161 | +"|test_cast_BFLOAT16_to_FLOAT" |
| 162 | +"|test_cast_BFLOAT16_to_FLOAT" |
| 163 | +"|test_castlike_BFLOAT16_to_FLOAT" |
| 164 | +"|test_castlike_FLOAT_to_BFLOAT16" |
| 165 | +"|test_castlike_FLOAT_to_BFLOAT16_expanded" |
| 166 | +"|test_cast_no_saturate_" |
| 167 | +"|_to_FLOAT8" |
| 168 | +"|_FLOAT8" |
| 169 | +"|test_quantizelinear_e4m3fn" |
| 170 | +"|test_quantizelinear_e5m2" |
| 171 | +")" |
| 172 | + ) |
| 173 | + |
| 174 | +# Disable test about float 8 |
| 175 | +backend_test.exclude( |
| 176 | +"(test_castlike_BFLOAT16*" |
| 177 | +"|test_cast_BFLOAT16*" |
| 178 | +"|test_cast_no_saturate*" |
| 179 | +"|test_cast_FLOAT_to_FLOAT8*" |
| 180 | +"|test_cast_FLOAT16_to_FLOAT8*" |
| 181 | +"|test_cast_FLOAT8_to_*" |
| 182 | +"|test_castlike_BFLOAT16*" |
| 183 | +"|test_castlike_no_saturate*" |
| 184 | +"|test_castlike_FLOAT_to_FLOAT8*" |
| 185 | +"|test_castlike_FLOAT16_to_FLOAT8*" |
| 186 | +"|test_castlike_FLOAT8_to_*" |
| 187 | +"|test_quantizelinear_e*)" |
| 188 | + ) |
| 189 | + |
| 190 | +# The following tests are too slow with the reference implementation (Conv). |
| 191 | +backend_test.exclude( |
| 192 | +"(test_bvlc_alexnet" |
| 193 | +"|test_densenet121" |
| 194 | +"|test_inception_v1" |
| 195 | +"|test_inception_v2" |
| 196 | +"|test_resnet50" |
| 197 | +"|test_shufflenet" |
| 198 | +"|test_squeezenet" |
| 199 | +"|test_vgg19" |
| 200 | +"|test_zfnet512)" |
| 201 | +) |
| 202 | + |
| 203 | +# The following tests cannot pass because they consists in generating random number. |
| 204 | +backend_test.exclude("(test_bernoulli)") |
| 205 | + |
| 206 | +ifonnx_opset_version()<21: |
| 207 | +# The following tests fail due to a bug in the backend test comparison. |
| 208 | +backend_test.exclude( |
| 209 | +"(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)" |
| 210 | + ) |
| 211 | + |
| 212 | +# The following tests fail due to a shape mismatch. |
| 213 | +backend_test.exclude( |
| 214 | +"(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)" |
| 215 | + ) |
| 216 | + |
| 217 | +# The following tests fail due to a type mismatch. |
| 218 | +backend_test.exclude("(test_eyelike_without_dtype)") |
| 219 | + |
| 220 | +# The following tests fail due to discrepancies (small but still higher than 1e-7). |
| 221 | +backend_test.exclude("test_adam_multiple")# 1e-2 |
| 222 | + |
| 223 | + |
| 224 | +# import all test cases at global scope to make them visible to python.unittest |
| 225 | +globals().update(backend_test.test_cases) |
| 226 | + |
| 227 | +if__name__=="__main__": |
| 228 | +res=unittest.main(verbosity=2,exit=False) |
| 229 | +tests_run=res.result.testsRun |
| 230 | +errors=len(res.result.errors) |
| 231 | +skipped=len(res.result.skipped) |
| 232 | +unexpected_successes=len(res.result.unexpectedSuccesses) |
| 233 | +expected_failures=len(res.result.expectedFailures) |
| 234 | +print("---------------------------------") |
| 235 | +print( |
| 236 | +f"tests_run={tests_run} errors={errors} skipped={skipped} " |
| 237 | +f"unexpected_successes={unexpected_successes} " |
| 238 | +f"expected_failures={expected_failures}" |
| 239 | + ) |