torch.library
Created On: Jun 13, 2022 | Last Updated On: Aug 13, 2025
torch.library is a collection of APIs for extending PyTorch’s core library of operators. It contains utilities for testing custom operators, creating new custom operators, and extending operators defined with PyTorch’s C++ operator registration APIs (e.g. aten operators).
For a detailed guide on using these APIs effectively, please see the PyTorch Custom Operators Landing Page.
Testing custom ops
Use torch.library.opcheck() to test custom ops for incorrect usage of the Python torch.library and/or C++ TORCH_LIBRARY APIs. Also, if your operator supports training, use torch.autograd.gradcheck() to test that the gradients are mathematically correct.
- torch.library.opcheck(op, args, kwargs=None, *, test_utils=('test_schema', 'test_autograd_registration', 'test_faketensor', 'test_aot_dispatch_dynamic'), raise_exception=True, atol=None, rtol=None)
Given an operator and some sample arguments, tests whether the operator is registered correctly.
That is, when you use the torch.library/TORCH_LIBRARY APIs to create a custom op, you specify metadata (e.g. mutability info) about the custom op, and these APIs require that the functions you pass them satisfy certain properties (e.g. no data pointer access in the fake/meta/abstract kernel).
opcheck tests these metadata and properties. Concretely, we test the following:
test_schema: Whether the schema matches the implementation of the operator. For example, if the schema specifies that a Tensor is mutated, then we check that the implementation mutates the Tensor. If the schema specifies that we return a new Tensor, then we check that the implementation returns a new Tensor (instead of an existing one or a view of an existing one).
test_autograd_registration: If the operator supports training (autograd), we check that its autograd formula is registered via torch.library.register_autograd or a manual registration to one or more DispatchKey::Autograd keys. Any other DispatchKey-based registrations may lead to undefined behavior.
test_faketensor: Whether the operator has a FakeTensor kernel (and whether it is correct). The FakeTensor kernel is necessary (but not sufficient) for the operator to work with PyTorch compilation APIs (torch.compile/export/FX). We check that a FakeTensor kernel (also sometimes known as a meta kernel) was registered for the operator and that it is correct. This test takes the result of running the operator on real tensors and the result of running the operator on FakeTensors and checks that they have the same Tensor metadata (sizes/strides/dtype/device/etc).
test_aot_dispatch_dynamic: Whether the operator has correct behavior with PyTorch compilation APIs (torch.compile/export/FX). This checks that the outputs (and gradients, if applicable) are the same under eager-mode PyTorch and torch.compile. This test is a superset of test_faketensor and is an end-to-end test; it also checks that the operator supports functionalization and that the backward pass (if it exists) also supports FakeTensor and functionalization.
For best results, please call opcheck multiple times with a representative set of inputs. If your operator supports autograd, please use opcheck with inputs with requires_grad=True; if your operator supports multiple devices (e.g. CPU and CUDA), please use opcheck with inputs on all supported devices.
- Parameters
op (Union[OpOverload, OpOverloadPacket, CustomOpDef]) – The operator. Must either be a function decorated with torch.library.custom_op() or an OpOverload/OpOverloadPacket found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo).
args – The args to the operator.
kwargs (Optional[dict[str, Any]]) – The kwargs to the operator.
test_utils (Union[str, Sequence[str]]) – Tests that we should run. Default: all of them. Example: (“test_schema”, “test_faketensor”)
raise_exception (bool) – Whether to raise an exception on the first error. If False, we return a dict with information on whether each test passed.
rtol (Optional[float]) – Relative tolerance for floating point comparisons. If specified, atol must also be specified. If omitted, default values based on the dtype are selected (see the table in torch.testing.assert_close()).
atol (Optional[float]) – Absolute tolerance for floating point comparisons. If specified, rtol must also be specified. If omitted, default values based on the dtype are selected (see the table in torch.testing.assert_close()).
- Return type
Warning
opcheck and torch.autograd.gradcheck() test different things; opcheck tests whether your usage of torch.library APIs is correct, while torch.autograd.gradcheck() tests whether your autograd formula is mathematically correct. Use both to test custom ops that support gradient computation.
Example
>>> import torch
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: float) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     z_np = x_np * y
>>>     return torch.from_numpy(z_np).to(x.device)
>>>
>>> @numpy_mul.register_fake
>>> def _(x, y):
>>>     return torch.empty_like(x)
>>>
>>> def setup_context(ctx, inputs, output):
>>>     _, y = inputs
>>>     ctx.y = y
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.y, None
>>>
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
>>>
>>> sample_inputs = [
>>>     (torch.randn(3), 3.14),
>>>     (torch.randn(2, 3, device='cuda'), 2.718),
>>>     (torch.randn(1, 10, requires_grad=True), 1.234),
>>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
>>> ]
>>>
>>> for args in sample_inputs:
>>>     torch.library.opcheck(numpy_mul, args)
Creating new custom ops in Python
Use torch.library.custom_op() to create new custom ops.
- torch.library.custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None, tags=None)
Wraps a function into a custom operator.
Reasons why you may want to create a custom op include:
- Wrapping a third-party library or custom kernel to work with PyTorch subsystems like Autograd.
- Preventing torch.compile/export/FX tracing from peeking inside your function.
This API is used as a decorator around a function (please see examples). The provided function must have type hints; these are needed to interface with PyTorch’s various subsystems.
- Parameters
name (str) – A name for the custom op that looks like “{namespace}::{name}”, e.g. “mylib::my_linear”. The name is used as the op’s stable identifier in PyTorch subsystems (e.g. torch.export, FX graphs). To avoid name collisions, please use your project name as the namespace; e.g. all custom ops in pytorch/fbgemm use “fbgemm” as the namespace.
mutates_args (Iterable[str] or "unknown") – The names of args that the function mutates. This MUST be accurate; otherwise, the behavior is undefined. If “unknown”, it pessimistically assumes that all inputs to the operator are being mutated.
device_types (None | str | Sequence[str]) – The device type(s) the function is valid for. If no device type is provided, then the function is used as the default implementation for all device types. Examples: “cpu”, “cuda”. When registering a device-specific implementation for an operator that accepts no Tensors, we require the operator to have a “device: torch.device” argument.
schema (None | str) – A schema string for the operator. If None (recommended), we’ll infer a schema for the operator from its type annotations. We recommend letting us infer a schema unless you have a specific reason not to. Example: “(Tensor x, int y) -> (Tensor, Tensor)”.
- Return type
Union[Callable[[Callable[[…], object]], CustomOpDef], CustomOpDef]
Note
We recommend not passing in a schema arg and instead letting us infer it from the type annotations. It is error-prone to write your own schema. You may wish to provide your own schema if our interpretation of the type annotation is not what you want. For more info on how to write a schema string, see here.
- Examples::
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> @custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that only works for one device type.
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> x = torch.randn(3)
>>> y = numpy_sin_cpu(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example of a custom op that mutates an input
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
>>> def numpy_sin_inplace(x: Tensor) -> None:
>>>     x_np = x.numpy()
>>>     np.sin(x_np, out=x_np)
>>>
>>> x = torch.randn(3)
>>> expected = x.sin()
>>> numpy_sin_inplace(x)
>>> assert torch.allclose(x, expected)
>>>
>>> # Example of a factory function
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
>>> def bar(device: torch.device) -> Tensor:
>>>     return torch.ones(3)
>>>
>>> bar("cpu")
- torch.library.triton_op(name, fn=None, /, *, mutates_args, schema=None)
Create a custom operator whose implementation is backed by one or more triton kernels.
This is a more structured way of using triton kernels with PyTorch. Prefer using triton kernels with no torch.library custom operator wrappers (like torch.library.custom_op(), torch.library.triton_op()) because that is simpler; only use torch.library.custom_op()/torch.library.triton_op() if you want to create an operator that behaves like PyTorch built-in operators. For example, you may use a torch.library wrapper API to define the behavior of the triton kernel when passed a tensor subclass or under a TorchDispatchMode.
Use torch.library.triton_op() instead of torch.library.custom_op() when the implementation consists of one or more triton kernels. torch.library.custom_op() treats custom operators as opaque (torch.compile() and torch.export.export() will never trace into them), but triton_op makes the implementation visible to these subsystems, allowing them to optimize the triton kernel(s).
Note that fn must only consist of calls to PyTorch-understood operators and triton kernels. Any triton kernels called inside fn must be wrapped in a call to torch.library.wrap_triton().
- Parameters
name (str) – A name for the custom op that looks like “{namespace}::{name}”, e.g. “mylib::my_linear”. The name is used as the op’s stable identifier in PyTorch subsystems (e.g. torch.export, FX graphs). To avoid name collisions, please use your project name as the namespace; e.g. all custom ops in pytorch/fbgemm use “fbgemm” as the namespace.
mutates_args (Iterable[str] or "unknown") – The names of args that the function mutates. This MUST be accurate; otherwise, the behavior is undefined. If “unknown”, it pessimistically assumes that all inputs to the operator are being mutated.
schema (None | str) – A schema string for the operator. If None (recommended), we’ll infer a schema for the operator from its type annotations. We recommend letting us infer a schema unless you have a specific reason not to. Example: “(Tensor x, int y) -> (Tensor, Tensor)”.
- Return type
Example:
>>> import torch
>>> from torch.library import triton_op, wrap_triton
>>>
>>> import triton
>>> from triton import language as tl
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> @triton_op("mylib::add", mutates_args={})
>>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
>>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> @torch.compile
>>> def f(x, y):
>>>     return add(x, y)
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>>
>>> z = f(x, y)
>>> assert torch.allclose(z, x + y)
- torch.library.wrap_triton(triton_kernel, /)
Allows capture of a triton kernel into a graph via make_fx or non-strict torch.export.
These technologies perform Dispatcher-based tracing (via __torch_dispatch__) and cannot see calls to raw triton kernels. The wrap_triton API wraps a triton kernel into a callable that can actually be traced into a graph.
Please use this API together with torch.library.triton_op().
Examples
>>> import torch
>>> import triton
>>> from triton import language as tl
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>> from torch.library import wrap_triton
>>>
>>> @triton.jit
>>> def add_kernel(
>>>     in_ptr0,
>>>     in_ptr1,
>>>     out_ptr,
>>>     n_elements,
>>>     BLOCK_SIZE: "tl.constexpr",
>>> ):
>>>     pid = tl.program_id(axis=0)
>>>     block_start = pid * BLOCK_SIZE
>>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
>>>     mask = offsets < n_elements
>>>     x = tl.load(in_ptr0 + offsets, mask=mask)
>>>     y = tl.load(in_ptr1 + offsets, mask=mask)
>>>     output = x + y
>>>     tl.store(out_ptr + offsets, output, mask=mask)
>>>
>>> def add(x, y):
>>>     output = torch.empty_like(x)
>>>     n_elements = output.numel()
>>>
>>>     def grid_fn(meta):
>>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
>>>
>>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
>>>     return output
>>>
>>> x = torch.randn(3, device="cuda")
>>> y = torch.randn(3, device="cuda")
>>> gm = make_fx(add)(x, y)
>>> print(gm.code)
>>> # def forward(self, x_1, y_1):
>>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
>>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
>>> #         kernel_idx = 0, constant_args_idx = 0,
>>> #         grid = [(1, 1, 1)], kwargs = {
>>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
>>> #             'n_elements': 3, 'BLOCK_SIZE': 16
>>> #         })
>>> #     return empty_like
- Return type
Extending custom ops (created from Python or C++)
Use the register.* methods, such as torch.library.register_kernel() and torch.library.register_fake(), to add implementations for any operators (they may have been created using torch.library.custom_op() or via PyTorch’s C++ operator registration APIs).
- torch.library.register_kernel(op, device_types, func=None, /, *, lib=None)
Register an implementation for a device type for this operator.
Some valid device_types are: “cpu”, “cuda”, “xla”, “mps”, “ipu”, “xpu”. This API may be used as a decorator.
- Parameters
op (str |OpOverload) – The operator to register an impl to.
device_types (None | str | Sequence[str]) – The device_types to register an impl to. If None, we will register to all device types; please only use this option if your implementation is truly device-type-agnostic.
func (Callable) – The function to register as the implementation for the given device types.
lib (Optional[Library]) – If provided, the lifetime of this registration will be tied to the lifetime of the Library object.
- Examples::
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>> import numpy as np
>>>
>>> # Create a custom op that works on cpu
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np)
>>>
>>> # Add implementations for the cuda device
>>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
>>> def _(x):
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> x_cpu = torch.randn(3)
>>> x_cuda = x_cpu.cuda()
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
- torch.library.register_autocast(op, device_type, cast_inputs, /, *, lib=None)
Register an autocast dispatch rule for this custom op.
Valid device_type values include: “cpu” and “cuda”.
- Parameters
op (str |OpOverload) – The operator to register an autocast dispatch rule to.
device_type (str) – Device type to use. ‘cuda’ or ‘cpu’. The type is the same as the type attribute of a torch.device. Thus, you may obtain the device type of a tensor using Tensor.device.type.
cast_inputs (torch.dtype) – When the custom op runs in an autocast-enabled region, casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors are not affected), then executes the custom op with autocast disabled.
lib (Optional[Library]) – If provided, the lifetime of this registration will be tied to the lifetime of the Library object.
- Examples::
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>>     return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>>     y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
- torch.library.register_autograd(op, backward, /, *, setup_context=None, lib=None)
Register a backward formula for this custom op.
In order for an operator to work with autograd, you need to register a backward formula:
1. You must tell us how to compute gradients during the backward pass by providing us a “backward” function.
2. If you need any values from the forward to compute gradients, you can use setup_context to save values for backward.
backward runs during the backward pass. It accepts (ctx, *grads):
- grads is one or more gradients. The number of gradients matches the number of outputs of the operator.
- The ctx object is the same ctx object used by torch.autograd.Function. The semantics of backward are the same as torch.autograd.Function.backward().
setup_context(ctx, inputs, output) runs during the forward pass. Please save quantities needed for backward onto the ctx object via either torch.autograd.function.FunctionCtx.save_for_backward() or assigning them as attributes of ctx. If your custom op has kwarg-only arguments, we expect the signature of setup_context to be setup_context(ctx, inputs, keyword_only_inputs, output).
Both setup_context and backward must be traceable. That is, they may not directly access torch.Tensor.data_ptr() and they must not depend on or mutate global state. If you need a non-traceable backward, you can make it a separate custom_op that you call inside backward.
If you need different autograd behavior on different devices, then we recommend creating two different custom operators, one for each device that needs different behavior, and switching between them at runtime.
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
>>> def numpy_sin(x: Tensor) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = np.sin(x_np)
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, output) -> None:
>>>     x, = inputs
>>>     ctx.save_for_backward(x)
>>>
>>> def backward(ctx, grad):
>>>     x, = ctx.saved_tensors
>>>     return grad * x.cos()
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_sin", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_sin(x)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, x.cos())
>>>
>>> # Example with a keyword-only arg
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
>>>     x_np = x.cpu().numpy()
>>>     y_np = x_np * val
>>>     return torch.from_numpy(y_np).to(device=x.device)
>>>
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
>>>     ctx.val = keyword_only_inputs["val"]
>>>
>>> def backward(ctx, grad):
>>>     return grad * ctx.val
>>>
>>> torch.library.register_autograd(
...     "mylib::numpy_mul", backward, setup_context=setup_context
... )
>>>
>>> x = torch.randn(3, requires_grad=True)
>>> y = numpy_mul(x, val=3.14)
>>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
- torch.library.register_fake(op, func=None, /, *, lib=None, _stacklevel=1, allow_override=False)
Register a FakeTensor implementation (“fake impl”) for this operator.
It is also sometimes known as a “meta kernel” or “abstract impl”.
A “FakeTensor implementation” specifies the behavior of this operator on Tensors that carry no data (“FakeTensor”). Given some input Tensors with certain properties (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are.
The FakeTensor implementation has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor implementation, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor implementation must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).
This API may be used as a decorator (see examples).
For a detailed guide on custom ops, please see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
- Parameters
op – Operator name (along with the overload) or OpOverload object.
lib (Optional[Library]) – Library to register the fake tensor to.
allow_override (bool) – Flag controlling whether to override an existing registered fake impl. This is off by default, and registering a fake impl to an operator that already has one raises an error. This flag only applies if the custom operator was not created via torch.library.custom_op, since overriding an existing fake impl is already allowed for those operators.
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>>     raise NotImplementedError("Implementation goes here")
>>>
>>> @torch.library.register_fake("mylib::custom_linear")
>>> def _(x, weight, bias):
>>>     assert x.dim() == 2
>>>     assert weight.dim() == 2
>>>     assert bias.dim() == 1
>>>     assert x.shape[1] == weight.shape[1]
>>>     assert weight.shape[0] == bias.shape[0]
>>>     assert x.device == weight.device
>>>
>>>     return (x @ weight.t()) + bias
>>>
>>> with torch._subclasses.fake_tensor.FakeTensorMode():
>>>     x = torch.randn(2, 3)
>>>     w = torch.randn(3, 3)
>>>     b = torch.randn(3)
>>>     y = torch.ops.mylib.custom_linear(x, w, b)
>>>
>>> assert y.shape == (2, 3)
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>>     x_np = x.numpy(force=True)
>>>     res = np.stack(np.nonzero(x_np), axis=1)
>>>     return torch.tensor(res, device=x.device)
>>>
>>> @torch.library.register_fake("mylib::custom_nonzero")
>>> def _(x):
>>>     # Number of nonzero-elements is data-dependent.
>>>     # Since we cannot peek at the data in a fake impl,
>>>     # we use the ctx object to construct a new symint that
>>>     # represents the data-dependent size.
>>>     ctx = torch.library.get_ctx()
>>>     nnz = ctx.new_dynamic_size()
>>>     shape = [nnz, x.dim()]
>>>     result = x.new_empty(shape, dtype=torch.int64)
>>>     return result
>>>
>>> from torch.fx.experimental.proxy_tensor import make_fx
>>>
>>> x = torch.tensor([0, 1, 2, 3, 4, 0])
>>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
>>> trace.print_readable()
>>>
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
- torch.library.register_vmap(op, func=None, /, *, lib=None)
Register a vmap implementation to support torch.vmap() for this custom op. This API may be used as a decorator (see examples).
In order for an operator to work with torch.vmap(), you may need to register a vmap implementation with the following signature:
vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs),
where *args and **kwargs are the arguments and kwargs for op. We do not support kwarg-only Tensor args.
It specifies how to compute the batched version of op given inputs with an additional dimension (specified by in_dims).
For each arg in args, in_dims has a corresponding Optional[int]. It is None if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer specifying which dimension of the Tensor is being vmapped over.
info is a collection of additional metadata that may be helpful: info.batch_size specifies the size of the dimension being vmapped over, while info.randomness is the randomness option that was passed to torch.vmap().
The return of the function func is a tuple of (output, out_dims). Similar to in_dims, out_dims should be of the same structure as output and contain one out_dim per output that specifies if the output has the vmapped dimension and what index it is in.
Examples
>>> import torch
>>> import numpy as np
>>> from torch import Tensor
>>> from typing import Tuple
>>>
>>> def to_numpy(tensor):
>>>     return tensor.cpu().numpy()
>>>
>>> lib = torch.library.Library("mylib", "FRAGMENT")
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
>>>     x_np = to_numpy(x)
>>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
>>>     return torch.tensor(x_np ** 3, device=x.device), dx
>>>
>>> def numpy_cube_vmap(info, in_dims, x):
>>>     result = numpy_cube(x)
>>>     return result, (in_dims[0], in_dims[0])
>>>
>>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
>>>
>>> x = torch.randn(3)
>>> torch.vmap(numpy_cube)(x)
>>>
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
>>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
>>>
>>> @torch.library.register_vmap("mylib::numpy_mul")
>>> def numpy_mul_vmap(info, in_dims, x, y):
>>>     x_bdim, y_bdim = in_dims
>>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
>>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
>>>     result = x * y
>>>     result = result.movedim(-1, 0)
>>>     return result, 0
>>>
>>>
>>> x = torch.randn(3)
>>> y = torch.randn(3)
>>> torch.vmap(numpy_mul)(x, y)
Note
The vmap function should aim to preserve the semantics of the entire custom operator. That is, grad(vmap(op)) should be replaceable with a grad(map(op)).
If your custom operator has any custom behavior in the backward pass, please keep this in mind.
- torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1)
This API was renamed to torch.library.register_fake() in PyTorch 2.4. Please use that instead.
- torch.library.get_ctx()
get_ctx() returns the current FakeImplCtx object.
Calling get_ctx() is only valid inside of a fake impl (see torch.library.register_fake() for more usage details).
- Return type
FakeImplCtx
- torch.library.register_torch_dispatch(op, torch_dispatch_class, func=None, /, *, lib=None)
Registers a torch_dispatch rule for the given operator and torch_dispatch_class.
This allows for open registration to specify the behavior between the operator and the torch_dispatch_class without needing to modify the torch_dispatch_class or the operator directly.
The torch_dispatch_class is either a Tensor subclass with __torch_dispatch__ or a TorchDispatchMode.
If it is a Tensor subclass, we expect func to have the following signature: (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
If it is a TorchDispatchMode, we expect func to have the following signature: (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
args and kwargs will have been normalized the same way they are in __torch_dispatch__ (see __torch_dispatch__ calling convention).
Examples
>>> import torch
>>>
>>> @torch.library.custom_op("mylib::foo", mutates_args={})
>>> def foo(x: torch.Tensor) -> torch.Tensor:
>>>     return x.clone()
>>>
>>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
>>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
>>>         return func(*args, **kwargs)
>>>
>>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
>>> def _(mode, func, types, args, kwargs):
>>>     x, = args
>>>     return x + 1
>>>
>>> x = torch.randn(3)
>>> y = foo(x)
>>> assert torch.allclose(y, x)
>>>
>>> with MyMode():
>>>     y = foo(x)
>>> assert torch.allclose(y, x + 1)
- torch.library.infer_schema(prototype_function, /, *, mutates_args, op_name=None)
Parses the schema of a given function with type hints. The schema is inferred from the function’s type hints, and can be used to define a new operator.
We make the following assumptions:
None of the outputs alias any of the inputs or each other.
- String type annotations “device, dtype, Tensor, types” without library specification are assumed to be torch.*. Similarly, string type annotations “Optional, List, Sequence, Union” without library specification are assumed to be typing.*.
- Only the args listed in mutates_args are being mutated. If mutates_args is “unknown”, it assumes that all inputs to the operator are being mutated.
Callers (e.g. the custom ops API) are responsible for checking these assumptions.
- Parameters
prototype_function (Callable) – The function whose type annotations the schema is inferred from.
op_name (Optional[str]) – The name of the operator in the schema. If op_name is None, then the name is not included in the inferred schema. Note that the input schema to torch.library.Library.define requires an operator name.
mutates_args ("unknown" | Iterable[str]) – The arguments that are mutated in the function.
- Returns
The inferred schema.
- Return type
Example
>>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
>>>     return x.sin()
>>>
>>> infer_schema(foo_impl, op_name="foo", mutates_args={})
foo(Tensor x) -> Tensor
>>>
>>> infer_schema(foo_impl, mutates_args={})
(Tensor x) -> Tensor
- class torch._library.custom_ops.CustomOpDef(namespace, name, schema, fn, tags=None)
CustomOpDef is a wrapper around a function that turns it into a custom op.
It has various methods for registering additional behavior for this custom op.
You should not instantiate CustomOpDef directly; instead, use the torch.library.custom_op() API.
- set_kernel_enabled(device_type, enabled=True)
Disable or re-enable an already registered kernel for this custom operator.
If the kernel is already disabled/enabled, this is a no-op.
Note
If a kernel is first disabled and then registered, it is disabled until enabled again.
- Parameters
Example
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> inp = torch.randn(1)
>>>
>>> # define custom op `f`.
>>> @custom_op("mylib::f", mutates_args=())
>>> def f(x: Tensor) -> Tensor:
>>>     return torch.zeros(1)
>>>
>>> print(f(inp))  # tensor([0.]), default kernel
>>>
>>> @f.register_kernel("cpu")
>>> def _(x):
>>>     return torch.ones(1)
>>>
>>> print(f(inp))  # tensor([1.]), CPU kernel
>>>
>>> # temporarily disable the CPU kernel
>>> with f.set_kernel_enabled("cpu", enabled=False):
>>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled
- torch.library.get_kernel(op, dispatch_key)
Returns the computed kernel for a given operator and dispatch key.
This function retrieves the kernel that would be executed for a given operator and dispatch key combination. The returned SafeKernelFunction can be used to call the kernel in a boxed fashion. The intended use case for this function is to retrieve the original kernel for a given dispatch key and then register another kernel to the same dispatch key that calls into the original kernel for certain cases.
- Parameters
op (Union[str, OpOverload, CustomOpDef]) – Operator name (along with the overload) or OpOverload object. Can be a string (e.g., “aten::add.Tensor”), an OpOverload, or a CustomOpDef.
dispatch_key (str | torch.DispatchKey) – The dispatch key to get the kernel for. Can be a string (e.g., “CPU”, “CUDA”) or a DispatchKey enum value.
- Returns
A safe kernel function that can be used to call the kernel.
- Return type
torch._C._SafeKernelFunction
- Raises
RuntimeError – If the operator does not exist.
Example
>>> # Get the CPU kernel for torch.add
>>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
>>>
>>> # You can also use DispatchKey enum
>>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
>>>
>>> # Or use an OpOverload directly
>>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
>>>
>>> # Example: Using get_kernel in a custom op with conditional dispatch
>>> # Get the original kernel for torch.sin
>>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
>>>
>>> # If input has negative values, use original sin, otherwise return zeros
>>> def conditional_sin_impl(dispatch_keys, x):
>>>     if (x < 0).any():
>>>         return original_sin_kernel.call_boxed(dispatch_keys, x)
>>>     else:
>>>         return torch.zeros_like(x)
>>>
>>> lib = torch.library.Library("aten", "IMPL")
>>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet,
>>> # which needs to be the first argument to ``kernel.call_boxed``
>>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
>>>
>>> # Test the conditional behavior
>>> x_positive = torch.tensor([1.0, 2.0])
>>> x_mixed = torch.tensor([-1.0, 2.0])
>>> torch.sin(x_positive)
tensor([0., 0.])
>>> torch.sin(x_mixed)
tensor([-0.8415, 0.9093])
Low-level APIs
The following APIs are direct bindings to PyTorch’s C++ low-level operator registration APIs.
Warning
The low-level operator registration APIs and the PyTorch Dispatcher are a complicated PyTorch concept. We recommend you use the higher-level APIs above (that do not require a torch.library.Library object) when possible. This blog post is a good starting point to learn about the PyTorch Dispatcher.
A tutorial that walks you through some examples on how to use this API is available on Google Colab.
- class torch.library.Library(ns, kind, dispatch_key='')
A class to create libraries that can be used to register new operators or override operators in existing libraries from Python. A user can optionally pass in a dispatch key name if they only want to register kernels corresponding to one specific dispatch key.
To create a library to override operators in an existing library (with name ns), set the kind to “IMPL”. To create a new library (with name ns) to register new operators, set the kind to “DEF”. To create a fragment of a possibly existing library to register operators (and bypass the limitation that there is only one library for a given namespace), set the kind to “FRAGMENT”.
- Parameters
ns – library name
kind – “DEF”, “IMPL”, “FRAGMENT”
dispatch_key – PyTorch dispatch key (default: “”)
- define(schema, alias_analysis='', *, tags=())
Defines a new operator and its semantics in the ns namespace.
- Parameters
schema – function schema to define a new operator.
alias_analysis (optional) – Indicates if the aliasing properties of the operator arguments can be inferred from the schema (default behavior) or not (“CONSERVATIVE”).
tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for torch.Tag carefully before applying it.
- Returns
name of the operator as inferred from the schema.
Example:
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
- fallback(fn, dispatch_key='', *, with_keyset=False)
Registers the function implementation as the fallback for the given key.
This function only works for a library with global namespace (“_”).
- Parameters
fn – function used as fallback for the given dispatch key, or fallthrough_kernel() to register a fallthrough.
dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.
with_keyset – flag controlling if the current dispatcher call keyset should be passed as the first argument to fn when calling. This should be used to create the appropriate keyset for redispatch calls.
Example:
>>> my_lib = Library("_", "IMPL")
>>> def fallback_kernel(op, *args, **kwargs):
...     # Handle all autocast ops generically
...     # ...
...     pass
>>> my_lib.fallback(fallback_kernel, "Autocast")
- impl(op_name, fn, dispatch_key='', *, with_keyset=False, allow_override=False)
Registers the function implementation for an operator defined in the library.
- Parameters
op_name – operator name (along with the overload) or OpOverload object.
fn – function that’s the operator implementation for the input dispatch key, or fallthrough_kernel() to register a fallthrough.
dispatch_key – dispatch key that the input function should be registered for. By default, it uses the dispatch key that the library was created with.
with_keyset – flag controlling if the current dispatcher call keyset should be passed as the first argument to fn when calling. This should be used to create the appropriate keyset for redispatch calls.
allow_override – flag controlling if we want to override an existing registered kernel implementation. This is off by default, and will raise an error if you try to register a kernel to a dispatch key that already has a kernel registered.
Example:
>>> my_lib = Library("aten", "IMPL")
>>> def div_cpu(self, other):
...     return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
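The sketch below, using a hypothetical demo_impl namespace and relu_like operator, registers separate kernels for the CPU and Meta dispatch keys on one operator, then shows that registering a second Python kernel for the same key without allow_override raises an error. The Meta kernel only allocates an output of the right shape, because meta tensors carry no data.

```python
import torch
from torch.library import Library

# Hypothetical namespace and operator, for illustration only.
lib = Library("demo_impl", "DEF")
lib.define("relu_like(Tensor x) -> Tensor")

def relu_like_cpu(x):
    # Real computation for CPU tensors.
    return x.clamp(min=0)

def relu_like_meta(x):
    # Meta tensors have shape/dtype but no data: only allocate the output.
    return torch.empty_like(x)

lib.impl("relu_like", relu_like_cpu, "CPU")
lib.impl("relu_like", relu_like_meta, "Meta")

# A second Python registration for the same (op, key) pair is rejected
# unless the override is explicitly requested.
duplicate_rejected = False
try:
    lib.impl("relu_like", relu_like_cpu, "CPU")
except RuntimeError:
    duplicate_rejected = True

cpu_out = torch.ops.demo_impl.relu_like(torch.tensor([-1.0, 2.0]))
meta_out = torch.ops.demo_impl.relu_like(torch.empty(4, 3, device="meta"))
```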
- torch.library.fallthrough_kernel()
A dummy function to pass to Library.impl in order to register a fallthrough.
- torch.library.define(qualname, schema, *, lib=None, tags=())
- torch.library.define(lib, schema, alias_analysis='')
Defines a new operator.
In PyTorch, defining an op (short for “operator”) is a two-step process:
- we need to define the op (by providing an operator name and schema)
- we need to implement behavior for how the operator interacts with various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
This entrypoint defines the custom operator (the first step); you must then perform the second step by calling various impl_* APIs, like torch.library.impl() or torch.library.register_fake().
- Parameters
qualname (str) – The qualified name for the operator. Should be a string that looks like “namespace::name”, e.g. “aten::sin”. Operators in PyTorch need a namespace to avoid name collisions; a given operator may only be created once. If you are writing a Python library, we recommend using the name of your top-level module as the namespace.
schema (str) – The schema of the operator. E.g. “(Tensor x) -> Tensor” for an op that accepts one Tensor and returns one Tensor. It does not contain the operator name (that is passed in qualname).
lib (Optional[Library]) – If provided, the lifetime of this operator will be tied to the lifetime of the Library object.
tags (Tag | Sequence[Tag]) – one or more torch.Tag to apply to this operator. Tagging an operator changes the operator’s behavior under various PyTorch subsystems; please read the docs for torch.Tag carefully before applying it.
- Example::
>>> import torch
>>> import numpy as np
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin", "cpu")
... def f(x):
...     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.sin(x)
>>> assert torch.allclose(y, x.sin())
- torch.library.impl(lib, name, dispatch_key='')
- torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Literal[None] = None, *, lib: Optional[Library] = None) → Callable[[Callable[..., object]], None]
- torch.library.impl(qualname: str, types: Union[str, Sequence[str]], func: Callable[..., object], *, lib: Optional[Library] = None) → None
- torch.library.impl(lib: Library, name: str, dispatch_key: str = '') → Callable[[Callable[_P, _T]], Callable[_P, _T]]
Register an implementation for a device type for this operator.
You may pass “default” for types to register this implementation as the default implementation for ALL device types. Please only use this if the implementation truly supports all device types; for example, this is true if it is a composition of built-in PyTorch operators.
This API may be used as a decorator. You can use nested decorators with this API provided they return a function and are placed inside this API (see Example 2).
Some valid types are: “cpu”, “cuda”, “xla”, “mps”, “ipu”, “xpu”.
- Parameters
Examples
>>> import torch
>>> import numpy as np
>>> # Example 1: Register function.
>>> # Define the operator
>>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the cpu device
>>> @torch.library.impl("mylib::mysin", "cpu")
... def f(x):
...     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> x = torch.randn(3)
>>> y = torch.ops.mylib.mysin(x)
>>> assert torch.allclose(y, x.sin())
>>>
>>> # Example 2: Register function with decorator.
>>> def custom_decorator(func):
...     def wrapper(*args, **kwargs):
...         return func(*args, **kwargs) + 1
...     return wrapper
>>>
>>> # Define the operator
>>> torch.library.define("mylib::sin_plus_one", "(Tensor x) -> Tensor")
>>>
>>> # Add implementations for the operator
>>> @torch.library.impl("mylib::sin_plus_one", "cpu")
... @custom_decorator
... def f(x):
...     return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Call the new operator from torch.ops.
>>> x = torch.randn(3)
>>>
>>> y1 = torch.ops.mylib.sin_plus_one(x)
>>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2)