|
1 | 1 | fromtypingimportAny,Callable,List,Optional,Tuple |
2 | 2 | importnumpyasnp |
3 | | -fromonnximportModelProto |
| 3 | +fromonnximportModelProto,TensorProto |
4 | 4 | fromonnx.referenceimportReferenceEvaluator |
5 | 5 | from .._helpersimportnp_dtype_to_tensor_dtype |
6 | 6 | from .npx_numpy_tensors_opsimportConstantOfShape |
@@ -183,6 +183,60 @@ def __array_namespace__(self, api_version: Optional[str] = None): |
183 | 183 | f"Unable to return an implementation for api_version={api_version!r}." |
184 | 184 | ) |
185 | 185 |
|
| 186 | +def__bool__(self): |
| 187 | +"Implicit conversion to bool." |
| 188 | +ifself.dtype!=DType(TensorProto.BOOL): |
| 189 | +raiseTypeError( |
| 190 | +f"Conversion to bool only works for bool scalar, not for{self!r}." |
| 191 | + ) |
| 192 | +ifself.shape== (0,): |
| 193 | +returnFalse |
| 194 | +iflen(self.shape)!=0: |
| 195 | +raiseValueError( |
| 196 | +f"Conversion to bool only works for scalar, not for{self!r}." |
| 197 | + ) |
| 198 | +returnbool(self._tensor) |
| 199 | + |
| 200 | +def__int__(self): |
| 201 | +"Implicit conversion to bool." |
| 202 | +iflen(self.shape)!=0: |
| 203 | +raiseValueError( |
| 204 | +f"Conversion to bool only works for scalar, not for{self!r}." |
| 205 | + ) |
| 206 | +ifself.dtypenotin { |
| 207 | +DType(TensorProto.INT64), |
| 208 | +DType(TensorProto.INT32), |
| 209 | +DType(TensorProto.INT16), |
| 210 | +DType(TensorProto.INT8), |
| 211 | +DType(TensorProto.UINT64), |
| 212 | +DType(TensorProto.UINT32), |
| 213 | +DType(TensorProto.UINT16), |
| 214 | +DType(TensorProto.UINT8), |
| 215 | + }: |
| 216 | +raiseTypeError( |
| 217 | +f"Conversion to int only works for int scalar, " |
| 218 | +f"not for dtype={self.dtype}." |
| 219 | + ) |
| 220 | +returnint(self._tensor) |
| 221 | + |
| 222 | +def__float__(self): |
| 223 | +"Implicit conversion to bool." |
| 224 | +iflen(self.shape)!=0: |
| 225 | +raiseValueError( |
| 226 | +f"Conversion to bool only works for scalar, not for{self!r}." |
| 227 | + ) |
| 228 | +ifself.dtypenotin { |
| 229 | +DType(TensorProto.FLOAT), |
| 230 | +DType(TensorProto.DOUBLE), |
| 231 | +DType(TensorProto.FLOAT16), |
| 232 | +DType(TensorProto.BFLOAT16), |
| 233 | + }: |
| 234 | +raiseTypeError( |
| 235 | +f"Conversion to int only works for float scalar, " |
| 236 | +f"not for dtype={self.dtype}." |
| 237 | + ) |
| 238 | +returnfloat(self._tensor) |
| 239 | + |
186 | 240 |
|
187 | 241 | classJitNumpyTensor(NumpyTensor,JitTensor): |
188 | 242 | """ |
|