
torch.utils.cpp_extension#

Created On: Mar 07, 2018 | Last Updated On: Feb 16, 2025

torch.utils.cpp_extension.CppExtension(name, sources, *args, **kwargs)[source]#

Create a setuptools.Extension for C++.

Convenience method that creates a setuptools.Extension with the bare minimum (but often sufficient) arguments to build a C++ extension.

All arguments are forwarded to the setuptools.Extension constructor. The full list of arguments can be found at https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

Warning

The PyTorch Python API (as provided in libtorch_python) cannot be built with the flag py_limited_api=True. When this flag is passed, it is the user's responsibility to ensure that their library does not use APIs from libtorch_python (in particular PyTorch/Python bindings) and only uses APIs from libtorch (ATen objects, operators and the dispatcher). For example, to give access to custom ops from Python, the library should register the ops through the dispatcher.

Contrary to CPython setuptools, which does not define -DPy_LIMITED_API as a compile flag when py_limited_api is specified as an option for the "bdist_wheel" command in setup, PyTorch does! We will specify -DPy_LIMITED_API=min_supported_cpython to best enforce consistency, safety, and sanity in order to encourage best practices. To target a different version, set min_supported_cpython to the hexcode of the CPython version of choice.
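For instance, a minimal sketch of opting into the limited API might look like the following (this is only an illustration, assuming the extension uses libtorch APIs exclusively; tagging the wheel for CPython >= 3.9 via cp39 is an arbitrary choice here):

>>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CppExtension
>>> setup(
...     name='extension',
...     ext_modules=[
...         CppExtension(
...             name='extension',
...             sources=['extension.cpp'],
...             py_limited_api=True)  # forwarded to setuptools.Extension
...     ],
...     cmdclass={'build_ext': BuildExtension},
...     # tag the resulting wheel as abi3 for CPython >= 3.9
...     options={'bdist_wheel': {'py_limited_api': 'cp39'}})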

Example

>>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CppExtension
>>> setup(
...     name='extension',
...     ext_modules=[
...         CppExtension(
...             name='extension',
...             sources=['extension.cpp'],
...             extra_compile_args=['-g'],
...             extra_link_args=['-Wl,--no-as-needed', '-lm'])
...     ],
...     cmdclass={
...         'build_ext': BuildExtension
...     })
torch.utils.cpp_extension.CUDAExtension(name, sources, *args, **kwargs)[source]#

Create a setuptools.Extension for CUDA/C++.

Convenience method that creates a setuptools.Extension with the bare minimum (but often sufficient) arguments to build a CUDA/C++ extension. This includes the CUDA include path, library path and runtime library.

All arguments are forwarded to the setuptools.Extension constructor. The full list of arguments can be found at https://setuptools.pypa.io/en/latest/userguide/ext_modules.html#extension-api-reference

Warning

The PyTorch Python API (as provided in libtorch_python) cannot be built with the flag py_limited_api=True. When this flag is passed, it is the user's responsibility to ensure that their library does not use APIs from libtorch_python (in particular PyTorch/Python bindings) and only uses APIs from libtorch (ATen objects, operators and the dispatcher). For example, to give access to custom ops from Python, the library should register the ops through the dispatcher.

Contrary to CPython setuptools, which does not define -DPy_LIMITED_API as a compile flag when py_limited_api is specified as an option for the "bdist_wheel" command in setup, PyTorch does! We will specify -DPy_LIMITED_API=min_supported_cpython to best enforce consistency, safety, and sanity in order to encourage best practices. To target a different version, set min_supported_cpython to the hexcode of the CPython version of choice.

Example

>>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
>>> setup(
...     name='cuda_extension',
...     ext_modules=[
...         CUDAExtension(
...             name='cuda_extension',
...             sources=['extension.cpp', 'extension_kernel.cu'],
...             extra_compile_args={'cxx': ['-g'],
...                                 'nvcc': ['-O2']},
...             extra_link_args=['-Wl,--no-as-needed', '-lcuda'])
...     ],
...     cmdclass={
...         'build_ext': BuildExtension
...     })

Compute capabilities:

By default the extension will be compiled to run on all archs of the cards visible during the building process of the extension, plus PTX. If down the road a new card is installed, the extension may need to be recompiled. If a visible card has a compute capability (CC) that's newer than the newest version for which your nvcc can build fully-compiled binaries, PyTorch will make nvcc fall back to building kernels with the newest version of PTX your nvcc does support (see below for details on PTX).

You can override the default behavior using TORCH_CUDA_ARCH_LIST to explicitly specify which CCs you want the extension to support:

TORCH_CUDA_ARCH_LIST="6.18.6"pythonbuild_my_extension.pyTORCH_CUDA_ARCH_LIST="5.26.06.17.07.58.08.6+PTX"pythonbuild_my_extension.py

The +PTX option causes extension kernel binaries to include PTX instructions for the specified CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >= the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on those newer CCs. If you know the exact CC(s) of the GPUs you want to target, you're always better off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6, "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but "8.0 8.6" would be better.

Note that while it's possible to include all supported archs, the more archs get included the slower the building process will be, as it will build a separate kernel image for each arch.

Note that CUDA 11.5 nvcc will hit an internal compiler error while parsing torch/extension.h on Windows. To work around the issue, move the Python binding logic to a pure C++ file.

Example use:

#include <ATen/ATen.h>
at::Tensor SigmoidAlphaBlendForwardCuda(….)

Instead of:

#include <torch/extension.h>
torch::Tensor SigmoidAlphaBlendForwardCuda(…)

Currently open issue for the nvcc bug: pytorch/pytorch#69460. Complete workaround code example: facebookresearch/pytorch3d
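At the build level, the split amounts to keeping the pybind11 bindings in a pure C++ source file while the .cu file includes only ATen headers, so nvcc never parses torch/extension.h. A rough sketch with hypothetical file names:

>>> CUDAExtension(
...     name='cuda_extension',
...     # bindings.cpp includes <torch/extension.h> and defines the Python bindings;
...     # extension_kernel.cu includes only <ATen/ATen.h> and the CUDA kernels
...     sources=['bindings.cpp', 'extension_kernel.cu'])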

Relocatable device code linking:

If you want to reference device symbols across compilation units (across object files), the object files need to be built with relocatable device code (-rdc=true or -dc). An exception to this rule is "dynamic parallelism" (nested kernel launches), which is not used a lot anymore. Relocatable device code is less optimized, so it should be used only on object files that need it. Using -dlto (Device Link Time Optimization) at the device code compilation step and the dlink step helps reduce the potential perf degradation of -rdc. Note that it needs to be used at both steps to be useful.

If you have rdc objects, you need to have an extra -dlink (device linking) step before the CPU symbol linking step. There is also a case where -dlink is used without -rdc: when an extension is linked against a static lib containing rdc-compiled objects, like the NVSHMEM library (https://developer.nvidia.com/nvshmem).

Note: Ninja is required to build a CUDA Extension with RDC linking.

Example

>>> CUDAExtension(
...     name='cuda_extension',
...     sources=['extension.cpp', 'extension_kernel.cu'],
...     dlink=True,
...     dlink_libraries=["dlink_lib"],
...     extra_compile_args={'cxx': ['-g'],
...                         'nvcc': ['-O2', '-rdc=true']})
torch.utils.cpp_extension.SyclExtension(name, sources, *args, **kwargs)[source]#

Creates a setuptools.Extension for SYCL/C++.

Convenience method that creates a setuptools.Extension with the bare minimum (but often sufficient) arguments to build a SYCL/C++ extension.

All arguments are forwarded to the setuptools.Extension constructor.

Warning

The PyTorch Python API (as provided in libtorch_python) cannot be built with the flag py_limited_api=True. When this flag is passed, it is the user's responsibility to ensure that their library does not use APIs from libtorch_python (in particular PyTorch/Python bindings) and only uses APIs from libtorch (ATen objects, operators and the dispatcher). For example, to give access to custom ops from Python, the library should register the ops through the dispatcher.

Contrary to CPython setuptools, which does not define -DPy_LIMITED_API as a compile flag when py_limited_api is specified as an option for the "bdist_wheel" command in setup, PyTorch does! We will specify -DPy_LIMITED_API=min_supported_cpython to best enforce consistency, safety, and sanity in order to encourage best practices. To target a different version, set min_supported_cpython to the hexcode of the CPython version of choice.

Example

>>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, SyclExtension
>>> setup(
...     name='xpu_extension',
...     ext_modules=[
...         SyclExtension(
...             name='xpu_extension',
...             sources=['extension.cpp', 'extension_kernel.cpp'],
...             extra_compile_args={'cxx': ['-g', '-std=c++20', '-fPIC']})
...     ],
...     cmdclass={
...         'build_ext': BuildExtension
...     })

By default the extension will be compiled to run on all archs of the cards visible during the building process of the extension. If down the road a new card is installed, the extension may need to be recompiled. You can override the default behavior using TORCH_XPU_ARCH_LIST to explicitly specify which device architectures you want the extension to support:

TORCH_XPU_ARCH_LIST="pvc,xe-lpg" python build_my_extension.py

Note that while it's possible to include all supported archs, the more archs get included the slower the building process will be, as it will build a separate kernel image for each arch.

Note: Ninja is required to build SyclExtension.

torch.utils.cpp_extension.BuildExtension(*args, **kwargs)[source]#

A custom setuptools build extension.

This setuptools.build_ext subclass takes care of passing the minimum required compiler flags (e.g. -std=c++17) as well as mixed C++/CUDA/SYCL compilation (and support for CUDA/SYCL files in general).

When using BuildExtension, it is allowed to supply a dictionary for extra_compile_args (rather than the usual list) that maps from languages/compilers (the only expected values are cxx, nvcc or sycl) to a list of additional compiler flags to supply to the compiler. This makes it possible to supply different flags to the C++, CUDA and SYCL compiler during mixed compilation.

use_ninja (bool): If use_ninja is True (default), then we attempt to build using the Ninja backend. Ninja greatly speeds up compilation compared to the standard setuptools.build_ext. Falls back to the standard distutils backend if Ninja is not available.

Note

By default, the Ninja backend uses #CPUS + 2 workers to build the extension. This may use up too many resources on some systems. One can control the number of workers by setting the MAX_JOBS environment variable to a non-negative number.
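As a minimal sketch (hypothetical file names; BuildExtension.with_options is a convenience helper for forwarding keyword arguments such as use_ninja), per-compiler flags and the Ninja toggle can be combined like this:

>>> from setuptools import setup
>>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
>>> setup(
...     name='extension',
...     ext_modules=[
...         CUDAExtension(
...             name='extension',
...             sources=['extension.cpp', 'extension_kernel.cu'],
...             # dictionary form: separate flags for the C++ and CUDA compilers
...             extra_compile_args={'cxx': ['-O3'], 'nvcc': ['-O2']})
...     ],
...     # explicitly fall back to the distutils backend instead of Ninja
...     cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)})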

torch.utils.cpp_extension.load(name, sources, extra_cflags=None, extra_cuda_cflags=None, extra_sycl_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, with_sycl=None, is_python_module=True, is_standalone=False, keep_intermediates=True)[source]#

Load a PyTorch C++ extension just-in-time (JIT).

To load an extension, a Ninja build file is emitted, which is used to compile the given sources into a dynamic library. This library is subsequently loaded into the current Python process as a module and returned from this function, ready for use.

By default, the directory to which the build file is emitted and the resulting library compiled to is <tmp>/torch_extensions/<name>, where <tmp> is the temporary folder on the current platform and <name> the name of the extension. This location can be overridden in two ways. First, if the TORCH_EXTENSIONS_DIR environment variable is set, it replaces <tmp>/torch_extensions and all extensions will be compiled into subfolders of this directory. Second, if the build_directory argument to this function is supplied, it overrides the entire path, i.e. the library will be compiled into that folder directly.
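A rough sketch of both overrides (the paths below are purely hypothetical):

>>> import os
>>> from torch.utils.cpp_extension import load
>>> # all JIT-built extensions now go into subfolders of this directory
>>> os.environ['TORCH_EXTENSIONS_DIR'] = '/tmp/my_torch_extensions'
>>> # or, for a single extension, build directly into a chosen (existing) folder
>>> os.makedirs('/tmp/extension_build', exist_ok=True)
>>> module = load(name='extension',
...               sources=['extension.cpp'],
...               build_directory='/tmp/extension_build')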

To compile the sources, the default system compiler (c++) is used, which can be overridden by setting the CXX environment variable. To pass additional arguments to the compilation process, extra_cflags or extra_ldflags can be provided. For example, to compile your extension with optimizations, pass extra_cflags=['-O3']. You can also use extra_cflags to pass further include directories.

CUDA support with mixed compilation is provided. Simply pass CUDA source files (.cu or .cuh) along with other sources. Such files will be detected and compiled with nvcc rather than the C++ compiler. This includes passing the CUDA lib64 directory as a library directory, and linking cudart. You can pass additional flags to nvcc via extra_cuda_cflags, just like with extra_cflags for C++. Various heuristics for finding the CUDA install directory are used, which usually work fine. If not, setting the CUDA_HOME environment variable is the safest option.

SYCL support with mixed compilation is provided. Simply pass SYCL source files (.sycl) along with other sources. Such files will be detected and compiled with a SYCL compiler (such as the Intel DPC++ Compiler) rather than the C++ compiler. You can pass additional flags to the SYCL compiler via extra_sycl_cflags, just like with extra_cflags for C++. The SYCL compiler is expected to be found via the system PATH environment variable.

Parameters
  • name – The name of the extension to build. This MUST be the same as the name of the pybind11 module!

  • sources (Union[str, list[str]]) – A list of relative or absolute paths to C++ source files.

  • extra_cflags – optional list of compiler flags to forward to the build.

  • extra_cuda_cflags – optional list of compiler flags to forward to nvcc when building CUDA sources.

  • extra_sycl_cflags – optional list of compiler flags to forward to the SYCL compiler when building SYCL sources.

  • extra_ldflags – optional list of linker flags to forward to the build.

  • extra_include_paths – optional list of include directories to forward to the build.

  • build_directory – optional path to use as build workspace.

  • verbose – If True, turns on verbose logging of load steps.

  • with_cuda (Optional[bool]) – Determines whether CUDA headers and libraries are added to the build. If set to None (default), this value is automatically determined based on the existence of .cu or .cuh in sources. Set it to True to force CUDA headers and libraries to be included.

  • with_sycl (Optional[bool]) – Determines whether SYCL headers and libraries are added to the build. If set to None (default), this value is automatically determined based on the existence of .sycl in sources. Set it to True to force SYCL headers and libraries to be included.

  • is_python_module – If True (default), imports the produced shared library as a Python module. If False, behavior depends on is_standalone.

  • is_standalone – If False (default), loads the constructed extension into the process as a plain dynamic library. If True, builds a standalone executable.

Returns

If is_python_module is True:

Returns the loaded PyTorch extension as a Python module.

If is_python_module is False and is_standalone is False:

Returns nothing. (The shared library is loaded into the process as a side effect.)

If is_standalone is True:

Returns the path to the executable. (On Windows, TORCH_LIB_PATH is added to the PATH environment variable as a side effect.)

Return type

If is_python_module is True

Example

>>> from torch.utils.cpp_extension import load
>>> module = load(
...     name='extension',
...     sources=['extension.cpp', 'extension_kernel.cu'],
...     extra_cflags=['-O2'],
...     verbose=True)
torch.utils.cpp_extension.load_inline(name, cpp_sources, cuda_sources=None, sycl_sources=None, functions=None, extra_cflags=None, extra_cuda_cflags=None, extra_sycl_cflags=None, extra_ldflags=None, extra_include_paths=None, build_directory=None, verbose=False, with_cuda=None, with_sycl=None, is_python_module=True, with_pytorch_error_handling=True, keep_intermediates=True, use_pch=False, no_implicit_headers=False)[source]#

Load a PyTorch C++ extension just-in-time (JIT) from string sources.

This function behaves exactly like load(), but takes its sources as strings rather than filenames. These strings are stored to files in the build directory, after which the behavior of load_inline() is identical to load().

See the tests for good examples of using this function.

Sources may omit two required parts of a typical non-inline C++ extension: the necessary header includes, as well as the (pybind11) binding code. More precisely, strings passed to cpp_sources are first concatenated into a single .cpp file. This file is then prepended with #include <torch/extension.h>.

Furthermore, if the functions argument is supplied, bindings will be automatically generated for each function specified. functions can either be a list of function names, or a dictionary mapping from function names to docstrings. If a list is given, the name of each function is used as its docstring.

The sources in cuda_sources are concatenated into a separate .cu file and prepended with torch/types.h, cuda.h and cuda_runtime.h includes. The .cpp and .cu files are compiled separately, but ultimately linked into a single library. Note that no bindings are generated for functions in cuda_sources per se. To bind to a CUDA kernel, you must create a C++ function that calls it, and either declare or define this C++ function in one of the cpp_sources (and include its name in functions).

The sources in sycl_sources are concatenated into a separate .sycl file and prepended with torch/types.h and sycl/sycl.hpp includes. The .cpp and .sycl files are compiled separately, but ultimately linked into a single library. Note that no bindings are generated for functions in sycl_sources per se. To bind to a SYCL kernel, you must create a C++ function that calls it, and either declare or define this C++ function in one of the cpp_sources (and include its name in functions).
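For the CUDA case, a rough sketch of this pattern looks as follows (the kernel, function names and launch configuration here are hypothetical and purely illustrative):

>>> from torch.utils.cpp_extension import load_inline
>>> cuda_source = """
__global__ void square_kernel(const float* x, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = x[i] * x[i];
}

at::Tensor square_cuda(at::Tensor x) {
  auto out = at::empty_like(x);
  const int n = x.numel();
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  square_kernel<<<blocks, threads>>>(x.data_ptr<float>(), out.data_ptr<float>(), n);
  return out;
}
"""
>>> # the C++ side only declares the function; the binding is generated for it
>>> cpp_source = "at::Tensor square_cuda(at::Tensor x);"
>>> module = load_inline(name='square_inline',
...                      cpp_sources=[cpp_source],
...                      cuda_sources=[cuda_source],
...                      functions=['square_cuda'])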

See load() for a description of arguments omitted below.

Parameters
  • cpp_sources – A string, or list of strings, containing C++ source code.

  • cuda_sources – A string, or list of strings, containing CUDA source code.

  • sycl_sources – A string, or list of strings, containing SYCL source code.

  • functions – A list of function names for which to generate function bindings. If a dictionary is given, it should map function names to docstrings (which are otherwise just the function names).

  • with_cuda – Determines whether CUDA headers and libraries are added to the build. If set to None (default), this value is automatically determined based on whether cuda_sources is provided. Set it to True to force CUDA headers and libraries to be included.

  • with_sycl – Determines whether SYCL headers and libraries are added to the build. If set to None (default), this value is automatically determined based on whether sycl_sources is provided. Set it to True to force SYCL headers and libraries to be included.

  • with_pytorch_error_handling – Determines whether PyTorch error and warning macros are handled by PyTorch instead of pybind. To do this, each function foo is called via an intermediary _safe_foo function. This redirection might cause issues in obscure cases of cpp. This flag should be set to False when this redirect causes issues.

  • no_implicit_headers – If True, skips automatically adding headers, most notably the #include <torch/extension.h> and #include <torch/types.h> lines. Use this option to improve cold start times when you already include the necessary headers in your source code. Default: False.

Example

>>> from torch.utils.cpp_extension import load_inline
>>> source = """
at::Tensor sin_add(at::Tensor x, at::Tensor y) {
  return x.sin() + y.sin();
}
"""
>>> module = load_inline(name='inline_extension',
...                      cpp_sources=[source],
...                      functions=['sin_add'])

Note

Since load_inline will just-in-time compile the source code, please ensure that you have the right toolchains installed in the runtime. For example, when loading C++, make sure a C++ compiler is available. If you're loading a CUDA extension, you will need to additionally install the corresponding CUDA toolkit (nvcc and any other dependencies your code has). Compiling toolchains are not included when you install torch and must be additionally installed.

During compiling, by default, the Ninja backend uses #CPUS + 2 workers to build the extension. This may use up too many resources on some systems. One can control the number of workers by setting the MAX_JOBS environment variable to a non-negative number.

torch.utils.cpp_extension.include_paths(device_type='cpu')[source]#

Get the include paths required to build a C++ or CUDA or SYCL extension.

Parameters

device_type (str) – Defaults to “cpu”.

Returns

A list of include path strings.

Return type

list[str]
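A minimal usage sketch (the exact directories returned depend on your installation, so the output is omitted):

>>> from torch.utils.cpp_extension import include_paths
>>> cpp_includes = include_paths()                     # C++-only include directories
>>> cuda_includes = include_paths(device_type='cuda')  # additionally includes the CUDA include directories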

torch.utils.cpp_extension.get_compiler_abi_compatibility_and_version(compiler)[source]#

Determine whether the given compiler is ABI-compatible with PyTorch, along with its version.

Parameters

compiler (str) – The compiler executable name to check (e.g. g++). Must be executable in a shell process.

Returns

A tuple that contains a boolean that defines if the compiler is (likely) ABI-incompatible with PyTorch, followed by a TorchVersion string that contains the compiler version separated by dots.

Return type

tuple[bool, torch.torch_version.TorchVersion]

torch.utils.cpp_extension.verify_ninja_availability()[source]#

Raise RuntimeError if the ninja build system is not available on the system; does nothing otherwise.

torch.utils.cpp_extension.is_ninja_available()[source]#

Return True if the ninja build system is available on the system, False otherwise.
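A small usage sketch combining the two helpers:

>>> from torch.utils.cpp_extension import is_ninja_available, verify_ninja_availability
>>> if is_ninja_available():
...     print("ninja found; JIT builds will use the Ninja backend")
... else:
...     verify_ninja_availability()  # raises RuntimeError explaining that ninja is missing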