Commit d0d6b1f
[torchgen] Generate out variant for functional operator (#81437)
Summary: Previously we did not generate an out variant (neither schema nor kernel) for an operator that only has a functional variant. This adds support for that and adds tests.

## Changes on `native_function_generation.py`

We now generate an out variant for every functional variant where possible. This PR introduces a lot of newly generated out variants, and `native_functions.yaml` needs to incorporate the changes by adding `autogen` keywords.

The logic for determining which operators we should generate an out variant for is the following (a minimal sketch of this check appears after the summary):

1. No existing out variant for this `NativeFunction`.
2. An existing in-place, mutable, or functional variant.
3. At least one tensor-like return.

Operators matching the first two conditions but failing the third are listed in `FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT`.

## Special handling

The following operators satisfy all three criteria above, but we chose not to autogen them, for the following reasons:

* `mkldnn_adaptive_avg_pool2d`: the generated out variant `mkldnn_adaptive_avg_pool2d.out` collides with the `mkldnn_adaptive_avg_pool2d_out` kernel of the `adaptive_avg_pool2d.out` operator (see the second sketch below). I manually created `mkldnn_adaptive_avg_pool2d.out` and renamed `mkldnn_adaptive_avg_pool2d_out` to `mkldnn_adaptive_avg_pool2d_out_stub`.
* `min`, `max`, and `mean`: `min.out`, `max.out`, and `mean.out` already exist, but they have different semantics than the functional variants. I manually created `min.unary_out`, `max.unary_out`, and `mean.dtype_out` to disambiguate.

## Autograd Changes

We introduced logic to not match derivatives info in `derivatives.yaml` to out variants, since we generate `NOT_IMPLEMENTED` kernels for those out variants anyway. The issue with the original logic is that it does not handle `TensorOptions` arguments well. For example, consider these two operators:

* `_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor`
* `_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)`

If we used the `_to_copy` derivative info, there would be a compilation error, since `dtype` is missing from the `_to_copy.out` signature (see the third sketch below).

Test Plan: Rely on unit tests

Differential Revision: D37832342

Pull Request resolved: #81437

Approved by: https://github.com/iseeyuan, https://github.com/bdhirsh

1 parent bb1e3d8
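A minimal sketch of the three-part check described above, under simplified assumptions; `needs_out_variant` and its parameters are illustrative names, not the actual torchgen API:

```python
from typing import Sequence


def needs_out_variant(
    has_out_variant: bool,
    has_inplace_mutable_or_functional: bool,
    return_types: Sequence[str],
) -> bool:
    """Illustrative three-part check for autogenerating an out= variant."""
    # 1. There must be no existing out variant for this NativeFunction.
    if has_out_variant:
        return False
    # 2. An in-place, mutable, or functional variant must already exist.
    if not has_inplace_mutable_or_functional:
        return False
    # 3. There must be at least one tensor-like return to write into `out`.
    return any(t.startswith("Tensor") for t in return_types)


# A functional-only op returning a Tensor qualifies...
assert needs_out_variant(False, True, ["Tensor"])
# ...while one with no tensor-like returns would instead be listed in
# FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT.
assert not needs_out_variant(False, True, ["int"])
```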
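To make the `mkldnn_adaptive_avg_pool2d` special case concrete, here is a hypothetical sketch of the naming collision; `default_out_kernel_name` is an invented helper, and the simplified `<op>_out` naming convention is an assumption:

```python
def default_out_kernel_name(op_name: str) -> str:
    # Simplified convention: the out= overload of `op` dispatches to `op_out`.
    return f"{op_name}_out"


# Per the summary, `adaptive_avg_pool2d.out` already uses a kernel with this
# exact name, so autogenerating `mkldnn_adaptive_avg_pool2d.out` would collide.
existing_kernel_names = {"mkldnn_adaptive_avg_pool2d_out"}

generated = default_out_kernel_name("mkldnn_adaptive_avg_pool2d")
assert generated in existing_kernel_names  # collision
# Resolution in this PR: hand-write mkldnn_adaptive_avg_pool2d.out and rename
# the pre-existing kernel to mkldnn_adaptive_avg_pool2d_out_stub.
```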
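The `_to_copy` example under "Autograd Changes" can be illustrated with a small check. The two argument sets come straight from the schemas quoted in the summary; a derivative formula referencing `dtype` is assumed purely for illustration:

```python
# Keyword arguments of the two schemas quoted in the summary.
functional_args = {
    "self", "dtype", "layout", "device",
    "pin_memory", "non_blocking", "memory_format",
}
out_variant_args = {"self", "non_blocking", "memory_format", "out"}


def formula_applies(formula_args: set, schema_args: set) -> bool:
    # A derivative formula can only be reused for a schema that still
    # exposes every argument the formula references.
    return formula_args <= schema_args


# Assume the `_to_copy` derivative formula references `dtype`.
formula_args = {"self", "dtype"}

assert formula_applies(formula_args, functional_args)       # fine for _to_copy
assert not formula_applies(formula_args, out_variant_args)  # `dtype` is gone
# Reusing the formula for _to_copy.out would therefore fail to compile; the
# autogenerated out variants just get NOT_IMPLEMENTED kernels instead.
```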
File tree: 11 files changed, +647 −56 lines, under aten/src/ATen/native (including mkldnn), tools/test, torchgen (including api), and torch/_subclasses.