- Notifications
You must be signed in to change notification settings - Fork26.8k
[ONNX] Supporting different opset versions for torchlib registry#149901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Changes fromall commits
e6c6b1e0978311872b5be7fec74a4ec0a695fde15a9b220c42d2d569581da79File filter
Filter by extension
Conversations
Uh oh!
There was an error while loading.Please reload this page.
Jump to
Uh oh!
There was an error while loading.Please reload this page.
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -46,7 +46,7 @@ | ||
| import ops_test_common | ||
| import torch | ||
| from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops | ||
| from torch.testing._internal import common_methods_invocations | ||
| from torch.testing._internal.opinfo import definitions as opinfo_definitions | ||
| @@ -78,6 +78,12 @@ class TorchLibOpInfo: | ||
| compare_shape_only_for_output: tuple[int, ...] = () | ||
| # Whether the function is designed for complex inputs | ||
| complex: bool = False | ||
| # The ONNX opset version in which the function was introduced. | ||
| # Its specifies the minimum ONNX opset version required to use the function. | ||
| # It ensures that the function is only used when the target ONNX opset version | ||
| # is compatible. For example, if `opset_introduced=20`, the function will only | ||
| # be used when exporting to ONNX models targeting opset version 20 or higher. | ||
| opset_introduced: int = 18 | ||
shubhambhokare1 marked this conversation as resolved. OutdatedShow resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| # The acceptable tolerance of the inference result difference between PyTorch and ORT. | ||
| # Format: {dtype: (rtol, atol)}. | ||
| # For example: {torch.float16: (1e-3, 1e-3)} | ||
| @@ -447,8 +453,10 @@ def _where_input_wrangler( | ||
| TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True), | ||
| TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), | ||
| TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True), | ||
| TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20), | ||
| ) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims")) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims")) | ||
| ops_test_common.duplicate_opinfo( | ||
| @@ -500,6 +508,7 @@ def _where_input_wrangler( | ||
| "nn.functional.replication_pad3d", | ||
| ), | ||
| ) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",)) | ||
| ops_test_common.duplicate_opinfo( | ||
| OPS_DB, | ||
| "nn.functional.scaled_dot_product_attention", | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -50,7 +50,7 @@ def export_compat( | ||
| verbose: bool | None = None, | ||
| input_names: Sequence[str] | None = None, | ||
| output_names: Sequence[str] | None = None, | ||
| opset_version: int | None =_constants.TORCHLIB_OPSET, | ||
Collaborator
| ||
| custom_translation_table: dict[Callable, Callable | Sequence[Callable]] | ||
| | None = None, | ||
| dynamic_axes: Mapping[str, Mapping[int, str]] | ||
| @@ -105,8 +105,7 @@ def export_compat( | ||
| dynamic_shapes_with_export_dim, need_axis_mapping = ( | ||
| _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes) | ||
| ) | ||
| registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version) | ||
| if custom_translation_table is not None: | ||
| for torch_op, onnx_ops in custom_translation_table.items(): | ||
| # TODO(justinchuby): Support complex inputs with annotations | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None: | ||
| value.shape = ir.Shape(new_shape) | ||
| def add_torchlib_common_imports( | ||
shubhambhokare1 marked this conversation as resolved. OutdatedShow resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET | ||
| ) -> None: | ||
| """Hack to add torchlib common imports to the model.""" | ||
| try: | ||
| @@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None: | ||
| model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 | ||
| rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) | ||
| rank_func.opset_imports[""] = opset_version | ||
| is_scalar_func = ir.serde.deserialize_function( | ||
| common_ops.IsScalar.to_function_proto() | ||
| ) | ||
| is_scalar_func.opset_imports[""] = opset_version | ||
| model.functions[rank_func.identifier()] = rank_func | ||
| model.functions[is_scalar_func.identifier()] = is_scalar_func | ||
| except Exception: | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -42,6 +42,9 @@ class OnnxDecompMeta: | ||
| signature: The ONNX signature of the function. When None, the signature is inferred. | ||
| is_custom: Whether the function is a custom function. | ||
| is_complex: Whether the function is a function that handles complex valued inputs. | ||
| opset_introduced: | ||
| The ONNX opset version in which the function was introduced. | ||
| Its specifies the minimum ONNX opset version required to use the function. | ||
| device: The device the function is registered to. If None, it is registered to all devices. | ||
| skip_signature_inference: Whether to skip signature inference for the function. | ||
| """ | ||
| @@ -51,6 +54,7 @@ class OnnxDecompMeta: | ||
| signature: _schemas.OpSignature | None | ||
| is_custom: bool = False | ||
| is_complex: bool = False | ||
| opset_introduced: int = 18 | ||
shubhambhokare1 marked this conversation as resolved. OutdatedShow resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
| device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051 | ||
| skip_signature_inference: bool = False | ||
| @@ -150,13 +154,14 @@ def opset_version(self) -> int: | ||
| return self._opset_version | ||
| @classmethod | ||
| def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry: | ||
| """Populates the registry with ATen functions from torchlib. | ||
| Args: | ||
| torchlib_registry: The torchlib registry to use for populating the registry. | ||
| """ | ||
| registry = cls() | ||
| registry._opset_version = opset_version | ||
| for meta in _torchlib_registry.get_torchlib_ops(): | ||
| registry._register(meta.fx_target, meta) | ||
| @@ -185,6 +190,7 @@ def from_torchlib(cls) -> ONNXRegistry: | ||
| logger.exception("Failed to register '%s'. Skipped", qualified_name) | ||
| continue | ||
| registry._cleanup_registry_based_on_opset_version() | ||
| return registry | ||
| def _register( | ||
| @@ -274,5 +280,24 @@ def is_registered(self, target: TorchOp) -> bool: | ||
| """ | ||
| return bool(self.get_decomps(target)) | ||
| def _cleanup_registry_based_on_opset_version(self) -> None: | ||
| """Pick the implementation with the highest opset version valid until the current opset version.""" | ||
| cleaned_functions = {} | ||
| for target_or_name, decomps in self.functions.items(): | ||
| # Filter decompositions to only include those with opset_introduced <= opset_version | ||
| decomps = [d for d in decomps if d.opset_introduced <= self.opset_version] | ||
| # Keep only the decomposition with the highest opset_introduced | ||
| if decomps: | ||
| # Find the maximum opset_introduced | ||
| max_opset = max(d.opset_introduced for d in decomps) | ||
| # Keep all decompositions with the maximum opset_introduced | ||
| cleaned_functions[target_or_name] = [ | ||
| d for d in decomps if d.opset_introduced == max_opset | ||
| ] | ||
| self.functions = cleaned_functions | ||
| def __repr__(self) -> str: | ||
| return f"{self.__class__.__name__}(functions={self.functions})" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -30,6 +30,7 @@ def onnx_impl( | ||
| *, | ||
| trace_only: bool = False, | ||
| complex: bool = False, | ||
| opset_introduced: int = 18, | ||
Collaborator There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others.Learn more. Maybe set this to _constants.TORCHLIB_OPSET too? | ||
| no_compile: bool = False, | ||
| private: bool = False, | ||
| ) -> Callable[[_T], _T]: | ||
| @@ -74,6 +75,7 @@ def wrapper( | ||
| fx_target=t, | ||
| signature=None, | ||
| is_complex=complex, | ||
| opset_introduced=opset_introduced, | ||
| skip_signature_inference=no_compile, | ||
| ) | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| from __future__ import annotations | ||
| __all__ = ["core", "hop", "nn", "symbolic"] | ||
| from torch.onnx._internal.exporter._torchlib.ops import core, hop,nn,symbolic |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| """torch.ops.aten operators under the `core` module.""" | ||
| # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" | ||
| # ruff: noqa: TCH001,TCH002 | ||
| # flake8: noqa | ||
| from __future__ import annotations | ||
| import math | ||
| from onnxscript.onnx_opset import opset20 as op20 | ||
| import torch | ||
| from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal | ||
| from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl | ||
| aten = torch.ops.aten | ||
| @onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) | ||
| def aten_gelu_opset20( | ||
| self: TReal, | ||
| approximate: str = "none", | ||
| ) -> TReal: | ||
| """gelu(Tensor self, *, bool approximate=False) -> Tensor""" | ||
| return op20.Gelu(self, approximate=approximate) | ||
shubhambhokare1 marked this conversation as resolved. OutdatedShow resolvedHide resolvedUh oh!There was an error while loading.Please reload this page. | ||
Uh oh!
There was an error while loading.Please reload this page.