
torch.testing

Created On: May 07, 2021 | Last Updated On: Jun 10, 2025

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)

Asserts that actual and expected are close.

If actual and expected are strided, non-quantized, real-valued, and finite, they are considered close if

$$\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$$

Non-finite values (-inf and inf) are considered close if and only if they are equal. NaN's are only considered equal to each other if equal_nan is True.
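To make the rule concrete, here is a minimal pure-Python sketch of the closeness check for a single pair of real scalars. `is_close` is a hypothetical helper written for illustration, not part of `torch.testing`; it only mirrors the formula and the NaN/inf rules stated above.

```python
import math


def is_close(actual: float, expected: float, *, rtol: float, atol: float,
             equal_nan: bool = False) -> bool:
    """Hypothetical scalar version of the documented closeness rule."""
    if math.isnan(actual) or math.isnan(expected):
        # NaN never matches a number; two NaNs match only if equal_nan is set.
        return equal_nan and math.isnan(actual) and math.isnan(expected)
    if math.isinf(actual) or math.isinf(expected):
        # Non-finite values are close only if they are exactly equal.
        return actual == expected
    # Finite values: |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)


# float32 defaults from the tolerance table: rtol=1.3e-6, atol=1e-5
print(is_close(1.0, 1.0 + 5e-6, rtol=1.3e-6, atol=1e-5))  # True
print(is_close(1.0, 1.1, rtol=1.3e-6, atol=1e-5))         # False
print(is_close(math.nan, math.nan, rtol=0.0, atol=0.0, equal_nan=True))  # True
```

Note that the rule is asymmetric: `rtol` scales with `|expected|` only, so `is_close(1.0, 2.0, rtol=0.5, atol=0.0)` holds while `is_close(2.0, 1.0, rtol=0.5, atol=0.0)` does not.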

In addition, they are only considered close if they have the same

  • device (if check_device is True),

  • dtype (if check_dtype is True),

  • layout (if check_layout is True), and

  • stride (if check_stride is True).
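A short sketch of how the attribute flags interact, assuming a CPU-only environment: a dtype mismatch fails by default but passes once `check_dtype=False` allows promotion, and two tensors with equal values but different strides pass by default but fail with `check_stride=True`. The variable names are illustrative only.

```python
import torch

a32 = torch.tensor([1.0, 2.0], dtype=torch.float32)
a64 = torch.tensor([1.0, 2.0], dtype=torch.float64)

# check_dtype=True (default): the dtype mismatch alone fails the comparison.
try:
    torch.testing.assert_close(a32, a64)
    dtype_failed = False
except AssertionError:
    dtype_failed = True

# check_dtype=False: both tensors are promoted to a common dtype, then compared.
torch.testing.assert_close(a32, a64, check_dtype=False)

# Same (2, 3) values, different memory layout: strides (3, 1) vs. (1, 2).
row_major = torch.arange(6.0).reshape(2, 3)
col_major = row_major.t().contiguous().t()

# check_stride=False (default): strides are ignored.
torch.testing.assert_close(row_major, col_major)

# check_stride=True: the stride mismatch is reported.
try:
    torch.testing.assert_close(row_major, col_major, check_stride=True)
    stride_failed = False
except AssertionError:
    stride_failed = True

print(dtype_failed, stride_failed)  # True True
```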

If either actual or expected is a meta tensor, only the attribute checks will be performed.
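Meta tensors hold metadata but no data, so only attributes such as shape and dtype can be compared. A minimal sketch, assuming a PyTorch version with `device="meta"` support:

```python
import torch

# Meta tensors carry shape/dtype/device metadata but no values.
m1 = torch.empty(2, 3, device="meta")
m2 = torch.empty(2, 3, device="meta")

# Passes: attributes match and no values are compared.
torch.testing.assert_close(m1, m2)

# A shape mismatch is still an attribute mismatch, so this fails.
try:
    torch.testing.assert_close(m1, torch.empty(3, 2, device="meta"))
    shape_failed = False
except AssertionError:
    shape_failed = True
```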

If actual and expected are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are checked individually. Indices, namely indices for COO, crow_indices and col_indices for CSR and BSR, or ccol_indices and row_indices for CSC and BSC layouts, respectively, are always checked for equality, whereas the values are checked for closeness according to the definition above.
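The sketch below illustrates the split between index equality and value closeness for the COO layout, and how `check_layout=False` lets a sparse tensor be compared with its strided counterpart. It assumes `Tensor.to_sparse()` (which yields a coalesced COO tensor); the tensors themselves are arbitrary.

```python
import torch

dense = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
actual = dense.to_sparse()  # COO layout

# Same sparsity pattern, slightly perturbed value: indices match exactly,
# and the values are within the float32 tolerances.
expected = torch.tensor([[0.0, 1.0 + 1e-7], [2.0, 0.0]]).to_sparse()
torch.testing.assert_close(actual, expected)

# Same number of stored elements but different indices: this always fails,
# because indices are compared for equality, not closeness.
shifted = torch.tensor([[1.0, 0.0], [2.0, 0.0]]).to_sparse()
try:
    torch.testing.assert_close(actual, shifted)
    index_failed = False
except AssertionError:
    index_failed = True

# With check_layout=False the sparse input is converted to strided first.
torch.testing.assert_close(actual, dense, check_layout=False)
```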

If actual and expected are quantized, they are considered close if they have the same qscheme() and the result of dequantize() is close according to the definition above.
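As an illustration, two per-tensor affine quantizations of the same data with different scales share a `qscheme()` and dequantize to matching values, so they compare as close. This is a sketch that assumes a PyTorch build with quantized CPU kernels available for `torch.quantize_per_tensor`.

```python
import torch

x = torch.tensor([0.5, 1.0, 1.5])

# Two quantizations of the same data: (input, scale, zero_point, dtype).
q_a = torch.quantize_per_tensor(x, 0.1, 0, torch.quint8)
q_b = torch.quantize_per_tensor(x, 0.05, 0, torch.quint8)

# Same qscheme() (per-tensor affine) and matching dequantized values,
# so the comparison passes even though the scales differ.
torch.testing.assert_close(q_a, q_b)
```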

actual and expected can be Tensor's or any tensor-or-scalar-likes from which torch.Tensor's can be constructed with torch.as_tensor(). Except for Python scalars, the input types have to be directly related. In addition, actual and expected can be Sequence's or Mapping's, in which case they are considered close if their structure matches and all their elements are considered close according to the above definition.

Note

Python scalars are an exception to the type relation requirement, because their type(), i.e. int, float, and complex, is equivalent to the dtype of a tensor-like. Thus, Python scalars of different types can be checked, but require check_dtype=False.

Parameters
  • actual (Any) – Actual input.

  • expected (Any) – Expected input.

  • allow_subclasses (bool) – If True (default) and except for Python scalars, inputs of directly related types are allowed. Otherwise type equality is required.

  • rtol (Optional[float]) – Relative tolerance. If specified, atol must also be specified. If omitted, default values based on the dtype are selected with the below table.

  • atol (Optional[float]) – Absolute tolerance. If specified, rtol must also be specified. If omitted, default values based on the dtype are selected with the below table.

  • equal_nan (Union[bool, str]) – If True, two NaN values will be considered equal.

  • check_device (bool) – If True (default), asserts that corresponding tensors are on the same device. If this check is disabled, tensors on different device's are moved to the CPU before being compared.

  • check_dtype (bool) – If True (default), asserts that corresponding tensors have the same dtype. If this check is disabled, tensors with different dtype's are promoted to a common dtype (according to torch.promote_types()) before being compared.

  • check_layout (bool) – If True (default), asserts that corresponding tensors have the same layout. If this check is disabled, tensors with different layout's are converted to strided tensors before being compared.

  • check_stride (bool) – IfTrue and corresponding tensors are strided, asserts that they have the same stride.

  • msg (Optional[Union[str, Callable[[str], str]]]) – Optional error message to use in case a failure occurs during the comparison. Can also be passed as a callable, in which case it will be called with the generated message and should return the new message.
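A brief sketch of the rtol/atol parameters: a 1e-3 difference is rejected by the float32 defaults, accepted once both tolerances are given explicitly, and passing only one of the two raises ValueError. The tensors and tolerance values are arbitrary choices for illustration.

```python
import torch

actual = torch.tensor([1.0, 2.0])
expected = torch.tensor([1.0, 2.001])

# Default float32 tolerances (rtol=1.3e-6, atol=1e-5) reject a ~1e-3 difference.
try:
    torch.testing.assert_close(actual, expected)
    default_failed = False
except AssertionError:
    default_failed = True

# Explicit tolerances: |2.0 - 2.001| <= 1e-5 + 1e-3 * 2.001, so this passes.
torch.testing.assert_close(actual, expected, rtol=1e-3, atol=1e-5)

# rtol and atol must be given together; passing only one raises ValueError.
try:
    torch.testing.assert_close(actual, expected, rtol=1e-3)
    partial_tol_error = False
except ValueError:
    partial_tol_error = True
```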

Raises
  • ValueError – If no torch.Tensor can be constructed from an input.

  • ValueError – If only rtol or atol is specified.

  • AssertionError – If corresponding inputs are not Python scalars and are not directly related.

  • AssertionError – If allow_subclasses is False, but corresponding inputs are not Python scalars and have different types.

  • AssertionError – If the inputs are Sequence's, but their length does not match.

  • AssertionError – If the inputs are Mapping's, but their sets of keys do not match.

  • AssertionError – If corresponding tensors do not have the same shape.

  • AssertionError – If check_layout is True, but corresponding tensors do not have the same layout.

  • AssertionError – If only one of corresponding tensors is quantized.

  • AssertionError – If corresponding tensors are quantized, but have different qscheme()'s.

  • AssertionError – If check_device is True, but corresponding tensors are not on the same device.

  • AssertionError – If check_dtype is True, but corresponding tensors do not have the same dtype.

  • AssertionError – If check_stride is True, but corresponding strided tensors do not have the same stride.

  • AssertionError – If the values of corresponding tensors are not close according to the definition above.
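The structural checks can be exercised with a small sketch. `raises` is a hypothetical test helper, not part of `torch.testing`; it simply reports whether `assert_close` raised the given exception type.

```python
import torch


def raises(exc_type, *args, **kwargs) -> bool:
    """Hypothetical helper: True if assert_close raises exc_type for these inputs."""
    try:
        torch.testing.assert_close(*args, **kwargs)
    except exc_type:
        return True
    return False


# Sequences of different length
print(raises(AssertionError, [1.0, 2.0], [1.0]))               # True
# Mappings with different key sets
print(raises(AssertionError, {"a": 1.0}, {"b": 1.0}))          # True
# Tensors with different shapes
print(raises(AssertionError, torch.zeros(2), torch.zeros(3)))  # True
# A list and a tuple of the same length and values are fine
print(raises(AssertionError, [1.0], (1.0,)))                   # False
```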

The following table displays the default rtol and atol for different dtype's. In case of mismatching dtype's, the maximum of both tolerances is used.

| dtype      | rtol   | atol |
|------------|--------|------|
| float16    | 1e-3   | 1e-5 |
| bfloat16   | 1.6e-2 | 1e-5 |
| float32    | 1.3e-6 | 1e-5 |
| float64    | 1e-7   | 1e-7 |
| complex32  | 1e-3   | 1e-5 |
| complex64  | 1.3e-6 | 1e-5 |
| complex128 | 1e-7   | 1e-7 |
| quint8     | 1.3e-6 | 1e-5 |
| quint2x4   | 1.3e-6 | 1e-5 |
| quint4x2   | 1.3e-6 | 1e-5 |
| qint8      | 1.3e-6 | 1e-5 |
| qint32     | 1.3e-6 | 1e-5 |
| other      | 0.0    | 0.0  |
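A sketch of how the table plays out, using one bfloat16 step near 1.0 (2**-7 = 0.0078125) as the difference: it passes under the bfloat16 defaults, fails under the float32 defaults, passes again for a float32/bfloat16 pair compared with `check_dtype=False` (where the larger tolerances apply), and integer dtypes fall under "other", i.e. exact equality.

```python
import torch

# bfloat16 defaults (rtol=1.6e-2, atol=1e-5) accept one bfloat16 step near 1.0.
a_bf16 = torch.tensor([1.0], dtype=torch.bfloat16)
b_bf16 = torch.tensor([1.0078125], dtype=torch.bfloat16)
torch.testing.assert_close(a_bf16, b_bf16)

# float32 defaults (rtol=1.3e-6, atol=1e-5) reject the same difference.
try:
    torch.testing.assert_close(a_bf16.float(), b_bf16.float())
    float32_failed = False
except AssertionError:
    float32_failed = True

# Mismatching dtypes: the maximum of both tolerances is used, so this passes.
torch.testing.assert_close(a_bf16.float(), b_bf16, check_dtype=False)

# Integer dtypes fall under "other": rtol=atol=0, i.e. exact equality.
try:
    torch.testing.assert_close(torch.tensor([1]), torch.tensor([2]))
    int_failed = False
except AssertionError:
    int_failed = True
```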

Note

assert_close() is highly configurable with strict default settings. Users are encouraged to partial() it to fit their use case. For example, if an equality check is needed, one might define an assert_equal that uses zero tolerances for every dtype by default:

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!
Expected 1e-10 but got 1e-09.
Absolute difference: 9.000000000000001e-10
Relative difference: 9.0

Examples

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)

>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)

>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)

>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)

>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)

>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type
<class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
and <class 'torch.Tensor'>.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)

>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!
Expected nan but got nan.
Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)

>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default error message can be overwritten.
>>> torch.testing.assert_close(
...     actual, expected, msg="Argh, the tensors are not close!"
... )
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # If msg is a callable, it can be used to augment the generated message with
>>> # extra information
>>> torch.testing.assert_close(
...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
... )
Traceback (most recent call last):
...
AssertionError: Header
Tensor-likes are not close!
Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
Footer
torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)

Creates a tensor with the given shape, device, and dtype, filled with values uniformly drawn from [low, high).

If low or high are specified and are outside the range of the dtype's representable finite values, then they are clamped to the lowest or highest representable finite value, respectively. If None, then the following table describes the default values for low and high, which depend on dtype.

| dtype                  | low | high |
|------------------------|-----|------|
| boolean type           | 0   | 2    |
| unsigned integral type | 0   | 10   |
| signed integral types  | -9  | 10   |
| floating types         | -9  | 9    |
| complex types          | -9  | 9    |
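Because the values are random, a sketch of the defaults can only check ranges, not exact values. This one assumes a CPU device and shows the default ranges from the table plus the clamping rule: `low=-5` lies below the smallest uint8 value, so it is clamped to 0.

```python
import torch
from torch.testing import make_tensor

# Defaults from the table: signed integers in [-9, 10), floats in [-9, 9).
ints = make_tensor((1000,), dtype=torch.int8, device="cpu")
floats = make_tensor((1000,), dtype=torch.float32, device="cpu")

# low=-5 is outside uint8's range, so it is clamped to 0: values land in [0, 5).
clamped = make_tensor((1000,), dtype=torch.uint8, device="cpu", low=-5, high=5)

print(ints.min().item(), ints.max().item())
print(clamped.min().item(), clamped.max().item())
```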

Parameters
  • shape (Tuple[int, ...]) – Single integer or a sequence of integers defining the shape of the output tensor.

  • dtype (torch.dtype) – The data type of the returned tensor.

  • device (Union[str, torch.device]) – The device of the returned tensor.

  • low (Optional[Number]) – Sets the lower limit (inclusive) of the given range. If a number is provided, it is clamped to the least representable finite value of the given dtype. When None (default), this value is determined based on the dtype (see the table above). Default: None.

  • high (Optional[Number]) –

    Sets the upper limit (exclusive) of the given range. If a number is provided, it is clamped to the greatest representable finite value of the given dtype. When None (default), this value is determined based on the dtype (see the table above). Default: None.

    Deprecated since version 2.1: Passing low==high to make_tensor() for floating or complex types is deprecated since 2.1 and will be removed in 2.3. Use torch.full() instead.

  • requires_grad (Optional[bool]) – Whether autograd should record operations on the returned tensor. Default: False.

  • noncontiguous (Optional[bool]) – If True, the returned tensor will be noncontiguous. This argument is ignored if the constructed tensor has fewer than two elements. Mutually exclusive with memory_format.

  • exclude_zero (Optional[bool]) – If True, zeros are replaced with a small positive value that depends on the dtype. For bool and integer types, zero is replaced with one. For floating point types, it is replaced with the dtype's smallest positive normal number (the "tiny" value of the dtype's finfo() object), and for complex types it is replaced with a complex number whose real and imaginary parts are both the smallest positive normal number representable by the complex type. Default: False.

  • memory_format (Optional[torch.memory_format]) – The memory format of the returned tensor. Mutually exclusive with noncontiguous.
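A minimal sketch of the optional parameters on CPU: a noncontiguous tensor, a 4-D tensor in `torch.channels_last` format, an integer tensor with zeros excluded, and a floating-point tensor that requires gradients. Shapes and dtypes are arbitrary choices.

```python
import torch
from torch.testing import make_tensor

nc = make_tensor((4, 3), dtype=torch.float32, device="cpu", noncontiguous=True)
cl = make_tensor((2, 3, 4, 5), dtype=torch.float32, device="cpu",
                 memory_format=torch.channels_last)
nz = make_tensor((1000,), dtype=torch.int64, device="cpu", exclude_zero=True)
g = make_tensor((3,), dtype=torch.float64, device="cpu", requires_grad=True)

print(nc.is_contiguous())                                   # False
print(cl.is_contiguous(memory_format=torch.channels_last))  # True
print(bool((nz != 0).all()))                                # True
print(g.requires_grad)                                      # True
```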

Raises
  • ValueError – If requires_grad=True is passed for an integral dtype.

  • ValueError – If low >= high.

  • ValueError – If either low or high is nan.

  • ValueError – If both noncontiguous and memory_format are passed.

  • TypeError – If dtype isn't supported by this function.

Return type

Tensor

Examples

>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device="cpu", dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205,  0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device="cuda", dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')

Warning

torch.testing.assert_allclose() is deprecated since 1.12 and will be removed in a future release. Please use torch.testing.assert_close() instead. You can find detailed upgrade instructions here.
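One difference visible in the two signatures is the equal_nan default: True for assert_allclose(), False for assert_close(). The sketch below covers only that point; other migration details (such as tolerances or dtype handling) are deliberately not addressed here and are left to the upgrade instructions.

```python
import torch

actual = torch.tensor([1.0, float("nan")])
expected = torch.tensor([1.0, float("nan")])

# assert_close defaults to equal_nan=False, so matching NaNs still fail ...
try:
    torch.testing.assert_close(actual, expected)
    nan_failed = False
except AssertionError:
    nan_failed = True

# ... so pass equal_nan=True explicitly to keep the NaN behavior of
# assert_allclose's default when switching to assert_close.
torch.testing.assert_close(actual, expected, equal_nan=True)
```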