[ONNX] Supporting different opset versions for torchlib registry #149901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Closed
25 changes: 25 additions & 0 deletions test/onnx/exporter/test_api.py
@@ -246,6 +246,31 @@ def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(se
)
)

+    def test_upgraded_torchlib_impl(self):
+        class GeluModel(torch.nn.Module):
+            def forward(self, input):
+                # Use GELU activation function
+                return torch.nn.functional.gelu(input, approximate="tanh")
+
+        input = torch.randn(1, 3, 4, 4)
+        onnx_program_op18 = torch.onnx.export(
+            GeluModel(),
+            input,
+            dynamo=True,
+        )
+        all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph]
+        self.assertIn("Tanh", all_nodes_op18)
+        self.assertNotIn("Gelu", all_nodes_op18)
+
+        onnx_program_op20 = torch.onnx.export(
+            GeluModel(),
+            input,
+            opset_version=20,
+            dynamo=True,
+        )
+        all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph]
+        self.assertIn("Gelu", all_nodes_op20)
    def test_refine_dynamic_shapes_with_onnx_export(self):
        # NOTE: From test/export/test_export.py

8 changes: 5 additions & 3 deletions test/onnx/torchlib/ops_test_common.py
@@ -52,6 +52,7 @@
    torch.float64,
)


TEST_OPSET_VERSION = 18
IS_MACOS = sys.platform.startswith("darwin")
IS_WINDOWS = os.name == "nt"
@@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -
def graph_executor(
    test_name: str,
    outputs: Sequence[Any],
+    opset_version: int = TEST_OPSET_VERSION,
) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]:
    """Eagerly executes a function."""

@@ -500,10 +502,10 @@ def _capture_graph_and_evaluate_torch_script_evaluator(
        (),
        (),
        nodes=(),
-        opset_imports={"": 18, "pkg.torch.onnx": 1},
+        opset_imports={"": opset_version, "pkg.torch.onnx": 1},
        name="main_graph",
    )
-    opset = onnxscript.opset18
+    opset = onnxscript.values.Opset("", opset_version)
    tracer = _building.OpRecorder(opset, {})
    ort_inputs = {}
    onnxscript_args: list[Any] = []
@@ -590,7 +592,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(
            proto = onnxscript_function.to_function_proto()
            ir_function = ir.serde.deserialize_function(proto)
            onnx_model.functions[identifier] = ir_function
-    _ir_passes.add_torchlib_common_imports(onnx_model)
+    _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version)
    _ir_passes.add_opset_imports(onnx_model)
    # Make sure the model is valid
    model_proto = ir.to_proto(onnx_model)
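Note: the key change in this test helper is that it no longer pins `onnxscript.opset18`; `onnxscript.values.Opset("", opset_version)` builds a handle to the default ONNX domain at any requested version. A quick standalone sketch of the equivalence (illustration only, assuming onnxscript is installed; not part of this PR):

import onnxscript
from onnxscript import values

# onnxscript.opset18 is a prebuilt Opset for the default ("") domain at
# version 18, so values.Opset("", 18) yields an equivalent handle.
pinned = onnxscript.opset18
parametric = values.Opset("", 18)
print(pinned.domain == parametric.domain)    # both use the default ONNX domain
print(pinned.version == parametric.version)  # both target opset 18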
11 changes: 10 additions & 1 deletion test/onnx/torchlib/ops_test_data.py
@@ -46,7 +46,7 @@
import ops_test_common

import torch
-from torch.onnx._internal.exporter._torchlib.ops import core as core_ops
+from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops
from torch.testing._internal import common_methods_invocations
from torch.testing._internal.opinfo import definitions as opinfo_definitions

@@ -78,6 +78,12 @@ class TorchLibOpInfo:
    compare_shape_only_for_output: tuple[int, ...] = ()
    # Whether the function is designed for complex inputs
    complex: bool = False
+    # The ONNX opset version in which the function was introduced.
+    # It specifies the minimum ONNX opset version required to use the function,
+    # ensuring that the function is only used when the target ONNX opset version
+    # is compatible. For example, if `opset_introduced=20`, the function will only
+    # be used when exporting to ONNX models targeting opset version 20 or higher.
+    opset_introduced: int = 18
    # The acceptable tolerance of the inference result difference between PyTorch and ORT.
    # Format: {dtype: (rtol, atol)}.
    # For example: {torch.float16: (1e-3, 1e-3)}
@@ -447,8 +453,10 @@ def _where_input_wrangler(
TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True),
TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}),
TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True),
TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20),
)


ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))
ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims"))
ops_test_common.duplicate_opinfo(
@@ -500,6 +508,7 @@ def _where_input_wrangler(
"nn.functional.replication_pad3d",
),
)
ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",))
ops_test_common.duplicate_opinfo(
OPS_DB,
"nn.functional.scaled_dot_product_attention",
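Note: `duplicate_opinfo` clones an existing OpInfo entry under new names so the same sample inputs can exercise a second torchlib implementation, here the opset-20 Gelu. Conceptually it behaves like the simplified sketch below (illustrative only; the real helper in ops_test_common does more bookkeeping, and `duplicate_opinfo_sketch` is a hypothetical name):

import copy

def duplicate_opinfo_sketch(ops_db, source_name, new_names):
    # Clone every OpInfo whose name matches source_name under each new name,
    # so "gelu_op20" can reuse the sample inputs of "nn.functional.gelu".
    for opinfo in tuple(ops_db):
        if opinfo.name == source_name:
            for new_name in new_names:
                clone = copy.deepcopy(opinfo)
                clone.name = new_name
                ops_db.append(clone)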
4 changes: 3 additions & 1 deletion test/onnx/torchlib/test_ops.py
@@ -220,7 +220,9 @@ def run_test_output_match(

    test_name = test_suite.id()
    function_output, model_proto = function_executor(
-        test_name, reference_torch_outputs
+        test_name,
+        reference_torch_outputs,
+        opset_version=torchlib_op_info.opset_introduced,
    )(onnx_function, input_onnx, kwargs_onnx)
    # Finally we re-flatten everything
    # TODO: add pytree structure comparison.
5 changes: 2 additions & 3 deletions torch/onnx/_internal/exporter/_compat.py
@@ -50,7 +50,7 @@ def export_compat(
    verbose: bool | None = None,
    input_names: Sequence[str] | None = None,
    output_names: Sequence[str] | None = None,
-    opset_version: int | None = None,
+    opset_version: int | None = _constants.TORCHLIB_OPSET,

Review comment from titaiwangms (Collaborator), Jul 25, 2025 (edited):
Should we delete the None on lines 70 and 71? @justinchuby
Are there other contexts?

Reply from a collaborator:
None is an allowed input in the main export api, so we accept it here.

    custom_translation_table: dict[Callable, Callable | Sequence[Callable]]
    | None = None,
    dynamic_axes: Mapping[str, Mapping[int, str]]
@@ -105,8 +105,7 @@ def export_compat(
    dynamic_shapes_with_export_dim, need_axis_mapping = (
        _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
    )

-    registry = _registration.ONNXRegistry.from_torchlib()
+    registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version)
    if custom_translation_table is not None:
        for torch_op, onnx_ops in custom_translation_table.items():
            # TODO(justinchuby): Support complex inputs with annotations
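Note: with this change the effective default becomes `_constants.TORCHLIB_OPSET`, and the registry is built for whatever opset the caller requests. A hedged sketch of exercising the internal API directly (these are private modules whose paths may change between releases):

from torch.onnx._internal.exporter import _registration

# Build a registry resolved against opset 20: implementations with
# opset_introduced > 20 are dropped, and for each op the highest
# remaining opset_introduced wins.
registry = _registration.ONNXRegistry().from_torchlib(opset_version=20)
print(registry.opset_version)  # 20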
6 changes: 5 additions & 1 deletion torch/onnx/_internal/exporter/_ir_passes.py
@@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None:
            value.shape = ir.Shape(new_shape)


-def add_torchlib_common_imports(model: ir.Model) -> None:
+def add_torchlib_common_imports(
+    model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET
+) -> None:
    """Hack to add torchlib common imports to the model."""

    try:
@@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None:

model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
rank_func.opset_imports[""] = opset_version
is_scalar_func = ir.serde.deserialize_function(
common_ops.IsScalar.to_function_proto()
)
is_scalar_func.opset_imports[""] = opset_version
model.functions[rank_func.identifier()] = rank_func
model.functions[is_scalar_func.identifier()] = is_scalar_func
except Exception:
27 changes: 26 additions & 1 deletion torch/onnx/_internal/exporter/_registration.py
@@ -42,6 +42,9 @@ class OnnxDecompMeta:
        signature: The ONNX signature of the function. When None, the signature is inferred.
        is_custom: Whether the function is a custom function.
        is_complex: Whether the function is a function that handles complex valued inputs.
+        opset_introduced:
+            The ONNX opset version in which the function was introduced.
+            It specifies the minimum ONNX opset version required to use the function.
        device: The device the function is registered to. If None, it is registered to all devices.
        skip_signature_inference: Whether to skip signature inference for the function.
    """
@@ -51,6 +54,7 @@
    signature: _schemas.OpSignature | None
    is_custom: bool = False
    is_complex: bool = False
+    opset_introduced: int = 18
    device: Literal["cuda", "cpu"] | str | None = None  # noqa: PYI051
    skip_signature_inference: bool = False

@@ -150,13 +154,14 @@ def opset_version(self) -> int:
        return self._opset_version

    @classmethod
-    def from_torchlib(cls) -> ONNXRegistry:
+    def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry:
        """Populates the registry with ATen functions from torchlib.

        Args:
            opset_version: The target ONNX opset version to build the registry for.
        """
        registry = cls()
+        registry._opset_version = opset_version
        for meta in _torchlib_registry.get_torchlib_ops():
            registry._register(meta.fx_target, meta)
@@ -185,6 +190,7 @@ def from_torchlib(cls) -> ONNXRegistry:
                logger.exception("Failed to register '%s'. Skipped", qualified_name)
                continue

+        registry._cleanup_registry_based_on_opset_version()
        return registry

    def _register(
@@ -274,5 +280,24 @@ def is_registered(self, target: TorchOp) -> bool:
"""
return bool(self.get_decomps(target))

def _cleanup_registry_based_on_opset_version(self) -> None:
"""Pick the implementation with the highest opset version valid until the current opset version."""
cleaned_functions = {}
for target_or_name, decomps in self.functions.items():
# Filter decompositions to only include those with opset_introduced <= opset_version
decomps = [d for d in decomps if d.opset_introduced <= self.opset_version]

# Keep only the decomposition with the highest opset_introduced
if decomps:
# Find the maximum opset_introduced
max_opset = max(d.opset_introduced for d in decomps)

# Keep all decompositions with the maximum opset_introduced
cleaned_functions[target_or_name] = [
d for d in decomps if d.opset_introduced == max_opset
]

self.functions = cleaned_functions

def __repr__(self) -> str:
return f"{self.__class__.__name__}(functions={self.functions})"
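Note: to make the selection rule concrete, implementations introduced after the requested opset are discarded, and among the rest only those tagged with the highest `opset_introduced` survive. A self-contained sketch of the same logic (the `Decomp` dataclass below is a purely illustrative stand-in for `OnnxDecompMeta`):

from dataclasses import dataclass

@dataclass
class Decomp:  # illustrative stand-in for OnnxDecompMeta
    name: str
    opset_introduced: int

def pick(decomps: list[Decomp], target_opset: int) -> list[Decomp]:
    # Drop implementations that require a newer opset than the target
    eligible = [d for d in decomps if d.opset_introduced <= target_opset]
    if not eligible:
        return []
    # Among the remaining candidates keep the newest implementation(s)
    best = max(d.opset_introduced for d in eligible)
    return [d for d in eligible if d.opset_introduced == best]

candidates = [Decomp("gelu_tanh_decomposition", 18), Decomp("gelu_native", 20)]
print([d.name for d in pick(candidates, 18)])  # ['gelu_tanh_decomposition']
print([d.name for d in pick(candidates, 20)])  # ['gelu_native']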
torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py
@@ -30,6 +30,7 @@ def onnx_impl(
    *,
    trace_only: bool = False,
    complex: bool = False,
+    opset_introduced: int = 18,

Review comment from a collaborator:
Maybe set this to _constants.TORCHLIB_OPSET too?

    no_compile: bool = False,
    private: bool = False,
) -> Callable[[_T], _T]:
@@ -74,6 +75,7 @@ def wrapper(
                fx_target=t,
                signature=None,
                is_complex=complex,
+                opset_introduced=opset_introduced,
                skip_signature_inference=no_compile,
            )
        )
4 changes: 2 additions & 2 deletions torch/onnx/_internal/exporter/_torchlib/ops/__init__.py
@@ -1,6 +1,6 @@
from __future__ import annotations


__all__ = ["core", "hop", "symbolic"]
__all__ = ["core", "hop", "nn", "symbolic"]

from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
from torch.onnx._internal.exporter._torchlib.ops import core, hop,nn,symbolic
26 changes: 26 additions & 0 deletions torch/onnx/_internal/exporter/_torchlib/ops/nn.py
@@ -0,0 +1,26 @@
"""torch.ops.aten operators under the `core` module."""
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index"
# ruff: noqa: TCH001,TCH002
# flake8: noqa

from __future__ import annotations

import math

from onnxscript.onnx_opset import opset20 as op20

import torch
from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl


aten = torch.ops.aten


@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20)
def aten_gelu_opset20(
    self: TReal,
    approximate: str = "none",
) -> TReal:
    """gelu(Tensor self, *, str approximate='none') -> Tensor"""
    return op20.Gelu(self, approximate=approximate)
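Note: ONNX only gained a native Gelu operator (with an `approximate` attribute) in opset 20, which is why this implementation is gated with `opset_introduced=20`; at opset 18 the exporter instead emits the decomposed tanh approximation, which computes roughly the following (a NumPy sketch of the approximate="tanh" GELU, for illustration only, not part of this PR):

import numpy as np

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU, matching torch.nn.functional.gelu(x, approximate="tanh")
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))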