Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

slice_scatter decomposition#2519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
apbose merged 1 commit intomainfromslice_scatter_decomposition
May 30, 2024
Merged

Conversation

@apbose
Copy link
Collaborator

@apboseapbose commentedDec 6, 2023
edited
Loading

Fixes#2434
This PR would be dependant on#2664 and#2669. Major changes

  1. 2664- Implementation makes use ofaten::scatter.src
  2. 2669- Constants getting converted to fake tensors inget_attr call due to which different device locationmeta and cpu in torch

@apboseapbose marked this pull request as draftDecember 6, 2023 09:08
@github-actionsgithub-actionsbot added component: api [Python]Issues re: Python API component: dynamoIssues relating to the `torch.compile` or `torch._dynamo.export` paths component: loweringIssues re: The lowering / preprocessing passes component: testsIssues re: Tests labelsDec 6, 2023
Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py2023-12-06 09:08:13.895012+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py2023-12-06 09:11:58.776404+00:00@@ -186,21 +186,22 @@    src_dim = list(src_tensor.shape())    src_dim[dim] = torch.floor_divide(end - start, step)    src = torch.expand(src, src_dim)-    if (start == 0 and end == dim_size and step == 0):+    if start == 0 and end == dim_size and step == 0:        return input_tensor    mask = []    if start != 0:        mask.append(torch.ge(input_tensor_shape, start))    if end != dim_size:        mask.append(torch.ge(input_tensor_shape, end))    if step != 1:        mask.append(torch.eq(src_dim, 0))    src_val = torch.masked(mask, src_dim, 0)-    return torch.where(mask, src_val,input_tensor)+    return torch.where(mask, src_val, input_tensor)+def get_decompositions(    enable_experimental_decompositions: bool = False,) -> Dict[OpOverload, Callable[[Any], Any]]:    if enable_experimental_decompositions:--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py2023-12-06 09:08:13.915012+00:00+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py2023-12-06 09:12:02.062349+00:00@@ -418,11 +418,10 @@            0,            DECIMALS_OF_AGREEMENT,            f"MaxPool3d TRT outputs don't match with the original model.",        )-    def test_lowering_select_scatter_module(self):        class selectScatter(torch.nn.Module):            def __init__(self, *args, **kwargs) -> None:                super().__init__(*args, **kwargs)@@ -435,11 +434,10 @@            torch.ops.aten.lt.default,            torch.ops.aten.lt.default,            torch.ops.aten.expand.default,            torch.ops.aten.eq.default,            torch.ops.aten.where.default,-        }        unexpected_ops = {torch.ops.aten.select_scatter}        inputs = [torch.randn(2, 2), torch.ones(2)]@@ -485,7 +483,8 @@            0,            DECIMALS_OF_AGREEMENT,            f"Select_scatter TRT outputs don't match with the original model.",        )+if __name__ == "__main__":    run_tests()

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py2023-12-19 18:39:51.699972+00:00+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_decompositions.py2023-12-19 18:41:49.917712+00:00@@ -418,11 +418,10 @@            0,            DECIMALS_OF_AGREEMENT,            f"MaxPool3d TRT outputs don't match with the original model.",        )-    def test_lowering_select_scatter_module(self):        class selectScatter(torch.nn.Module):            def __init__(self, *args, **kwargs) -> None:                super().__init__(*args, **kwargs)@@ -435,11 +434,10 @@            torch.ops.aten.lt.default,            torch.ops.aten.lt.default,            torch.ops.aten.expand.default,            torch.ops.aten.eq.default,            torch.ops.aten.where.default,-        }        unexpected_ops = {torch.ops.aten.select_scatter}        inputs = [torch.randn(2, 2), torch.ones(2)]@@ -485,7 +483,8 @@            0,            DECIMALS_OF_AGREEMENT,            f"Select_scatter TRT outputs don't match with the original model.",        )+if __name__ == "__main__":    run_tests()

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Code conforms to Python style guidelines

@apboseapbose marked this pull request as ready for reviewJanuary 2, 2024 22:00
Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Code conforms to Python style guidelines

Comment on lines 178 to 192
ifstartisnotNoneandstart<0:
start=start+dim_size
ifendisnotNoneandend<0:
end=end+dim_size
ifstartisNone:
start=0
ifendisNone:
end=dim_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Consider switching to useget_positive_dim utility.


ifstart==0andend==dim_sizeandstep==0:
returninput_tensor
index_tensor=np.arange(start,end_dim,step)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Does this work withtorch.arange?

