
A Python-embedded DSL that makes it easy to write fast, scalable ML kernels with minimal boilerplate.


pytorch/helion



About

📚 View Documentation 📚 | 🎥 Watch Talk 🎥 | 🚀 Try In Colab 🚀 | Try In AMD DevCloud

Helion is a Python-embedded domain-specific language (DSL) for authoring machine learning kernels, designed to compile down to Triton, a performant backend for programming GPUs and other devices. Helion aims to raise the level of abstraction compared to Triton, making it easier to write correct and efficient kernels while enabling more automation in the autotuning process.

The name Helion refers to the nucleus of a helium-3 atom, while Triton refers to the nucleus of hydrogen-3 (tritium).

Helion can be viewed either as PyTorch with tiles or as a higher-level Triton. Compared to Triton, Helion reduces manual coding effort through autotuning. Helion's autotuning takes longer (approximately 10 minutes) because it evaluates hundreds of potential Triton implementations generated from a single Helion kernel. This larger search space also makes kernels more performance-portable across different hardware. Helion automates and autotunes over:

  1. Tensor Indexing:

    • Automatically calculates strides and indices.
    • Autotunes choices among various indexing methods (pointers, block pointers, TensorDescriptors).
    • Supports per-operation indexing strategies for fine-grained memory access control of loads and stores.
  2. Masking:

    • Most masking is implicit in Helion, and is optimized away when not needed.
  3. Grid Sizes and PID Calculations:

    • Automatically determines grid sizes.
    • Autotunes multiple mappings from Program IDs (PIDs) to data tiles.
  4. Implicit Search Space Definition:

    • Eliminates the need to manually define search configurations.
    • Automatically generates configuration flags and exploration spaces.
  5. Kernel Arguments Management:

    • Automates the handling of kernel arguments, including tensor sizes and strides.
    • Lifts global variables and (nested) closures into kernel arguments, allowing better templating.
  6. Looping Reductions:

    • Can automatically convert large reductions into looped implementations.
  7. Automated Optimizations:

    • PID swizzling for improved L2 cache reuse.
    • Loop reordering.
    • Persistent kernel strategies.
    • Warp specialization choices, unrolling, and more.
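To illustrate item 6 (together with the implicit masking from item 2), here is a minimal pure-Python sketch contrasting a persistent reduction, which reduces a whole row in one step, with a looped reduction, which walks the row in fixed-size chunks and masks off the partial last chunk. It is not code that Helion generates; `softmax_persistent`, `softmax_looped`, and the `block` parameter are hypothetical names that only mirror the idea behind the `reduction_loops` config option described below.

```python
import math
from typing import List


def softmax_persistent(row: List[float]) -> List[float]:
    """Persistent reduction: the entire row is reduced as a single 'tile'."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_looped(row: List[float], block: int) -> List[float]:
    """Looped reduction: visit the row in chunks of `block` elements.

    The last chunk may be partial; the `offset < n` check plays the role of
    the boundary mask a generated kernel would apply (or optimize away).
    """
    n = len(row)
    # Pass 1: running maximum, chunk by chunk.
    m = -math.inf
    for start in range(0, n, block):
        for offset in range(start, start + block):
            if offset < n:  # mask out-of-bounds lanes in the tail chunk
                m = max(m, row[offset])
    # Pass 2: running sum of exponentials, chunk by chunk.
    total = 0.0
    for start in range(0, n, block):
        for offset in range(start, start + block):
            if offset < n:
                total += math.exp(row[offset] - m)
    # Pass 3: normalize and write each chunk back out.
    out = [0.0] * n
    for start in range(0, n, block):
        for offset in range(start, start + block):
            if offset < n:
                out[offset] = math.exp(row[offset] - m) / total
    return out


if __name__ == "__main__":
    row = [0.5, -1.0, 2.0, 3.5, 0.0, 1.25, -2.5]  # 7 elements: not a multiple of 4
    a = softmax_persistent(row)
    b = softmax_looped(row, block=4)
    print(max(abs(x - y) for x, y in zip(a, b)))
```

Both functions return the same probabilities up to floating-point rounding; the looped form trades extra passes over the data for a bounded working set, which is why it helps once a row no longer fits in registers.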

Example

A minimal matrix multiplication kernel in Helion looks like this:

import torch, helion, helion.language as hl

@helion.kernel()
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    m, k = x.size()
    k, n = y.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(k):
            acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n])
        out[tile_m, tile_n] = acc

    return out

The code outside the for loops is standard PyTorch code executed on the CPU. It is typically used for tasks like allocating output tensors and performing shape computations.

The code inside the for loops is compiled into a Triton kernel, resulting in a single GPU kernel. A single Helion kernel is always compiled to exactly one GPU kernel.

The hl.tile function subdivides the iteration space (in this case m by n) into tiles. These tiles are executed in parallel on the GPU. Tiling details, such as dimensionality (1D vs 2D), tile sizes, and loop ordering, are automatically determined by Helion's autotuner. Alternatively, these details can be explicitly specified using the config= argument in helion.kernel.
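A minimal pure-Python sketch of how a 2D iteration space might be split into tiles, assuming fixed block sizes picked by hand (in Helion, the autotuner chooses them). It shows how the grid size follows from the tile sizes and why the edge tiles need masking. The helpers `cdiv` and `tiles_2d` are hypothetical and not part of Helion's API.

```python
from typing import Iterator, List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used to size a launch grid."""
    return (a + b - 1) // b


def tiles_2d(m: int, n: int, block_m: int, block_n: int) -> Iterator[Tuple[range, range]]:
    """Yield (rows, cols) index ranges for each tile, clipped at the edges.

    Clipping with min() stands in for the boundary mask a generated kernel
    applies to the last, partially filled tile in each dimension.
    """
    for pid_m in range(cdiv(m, block_m)):
        for pid_n in range(cdiv(n, block_n)):
            rows = range(pid_m * block_m, min((pid_m + 1) * block_m, m))
            cols = range(pid_n * block_n, min((pid_n + 1) * block_n, n))
            yield rows, cols


if __name__ == "__main__":
    m, n, block_m, block_n = 100, 70, 32, 64
    grid = cdiv(m, block_m) * cdiv(n, block_n)
    print("grid size:", grid)  # 4 tiles along m times 2 tiles along n
    covered: List[Tuple[int, int]] = [
        (i, j) for rows, cols in tiles_2d(m, n, block_m, block_n) for i in rows for j in cols
    ]
    # Every element is visited exactly once, including the ragged edges.
    assert len(covered) == len(set(covered)) == m * n
```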

  • The outer for loop is mapped onto the grid of the generated kernel. The grid size is determined automatically based on the chosen tile size.

  • The inner for loop translates into a loop within the generated kernel, and its tile size is also determined automatically.
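To make this mapping concrete, here is a pure-Python reference that follows the same loop structure as the matmul kernel, assuming plain nested lists instead of tensors and hand-picked tile sizes: the two outer loops enumerate output tiles (one per program instance in the launch grid), and the inner loop walks the k dimension while accumulating into a per-tile buffer. It only illustrates the control flow; it is not what Helion generates, and `tiled_matmul`, `bm`, `bn`, and `bk` are hypothetical names.

```python
from typing import List

Matrix = List[List[float]]


def tiled_matmul(x: Matrix, y: Matrix, bm: int, bn: int, bk: int) -> Matrix:
    m, k = len(x), len(x[0])
    k2, n = len(y), len(y[0])
    assert k == k2, "inner dimensions must match"
    out = [[0.0] * n for _ in range(m)]
    # Outer loops: one iteration per output tile (the "grid").
    for m0 in range(0, m, bm):
        for n0 in range(0, n, bn):
            rows = range(m0, min(m0 + bm, m))
            cols = range(n0, min(n0 + bn, n))
            # Per-tile accumulator, playing the role of hl.zeros([tile_m, tile_n]).
            acc = [[0.0] * len(cols) for _ in rows]
            # Inner loop: march over the reduction dimension one k-tile at a time.
            for k0 in range(0, k, bk):
                ks = range(k0, min(k0 + bk, k))
                # acc += x[tile_m, tile_k] @ y[tile_k, tile_n]  (the addmm step)
                for a, i in enumerate(rows):
                    for b, j in enumerate(cols):
                        acc[a][b] += sum(x[i][t] * y[t][j] for t in ks)
            # Write the finished tile back to the output.
            for a, i in enumerate(rows):
                for b, j in enumerate(cols):
                    out[i][j] = acc[a][b]
    return out


if __name__ == "__main__":
    x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    y = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    print(tiled_matmul(x, y, bm=1, bn=1, bk=2))
```

Tile sizes that do not divide the matrix dimensions are handled by clipping each range, which is the pure-Python counterpart of the implicit masking mentioned earlier.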

Within a Helion kernel, standard PyTorch operators (like torch.addmm) are automatically mapped to Triton operations using TorchInductor. Thus, familiarity with PyTorch means you already know most of Helion. Helion supports a wide range of operations including pointwise (add, sigmoid, etc.), reductions (sum, softmax, etc.), views, and matrix multiplication operations. Arbitrary function calls within a Helion kernel are supported, but must be traceable with make_fx.

Autotuning

The above example can be executed with:

out = matmul(torch.randn([2048, 2048], device="cuda"), torch.randn([2048, 2048], device="cuda"))

When a kernel runs for the first time, Helion initiates autotuning. A typical autotuning session produces output similar to:

[0s] Starting DifferentialEvolutionSearch with population=40, generations=20, crossover_rate=0.8
[20s] Initial population: failed=4 min=0.0266 mid=0.1577 max=1.2390 best=Config(block_sizes=[64, 32, 64], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[3, 1], range_warp_specializes=[True, False], range_num_stages=[1, 0], range_multi_buffers=[True, True], range_flattens=[None, False], num_warps=4, num_stages=7, indexing='block_ptr', pid_type='persistent_blocked')
[51s] Generation 2: replaced=17 min=0.0266 mid=0.0573 max=0.1331 best=Config(block_sizes=[64, 32, 64], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[3, 1], range_warp_specializes=[True, False], range_num_stages=[1, 0], range_multi_buffers=[True, True], range_flattens=[None, False], num_warps=4, num_stages=7, indexing='block_ptr', pid_type='persistent_blocked')
[88s] Generation 3: replaced=18 min=0.0225 mid=0.0389 max=0.1085 best=Config(block_sizes=[64, 64, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None], range_num_stages=[0, 0], range_multi_buffers=[None, False], range_flattens=[None, None], num_warps=4, num_stages=6, indexing='pointer', pid_type='flat')
...
[586s] Generation 19: replaced=3 min=0.0184 mid=0.0225 max=0.0287 best=Config(block_sizes=[64, 64, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[0, 1], range_warp_specializes=[None, False], range_num_stages=[0, 3], range_multi_buffers=[None, False], range_flattens=[None, None], num_warps=8, num_stages=6, indexing='block_ptr', pid_type='flat')
[586s] Autotuning complete in 586.6s after searching 1520 configs.
One can hardcode the best config and skip autotuning with:
    @helion.kernel(config=helion.Config(block_sizes=[64, 64, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[0, 1], range_warp_specializes=[None, False], range_num_stages=[0, 3], range_multi_buffers=[None, False], range_flattens=[None, None], num_warps=8, num_stages=6, indexing='block_ptr', pid_type='flat'))

Because autotuning can be time-consuming (around 10 minutes in the above example), you may want to manually specify the best configuration found from autotuning to avoid repeated tuning:

@helion.kernel(config=helion.Config(
    block_sizes=[64, 64, 64],
    loop_orders=[[0, 1]],
    l2_groupings=[4],
    range_unroll_factors=[0, 1],
    range_warp_specializes=[None, False],
    range_num_stages=[0, 3],
    range_multi_buffers=[None, False],
    range_flattens=[None, None],
    num_warps=8,
    num_stages=6,
    indexing='block_ptr',
    pid_type='flat'
))
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    ...

This explicit configuration skips autotuning on subsequent runs.

You can also specify multiple configurations, prompting Helion to perform a more lightweight autotuning process:

@helion.kernel(configs=[
    helion.Config(...),
    helion.Config(...),
])
def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    ...

In this case, Helion evaluates the provided configurations and selects the fastest one.
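Conceptually, choosing among a short list of configs comes down to timing each candidate and keeping the fastest. The sketch below assumes each config is represented by a plain Python callable standing in for a compiled kernel, times it with `time.perf_counter`, and keeps the median of several runs. Helion's own benchmarking runs on the GPU and also handles compilation, warmup, and failing configs, so this is only an outline; `bench` and `pick_fastest` are hypothetical names.

```python
import statistics
import time
from typing import Callable, Dict, Tuple


def bench(fn: Callable[[], object], repeats: int = 5) -> float:
    """Return the median wall-clock time of `fn` over `repeats` runs."""
    fn()  # warmup run, excluded from timing
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def pick_fastest(candidates: Dict[str, Callable[[], object]]) -> Tuple[str, float]:
    """Benchmark every candidate and return (name, median_seconds) of the fastest."""
    timings = {name: bench(fn) for name, fn in candidates.items()}
    best = min(timings, key=timings.__getitem__)
    return best, timings[best]


if __name__ == "__main__":
    data = list(range(200_000))
    candidates = {
        "config_a": lambda: sum(data),
        "config_b": lambda: [v * 2 for v in data],  # stand-in for a slower config
    }
    print(pick_fastest(candidates))
```

Using the median rather than the mean makes the choice less sensitive to a single noisy run.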

Additionally, Helion provides programmatic APIs to manage autotuning and configurations directly from your code.

For production deployment, we recommend using ahead-of-time tuned configurations rather than relying on runtime autotuning. The autotuning process can be time-consuming and resource-intensive, making it unsuitable for production environments where predictable performance and startup times are critical.

Static shapes and autotuning keys

By default Helion uses static shapes (static_shapes=True). This means each unique input shape/stride signature is treated as its own specialization and will be autotuned separately. This typically yields the best performance, but may increase autotuning time when many shapes are encountered.

If you want to reduce autotuning time by sharing configurations between different shapes, set static_shapes=False. In this mode, the autotuning key ignores exact sizes, allowing a single tuned config to be reused across multiple shapes. This can come with a performance penalty compared to fully specialized static shapes.

@helion.kernel(static_shapes=False)
def my_kernel(x: torch.Tensor) -> torch.Tensor:
    ...
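The difference between the two modes can be pictured as a difference in the cache key used to look up a tuned config. The sketch below is a simplified, hypothetical model: tensors are described by dtype, shape, and stride tuples, and `TensorSpec`, `autotune_key`, and `contiguous` are illustrative names, not Helion functions. With static shapes the exact sizes and strides are part of the key; in the dynamic mode only coarser properties (here, dtype and rank) are kept, so different shapes can share one entry. Helion's actual key differs in detail.

```python
from typing import Dict, Hashable, NamedTuple, Tuple


class TensorSpec(NamedTuple):
    """Minimal stand-in for the properties of a tensor argument."""
    dtype: str
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]


def autotune_key(args: Tuple[TensorSpec, ...], static_shapes: bool) -> Hashable:
    if static_shapes:
        # Every distinct shape/stride signature gets its own specialization.
        return tuple((a.dtype, a.shape, a.stride) for a in args)
    # Dynamic mode (simplified): drop exact sizes so many shapes share an entry.
    return tuple((a.dtype, len(a.shape)) for a in args)


def contiguous(dtype: str, *shape: int) -> TensorSpec:
    """Build a row-major (contiguous) TensorSpec for the given shape."""
    stride, acc = [], 1
    for size in reversed(shape):
        stride.append(acc)
        acc *= size
    return TensorSpec(dtype, tuple(shape), tuple(reversed(stride)))


if __name__ == "__main__":
    shapes = [(1024, 1024), (2048, 512), (4096, 256)]
    for static in (True, False):
        cache: Dict[Hashable, str] = {}
        for s in shapes:
            cache.setdefault(autotune_key((contiguous("float32", *s),), static), "tuned config")
        print(f"static_shapes={static}: {len(cache)} autotuning run(s)")
```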

Configurations

Helion configurations include the following options:

  • block_sizes (list[int]): Controls tile sizes corresponding to each dimension passed to hl.tile or each call to hl.register_block_size in the kernel.

  • loop_orders (list[list[int]]): Contains one entry per hl.tile call with two or more dimensions, allowing you to permute the iteration order of the tiles.

  • flatten_loops (list[bool]): Contains one entry per hl.tile call with two or more dimensions, allowing you to flatten the iteration space into a single dimension.

  • range_unroll_factors (list[int]): Contains one entry per loop dimension, specifying the unroll factor for tl.range() calls. Values less than 1 omit the loop_unroll_factor parameter.

  • range_num_stages (list[int]): Contains one entry per loop dimension, specifying the number of stages for tl.range() calls. Values less than 1 omit the num_stages parameter.

  • range_multi_buffers (list[bool | None]): Contains one entry per loop dimension, controlling the disallow_acc_multi_buffer parameter for tl.range() calls. True allows multi-buffering (sets disallow_acc_multi_buffer=False), False disallows multi-buffering (sets disallow_acc_multi_buffer=True), and None omits the parameter.

  • range_flattens (list[bool | None]): Contains one entry per loop dimension, controlling the flatten parameter for tl.range() calls. True sets flatten=True, False sets flatten=False, and None omits the parameter.

  • range_warp_specializes (list[bool | None]): Contains one entry per loop dimension, controlling the warp_specialize parameter for tl.range() calls. True sets warp_specialize=True, False sets warp_specialize=False, and None omits the parameter. Only available on CUDA devices with Blackwell or newer architectures when the allow_warp_specialize setting is enabled.

  • static_ranges (list[bool]): Contains one entry per loop dimension with static bounds, controlling whether to use tl.static_range() calls. True generates tl.static_range() and ignores range_* configs for that loop. False generates tl.range().

  • reduction_loops (list[int | None]): Contains one entry per reduction dimension (see examples/softmax.py). Using None triggers a persistent reduction, where the entire reduction is processed in a single tile. Specifying an integer block size converts the reduction into a loop, which is beneficial for larger reductions that exceed the available registers.

  • l2_groupings (list[int]): Reorders the program IDs (PIDs) of the generated kernel for improved L2 cache behavior. A value of 1 disables this optimization, while higher values specify the grouping size.

  • indexing ("pointer", "tensor_descriptor", "block_ptr", or a list of these): Specifies the memory indexing strategy for load and store operations. Can be:

    • A single strategy (applies to all loads and stores): indexing="block_ptr"
    • A list of strategies (one per load/store in execution order): indexing=["pointer", "pointer", "block_ptr"]
    • Empty/omitted (defaults to "pointer" for all operations)
    • When using a list, provide strategies in order: [load1, load2, ..., store1, store2, ...]

    The "tensor_descriptor" option uses Tensor Memory Accelerators (TMAs) but requires a Hopper or newer GPU and the latest development version of Triton.

  • pid_type ("flat", "xyz", "persistent_blocked", or "persistent_interleaved"): Specifies the program ID mapping strategy. "flat" uses only the x-dimension, "xyz" utilizes multiple grid dimensions, and persistent strategies enable persistent kernels for improved SM utilization.

  • num_warps (int): Sets the number of warps the kernel will use.

  • num_stages (int): Defines the number of pipeline stages to be passed to Triton.

  • load_eviction_policies (list[str]): Controls the eviction policy used for loads discovered in device loops. Each entry corresponds to a load site; allowed values are "" (no policy), "first" (maps to Triton evict_first), and "last" (maps to Triton evict_last). An explicit eviction_policy=... on hl.load overrides this config.
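The per-loop range_* rules above can be summarized as a translation from one loop's config entries to keyword arguments on a generated tl.range(...) call. The sketch below encodes only the rules stated in this list; it is not Helion's code generator, and `range_kwargs` and `render_range` are hypothetical names. It emits a source-code string, so it runs without Triton installed.

```python
from typing import Dict, Optional


def range_kwargs(
    unroll_factor: int = 0,
    num_stages: int = 0,
    multi_buffer: Optional[bool] = None,
    flatten: Optional[bool] = None,
    warp_specialize: Optional[bool] = None,
) -> Dict[str, object]:
    """Map one loop's range_* config entries to tl.range keyword arguments."""
    kwargs: Dict[str, object] = {}
    if unroll_factor >= 1:  # values less than 1 omit the parameter
        kwargs["loop_unroll_factor"] = unroll_factor
    if num_stages >= 1:  # values less than 1 omit the parameter
        kwargs["num_stages"] = num_stages
    if multi_buffer is not None:
        # True *allows* multi-buffering, so the "disallow" flag is inverted.
        kwargs["disallow_acc_multi_buffer"] = not multi_buffer
    if flatten is not None:
        kwargs["flatten"] = flatten
    if warp_specialize is not None:
        kwargs["warp_specialize"] = warp_specialize
    return kwargs


def render_range(bound: str, step: str, static: bool = False, **cfg: object) -> str:
    """Render a loop header as Triton source text; static_ranges ignores range_*."""
    if static:
        return f"tl.static_range(0, {bound}, {step})"
    extra = "".join(f", {k}={v!r}" for k, v in range_kwargs(**cfg).items())
    return f"tl.range(0, {bound}, {step}{extra})"


if __name__ == "__main__":
    # Values from the second entry of each range_* list in the tuned config
    # shown earlier; "k" and "BLOCK_K" are placeholder bound/step names.
    print(render_range("k", "BLOCK_K", unroll_factor=1, num_stages=3,
                       multi_buffer=False, warp_specialize=False))
```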

Changing these options often results in significantly different output Triton code, allowing the autotuner to explore a wide range of implementations from a single Helion kernel.
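As one example of how a single option reshapes the generated code, l2_groupings controls how a flat program ID is turned into a (row-tile, column-tile) pair. The sketch below uses the grouped ordering popularized by Triton's matrix-multiplication tutorial; it is assumed to be representative of what the option does rather than a copy of Helion's code, and `swizzle_pid` and `launch_order` are hypothetical names. With a group size of 1 it reduces to plain row-major order, consistent with a value of 1 disabling the optimization.

```python
from typing import List, Tuple


def swizzle_pid(pid: int, num_pid_m: int, num_pid_n: int, group_m: int) -> Tuple[int, int]:
    """Map a flat program id to (pid_m, pid_n) using grouped ordering."""
    num_pid_in_group = group_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    # The last group may be shorter if num_pid_m is not a multiple of group_m.
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def launch_order(num_pid_m: int, num_pid_n: int, group_m: int) -> List[Tuple[int, int]]:
    """Tile coordinates in the order consecutive program ids would visit them."""
    return [swizzle_pid(p, num_pid_m, num_pid_n, group_m) for p in range(num_pid_m * num_pid_n)]


if __name__ == "__main__":
    # group_m=1: consecutive programs sweep one full row of output tiles.
    print(launch_order(4, 4, 1)[:4])
    # group_m=2: consecutive programs cover a 2x2 block of tiles, so the same
    # rows of x and columns of y are reused while they are still in L2.
    print(launch_order(4, 4, 2)[:4])
```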

Settings for Development and Debugging

When developing kernels with Helion, you might prefer skipping autotuning for faster iteration. To do this, set the environment variable HELION_AUTOTUNE_EFFORT=none or use the decorator argument @helion.kernel(autotune_effort="none"). Warning: The default configuration is slow and not intended for production or performance testing.

To view the generated Triton code, set the environment variable HELION_PRINT_OUTPUT_CODE=1 or include print_output_code=True in the @helion.kernel decorator. This prints the Triton code to stderr, which is helpful for debugging and understanding Helion's compilation process. One can also use foo_kernel.bind(args).to_triton_code(config) to get the Triton code as a string.

To emit a repro script that includes the Helion kernel definition, the config decorator, and a helion_repro_caller() helper that recreates the runtime inputs before invoking the Helion kernel, set HELION_PRINT_REPRO=1 or include print_repro=True in the @helion.kernel decorator. This prints the repro script to stderr, which is helpful for debugging and for sharing a minimal repro on the GitHub issue tracker.

Within an hl.tile/hl.grid device loop, if you want to print intermediate results using print("x", ...) syntax, or pause execution using Python's built-in breakpoint(), set either TRITON_INTERPRET=1 (runs Triton's CPU interpreter) or HELION_INTERPRET=1 (runs the Helion kernel in eager mode).

To force autotuning, bypassing provided configurations, set HELION_FORCE_AUTOTUNE=1 or invoke foo_kernel.autotune(args, force=True).

Additional settings are available in settings.py. If both an environment variable and a kernel decorator argument are set, the kernel decorator argument takes precedence, and the environment variable will be ignored.
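The precedence rule can be expressed as a small lookup. This is a hypothetical sketch, not Helion's settings.py: `resolve_setting` checks a decorator keyword argument first, then an environment variable, then a default. The pairing of autotune_effort with HELION_AUTOTUNE_EFFORT and print_output_code with HELION_PRINT_OUTPUT_CODE comes from the text above; the defaults and the truthy spellings accepted by `parse_bool` are assumptions.

```python
import os
from typing import Any, Callable, Mapping, Optional


def resolve_setting(
    name: str,
    env_var: str,
    default: Any,
    parse: Callable[[str], Any],
    decorator_kwargs: Mapping[str, Any],
    environ: Optional[Mapping[str, str]] = None,
) -> Any:
    """Resolve one setting: decorator argument > environment variable > default."""
    env = os.environ if environ is None else environ
    if name in decorator_kwargs:
        # An explicit decorator argument wins; the environment variable is ignored.
        return decorator_kwargs[name]
    if env_var in env:
        return parse(env[env_var])
    return default


def parse_bool(text: str) -> bool:
    """Interpret common truthy spellings such as "1" or "true" (assumed set)."""
    return text.strip().lower() in {"1", "true", "yes", "on"}


if __name__ == "__main__":
    env = {"HELION_PRINT_OUTPUT_CODE": "1", "HELION_AUTOTUNE_EFFORT": "none"}
    # The decorator passes print_output_code=False, so the env var is ignored.
    print(resolve_setting("print_output_code", "HELION_PRINT_OUTPUT_CODE", False,
                          parse_bool, {"print_output_code": False}, env))
    # No decorator argument for autotune_effort, so the env var applies.
    print(resolve_setting("autotune_effort", "HELION_AUTOTUNE_EFFORT", "<default>",
                          str, {}, env))
```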

Enable logging by setting the environment variable HELION_LOGS=all for INFO-level logs, or HELION_LOGS=+all for DEBUG-level logs. Alternatively, you can specify logging for specific modules using a comma-separated list (e.g., HELION_LOGS=+helion.runtime.kernel).
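The HELION_LOGS syntax can be modeled with Python's standard logging module. This hypothetical sketch is built only from the forms in the text (all, +all, and a comma-separated module list where a leading + selects DEBUG). Two points are our assumptions, not documented behavior: that all corresponds to the root "helion" logger, and that a module name without + means INFO by analogy with all. `parse_helion_logs` and `apply_levels` are illustrative names.

```python
import logging
import os
from typing import Dict


def parse_helion_logs(spec: str, root: str = "helion") -> Dict[str, int]:
    """Turn a HELION_LOGS-style string into {logger_name: level}.

    Assumptions: "all" stands for the root "helion" logger, a leading "+"
    selects DEBUG, and an entry without "+" selects INFO.
    """
    levels: Dict[str, int] = {}
    for item in (part.strip() for part in spec.split(",")):
        if not item:
            continue  # tolerate empty entries such as a trailing comma
        level = logging.DEBUG if item.startswith("+") else logging.INFO
        name = item.lstrip("+")
        levels[root if name == "all" else name] = level
    return levels


def apply_levels(levels: Dict[str, int]) -> None:
    """Set the level on each named logger."""
    for name, level in levels.items():
        logging.getLogger(name).setLevel(level)


if __name__ == "__main__":
    logging.basicConfig()
    levels = parse_helion_logs(os.environ.get("HELION_LOGS", "+helion.runtime.kernel"))
    apply_levels(levels)
    print({name: logging.getLevelName(level) for name, level in levels.items()})
```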

Requirements

Helion currently targets Linux systems and requires a recent Python and PyTorch environment:

  • Linux-based OS
  • Python 3.10–3.14
  • PyTorch 2.9 or later
  • Triton 3.5 or later (older versions may work, but will lack support for features like TMA on Hopper/Blackwell GPUs and may exhibit lower performance)
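A quick way to check an environment against these requirements is sketched below using only the standard library (sys.version_info and importlib.metadata.version). The distribution names "torch" and "triton" are assumptions about how PyTorch and Triton are installed in your environment (some builds ship Triton under a different distribution name), the minimum versions come from the list above, and the version parsing in the hypothetical `leading_version` helper is deliberately simple.

```python
import sys
from importlib import metadata
from typing import Optional, Tuple


def leading_version(text: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '2.9.0+cu128' or '3.5.0'."""
    parts = []
    for piece in text.split(".")[:2]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 2:
        parts.append(0)
    return parts[0], parts[1]


def installed(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check() -> bool:
    ok = sys.platform.startswith("linux")
    print(f"Linux: {ok}")
    py_ok = (3, 10) <= sys.version_info[:2] <= (3, 14)
    print(f"Python {sys.version_info[0]}.{sys.version_info[1]}: {py_ok}")
    ok = ok and py_ok
    for dist, minimum in (("torch", (2, 9)), ("triton", (3, 5))):
        ver = installed(dist)
        dist_ok = ver is not None and leading_version(ver) >= minimum
        print(f"{dist} {ver}: {dist_ok}")
        ok = ok and dist_ok
    return ok


if __name__ == "__main__":
    print("all requirements met:", check())
```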

Installation

We recommend using uv to manage an isolated virtual environment. First, install compatible versions of PyTorch and Triton.

Once your environment is set up, you can install Helion:

pip install helion

Alternatively, you may install from source for development purposes. If using uv, create and activate a virtual environment first:

git clone https://github.com/pytorch/helion.git
cd helion
# Create and activate a virtual environment with uv (one-time)
uv venv .venv
source .venv/bin/activate
# Install in editable mode with the required dev packages
pip install -e .'[dev]'

This installs Helion in "editable" mode so that changes to the source code take effect without needing to reinstall.

Linting

We use pre-commit to run ruff, pyrefly, and other checks automatically.

– One-time setup (installs the git hook):

pip install pre-commit
pre-commit install

– Run all checks across the repository:

pre-commit run --all-files

Note: You can still run the underlying tools directly via ./lint.sh [fix|check|unsafe].

Community

Questions or feedback? Join us on the GPU MODE Discord in the #helion channel.

License

Helion is BSD-style licensed, as found in the LICENSE file.
