TunableOp#
Created On: Jun 03, 2024 | Last Updated On: Oct 15, 2025
Overview#
This module exposes a TunableOp interface.
Some operations, such as GEMMs, could be implemented using more than one library or more than one technique. For example, a GEMM could be implemented for CUDA or ROCm using either the blas or blasLt libraries. Further, ROCm’s rocblas and hipblaslt libraries allow the user to query for all possible algorithms and then choose one. How does one know which implementation is the fastest and should be chosen? That’s what TunableOp provides.
Enabling TunableOp and Tuning Separately#
The TunableOp feature is enabled separately from enabling the tuning phase itself. Enabling TunableOp means that PyTorch will replace any standard operators with their Tunable implementations. Any call to a TunableOp first checks whether it has already been tuned for the given operator inputs. If so, it will immediately call the tuned operation; no further tuning will take place even when the tuning setting is enabled. If instead no tuning result is found and tuning is enabled, the TunableOp will benchmark every registered implementation of that operator for the given set of inputs and select the fastest.
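The lookup-then-tune flow described above can be sketched in plain Python. This is a conceptual model only, not PyTorch's implementation: `dispatch`, the `tuned` dictionary, and the `benchmark` callable are hypothetical names introduced for illustration.

```python
from typing import Callable, Dict, Hashable, Tuple


def dispatch(
    op_key: Hashable,
    candidates: Dict[str, Callable[[], object]],
    tuned: Dict[Hashable, str],
    tuning_enabled: bool,
    benchmark: Callable[[Callable[[], object]], float],
    default: str = "Default",
) -> Tuple[str, object]:
    """Run the operator for `op_key`, tuning first only if needed and allowed.

    Hypothetical model: `tuned` maps an input signature to the name of the
    chosen implementation; `benchmark` returns a cost (lower is faster).
    """
    name = tuned.get(op_key)
    if name is None:
        if tuning_enabled:
            # No tuning result yet: benchmark every registered implementation
            # and remember the fastest one for this set of inputs.
            name = min(candidates, key=lambda n: benchmark(candidates[n]))
            tuned[op_key] = name
        else:
            # Nothing recorded and tuning is off: use the untuned default.
            name = default
    # A recorded result is used directly; no further tuning takes place.
    return name, candidates[name]()
```

Once an entry exists in `tuned`, later calls with the same key never reach the benchmarking branch, which mirrors the statement that no further tuning takes place even when tuning is enabled.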
File Input and Output#
The first time any TunableOp is invoked, the internal database of tuned operations will be prepared by attempting to read the results from the given file. The default filename is ‘tunableop_results.csv’. To support tuning when multiple GPUs are used across multiple processes, the GPU device ordinal is automatically inserted into the filename to avoid multiple processes overwriting the same file.
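As a rough illustration of the per-device filenames, the helper below inserts an ordinal just before the file extension. The helper name `with_device_ordinal` is hypothetical, and the exact insertion rule is an assumption inferred from names such as tunableop_results0.csv that appear elsewhere in this document.

```python
import os


def with_device_ordinal(filename: str, ordinal: int) -> str:
    """Insert the device ordinal just before the file extension.

    Assumed rule (not confirmed by the source): "name.csv" -> "name<N>.csv".
    """
    root, ext = os.path.splitext(filename)
    return f"{root}{ordinal}{ext}"


# One distinct file per GPU, so processes do not overwrite each other.
per_gpu_files = [with_device_ordinal("tunableop_results.csv", i) for i in range(2)]
```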
If tuning is enabled and new tunings are discovered during the course of your workload, it will also write out to this same filename with all tunings, both the ones it read in at startup as well as the new ones found at runtime. This can be used, for example, to build up a tunings file across many workloads by reusing the same file. The output file is automatically created when the application terminates. This behavior can be controlled by the C++ and Python APIs but not the environment variables.
Assuming you specified a filename, you’ll end up with a CSV file with contents like so:
    Validator,PT_VERSION,2.2.0
    Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
    Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
    Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
    GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262
    GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216,0.033
Note the “Validator” lines. If you change a library version, or ROCm version, or PyTorch version, TunableOp will detect this and reject the tunings file because the prior tunings are likely affected by other software changes.
The remaining lines are the tuned solutions for each TunableOp encountered during your execution. Each line consists of 4 comma-separated fields: operator name, operator parameters, solution name, and average execution time. The execution time is an optional field. The CSV file can be edited, but with caution. For example, the solution name (field 3) can be changed to “Default” and it will fall back to the original PyTorch untuned implementation. Or, in the case of ROCm’s hipBLAS or hipBLASLt libraries, if you know the specific solution index you can override the solution that TunableOp selected by replacing the value. The operator name and parameters (fields 1 and 2) are internally named and should not be modified. In the case of GemmTunableOp, field 1 indicates the datatype and whether the inputs are transposed (T) or not (N) and field 2 indicates the M, N, K input shapes.
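To inspect a results file before editing it by hand, a small reader along these lines may help. It is a sketch that assumes only the layout described above (Validator rows first, then rows with three or four fields); `parse_tunings` is a hypothetical helper and not part of torch.cuda.tunable.

```python
import csv
import io


def parse_tunings(text):
    """Split a TunableOp results file into validators and tuned entries."""
    validators, entries = {}, []
    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue
        if row[0] == "Validator":
            # Validator rows: marker, component name, version string.
            validators[row[1]] = row[2]
        else:
            op_name, params, solution = row[:3]
            # The average execution time (field 4) is optional.
            avg_time = float(row[3]) if len(row) > 3 and row[3] else None
            entries.append((op_name, params, solution, avg_time))
    return validators, entries


SAMPLE = """\
Validator,PT_VERSION,2.2.0
Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
GemmTunableOp_float_NT,nt_25088_4096_64,Gemm_Hipblaslt_1219,1.262
GemmTunableOp_float_NT,nt_4096_4096_64,Gemm_Rocblas_1216,0.033
"""

validators, entries = parse_tunings(SAMPLE)
```

Keeping the validators separate from the entries makes it easy to compare the recorded versions against the current environment before reusing a file.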
There is an option to enable verbose output but it is only recommended for debugging purposes. This will produce a lot of diagnostic messages but may be useful to see if TunableOp is being used at all. Otherwise, TunableOp is completely silent, besides file output, unless there is a warning or error during its use. The verbose option is only available by setting the environment variable PYTORCH_TUNABLEOP_VERBOSE=1.
A Note on Tuning Behavior, Warmup, and Cache Effects#
Tuning an operator consists of iterating through the list of registered implementations and profiling each one. The profile is established by running a single implementation in a loop multiple times and taking the average execution time. There is also an optional warmup phase prior to tuning that can help with reaching stable power states by the hardware. During tuning of a workload the various hardware caches will more likely produce hits than when not tuning. There are options for flushing the instruction cache and rotating the input tensors, which might help produce a more faithful profile of the tuned operator as if the operator were run within a larger workload instead of in a tight, repetitive loop.
By default, each possible solution for a given operator will be run for either 100 iterations or as many iterations as can be run within 30 ms, whichever is smaller, and its average execution time will be calculated. The fastest solution among all that were successfully profiled will be chosen. A profile might fail if the given solution doesn’t achieve the same accuracy as the default implementation or if the solution returns an error code.
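The default limits can be modelled with a simple timing loop. The sketch below is illustrative only: `profile` and `pick_fastest` are hypothetical names, it uses a host-side clock rather than whatever timers PyTorch uses internally, and it does not model the warmup phase or the accuracy check.

```python
import time


def profile(fn, max_iterations=100, max_duration_ms=30.0, clock=time.perf_counter):
    """Return the average run time of `fn` in milliseconds.

    Stops at whichever limit is reached first, but always runs at least once.
    """
    start = clock()
    runs = 0
    while True:
        fn()
        runs += 1
        elapsed_ms = (clock() - start) * 1000.0
        if runs >= max_iterations or elapsed_ms >= max_duration_ms:
            return elapsed_ms / runs


def pick_fastest(candidates, **limits):
    """Profile each candidate and return the name of the fastest one.

    Candidates that raise are treated as failed profiles and skipped.
    Returns None if every candidate failed.
    """
    timings = {}
    for name, fn in candidates.items():
        try:
            timings[name] = profile(fn, **limits)
        except Exception:
            continue
    return min(timings, key=timings.get) if timings else None
```

Passing `clock` as a parameter is a design choice for this sketch: it lets the limits be exercised deterministically with a fake clock.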
Current Tunable Operators#
TunableGemm for ROCm#
Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of PyTorch will function correctly when using TunableOp but the only solution available to CUDA builds is the ‘Default’ implementation, i.e. the original cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm() or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the fastest available implementation across both rocblas and hipblaslt.
Offline Tuning#
Motivation#
There are several use cases for offline tuning.
One use case involves a workload with high memory utilization, where regular tuning might lead to running out of memory.
Another use case is for compute-intensive workloads. In such cases, it is more resource-efficient to collect the GEMMs for the workload once and then tune repeatedly with different tuning parameters or libraries.
Workflow#
There are basically two steps:
1) Set the environment variables to collect the untuned GEMMs; this will generate tunableop_untuned0.csv:
    export PYTORCH_TUNABLEOP_ENABLED=1
    export PYTORCH_TUNABLEOP_TUNING=0
    export PYTORCH_TUNABLEOP_RECORD_UNTUNED=1
    ...
2) Run a Python script that reads tunableop_untuned0.csv and generates tunableop_results0.csv, like this:
    import torch.cuda.tunable as tunable
    import os

    os.putenv("PYTORCH_TUNABLEOP_ENABLED", "1")
    os.putenv("PYTORCH_TUNABLEOP_TUNING", "1")
    os.putenv("PYTORCH_TUNABLEOP_RECORD_UNTUNED", "0")
    tunable.tune_gemm_in_file("tunableop_untuned0.csv")
It is also possible to take multiple untuned files and distribute the GEMMs for tuning to multiple GPUs within a single node. In the first step, the GEMMs are gathered and duplicate GEMMs are eliminated. Next, the GEMMs are distributed to different GPUs for tuning. After all GEMMs are tuned, the results from all the GPUs are gathered into a single file whose base filename has _full0 appended to it (for example tunableop_results_full0.csv). Finally, this new file, containing the gathered results, is duplicated N times, once for each GPU, as a convenience to users who will run the workload with the tuned configuration on N GPUs.
    if __name__ == "__main__":
        num_gpus = 8  # number of GPUs that will be used during the tuning process
        tunable.mgpu_tune_gemm_in_file("tunableop_untuned?.csv", num_gpus)
Note that the usage of the mgpu_tune_gemm_in_file API is different from its single-GPU counterpart (tune_gemm_in_file). The body of the Python script that calls the API must be wrapped in a main guard (if __name__ == "__main__":) as shown, due to the use of the concurrent futures module. The argument to mgpu_tune_gemm_in_file must contain a wildcard expression (? or *) to generate the list of untuned files containing the GEMMs to be processed. num_gpus must be between 1 and the total number of GPUs available.
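The gather, de-duplicate, and distribute steps can be pictured with the following sketch. It is not the implementation behind mgpu_tune_gemm_in_file: the helper names are hypothetical, each line of an untuned file is treated as an opaque GEMM description, and round-robin assignment is an assumption, since the text does not say how GEMMs are split across GPUs.

```python
import glob
import os
import tempfile


def gather_unique_lines(pattern):
    """Read every file matching `pattern`; return unique lines in first-seen order."""
    seen = {}
    for path in sorted(glob.glob(pattern)):
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line:
                    seen.setdefault(line, None)
    return list(seen)


def distribute(items, num_gpus):
    """Assign items to GPUs round-robin; returns one list per GPU."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    return [items[i::num_gpus] for i in range(num_gpus)]


# Demo with two small stand-in files; real untuned files hold GEMM descriptions.
with tempfile.TemporaryDirectory() as tmp:
    for ordinal, body in enumerate(["gemm_a\ngemm_b\n", "gemm_b\ngemm_c\n"]):
        with open(os.path.join(tmp, f"tunableop_untuned{ordinal}.csv"), "w") as f:
            f.write(body)
    # "?" matches exactly one character, here the device ordinal.
    unique = gather_unique_lines(os.path.join(tmp, "tunableop_untuned?.csv"))

per_gpu = distribute(unique, num_gpus=2)
```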
Tuning Context#
The behavior of TunableOp is currently manipulated through environment variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the torch.cuda.tunable python interfaces. The environment variables take precedence over any setting you manipulate using the C++ or Python APIs.
Environment Variable Interface#
Environment variables are cached the first time they are read. You cannot use the environment variable interface programmatically since the settings become fixed. Use the C++ or Python APIs instead.
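A minimal sketch of configuring TunableOp through the Python API instead of the environment is shown below. The wrapper `configure_online_tuning` is a hypothetical helper; it only calls functions listed in the API reference of this page, and it takes the module as an argument so the sketch itself does not require a GPU build to load. Any environment variables that are set would still take precedence.

```python
def configure_online_tuning(tunable, filename, max_iterations=100, max_duration_ms=30):
    """Apply TunableOp settings programmatically.

    `tunable` is expected to be the torch.cuda.tunable module.
    """
    tunable.enable(True)         # turn the TunableOp feature on
    tunable.tuning_enable(True)  # allow tuning when no tuned entry is found
    tunable.set_max_tuning_iterations(max_iterations)
    tunable.set_max_tuning_duration(max_duration_ms)  # milliseconds
    # One results file per process in a 1-process-per-GPU launch.
    tunable.set_filename(filename, insert_device_ordinal=True)


# Typical use on a GPU build of PyTorch (not executed here):
#   import torch.cuda.tunable as tunable
#   configure_online_tuning(tunable, "tunableop_results.csv")
```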
API Reference#
- torch.cuda.tunable.enable(val=True)
This is the big on/off switch for all TunableOp implementations.
- torch.cuda.tunable.is_enabled()
Returns whether the TunableOp feature is enabled.
- Return type: bool
- torch.cuda.tunable.tuning_enable(val=True)
Enable tuning of TunableOp implementations.
When enabled, if a tuned entry isn’t found, run the tuning step and record the entry.
- torch.cuda.tunable.tuning_is_enabled()
Returns whether TunableOp implementations can be tuned.
- Return type: bool
- torch.cuda.tunable.record_untuned_enable(val=True)
Enable recording of untuned TunableOp operations for offline tuning.
When enabled, if a tuned entry isn’t found, write it to the untuned file.
- torch.cuda.tunable.record_untuned_is_enabled()
Returns whether TunableOp operations are recorded for offline tuning.
- Return type: bool
- torch.cuda.tunable.set_max_tuning_duration(duration)
Set max time in milliseconds to spend tuning a given solution.
If both max tuning duration and iterations are set, the smaller of the two will be honored. At minimum 1 tuning iteration will always be run.
- torch.cuda.tunable.get_max_tuning_duration()
Get max time to spend tuning a given solution.
- Return type: int
- torch.cuda.tunable.set_max_tuning_iterations(iterations)
Set max number of iterations to spend tuning a given solution.
If both max tuning duration and iterations are set, the smaller of the two will be honored. At minimum 1 tuning iteration will always be run.
- torch.cuda.tunable.get_max_tuning_iterations()
Get max iterations to spend tuning a given solution.
- Return type: int
- torch.cuda.tunable.set_filename(filename, insert_device_ordinal=False)
Set the filename to use for input/output of tuning results.
If insert_device_ordinal is True then the current device ordinal will be added to the given filename automatically. This can be used in a 1-process-per-gpu scenario to ensure all processes write to a separate file.
- torch.cuda.tunable.read_file(filename=None)
Read results from a TunableOp CSV file.
If filename is not given, get_filename() is called.
- Return type: bool
- torch.cuda.tunable.mgpu_tune_gemm_in_file(filename_pattern, num_gpus)
Process one or more files and distribute work over one or more GPUs.
- torch.cuda.tunable.set_rotating_buffer_size(buffer_size)
Set rotating buffer size to this value in MB, if the buffer size is greater than zero.
If less than zero, query L2 cache size. If equal to zero, the rotating buffer is deactivated.