end=dim_size

src_dim=src_tensor.shape
step_dim=torch.floor_divide(end-start,step)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

(end - start) // step

ifstep_dim>src_dim[dim]:
end_dim=src_dim[dim]
else:
indices=torch.Tensor(np.arange(0,step_dim))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

torch.arange

unbind_source_tensors=torch.unbind(src,dim)
unbind_source_tensors_list=list(unbind_source_tensors)

fori,indexinenumerate(index_tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

range(start, end_dim, step) instead of index tensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

See other comment

@apboseapboseforce-pushed theslice_scatter_decomposition branch fromafeba1e toa0b031fCompareJanuary 12, 2024 18:25

ifstart==0andend==dim_sizeandstep==0:
returninput_tensor
index_tensor=torch.arange(start,end_dim,step)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Is this tensor needed; could it be replaced withrange, as below?

unbind_source_tensors=torch.unbind(src,dim)
unbind_source_tensors_list=list(unbind_source_tensors)

fori,indexinenumerate(index_tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

See other comment

ifstep_dim>src_dim[dim]:
end_dim=src_dim[dim]
else:
indices=torch.Tensor(torch.arange(0,step_dim))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

torch.arange should already return a Tensor, so the cast should not be needed

indices=indices.to(torch.int32)
src=torch.index_select(src,dim,indices)

ifstart==0andend==dim_sizeandstep==0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Should this bestep == 1?

Copy link
CollaboratorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

I think this should bestep == 0 sincestep == 1 would result in tensors being inserted in source tensor at step 1 interval.

end_dim=src_dim[dim]
else:
indices=torch.arange(0,step_dim)
indices=indices.to(torch.int32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

If the indices areint64, it is fine to leave them as-is and not change the data type, since later operators may expect or requireint64

Copy link
CollaboratorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Wouldn't this be required for the subsequenttorch.index_select ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Doestorch expectint64 for the indices inindex_select, or TensorRT? If it is TensorRT, then there is no need to perform the cast, because the outputs of the above operation will already have been handled in theTRTInterpreter

Copy link
CollaboratorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Yes torch would expect int64 input for the indices. Since this is a constant, I think TRTInterpretor should be able to handle it. Yes I will remove this.

@apboseapboseforce-pushed theslice_scatter_decomposition branch fromcec6a4e to8fb696eCompareFebruary 20, 2024 20:01
Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py2024-02-20 19:59:59.374321+00:00+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py2024-02-20 20:01:49.660284+00:00@@ -1,10 +1,11 @@"""# Reference- [Very Deep Convolutional Networks for Large-Scale Image Recognition](    https://arxiv.org/abs/1409.1556) (ICLR 2015)"""+import torchimport torch.nn as nnimport torch.nn.functional as Ffrom functools import reduce--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py2024-02-20 19:59:59.382321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py2024-02-20 20:01:49.759276+00:00@@ -30,16 +30,18 @@        gpu_id (int): Device ID for target GPU        dla_core (int): Core ID for target DLA core        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed    """-    device_type: Optional[-        trt.DeviceType-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.+    device_type: Optional[trt.DeviceType] = (+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.+    )    gpu_id: int = -1  #: Device ID for target GPU    dla_core: int = -1  #: Core ID for target DLA core-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed+    allow_gpu_fallback: bool = (+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed+    )    def __init__(self, *args: Any, **kwargs: Any):        """__init__ Method for torch_tensorrt.Device        Device accepts one of a few construction patterns--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py2024-02-20 19:59:59.382321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py2024-02-20 20:01:49.959821+00:00@@ -26,16 +26,16 @@    class _ShapeMode(Enum):        STATIC = 0        DYNAMIC = 1-    shape_mode: Optional[-        _ShapeMode-    ] = None  #: Is input statically or dynamically shaped-    shape: Optional[-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``+    shape_mode: Optional[_ShapeMode] = (+        None  #: Is input statically or dynamically shaped+    )+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``+    )    dtype: _enums.dtype = (        _enums.dtype.unknown    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)    _explicit_set_dtype: bool = False    format: _enums.TensorFormat = (--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py2024-02-20 19:59:59.382321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py2024-02-20 20:01:50.013227+00:00@@ -212,13 +212,13 @@        "precision": precision,        "debug": debug,        "device": device,        "workspace_size": workspace_size,        "min_block_size": min_block_size,-        "torch_executed_ops": torch_executed_ops-        if torch_executed_ops is not None-        else set(),+        "torch_executed_ops": (+            torch_executed_ops if torch_executed_ops is not None else set()+        ),        "pass_through_build_failures": pass_through_build_failures,        "max_aux_streams": max_aux_streams,        "version_compatible": version_compatible,        "optimization_level": optimization_level,        "use_python_runtime": use_python_runtime,--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py2024-02-20 19:59:59.382321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py2024-02-20 20:01:50.235895+00:00@@ -26,13 +26,13 @@from packaging import version_LOGGER: logging.Logger = logging.getLogger(__name__)-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[-    Callable[[torch.fx.GraphModule], None]-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")+)class UnsupportedOperatorException(RuntimeError):    pass@@ -90,13 +90,13 @@        self.input_specs_iter = 0        self._cur_node_name: Optional[str] = None        self._cur_node: Optional[torch.fx.Node] = None        self._input_names: List[str] = []        self._output_names: List[str] = []-        self._itensor_to_tensor_meta: Dict[-            trt.tensorrt.ITensor, TensorMetadata-        ] = dict()+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (+            dict()+        )        self.compilation_settings = compilation_settings        # Data types for TRT Module output Tensors        self.output_dtypes = output_dtypes--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py2024-02-20 19:59:59.382321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py2024-02-20 20:01:50.278485+00:00@@ -322,17 +322,15 @@    else:        raise AssertionError(f"Cannot convert {input_val} to TRT constant")@overload-def get_positive_dim(dim: int, dim_size: int) -> int:-    ...+def get_positive_dim(dim: int, dim_size: int) -> int: ...@overload-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:-    ...+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...def get_positive_dim(    dim: Union[int, Sequence[int]], dim_size: int) -> Union[int, Tuple[int, ...]]:--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py2024-02-20 19:59:59.386321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py2024-02-20 20:01:50.623768+00:00@@ -5,13 +5,13 @@from torch._decomp import get_decompositions as get_torch_decompositionsfrom torch._ops import OpOverload, OpOverloadPacketaten = torch.ops.aten-_core_aten_decompositions: Dict[-    OpOverload, Callable[[Any], Any]-] = core_aten_decompositions()+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (+    core_aten_decompositions()+)torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {    aten._adaptive_avg_pool2d_backward,    aten.addcdiv,    aten.addcdiv_,    aten.addcmul,@@ -179,13 +179,13 @@torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {    aten._softmax.default,}-ENABLED_TORCH_DECOMPOSITIONS: Dict[-    OpOverload, Callable[[Any], Any]-] = get_torch_decompositions(torch_enabled_decompositions)+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (+    get_torch_decompositions(torch_enabled_decompositions)+)TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}def check_decomp_set_invariants() -> None:    """Validates no overlap between enabled and disabled decomposition sets"""--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py2024-02-20 19:59:59.386321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py2024-02-20 20:01:50.628829+00:00@@ -20,16 +20,14 @@        logger.debug(f"Graph after lowering linear:\n{gm.graph}")    return gm-def linear_replacement() -> (-    Tuple[-        torch.fx.GraphModule,-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],-    ]-):+def linear_replacement() -> Tuple[+    torch.fx.GraphModule,+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],+]:    """Constructs the original and replacement functions for linear"""    # Original graph    def orig(        input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py2024-02-20 19:59:59.386321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py2024-02-20 20:01:50.665412+00:00@@ -20,16 +20,14 @@        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")    return gm-def view_replacement() -> (-    Tuple[-        torch.fx.GraphModule,-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],-    ]-):+def view_replacement() -> Tuple[+    torch.fx.GraphModule,+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],+]:    """Constructs the original and replacement functions for view"""    # Original graph    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:        return torch.ops.aten.view.default(input, shape)--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py2024-02-20 19:59:59.386321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py2024-02-20 20:01:50.681914+00:00@@ -58,16 +58,14 @@        logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")    return gm-def scaled_dot_product_attention_replacement() -> (-    Tuple[-        Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],-    ]-):+def scaled_dot_product_attention_replacement() -> Tuple[+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],+]:    """Constructs the original and replacement functions for efficient attention"""    # Efficient Attention original graph    def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:        outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py2024-02-20 19:59:59.386321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py2024-02-20 20:01:50.959439+00:00@@ -99,25 +99,29 @@                self.engine.get_binding_dtype(idx), Frameworks.TORCH            )            for idx in self.output_binding_indices_in_order        ]        self.output_shapes = [-            tuple(self.engine.get_binding_shape(idx))-            if self.engine.has_implicit_batch_dimension-            else tuple()+            (+                tuple(self.engine.get_binding_shape(idx))+                if self.engine.has_implicit_batch_dimension+                else tuple()+            )            for idx in self.output_binding_indices_in_order        ]        self.hidden_output_dtypes = [            unified_dtype_converter(                self.engine.get_binding_dtype(idx), Frameworks.TORCH            )            for idx in self.hidden_output_binding_indices_in_order        ]        self.hidden_output_shapes = [-            tuple(self.engine.get_binding_shape(idx))-            if self.engine.has_implicit_batch_dimension-            else tuple()+            (+                tuple(self.engine.get_binding_shape(idx))+                if self.engine.has_implicit_batch_dimension+                else tuple()+            )            for idx in self.hidden_output_binding_indices_in_order        ]    def _check_initialized(self) -> None:        if not self.initialized:@@ -165,13 +169,15 @@        self.__dict__.update(state)        if self.engine:            self.context = self.engine.create_execution_context()    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:-        with torch.autograd.profiler.record_function(-            "PythonTorchTensorRTModule:Forward"-        ) if self.profiling_enabled else nullcontext():+        with (+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")+            if self.profiling_enabled+            else nullcontext()+        ):            self._check_initialized()            # If in safe mode, check at each iteration for for whether a switch is required            if (                torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE@@ -198,13 +204,17 @@                    torch.cuda.set_device(device_id)                    inputs = tuple([tensor.to(device) for tensor in inputs])                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")-            with torch.autograd.profiler.record_function(-                "PythonTorchTensorRTModule:ProcessInputs"-            ) if self.profiling_enabled else nullcontext():+            with (+                torch.autograd.profiler.record_function(+                    "PythonTorchTensorRTModule:ProcessInputs"+                )+                if self.profiling_enabled+                else nullcontext()+            ):                assert len(inputs) == len(                    self.input_names                ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."                contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]@@ -237,13 +247,17 @@                    self.context.set_binding_shape(                        idx, tuple(contiguous_inputs[i].shape)                    )-            with torch.autograd.profiler.record_function(-                "PythonTorchTensorRTModule:ProcessOutputs"-            ) if self.profiling_enabled else nullcontext():+            with (+                torch.autograd.profiler.record_function(+                    "PythonTorchTensorRTModule:ProcessOutputs"+                )+                if self.profiling_enabled+                else nullcontext()+            ):                # create output tensors                outputs: List[torch.Tensor] = []                for i, idx in enumerate(self.output_binding_indices_in_order):                    shape = tuple(self.context.get_binding_shape(idx))@@ -264,13 +278,17 @@                        dtype=self.hidden_output_dtypes[i],                        device=torch.cuda.current_device(),                    )                    bindings[idx] = output.data_ptr()-            with torch.autograd.profiler.record_function(-                "PythonTorchTensorRTModule:TensorRTRuntime"-            ) if self.profiling_enabled else nullcontext():+            with (+                torch.autograd.profiler.record_function(+                    "PythonTorchTensorRTModule:TensorRTRuntime"+                )+                if self.profiling_enabled+                else nullcontext()+            ):                self.context.execute_async_v2(                    bindings, torch.cuda.current_stream().cuda_stream                )            if len(outputs) == 1:--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py2024-02-20 19:59:59.390321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py2024-02-20 20:01:51.233651+00:00@@ -315,25 +315,21 @@    name: str,) -> Union[TRTTensor, Sequence[TRTTensor]]:    kwargs_new = {        "input": args[0],        "kernel_size": args[1],-        "stride": args[2]-        if len(args) > 2-        else (None, None)-        if len(args[1]) == 2-        else (None, None, None),-        "padding": args[3]-        if len(args) > 3-        else (0, 0)-        if len(args[1]) == 2-        else (0, 0, 0),-        "dilation": args[4]-        if len(args) > 4-        else (1, 1)-        if len(args[1]) == 2-        else (1, 1, 1),+        "stride": (+            args[2]+            if len(args) > 2+            else (None, None) if len(args[1]) == 2 else (None, None, None)+        ),+        "padding": (+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)+        ),+        "dilation": (+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)+        ),        "ceil_mode": args[5] if len(args) > 5 else False,    }    return acc_ops_converters.acc_ops_max_poolnd(        network, target, None, kwargs_new, name    )--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py2024-02-20 19:59:59.390321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py2024-02-20 20:01:51.283354+00:00@@ -124,25 +124,29 @@        interpreter = TRTInterpreter(            mod,            input_specs=self.lower_setting.input_specs,            explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,            explicit_precision=self.lower_setting.explicit_precision,-            logger_level=trt.Logger.VERBOSE-            if self.lower_setting.verbose_log-            else trt.Logger.WARNING,+            logger_level=(+                trt.Logger.VERBOSE+                if self.lower_setting.verbose_log+                else trt.Logger.WARNING+            ),        )        interp_result: TRTInterpreterResult = interpreter.run(            max_batch_size=self.lower_setting.max_batch_size,            max_workspace_size=self.lower_setting.max_workspace_size,            lower_precision=self.lower_setting.lower_precision,            strict_type_constraints=self.lower_setting.strict_type_constraints,            algorithm_selector=algo_selector,            timing_cache=cache_data,-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED-            if self.lower_setting.verbose_profile-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,+            profiling_verbosity=(+                trt.ProfilingVerbosity.DETAILED+                if self.lower_setting.verbose_profile+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY+            ),            tactic_sources=self.lower_setting.tactic_sources,        )        # Update timing cache file if needed        timing_cache = interp_result.serialized_cache@@ -295,14 +299,12 @@                module.half()                # A custom conversion function can be passed to the lowerer to                # handle inputs with custom types. By default, just handle                # tensors and NoneType.                if fp16_conversion_fn is None:-                    conversion_fn = (-                        lambda x: x.half()-                        if x is not None and x.dtype == torch.float32-                        else x+                    conversion_fn = lambda x: (+                        x.half() if x is not None and x.dtype == torch.float32 else x                    )                else:                    conversion_fn = fp16_conversion_fn                inputs = tuple(conversion_fn(x) for x in inputs)--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py2024-02-20 19:59:59.390321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py2024-02-20 20:01:51.328023+00:00@@ -19,13 +19,13 @@from .observer import Observerfrom .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks_LOGGER: logging.Logger = logging.getLogger(__name__)-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[-    Callable[[torch.fx.GraphModule], None]-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")+)class TRTInterpreterResult(NamedTuple):    engine: Any    input_names: Sequence[str]@@ -73,13 +73,13 @@        self.input_specs_iter = 0        self.validate_input_specs()        self._cur_node_name: Optional[str] = None        self._input_names: List[str] = []        self._output_names: List[str] = []-        self._itensor_to_tensor_meta: Dict[-            trt.tensorrt.ITensor, TensorMetadata-        ] = dict()+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (+            dict()+        )    def validate_input_specs(self):        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:            if not self.network.has_implicit_batch_dimension:                assert (--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py2024-02-20 19:59:59.390321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py2024-02-20 20:01:51.545029+00:00@@ -194,13 +194,15 @@                    lowering_start_time = datetime.datetime.now()                    self.lower_setting.input_specs = generate_input_specs(                        submod_inputs,                        self.lower_setting,-                        additional_submodule_inputs[submod_name]-                        if additional_submodule_inputs-                        else None,+                        (+                            additional_submodule_inputs[submod_name]+                            if additional_submodule_inputs+                            else None+                        ),                    )                    lowered_module = self._lower_func(                        submod, submod_inputs, self.lower_setting, submod_name                    )                    setattr(split_result.split_module, submod_name, lowered_module)@@ -234,13 +236,15 @@                if not submod_name.startswith(split_result.non_acc_submodule_prefix):                    _LOGGER.info(f"ACC submodule graph: {submod.graph}")                    lowering_start_time = datetime.datetime.now()                    self.lower_setting.additional_inputs = (-                        additional_submodule_inputs[submod_name]-                        if additional_submodule_inputs-                        else None,+                        (+                            additional_submodule_inputs[submod_name]+                            if additional_submodule_inputs+                            else None+                        ),                    )                    lowered_module = self._lower_func(                        submod, submod_inputs, self.lower_setting, submod_name                    )--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py2024-02-20 19:59:59.390321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py2024-02-20 20:01:51.722875+00:00@@ -193,13 +193,11 @@                kwargs2 = {"equal_nan": True}                if rtol:                    kwargs2["rtol"] = rtol                if atol:                    kwargs2["atol"] = atol-                kwargs2[-                    "msg"-                ] = (+                kwargs2["msg"] = (                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"                )                # If tensors are on different devices, make sure to compare                # their copies that are on the same device.                if x.get_device() != y.get_device():--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py2024-02-20 19:59:59.390321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py2024-02-20 20:01:51.782883+00:00@@ -536,13 +536,13 @@        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(            maybe_reshape        )        if not reshape_batch_size:            continue-        reshape_batch_size_inferred_source: Optional[-            fx.Node-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)+        reshape_batch_size_inferred_source: Optional[fx.Node] = (+            get_reshape_batch_size_inferred_source(reshape_batch_size)+        )        if not reshape_batch_size_inferred_source:            continue        reshape_input: fx.Node = maybe_reshape.kwargs["input"]        if reshape_input == reshape_batch_size_inferred_source:--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py2024-02-20 19:59:59.394321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py2024-02-20 20:01:52.206806+00:00@@ -21,13 +21,15 @@        inputs = [torch.randn(1, 10)]        self.run_test(            Split(),            inputs,            expected_ops={-                acc_ops.split-                if isinstance(split_size_or_sections, int)-                else acc_ops.slice_tensor+                (+                    acc_ops.split+                    if isinstance(split_size_or_sections, int)+                    else acc_ops.slice_tensor+                )            },            test_explicit_batch_dim=False,        )    @parameterized.expand(@@ -68,13 +70,15 @@        ]        self.run_test_with_dynamic_shape(            Split(),            input_specs,            expected_ops={-                acc_ops.split-                if isinstance(split_size_or_sections, int)-                else acc_ops.slice_tensor+                (+                    acc_ops.split+                    if isinstance(split_size_or_sections, int)+                    else acc_ops.slice_tensor+                )            },        )    # Testing with (-1, -1, -1) results into following error:    # AssertionError: Can't chunk on dynamic shape dimension!--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py2024-02-20 19:59:59.394321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py2024-02-20 20:01:52.903172+00:00@@ -152,13 +152,13 @@            mod.eval()            if len(expected_ops):                self.assert_has_op(mod, expected_ops)            interpreter_result = interpreter.run(-                lower_precision=LowerPrecision.FP16-                if fp16_mode-                else LowerPrecision.FP32+                lower_precision=(+                    LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32+                )            )            trt_mod = TRTModule(                interpreter_result.engine,                interpreter_result.input_names,                interpreter_result.output_names,--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py2024-02-20 19:59:59.398321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py2024-02-20 20:01:53.269384+00:00@@ -67,25 +67,29 @@                self.engine.get_binding_dtype(idx), Frameworks.TORCH            )            for idx in self.output_binding_indices_in_order        ]        self.output_shapes = [-            tuple(self.engine.get_binding_shape(idx))-            if self.engine.has_implicit_batch_dimension-            else tuple()+            (+                tuple(self.engine.get_binding_shape(idx))+                if self.engine.has_implicit_batch_dimension+                else tuple()+            )            for idx in self.output_binding_indices_in_order        ]        self.hidden_output_dtypes: Sequence[torch.dtype] = [            unified_dtype_converter(                self.engine.get_binding_dtype(idx), Frameworks.TORCH            )            for idx in self.hidden_output_binding_indices_in_order        ]        self.hidden_output_shapes = [-            tuple(self.engine.get_binding_shape(idx))-            if self.engine.has_implicit_batch_dimension-            else tuple()+            (+                tuple(self.engine.get_binding_shape(idx))+                if self.engine.has_implicit_batch_dimension+                else tuple()+            )            for idx in self.hidden_output_binding_indices_in_order        ]    def _check_initialized(self):        if not self.initialized:--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py2024-02-20 19:59:59.398321+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py2024-02-20 20:01:53.546949+00:00@@ -404,13 +404,13 @@        "inputs": inputs if inputs is not None else [],        # "input_signature": input_signature,        "device": device,        "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas        "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.-        "enabled_precisions": enabled_precisions-        if enabled_precisions is not None-        else set(),  # Enabling FP16 kernels+        "enabled_precisions": (+            enabled_precisions if enabled_precisions is not None else set()+        ),  # Enabling FP16 kernels        "refit": refit,  # enable refit        "debug": debug,  # enable debuggable engine        "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels        "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels        "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT

@apbose
Copy link
CollaboratorAuthor

apbose commentedFeb 20, 2024
edited
Loading

Monitoring the CI to see if this error comes in the test-

torch._dynamo.exc.BackendCompilerFailed: backend='functools.partial(<function fx_dynamo_testing_backend at 0x7f19514f7af0>, store_intermediate_graphs=[], min_block_size=1, torch_executed_ops=set(), use_fast_partitioner=True)' raised:RuntimeError: Attempted to set the storage of a tensor on device "meta" to a storage on different device "cpu".  This is no longer allowed; the devices must match.

Copy link

@github-actionsgithub-actionsbot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py2024-02-27 08:54:58.869787+00:00+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py2024-02-27 08:56:47.352375+00:00@@ -187,11 +187,11 @@    step_dim = (end - start) // step    end_dim = end    if step_dim > src_dim[dim]:        end_dim = src_dim[dim]    else:-        #In this case src first step_dim need to be selected+        # In this case src first step_dim need to be selected        indices = torch.Tensor(torch.arange(0, step_dim))        indices = indices.to(torch.int32)        src = torch.index_select(src_tensor, dim, indices)    if start == 0 and end == dim_size and step == 0:

@apboseapbose mentioned this pull requestFeb 27, 2024
@apboseapboseforce-pushed theslice_scatter_decomposition branch from13bbdab tof7e0642CompareMarch 8, 2024 01:00
@apboseapbose requested a review fromgs-oliveMarch 18, 2024 23:47
@apboseapboseforce-pushed theslice_scatter_decomposition branch 5 times, most recently fromdf7d401 to1bd061bCompareMarch 19, 2024 00:21
@apboseapbose mentioned this pull requestMar 19, 2024
Comment on lines 177 to 213
dim_size=input_tensor.shape[dim]
start=get_positive_dim(start,input_tensor.shape[dim])
ifendisNone:
end=dim_size
end=get_positive_dim(end,input_tensor.shape[dim])
ifstepisNone:
step=1

src_dim=src_tensor.shape
# step == 0 is not a valid torch case
# also src_dim should be equal to slice dimension

ifstart==0andend==dim_sizeandstep==1:
returnsrc_tensor

cat_tensors= []
index_tensor_shape= []
fori,src_each_diminenumerate(list(src_dim)):
ifi!=dim:
index_tensor_shape.append(src_each_dim)
forindexinrange(start,end,step):
cat_tensors.append(index*torch.ones(index_tensor_shape))
index_tensor=torch.stack(cat_tensors,dim)
index_tensor=index_tensor.to(torch.int64).cuda()
output_tensor=torch.scatter(input_tensor,dim,index_tensor,src)
returnoutput_tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Could this potentially be simplified to avoidfor-loops usingtorch.arange? For instance, see thisimplementation

Copy link
CollaboratorAuthor

@apboseapboseApr 3, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Hi@gs-olive I tried the above implementation.
I am not sure howget_expanded_index works, but I think it will be difficult to achieve the above behavior without for loops.
I tried two alternate

indices = torch.arange(start,stop, step)cat_tensors = torch.unsqueeze(indices,1) * torch.ones(index_tensor_shape)).split(1, dim = 0)#orcat_tensors = indices(:, None) * torch.ones(index_tensor_shape)).split(1, dim = 0)

The thing is we need to unsqueeze indices n no of times, where n is the dimension of index_tensor_shape. While the above would work for cases

input = torch.ones(8,8)src = torch.ones(8,2)out  = torch.slice_scatter(input, src, 1, 6, 8, 1)

or

input = torch.ones(8,8)src = torch.ones(8,1)out  = torch.slice_scatter(input, src, 1, 6, 7, 1)

it would start failing for input and src with sizestorch.ones(8,8,8) andtorch.zeros(8,2,8) ortorch.zeros(8,1,8)respectively. We would have to unsqueeze n no of times, eg: torch.unsqueeze(indices,1,1) or indices[:,None,None] would work, but then that would again be a for loop.

I cannot think of another way on top of my mind, if you have any suggestion you could let me know.
For now the test cases pass with for loop so I have reverted back to that,

gs-olive reacted with thumbs up emoji
@apboseapboseforce-pushed theslice_scatter_decomposition branch 2 times, most recently from8c37797 to498ff5eCompareMarch 26, 2024 20:40
@apboseapbose mentioned this pull requestApr 2, 2024
Copy link
Contributor

@gs-olivegs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Overall looks good - added a few comments/questions

forindexinrange(start,end,step):
cat_tensors.append(index*torch.ones(index_tensor_shape))
index_tensor=torch.stack(cat_tensors,dim)
index_tensor=index_tensor.to(torch.int64).cuda()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

This will cause a graph break if it inserts a cast in the graph representation, since TRT cannot support Int64 casts. What is the resultant output graph in this case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

This operation might be avoidable by specifyingdtype=torch.long in thetorch.ones(...) call, though if the index tensor is a constant and not anITensor, it may not be necessary.

Copy link
CollaboratorAuthor

@apboseapboseApr 4, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

@gs-olive thetorch.long was present since otherwise torch would have complained that torch requiresint64 input for torch.scatter in thisline.

The casetorch.slice_scatter(torch.zeros(8,8), torch.ones(8,2), 1, 6, None, 1) leads to this with the cast toindex_tensor = index_tensor.to(torch.int64).cuda() -
Pre-AOT Autograd graph:=============

graph():    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]    %l_src_ : torch.Tensor [num_users=1] = placeholder[target=L_src_]    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_src_,), kwargs = {})    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})    %slice_scatter : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter](args = (%clone_default, %clone_default_1, 1, 6, None, 1), kwargs = {})    return (slice_scatter,)

