Commit 6ddf5cf

desertfire authored and pytorchmergebot committed
[AOTI] Update cpp wrapper codegen to use v2 C shim (#120714)

Summary: To use the torchgen-ed v2 C shim interface, cpp wrapper codegen needs to update its rules for generating the right parameters and function calls. Because changing the emitted code will cause an FC (forward-compatibility) breakage, we add a flag to control the behavior.

Differential Revision: D54258086 (https://our.internmc.facebook.com/intern/diff/D54258086)
Pull Request resolved: #120714
Approved by: https://github.com/chenyang78
ghstack dependencies: #120513

1 parent bd19d6d, commit 6ddf5cf

File tree

4 files changed: +45 −8 lines changed

torch/_inductor/codegen/cpp_wrapper_cpu.py
Lines changed: 28 additions & 6 deletions
@@ -24,8 +24,9 @@ class CppWrapperCpu(WrapperCodeGen):
     """

     def __init__(self):
+        if not hasattr(self, "device"):
+            self.device = "cpu"
         super().__init__()
-
         self.declare = "auto "
         self.declare_maybe_reference = "decltype(auto) "
         self.ending = ";"
@@ -149,7 +150,12 @@ def write_header(self):
         )

         if config.abi_compatible:
-            self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
+            if config.c_shim_version == "1":
+                self.header.splice("#include <torch/csrc/inductor/aoti_torch/c/shim.h>")
+            else:
+                self.header.splice(
+                    f"#include <torch/csrc/inductor/aoti_torch/generated/c_shim_{self.device}.h>"
+                )
         else:
             if not V.graph.aot_mode:
                 self.header.splice("#include <pybind11/pybind11.h>")
@@ -924,7 +930,11 @@ def generate_c_shim_extern_kernel_call(self, kernel, args):
         kernel_suffix = kernel_tokens[-1]
         if kernel_suffix == "call":
             kernel_suffix = kernel_tokens[-2]
-        shim_fn = f"aoti_torch_{kernel_suffix}"
+        if config.c_shim_version == "1":
+            shim_fn = f"aoti_torch_{kernel_suffix}"
+        else:
+            shim_fn = f"aoti_torch_{self.device}_{kernel_suffix}"
+
         # HACK: val_to_arg_str jams multiple arguments together using a comma. If that
         # ever breaks, it needs to be reworked to be able to return multiple arguments,
         # and the split-on-comma code here needs to be removed.
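
A minimal sketch of the shim-name rule above, assuming an illustrative kernel suffix "addmm" (any aten op suffix would do; it is not taken from this commit):

# Hypothetical restatement of the shim_fn naming logic above.
def shim_fn_name(device: str, c_shim_version: str, kernel_suffix: str) -> str:
    if c_shim_version == "1":
        return f"aoti_torch_{kernel_suffix}"         # v1: device-agnostic name
    return f"aoti_torch_{device}_{kernel_suffix}"    # v2: per-device name

print(shim_fn_name("cpu", "1", "addmm"))   # -> aoti_torch_addmm
print(shim_fn_name("cpu", "2", "addmm"))   # -> aoti_torch_cpu_addmm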
@@ -1676,12 +1686,24 @@ def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
         ):
             if val is None:
                 return "0"  # nullptr is not available in C
-            if isinstance(val, (bool, int, str, float)):
+            if not isinstance(type_.getElementType(), torch.TensorType):
                 var_name = f"var_{next(self.arg_var_id)}"
                 self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
                 return f"&{var_name}"
-            if not isinstance(type_.getElementType(), torch.TensorType):
-                return f"&{self.val_to_arg_str(val)}"
+            elif config.c_shim_version == "2":
+                # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim
+                base_handle = self.val_to_arg_str(val)
+                if "wrap_with_raii_handle_if_needed" in base_handle:
+                    # wrap_with_raii_handle_if_needed creates a temp RAIIAtenTensorHandle, so we need to
+                    # explicitly store it. Otherwise, it will be destroyed before the fallback kernel call.
+                    tmp_var_name = f"var_{next(self.arg_var_id)}"
+                    self.writeline(
+                        f"RAIIAtenTensorHandle {tmp_var_name} = {base_handle};"
+                    )
+                    base_handle = tmp_var_name
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();")
+                return f"&{var_name}"

         return self.val_to_arg_str(val)
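
A standalone sketch of the v2 optional-tensor path above: the optional tensor is passed by pointer, and any temporary RAIIAtenTensorHandle is first pinned to a named variable so it outlives the fallback kernel call. The handle name "buf0" and the helper below are hypothetical, invented only to illustrate the emitted C++ lines.

# Hypothetical mirror of the optional-tensor argument handling above.
from itertools import count

arg_var_id = count()
emitted_cpp_lines = []  # stands in for self.writeline output

def optional_tensor_arg(base_handle: str) -> str:
    if "wrap_with_raii_handle_if_needed" in base_handle:
        # Pin the temp RAII handle so it is not destroyed before the call.
        tmp = f"var_{next(arg_var_id)}"
        emitted_cpp_lines.append(f"RAIIAtenTensorHandle {tmp} = {base_handle};")
        base_handle = tmp
    var = f"var_{next(arg_var_id)}"
    emitted_cpp_lines.append(f"AtenTensorHandle {var} = {base_handle}.get();")
    return f"&{var}"  # pointer denotes the optional arg is present

print(optional_tensor_arg("buf0"))  # -> &var_0
print(emitted_cpp_lines)            # ['AtenTensorHandle var_0 = buf0.get();']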

torch/_inductor/codegen/cpp_wrapper_cuda.py
Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class CppWrapperCuda(CppWrapperCpu):
     """

     def __init__(self):
+        self.device = "cuda"
         super().__init__()
         self.grid_id = count()
         self.cuda = True
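
Note the ordering: the CUDA subclass sets self.device before calling super().__init__(), so the hasattr guard added to CppWrapperCpu.__init__ only defaults to "cpu" when no subclass has claimed a device. A minimal sketch of that pattern (Base and Cuda are illustrative stand-ins for the two wrapper classes):

class Base:
    def __init__(self):
        # Default only when no subclass has already set a device.
        if not hasattr(self, "device"):
            self.device = "cpu"

class Cuda(Base):
    def __init__(self):
        self.device = "cuda"  # must run before super().__init__()
        super().__init__()

assert Base().device == "cpu"
assert Cuda().device == "cuda"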

torch/_inductor/config.py
Lines changed: 4 additions & 0 deletions
@@ -41,6 +41,10 @@ def enable_autotune_remote_cache():
     os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1"
 )

+c_shim_version = os.environ.get(
+    "TORCHINDUCTOR_C_SHIM_VERSION", "1" if is_fbcode() else "2"
+)
+
 # dead code elimination
 dce = False
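
This is the FC-control flag from the summary: fbcode defaults to the v1 shim, OSS to v2. One possible way to pin the version explicitly, given that the env var is read once when the config module is imported:

import os
os.environ["TORCHINDUCTOR_C_SHIM_VERSION"] = "1"  # keep the v1 C shim

# Must happen after the env var is set, since config.py reads it at import time.
import torch._inductor.config as inductor_config
print(inductor_config.c_shim_version)  # -> "1"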

torch/_inductor/ir.py
Lines changed: 12 additions & 2 deletions
@@ -4855,7 +4855,10 @@ def is_not_write(arg):
         self.init_args_default_value(kernel._schema)

     def is_legacy_abi_kernel(self):
-        return "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
+        return (
+            config.c_shim_version == "1"
+            and "_scaled_dot_product_flash_attention" in str(self.python_kernel_name)
+        )

     def init_args_default_value(self, schema):
         self.args_default_value = [
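
The behavioral effect of this change, restated as a standalone predicate (a sketch, not the method itself): the legacy-ABI special case for _scaled_dot_product_flash_attention now applies only under the v1 shim.

def is_legacy_abi_kernel(python_kernel_name: str, c_shim_version: str) -> bool:
    return (
        c_shim_version == "1"
        and "_scaled_dot_product_flash_attention" in str(python_kernel_name)
    )

assert is_legacy_abi_kernel("aten._scaled_dot_product_flash_attention", "1")
assert not is_legacy_abi_kernel("aten._scaled_dot_product_flash_attention", "2")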
@@ -4908,6 +4911,7 @@ def __repr__(self):
         self.abi_compatible_kernel = (
             f"{self.cpp_kernel_name}_v2"
             if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"}
+            and config.c_shim_version == "1"
             else self.cpp_kernel_name
         )
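
Likewise, the hand-written "_v2" shim variant of that kernel name is now selected only under the v1 shim. A hypothetical restatement of the expression above as a free function:

def abi_compatible_kernel(cpp_kernel_name: str, c_shim_version: str) -> str:
    return (
        f"{cpp_kernel_name}_v2"
        if cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"}
        and c_shim_version == "1"
        else cpp_kernel_name
    )

print(abi_compatible_kernel("at::_scaled_dot_product_flash_attention", "2"))
# -> at::_scaled_dot_product_flash_attention (no "_v2" suffix under v2)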

@@ -5065,7 +5069,13 @@ def codegen(self, wrapper):
         # Aten Fallback Ops
         assert isinstance(kernel, torch._ops.OpOverload)
         if V.graph.cpp_wrapper:
-            if config.is_fbcode() and kernel not in has_c_shim:
+            if (
+                config.is_fbcode()
+                and kernel not in has_c_shim
+                # C shim v2 is torchgen-ed, which should cover all aten ops.
+                # If you do hit a missed op, please update gen_aoti_c_shim.py.
+                and config.c_shim_version == "1"
+            ):
                 log.warning(
                     "%s is missing a c-shim implementation, using proxy executor as fallback",
                     kernel,
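
A hypothetical restatement of the guard above: the missing-shim warning (and the proxy-executor fallback it announces) is now limited to v1, since the torchgen-ed v2 shim should cover all aten ops.

def warn_missing_c_shim(is_fbcode: bool, has_shim: bool, c_shim_version: str) -> bool:
    return is_fbcode and not has_shim and c_shim_version == "1"

assert warn_missing_c_shim(True, False, "1")       # v1: warn and fall back
assert not warn_missing_c_shim(True, False, "2")   # v2: shim is expected to exist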

