Tensor Basics
The ATen tensor library backing PyTorch is a simple tensor library that exposes the Tensor operations in Torch directly in C++17. ATen's API is auto-generated from the same declarations PyTorch uses, so the two APIs will track each other over time.
Tensor types are resolved dynamically, such that the API is generic and does not include templates. That is, there is one Tensor type. It can hold a CPU or CUDA Tensor, and the tensor may hold doubles, floats, ints, etc. This design makes it easy to write generic code without templating everything.
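For example (a minimal illustration added here, assuming <torch/torch.h> and <iostream> are included), the dtype and device travel with the Tensor at runtime and can be queried directly:

torch::Tensor t = torch::ones({2, 2}, torch::kFloat64);
std::cout << t.dtype() << std::endl;   // prints: double
std::cout << t.device() << std::endl;  // prints: cpu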
See https://pytorch.org/cppdocs/api/namespace_at.html#functions for the provided API. Excerpt:
Tensor atan2(const Tensor &other) const;
Tensor &atan2_(const Tensor &other);

Tensor pow(Scalar exponent) const;
Tensor pow(const Tensor &exponent) const;
Tensor &pow_(Scalar exponent);
Tensor &pow_(const Tensor &exponent);

Tensor lerp(const Tensor &end, Scalar weight) const;
Tensor &lerp_(const Tensor &end, Scalar weight);

Tensor histc() const;
Tensor histc(int64_t bins) const;
Tensor histc(int64_t bins, Scalar min) const;
Tensor histc(int64_t bins, Scalar min, Scalar max) const;
In-place operations are also provided, and are always suffixed by _ to indicate that they will modify the Tensor.
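For instance (a small sketch using pow from the excerpt above), the plain variant returns a new Tensor while the underscored variant mutates its receiver:

torch::Tensor a = torch::ones({2, 2});
torch::Tensor b = a.pow(2);  // out-of-place: a is unchanged, b holds the result
a.pow_(2);                   // in-place: a itself now holds its elementwise square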
Efficient Access to Tensor Elements
When using Tensor-wide operations, the relative cost of dynamic dispatch is very small. However, there are cases, especially in your own kernels, where efficient element-wise access is needed, and the cost of dynamic dispatch inside the element-wise loop is very high. ATen provides accessors that are created with a single dynamic check that a Tensor has the expected type and number of dimensions. Accessors then expose an API for accessing the Tensor elements efficiently.
Accessors are temporary views of a Tensor. They are only valid for the lifetime of the tensor that they view and hence should only be used locally in a function, like iterators.

Note that accessors are not compatible with CUDA tensors inside kernel functions. Instead, you will have to use a packed accessor which behaves the same way but copies tensor metadata instead of pointing to it.

It is thus recommended to use accessors for CPU tensors and packed accessors for CUDA tensors.
CPU accessors
torch::Tensor foo = torch::rand({12, 12});

// assert foo is 2-dimensional and holds floats.
auto foo_a = foo.accessor<float, 2>();
float trace = 0;

for (int i = 0; i < foo_a.size(0); i++) {
  // use the accessor foo_a to get tensor data.
  trace += foo_a[i][i];
}
CUDA accessors
__global__ void packed_accessor_kernel(
    torch::PackedTensorAccessor64<float, 2> foo,
    float* trace) {
  int i = threadIdx.x;
  gpuAtomicAdd(trace, foo[i][i]);
}

// the tensor must live on the GPU for the kernel to read it.
torch::Tensor foo = torch::rand({12, 12}, torch::kCUDA);

// assert foo is 2-dimensional and holds floats.
auto foo_a = foo.packed_accessor64<float, 2>();

// accumulate the trace in device memory; a raw host float cannot be
// written to from within a CUDA kernel.
torch::Tensor trace = torch::zeros({}, foo.options());
packed_accessor_kernel<<<1, 12>>>(foo_a, trace.data_ptr<float>());
In addition to PackedTensorAccessor64 and packed_accessor64 there are also the corresponding PackedTensorAccessor32 and packed_accessor32, which use 32-bit integers for indexing. This can be quite a bit faster on CUDA but may lead to overflows in the indexing calculations.

Note that the template can hold other parameters such as the pointer restriction and the integer type for indexing. See the documentation for a thorough template description of accessors and packed accessors.
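As an illustration (a sketch in the spirit of the C++ extension tutorials, not part of this page's original examples), the third template parameter can apply torch::RestrictPtrTraits, which marks the accessor's data pointer __restrict__:

__global__ void restrict_kernel(
    // RestrictPtrTraits tells the compiler the accessor's memory is
    // not aliased by any other pointer visible to the kernel.
    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> foo) {
  int i = threadIdx.x;
  foo[i][i] *= 2;
}

torch::Tensor bar = torch::rand({12, 12}, torch::kCUDA);
// the traits and index width must match on the host side.
auto bar_a = bar.packed_accessor32<float, 2, torch::RestrictPtrTraits>();
restrict_kernel<<<1, 12>>>(bar_a);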
Using Externally Created Data
If you already have your tensor data allocated in memory (CPU or CUDA), you can view that memory as a Tensor in ATen:
float data[] = {1, 2, 3, 4, 5, 6};
torch::Tensor f = torch::from_blob(data, {2, 3});
These tensors cannot be resized because ATen does not own the memory, but otherwise behave as normal tensors.
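If the Tensor must outlive the external buffer, one common idiom (a brief sketch, not from the original page) is to copy the data into ATen-owned memory with clone():

// `f` is the from_blob view from above; clone() copies into memory that
// ATen owns, so `owned` stays valid after `data` goes away and can be
// resized like any other Tensor.
torch::Tensor owned = f.clone();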
Scalars and zero-dimensional tensors
In addition to the Tensor objects, ATen also includes Scalars that represent a single number. Like a Tensor, Scalars are dynamically typed and can hold any one of ATen's number types. Scalars can be implicitly constructed from C++ number types. Scalars are needed because some functions like addmm take numbers along with Tensors and expect these numbers to be the same dynamic type as the tensor. They are also used in the API to indicate places where a function will always return a Scalar value, like sum.
namespace torch {
Tensor addmm(Scalar beta, const Tensor &self,
             Scalar alpha, const Tensor &mat1,
             const Tensor &mat2);
Scalar sum(const Tensor &self);
} // namespace torch

// Usage.
torch::Tensor a = ...
torch::Tensor b = ...
torch::Tensor c = ...
torch::Tensor r = torch::addmm(1.0, a, .5, b, c);
In addition to Scalars, ATen also allows Tensor objects to be zero-dimensional. These Tensors hold a single value and they can be references to a single element in a larger Tensor. They can be used anywhere a Tensor is expected. They are normally created by operators like select which reduce the dimensions of a Tensor.
torch::Tensor two = torch::rand({10, 20});
two[1][2] = 4;
// ^^^^^^ <- zero-dimensional Tensor
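A zero-dimensional Tensor can also be converted back to a C++ number (a small addition beyond the original example) with item<T>():

torch::Tensor elem = two[1][2];    // zero-dimensional view into `two`
float value = elem.item<float>();  // extract the element as a C++ float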