Post AOT Autograd graph:=============

graph():    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})    %empty_strided : [num_users=1] = call_function[target=torch.ops.aten.empty_strided.default](args = ([8], [1]), kwargs = {dtype: to$ch.int64, layout: torch.strided, device: cpu, pin_memory: False})    %full_like : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%empty_strided, 1), kwargs = {pin_memo$y: False})    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%full_like, 6), kwargs = {})    %empty_strided_1 : [num_users=1] = call_function[target=torch.ops.aten.empty_strided.default](args = ([8], [1]), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu, pin_memory: False})    %full_like_1 : [num_users=1] = call_function[target=torch.ops.aten.full_like.default](args = (%empty_strided_1, 1), kwargs = {pin_memory: False})    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%full_like_1, 7), kwargs = {})    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul, 1), kwargs = {})    %unsqueeze_1 : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%mul_1, 1), kwargs = {})    %cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%unsqueeze, %unsqueeze_1], 1), kwargs = {})    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%cat,), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0})    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.src](args = (%clone_1, 1, %_to_copy, %clone), kwargs = {})    return (scatter,)

Post lowering Autograd graph:=============

graph():    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]    %scatter : [num_users=1] = call_function[target=torch.ops.aten.scatter.src](args = (%arg0_1, 1, %_frozen_param0, %arg1_1), kwargs = {})    return (scatter,)

