Commit 863d0eb

robieta authored and pytorchmergebot committed
Optimize Triton template heuristics (#170444)

Summary: This diff contains three small optimizations:

1) Directly cache the triton Config object import. Not a huge win, but measurably faster than relying on importlib's cache.
2) Only copy configs when the new value is different from the old one. Configs are fairly large objects, so unnecessary dict copies get expensive.
3) Replace `gcd(k, BLOCK_K) == BLOCK_K` with `(k % BLOCK_K) == 0`. This is equivalent when `BLOCK_K > 0`, which must be true.

Test Plan:

```
tlp buck run mode/opt //scripts/paulzhan:repro
```

and then looking at perfetto.

Differential Revision: D88415189
Pull Request resolved: #170444
Approved by: https://github.com/PaulZhang12, https://github.com/eellison, https://github.com/shunting314
1 parent c047f39 commit 863d0eb
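To make optimizations (2) and (3) concrete, here is a small standalone sketch (the `BlockConfig` dataclass and the numbers are invented for illustration; this is not the Inductor code itself). It checks that the modulo test agrees with the gcd test whenever `BLOCK_K > 0`, and that `dataclasses.replace` can be skipped when no field would actually change:

```python
import dataclasses
from math import gcd


@dataclasses.dataclass(frozen=True)
class BlockConfig:
    # Hypothetical stand-in for the much larger Inductor config objects.
    block_m: int
    block_n: int
    block_k: int


def maybe_rescale(c: BlockConfig, block_m: int, block_n: int, block_k: int) -> BlockConfig:
    # Optimization (2): only pay for the copy when a field actually changes.
    if (block_m, block_n, block_k) != (c.block_m, c.block_n, c.block_k):
        c = dataclasses.replace(c, block_m=block_m, block_n=block_n, block_k=block_k)
    return c


if __name__ == "__main__":
    # Optimization (3): for BLOCK_K > 0, gcd(k, BLOCK_K) == BLOCK_K iff BLOCK_K divides k.
    for k in range(1, 257):
        for block_k in (16, 32, 64):
            assert (gcd(k, block_k) == block_k) == ((k % block_k) == 0)

    cfg = BlockConfig(block_m=64, block_n=64, block_k=32)
    assert maybe_rescale(cfg, 64, 64, 32) is cfg       # unchanged: no copy is made
    assert maybe_rescale(cfg, 128, 64, 32) is not cfg  # changed: a new object is built
```

The diff below applies the same pattern to the GEMM configs in `_scale_mm_configs` and to the `EVEN_K` computation in `_convert_config_to_template_kwargs`.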

File tree

2 files changed, +48 -23 lines changed

test/inductor/test_max_autotune.py

Lines changed: 19 additions & 0 deletions
@@ -2283,6 +2283,25 @@ def mm_plus_mm_func(a1, b1, a2, b2) -> torch.Tensor:
         _, code_out = run_and_get_code(c_f, *args)
         FileCheck().check(output_code_padding_check).run(code_out[0])
 
+    @parametrize("k", (15, 16))
+    @parametrize("dynamic", (False, True))
+    def test_even_k(self, k: int, dynamic: bool):
+        M, N = 21, 31
+        a = torch.randn((M, k), dtype=torch.float16, device=GPU_TYPE)
+        b = torch.randn((k, N), dtype=torch.float16, device=GPU_TYPE)
+
+        if dynamic:
+            torch._dynamo.mark_dynamic(a, 1)
+            torch._dynamo.mark_dynamic(b, 0)
+
+        with config.patch({"max_autotune": True}), fresh_cache():
+            _ = torch.compile(torch.mm)(a, b)
+        cache = TritonTemplate.all_templates["mm"]._generated_code_cache._cache
+        cache_key = next(iter(cache))
+
+        self.assertObjectIn(k, (15, 16))
+        self.assertEqual("'EVEN_K': True" in cache_key, k == 16 and not dynamic)
+
 
 class TestMaxAutotunePrecompile(TestCase):
     def test_precompilation_threads(self):
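For intuition about what the new `test_even_k` asserts, here is a standalone sympy sketch (not the Inductor code path, which additionally routes the result through `V.graph.sizevars.statically_known_true`): the modulo expression only resolves to a definite truth value when `k` is a concrete integer, which is why `'EVEN_K': True` is expected in the cache key only for `k == 16` with static shapes.

```python
import sympy

BLOCK_K = 16  # example tile size; the real BLOCK_K comes from the autotuned config

# Concrete k: the modulo evaluates immediately.
print(sympy.Integer(16) % BLOCK_K == 0)  # True  -> EVEN_K can be baked into the template
print(sympy.Integer(15) % BLOCK_K == 0)  # False

# Symbolic k (as with torch._dynamo.mark_dynamic): the expression stays unevaluated,
# so divisibility is not statically known and EVEN_K cannot be assumed.
k = sympy.Symbol("k", positive=True, integer=True)
print(k % BLOCK_K == 0)  # False
```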

torch/_inductor/template_heuristics/triton.py

Lines changed: 29 additions & 23 deletions
@@ -45,6 +45,9 @@
 
     from triton import Config as TritonConfig
 
+else:
+    from torch._inductor.runtime.triton_compat import Config as TritonConfig
+
 
 # Gemm Configs
 @dataclasses.dataclass
@@ -625,18 +628,26 @@ def _scale_mm_configs(
         )
 
         for c in configs:
-            scaled_config = dataclasses.replace(
-                c,
-                block_m=max(min(int(c.block_m * scale), m_hint), min_block_size),
-                block_n=max(min(int(c.block_n * scale), n_hint), min_block_size),
-                block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k),
-                hint_override=hint_override,
-            )
-
-            if not exclude(
-                scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
-            ):
-                scaled_configs.append(scaled_config)
+            block_m = max(min(int(c.block_m * scale), m_hint), min_block_size)
+            block_n = max(min(int(c.block_n * scale), n_hint), min_block_size)
+            block_k = max(min(int(c.block_k * scale), k_hint), min_block_size_k)
+            if not exclude(block_m, block_n, block_k):
+                # This copy is expensive, so avoid it if we can.
+                if (block_m, block_n, block_k, hint_override) != (
+                    c.block_m,
+                    c.block_n,
+                    c.block_k,
+                    c.hint_override,
+                ):
+                    c = dataclasses.replace(
+                        c,
+                        block_m=block_m,
+                        block_n=block_n,
+                        block_k=block_k,
+                        hint_override=hint_override,
+                    )
+
+                scaled_configs.append(c)
 
         return scaled_configs
 
@@ -753,8 +764,6 @@ def preprocess_mm_configs(
     def triton_config(
         self, num_stages: int, num_warps: int, **kwargs: Any
     ) -> TritonConfig:
-        from triton import Config as TritonConfig  # type: ignore[attr-defined]
-
         return TritonConfig(kwargs, num_stages=num_stages, num_warps=num_warps)
 
     def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]:
@@ -1667,21 +1676,18 @@ def _get_template_configs_impl(
     def _convert_config_to_template_kwargs(
         self,
         triton_config: TritonConfig,
-        m: sympy.Integer,
-        n: sympy.Integer,
-        k: sympy.Integer,
+        m: sympy.Integer | sympy.Symbol,
+        n: sympy.Integer | sympy.Symbol,
+        k: sympy.Integer | sympy.Symbol,
         out_dtype: torch.dtype,
     ) -> dict[str, Any]:
         """
         Convert triton config to template kwargs.
         Moved from mm_common.mm_options.
         """
-        # Calculate EVEN_K symbolic
-        even_k_symbolic = (
-            # it isn't worth guarding on this
-            sympy.gcd(k, triton_config.kwargs["BLOCK_K"])
-            == triton_config.kwargs["BLOCK_K"]
-        )
+        # Calculate EVEN_K symbolic. (It isn't worth guarding on this)
+        even_k_symbolic = (k % triton_config.kwargs["BLOCK_K"]) == 0
+        even_k_symbolic = V.graph.sizevars.statically_known_true(even_k_symbolic)
 
 
         # Build options dict

