
Tensor Attributes

Created On: Apr 21, 2018 | Last Updated On: Jun 27, 2025

Each torch.Tensor has a torch.dtype, torch.device, and torch.layout.

torch.dtype

class torch.dtype

A torch.dtype is an object that represents the data type of a torch.Tensor. PyTorch has several different data types:

Floating point dtypes

| dtype | description |
|---|---|
| torch.float32 or torch.float | 32-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754 |
| torch.float64 or torch.double | 64-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754 |
| torch.float16 or torch.half | 16-bit floating point, as defined in https://en.wikipedia.org/wiki/IEEE_754, S-E-M 1-5-10 |
| torch.bfloat16 | 16-bit floating point, sometimes referred to as Brain floating point, S-E-M 1-8-7 |
| torch.complex32 or torch.chalf | 32-bit complex with two float16 components |
| torch.complex64 or torch.cfloat | 64-bit complex with two float32 components |
| torch.complex128 or torch.cdouble | 128-bit complex with two float64 components |
| torch.float8_e4m3fn [shell], [1] | 8-bit floating point, S-E-M 1-4-3, from https://arxiv.org/abs/2209.05433 |
| torch.float8_e5m2 [shell] | 8-bit floating point, S-E-M 1-5-2, from https://arxiv.org/abs/2209.05433 |
| torch.float8_e4m3fnuz [shell], [1] | 8-bit floating point, S-E-M 1-4-3, from https://arxiv.org/pdf/2206.02915 |
| torch.float8_e5m2fnuz [shell], [1] | 8-bit floating point, S-E-M 1-5-2, from https://arxiv.org/pdf/2206.02915 |
| torch.float8_e8m0fnu [shell], [1] | 8-bit floating point, S-E-M 0-8-0, from https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf |
| torch.float4_e2m1fn_x2 [shell], [1] | packed 4-bit floating point, S-E-M 1-2-1, from https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf |
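The S-E-M column gives the number of sign, exponent and mantissa bits. The following pure-Python sketch uses only the standard library. Its helper names are hypothetical and are not PyTorch APIs. It shows why bfloat16 (1-8-7) keeps the range of float32 while float16 (1-5-10) keeps more precision. The sketch makes a simplifying assumption: bfloat16 is modelled as the upper 16 bits of a float32 pattern, which is plain truncation rather than proper rounding.

```python
import struct


def float32_bits(x: float) -> int:
    """IEEE 754 single-precision bit pattern of x (x is first rounded to float32)."""
    return struct.unpack(">I", struct.pack(">f", x))[0]


def float16_bits(x: float) -> int:
    """IEEE 754 half-precision bit pattern of x; struct raises OverflowError if x is out of range."""
    return struct.unpack(">H", struct.pack(">e", x))[0]


def to_bfloat16_bits(x: float) -> int:
    """Hypothetical helper: bfloat16 pattern obtained by truncating float32 to its top 16 bits."""
    return float32_bits(x) >> 16


def from_bfloat16_bits(bits: int) -> float:
    """Widen a bfloat16 pattern back to float32 by appending 16 zero mantissa bits."""
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]


def split_sem(bits: int, exp_bits: int, man_bits: int) -> tuple:
    """Split a bit pattern into (sign, exponent, mantissa) fields of an S-E-M 1-exp-man layout."""
    mantissa = bits & ((1 << man_bits) - 1)
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    sign = bits >> (exp_bits + man_bits)
    return sign, exponent, mantissa


# -1.5 in both 16-bit formats: same sign bit, different exponent width/bias and mantissa width.
print(split_sem(float16_bits(-1.5), 5, 10))     # (1, 15, 512)
print(split_sem(to_bfloat16_bits(-1.5), 8, 7))  # (1, 127, 64)

# Range: 1e38 is (approximately) representable in bfloat16 but overflows float16.
print(from_bfloat16_bits(to_bfloat16_bits(1e38)))
try:
    float16_bits(1e38)
except OverflowError:
    print("1e38 does not fit in float16")

# Precision: bfloat16 keeps only 7 mantissa bits, so 1 + 2**-8 collapses to 1.0.
print(from_bfloat16_bits(to_bfloat16_bits(1.0 + 2**-8)))  # 1.0
```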

Integer dtypes

| dtype | description |
|---|---|
| torch.uint8 | 8-bit integer (unsigned) |
| torch.int8 | 8-bit integer (signed) |
| torch.uint16 [shell], [2] | 16-bit integer (unsigned) |
| torch.int16 or torch.short | 16-bit integer (signed) |
| torch.uint32 [shell], [2] | 32-bit integer (unsigned) |
| torch.int32 or torch.int | 32-bit integer (signed) |
| torch.uint64 [shell], [2] | 64-bit integer (unsigned) |
| torch.int64 or torch.long | 64-bit integer (signed) |
| torch.bool | Boolean |
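The bit widths in the table determine each type's range. With n bits, a signed (two's complement) type holds [-2^(n-1), 2^(n-1) - 1], and an unsigned type holds [0, 2^n - 1]. The following is a small pure-Python sketch of that arithmetic. `int_range` is a hypothetical helper. In PyTorch itself, torch.iinfo reports such bounds; for example, torch.iinfo(torch.int16).max is 32767.

```python
def int_range(bits: int, signed: bool) -> tuple:
    """(min, max) of a two's-complement (signed) or plain binary (unsigned) integer of `bits` bits."""
    if signed:
        return -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    return 0, (1 << bits) - 1


# (name, width in bits, signed?) for the integer dtypes listed in the table
INTEGER_DTYPES = [
    ("torch.uint8", 8, False), ("torch.int8", 8, True),
    ("torch.uint16", 16, False), ("torch.int16", 16, True),
    ("torch.uint32", 32, False), ("torch.int32", 32, True),
    ("torch.uint64", 64, False), ("torch.int64", 64, True),
]

for name, bits, signed in INTEGER_DTYPES:
    lo, hi = int_range(bits, signed)
    print(f"{name:13} [{lo}, {hi}]")
```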

[shell]

A shell dtype is a specialized dtype with limited op and backend support. Specifically, ops that support tensor creation (torch.empty, torch.fill, torch.zeros) and operations that do not peek inside the data elements (torch.cat, torch.view, torch.reshape) are supported. Ops that peek inside the data elements, such as casting, matrix multiplication and NaN/inf checks, are supported only on a case-by-case basis. Support depends on maturity, on the presence of hardware-accelerated kernels, and on established use cases.

[1]

The “fn”, “fnu” and “fnuz” dtype suffixes mean:

  • “f” - finite value encodings only, no infinity;

  • “n” - NaN value encodings differ from the IEEE spec;

  • “uz” - “unsigned zero” only, i.e. no negative zero encoding.

[2]

Unsigned types other than uint8 are currently planned to have only limited support in eager mode; they primarily exist to assist usage with torch.compile. If you need eager support and do not need the extra range, we recommend using their signed variants instead. See pytorch/pytorch#58734 for more details.
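The suffixes in footnote [1] can be made concrete with a minimal pure-Python decoder for two 1-4-3 layouts. The decoder assumes the encodings described in the cited papers:

  • float8_e4m3fn uses exponent bias 7 and has no infinities. A value is NaN only when all exponent and mantissa bits are set.

  • float8_e4m3fnuz uses exponent bias 8 and has no negative zero. The pattern 0x80, which would otherwise be -0, is its single NaN.

The function names are hypothetical. This sketch is not how PyTorch converts these types internally.

```python
import math


def _decode(byte: int, bias: int) -> float:
    """Decode an S-E-M 1-4-3 byte with the given exponent bias (special values handled by callers)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0:  # subnormal: no implicit leading 1
        return sign * (mantissa / 8) * 2.0 ** (1 - bias)
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - bias)


def decode_e4m3fn(byte: int) -> float:
    # "fn": finite only, no infinity; NaN is S.1111.111, every other exponent-1111 pattern is finite.
    if (byte & 0x7F) == 0x7F:
        return math.nan
    return _decode(byte, bias=7)


def decode_e4m3fnuz(byte: int) -> float:
    # "fnuz": finite only, and "unsigned zero": 0x80 (would-be -0.0) encodes the only NaN.
    if byte == 0x80:
        return math.nan
    return _decode(byte, bias=8)


print(decode_e4m3fn(0x7E))    # 448.0, the largest finite e4m3fn value
print(decode_e4m3fn(0x78))    # 256.0, a pattern an IEEE-754-style layout would use for +infinity
print(decode_e4m3fn(0x7F))    # nan
print(decode_e4m3fnuz(0x7F))  # 240.0, the largest finite e4m3fnuz value
print(decode_e4m3fnuz(0x80))  # nan instead of -0.0
```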

Note: legacy constructors such as torch.*.FloatTensor, torch.*.DoubleTensor, torch.*.HalfTensor, torch.*.BFloat16Tensor, torch.*.ByteTensor, torch.*.CharTensor, torch.*.ShortTensor, torch.*.IntTensor, torch.*.LongTensor and torch.*.BoolTensor remain only for backwards compatibility and should no longer be used.
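The following sketch shows the usual replacements and assumes a recent PyTorch. Instead of choosing a typed constructor, pass dtype (and, where needed, device) to a factory function such as torch.tensor, torch.empty or torch.zeros. To convert an existing tensor, use Tensor.to.

```python
import torch

# Instead of torch.FloatTensor([1, 2, 3]): build from data with an explicit dtype.
a = torch.tensor([1, 2, 3], dtype=torch.float32)

# Instead of torch.LongTensor(2, 3), which allocates uninitialized memory:
b = torch.empty(2, 3, dtype=torch.int64)

# Instead of torch.cuda.HalfTensor(4).zero_(): name the device explicitly
# ("cpu" here so the sketch runs anywhere; use "cuda" when a GPU is present).
c = torch.zeros(4, dtype=torch.float16, device="cpu")

# Instead of re-wrapping a tensor in another legacy constructor, convert it.
d = a.to(torch.float64)

print(a.dtype, b.dtype, c.dtype, d.dtype)
```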

To find out if a torch.dtype is a floating point data type, use the property is_floating_point. It returns True if the data type is a floating point data type.

To find out if a torch.dtype is a complex data type, use the property is_complex. It returns True if the data type is a complex data type.
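Both are properties of the dtype object itself, so no tensor is needed, and tensors offer matching methods. The following sketch assumes PyTorch is installed. It uses these properties to sort dtypes into the complex > floating > integral > boolean categories used by type promotion. `category` is a hypothetical helper, not a PyTorch API.

```python
import torch


def category(dtype: torch.dtype) -> str:
    """Hypothetical helper: classify a dtype into the categories used by type promotion."""
    if dtype.is_complex:
        return "complex"
    if dtype.is_floating_point:
        return "floating"
    if dtype == torch.bool:
        return "boolean"
    return "integral"


for dt in (torch.float32, torch.bfloat16, torch.complex64, torch.int64, torch.uint8, torch.bool):
    print(dt, dt.is_floating_point, dt.is_complex, category(dt))

# The same checks exist at the tensor level.
t = torch.zeros(3, dtype=torch.complex64)
print(t.is_floating_point(), t.is_complex())  # False True
```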

When the dtypes of inputs to an arithmetic operation (add, sub, div, mul) differ, we promote by finding the minimum dtype that satisfies the following rules:

  • If the type of a scalar operand is of a higher category than tensor operands (where complex > floating > integral > boolean), we promote to a type with sufficient size to hold all scalar operands of that category.

  • If a zero-dimension tensor operand has a higher category than dimensioned operands, we promote to a type with sufficient size and category to hold all zero-dim tensor operands of that category.

  • If there are no higher-category zero-dim operands, we promote to a type with sufficient size and category to hold all dimensioned operands.

A floating point scalar operand has dtype torch.get_default_dtype(), and an integral non-boolean scalar operand has dtype torch.int64. Unlike NumPy, we do not inspect values when determining the minimum dtypes of an operand. Complex types are not yet supported. Promotion for shell dtypes is not defined.

Promotion Examples:

>>> import torch
>>> float_tensor = torch.ones(1, dtype=torch.float)
>>> double_tensor = torch.ones(1, dtype=torch.double)
>>> complex_float_tensor = torch.ones(1, dtype=torch.complex64)
>>> complex_double_tensor = torch.ones(1, dtype=torch.complex128)
>>> int_tensor = torch.ones(1, dtype=torch.int)
>>> long_tensor = torch.ones(1, dtype=torch.long)
>>> uint_tensor = torch.ones(1, dtype=torch.uint8)
>>> bool_tensor = torch.ones(1, dtype=torch.bool)
# zero-dim tensors
>>> long_zerodim = torch.tensor(1, dtype=torch.long)
>>> int_zerodim = torch.tensor(1, dtype=torch.int)

>>> torch.add(5, 5).dtype
torch.int64
# 5 is an int64, but does not have higher category than int_tensor so is not considered.
>>> (int_tensor + 5).dtype
torch.int32
>>> (int_tensor + long_zerodim).dtype
torch.int32
>>> (long_tensor + int_tensor).dtype
torch.int64
>>> (bool_tensor + long_tensor).dtype
torch.int64
>>> (bool_tensor + uint_tensor).dtype
torch.uint8
>>> (float_tensor + double_tensor).dtype
torch.float64
>>> (complex_float_tensor + complex_double_tensor).dtype
torch.complex128
>>> (bool_tensor + int_tensor).dtype
torch.int32
# Since long is a different kind than float, result dtype only needs to be large enough
# to hold the float.
>>> torch.add(long_tensor, float_tensor).dtype
torch.float32

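These rules can also be queried directly, without running an arithmetic op. The following sketch assumes PyTorch is installed. torch.result_type and torch.promote_types return the promoted dtype, and the comments follow the rules and examples above. The last part temporarily changes the default dtype to show that a float scalar takes torch.get_default_dtype(). It restores the previous default afterwards.

```python
import torch

int_tensor = torch.ones(1, dtype=torch.int32)
long_zerodim = torch.tensor(1, dtype=torch.int64)

# torch.result_type computes the promoted dtype without performing the operation.
print(torch.result_type(int_tensor, 5))             # torch.int32: same-category scalar is ignored
print(torch.result_type(int_tensor, long_zerodim))  # torch.int32: same-category zero-dim tensor is ignored
print(torch.result_type(int_tensor, 2.5))           # the default dtype: a float scalar has a higher category

# torch.promote_types applies the rules to two dtypes directly.
print(torch.promote_types(torch.uint8, torch.int8))     # torch.int16
print(torch.promote_types(torch.int64, torch.float16))  # torch.float16

# A float scalar takes torch.get_default_dtype(), so changing the default changes the result.
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    print((int_tensor + 2.5).dtype)  # torch.float64
finally:
    torch.set_default_dtype(previous)
```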
When the output tensor of an arithmetic operation is specified, we allow casting to its dtype, except that:

  • An integral output tensor cannot accept a floating point tensor.

  • A boolean output tensor cannot accept a non-boolean tensor.

  • A non-complex output tensor cannot accept a complex tensor.

Casting Examples:

# allowed:
>>> float_tensor *= float_tensor
>>> float_tensor *= int_tensor
>>> float_tensor *= uint_tensor
>>> float_tensor *= bool_tensor
>>> float_tensor *= double_tensor
>>> int_tensor *= long_tensor
>>> int_tensor *= uint_tensor
>>> uint_tensor *= int_tensor

# disallowed (RuntimeError: result type can't be cast to the desired output type):
>>> int_tensor *= float_tensor
>>> bool_tensor *= int_tensor
>>> bool_tensor *= uint_tensor
>>> float_tensor *= complex_float_tensor
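The casting rules above can be checked before running an in-place op. The following sketch assumes PyTorch is installed. torch.can_cast(from_, to) reports whether a result of dtype from_ may be written into an output of dtype to. The last lines reproduce one of the disallowed cases.

```python
import torch

# torch.can_cast(from_, to) reports whether the output-casting rules allow from_ -> to.
print(torch.can_cast(torch.float64, torch.float32))    # True: same category, smaller size is fine
print(torch.can_cast(torch.int64, torch.uint8))        # True: integral -> integral
print(torch.can_cast(torch.float32, torch.int32))      # False: floating -> integral
print(torch.can_cast(torch.int32, torch.bool))         # False: non-boolean -> boolean
print(torch.can_cast(torch.complex64, torch.float32))  # False: complex -> non-complex

# The corresponding in-place op raises instead of silently truncating.
int_tensor = torch.ones(1, dtype=torch.int32)
raised = False
try:
    int_tensor *= torch.ones(1, dtype=torch.float32)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```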

torch.device

class torch.device

A torch.device is an object representing the device on which a torch.Tensor is or will be allocated.

The torch.device contains a device type (most commonly “cpu” or “cuda”, but also potentially “mps”, “xpu”, “xla” or “meta”) and an optional device ordinal for the device type. If the device ordinal is not present, this object always represents the current device for the device type, even after torch.cuda.set_device() is called. For example, a torch.Tensor constructed with device 'cuda' is equivalent to 'cuda:X', where X is the result of torch.cuda.current_device().

A torch.Tensor’s device can be accessed via the Tensor.device property.

A torch.device can be constructed using:

  • A device string, which is a string representation of the device type and optionally the device ordinal.

  • A device type and a device ordinal.

  • A device ordinal, where the current accelerator type will be used.

Via a device string:

>>> torch.device('cuda:0')
device(type='cuda', index=0)
>>> torch.device('cpu')
device(type='cpu')
>>> torch.device('mps')
device(type='mps')
>>> torch.device('cuda')  # implicit index is the "current device index"
device(type='cuda')

Via a device type and a device ordinal:

>>> torch.device('cuda', 0)
device(type='cuda', index=0)
>>> torch.device('mps', 0)
device(type='mps', index=0)
>>> torch.device('cpu', 0)
device(type='cpu', index=0)
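The parsed pieces are exposed as attributes. The following small sketch runs on a CPU-only install. Constructing a torch.device object only parses its description and does not check that the device exists; only allocating on it would. The sketch assumes PyTorch is installed.

```python
import torch

# Constructing a device object only parses the description; no GPU is needed here.
d = torch.device("cuda:1")
print(d.type, d.index)               # cuda 1
print(d == torch.device("cuda", 1))  # True: the string and (type, ordinal) forms agree
print(str(d))                        # cuda:1
print(torch.device("cpu").index)     # None: no ordinal was given

t = torch.zeros(2)  # allocated on the default device (CPU unless changed)
print(t.device)
```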

Via a device ordinal:

Note

This method will raise a RuntimeError if no accelerator is currently detected.

>>> torch.device(0)  # the current accelerator is cuda
device(type='cuda', index=0)
>>> torch.device(1)  # the current accelerator is xpu
device(type='xpu', index=1)
>>> torch.device(0)  # no current accelerator detected
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Cannot access accelerator device when none is available.

The device object can also be used as a context manager to change the default device on which tensors are allocated:

>>> with torch.device('cuda:1'):
...     r = torch.randn(2, 3)
>>> r.device
device(type='cuda', index=1)

This context manager has no effect if a factory function is passed an explicit, non-None device argument. To change the default device globally, see also torch.set_default_device().
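The context-manager behavior can be checked without a GPU by using the “meta” device mentioned above. This includes the exception for an explicit device argument. The meta device records shapes and dtypes but allocates no data. The sketch assumes PyTorch 2.0 or later, where torch.device gained context-manager support.

```python
import torch

with torch.device("meta"):
    a = torch.randn(2, 3)                # picks up the context's default device
    b = torch.randn(2, 3, device="cpu")  # an explicit device argument wins
c = torch.randn(2, 3)                    # outside the block: back to the previous default

print(a.device, b.device, c.device)
```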

Warning

This function imposes a slight performance cost on every Python call to the torch API (not just factory functions). If this is causing problems for you, please comment on pytorch/pytorch#92701.

Note

The torch.device argument in functions can generally be substituted with a string. This allows for fast prototyping of code.

>>> # Example of a function that takes in a torch.device
>>> cuda1 = torch.device('cuda:1')
>>> torch.randn((2, 3), device=cuda1)

>>> # You can substitute the torch.device with a string
>>> torch.randn((2, 3), device='cuda:1')

Note

Methods that take a device will generally accept a (properly formatted) string or an integer device ordinal, i.e. the following are all equivalent:

>>> torch.randn((2, 3), device=torch.device('cuda:1'))
>>> torch.randn((2, 3), device='cuda:1')
>>> torch.randn((2, 3), device=1)  # equivalent to 'cuda:1' if the current accelerator is cuda

Note

Tensors are never moved automatically between devices; moving them requires an explicit call from the user. Scalar Tensors (with tensor.dim() == 0) are the only exception to this rule. They are automatically transferred from CPU to GPU when needed, as this operation can be done “for free”. Example:

>>> # two scalars
>>> torch.ones(()) + torch.ones(()).cuda()  # OK, scalar auto-transferred from CPU to GPU
>>> torch.ones(()).cuda() + torch.ones(())  # OK, scalar auto-transferred from CPU to GPU

>>> # one scalar (CPU), one vector (GPU)
>>> torch.ones(()) + torch.ones(1).cuda()  # OK, scalar auto-transferred from CPU to GPU
>>> torch.ones(1).cuda() + torch.ones(())  # OK, scalar auto-transferred from CPU to GPU

>>> # one scalar (GPU), one vector (CPU)
>>> torch.ones(()).cuda() + torch.ones(1)  # Fail, scalar not auto-transferred from GPU to CPU and non-scalar not auto-transferred from CPU to GPU
>>> torch.ones(1) + torch.ones(()).cuda()  # Fail, scalar not auto-transferred from GPU to CPU and non-scalar not auto-transferred from CPU to GPU

torch.layout

class torch.layout

Warning

The torch.layout class is in beta and subject to change.

A torch.layout is an object that represents the memory layout of a torch.Tensor. Currently, we support torch.strided (dense Tensors) and have beta support for torch.sparse_coo (sparse COO Tensors).

torch.strided represents dense Tensors and is the most commonly used memory layout. Each strided tensor has an associated torch.Storage, which holds its data. These tensors provide a multi-dimensional, strided view of a storage. Strides are a list of integers: the k-th stride represents the jump in memory needed to go from one element to the next one in the k-th dimension of the Tensor. This concept makes it possible to perform many tensor operations efficiently; for example, a transpose only swaps strides and does not copy the data.

Example:

>>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
>>> x.stride()
(5, 1)
>>> x.t().stride()
(1, 5)
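The stride arithmetic can be sketched in plain Python, assuming a flat list stands in for torch.Storage. Element (i0, i1, ...) lives at storage offset sum(i_k * stride_k). A transpose like x.t() only swaps sizes and strides while sharing the same storage. The helper names are hypothetical, not PyTorch functions.

```python
def contiguous_strides(shape):
    """Row-major strides: each dimension jumps over the product of the sizes after it."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def element(storage, strides, index, storage_offset=0):
    """Read one element of a strided view: offset = storage_offset + sum(i_k * stride_k)."""
    return storage[storage_offset + sum(i * s for i, s in zip(index, strides))]


storage = list(range(1, 11))         # the data behind [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
shape = (2, 5)
strides = contiguous_strides(shape)  # (5, 1), like x.stride()

# Transposed view: swap sizes and strides, keep the same storage (no copy).
t_shape, t_strides = shape[::-1], strides[::-1]

print(t_shape, strides, t_strides)          # (5, 2) (5, 1) (1, 5)
print(element(storage, strides, (1, 2)))    # 8: row 1, column 2 of x
print(element(storage, t_strides, (2, 1)))  # 8: the same element, read through x.t()
```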

For more information on torch.sparse_coo tensors, see torch.sparse.

torch.memory_format

class torch.memory_format

A torch.memory_format is an object representing the memory format on which a torch.Tensor is or will be allocated.

Possible values are:

  • torch.contiguous_format: Tensor is or will be allocated in dense non-overlapping memory. Strides are represented by values in decreasing order.

  • torch.channels_last: Tensor is or will be allocated in dense non-overlapping memory. Strides are represented by values in strides[0] > strides[2] > strides[3] > strides[1] == 1, aka NHWC order.

  • torch.channels_last_3d: Tensor is or will be allocated in dense non-overlapping memory. Strides are represented by values in strides[0] > strides[2] > strides[3] > strides[4] > strides[1] == 1, aka NDHWC order.

  • torch.preserve_format: Used in functions like clone to preserve the memory format of the input tensor. If the input tensor is allocated in dense non-overlapping memory, the output tensor strides will be copied from the input. Otherwise, output strides will follow torch.contiguous_format.
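Each of these formats corresponds to a different order in which the logical dimensions are laid out in memory. The following pure-Python sketch assumes all sizes are greater than 1, because size-1 dimensions can have ambiguous strides. `strides_for_order` is a hypothetical helper, not a PyTorch function.

```python
def strides_for_order(shape, order):
    """Dense, non-overlapping strides when the logical dims are stored in `order`
    (outermost first, innermost last)."""
    strides = [0] * len(shape)
    step = 1
    for dim in reversed(order):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


nchw = (2, 3, 4, 5)      # N, C, H, W
ncdhw = (2, 3, 4, 5, 6)  # N, C, D, H, W

contiguous = strides_for_order(nchw, (0, 1, 2, 3))            # NCHW in memory
channels_last = strides_for_order(nchw, (0, 2, 3, 1))         # NHWC in memory
channels_last_3d = strides_for_order(ncdhw, (0, 2, 3, 4, 1))  # NDHWC in memory

print(contiguous)        # (60, 20, 5, 1)
print(channels_last)     # (60, 1, 15, 3)
print(channels_last_3d)  # (360, 1, 90, 18, 3)

s = channels_last
assert s[0] > s[2] > s[3] > s[1] == 1  # the ordering listed for torch.channels_last
```

With PyTorch installed, torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last).stride() should give the same (60, 1, 15, 3).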