As mentioned by you since it is a frozen param and a constant, there are no graph breaks and not necessary. Not sure if this would be the case always though.
Hence I changed it totorch.ones() with dtypetorch.long as suggested.
A side general question- Would the graph break lead to significant performance impact? That is the reason we should avoid them as far as possible?

}
unexpected_ops= {torch.ops.aten.select_scatter}

inputs= [torch.zeros(8,8).cuda(),torch.ones(8,2).cuda(),1,6,None,1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Could this case be modified to be 3D, as in your comment above.

Copy link
CollaboratorAuthor

@apboseapboseApr 4, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

I kept the old test case and added another with the 3D.

@apboseapboseforce-pushed theslice_scatter_decomposition branch 2 times, most recently from9bc7d6c to2b101ddCompareMay 30, 2024 16:52
changing decomposition patternslice scatter changesReview comments addressRemoving arange and replacing with rangeslice_scatter adding to decomposition groupusing aten::scatter in aten.slice_scatterCorrecting the slice_scatter case with aten::scatter useremoving unnecessary cases from slice_scatter impl and adding test casechanging for loop to torch.arangeReverting back the torch.arange to for loopAdding test case for 3d cases and removing the casting to torch.int64 and including it torch.onesRemoving aten.index in the decomposition ops
@apboseapbose merged commit6152607 intomainMay 30, 2024
Sign up for freeto join this conversation on GitHub. Already have an account?Sign in to comment

Reviewers

@github-actionsgithub-actions[bot]github-actions[bot] requested changes

@gs-olivegs-oliveAwaiting requested review from gs-olive

Assignees

No one assigned

Labels

cla signedcomponent: api [Python]Issues re: Python APIcomponent: dynamoIssues relating to the `torch.compile` or `torch._dynamo.export` pathscomponent: loweringIssues re: The lowering / preprocessing passescomponent: testsIssues re: Tests

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

Add support foraten.slice_scatter

5 participants

@apbose@gs-olive@narendasan@facebook-github-bot

[8]ページ先頭

©2009-2025 Movatter.jp