
torch.use_deterministic_algorithms

torch.use_deterministic_algorithms(mode, *, warn_only=False)

Sets whether PyTorch operations must use “deterministic” algorithms. That is, algorithms which, given the same input, and when run on the same software and hardware, always produce the same output. When enabled, operations will use deterministic algorithms when available, and if only nondeterministic algorithms are available they will throw a RuntimeError when called.

Note

This setting alone is not always enough to make an application reproducible. Refer to Reproducibility for more information.

Note

torch.set_deterministic_debug_mode() offers an alternative interface for this feature.
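As a quick illustration of that alternative interface, the debug mode can be set with a string or the equivalent integer, and queried with torch.get_deterministic_debug_mode() (which returns the integer form):

```python
import torch

# set_deterministic_debug_mode() accepts "default", "warn", or "error",
# or the equivalent ints 0, 1, 2. "warn" corresponds to
# use_deterministic_algorithms(True, warn_only=True).
torch.set_deterministic_debug_mode("warn")
print(torch.get_deterministic_debug_mode())  # 1

# Restore the default (nondeterministic algorithms allowed).
torch.set_deterministic_debug_mode("default")
```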

The following normally-nondeterministic operations will act deterministically when mode=True:

The following normally-nondeterministic operations will throw a RuntimeError when mode=True:

In addition, several operations fill uninitialized memory when this setting is turned on and when torch.utils.deterministic.fill_uninitialized_memory is turned on. See the documentation for that attribute for more information.

A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater, unless the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8 is set. See the CUDA documentation for more details: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility. If one of these environment variable configurations is not set, a RuntimeError will be raised from these operations when called with CUDA tensors:

Note that deterministic operations tend to have worse performance than nondeterministic operations.
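One common way to set the CUBLAS_WORKSPACE_CONFIG variable mentioned above is from Python, before any CUDA work happens (it can equally be exported in the shell). A sketch:

```python
import os

# Must be set before the first cuBLAS call, so do it at the very top
# of the script (or export it in the shell environment instead).
# ":4096:8" uses more workspace memory than ":16:8" but is generally faster.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.use_deterministic_algorithms(True)
# cuBLAS-backed CUDA operations (e.g. torch.mm on CUDA tensors) can now
# run under deterministic mode without raising a RuntimeError.
torch.use_deterministic_algorithms(False)
```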

Note

This flag does not detect or prevent nondeterministic behavior caused by calling an in-place operation on a tensor with an internal memory overlap or by giving such a tensor as the out argument for an operation. In these cases, multiple writes of different data may target a single memory location, and the order of writes is not guaranteed.

Parameters

mode (bool) – If True, makes potentially nondeterministic operations switch to a deterministic algorithm or throw a runtime error. If False, allows nondeterministic operations.

Keyword Arguments

warn_only (bool, optional) – If True, operations that do not have a deterministic implementation will throw a warning instead of an error. Default: False
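A quick check of the warn_only flag using PyTorch's query functions (torch.are_deterministic_algorithms_enabled and torch.is_deterministic_algorithms_warn_only_enabled):

```python
import torch

# With warn_only=True, ops lacking a deterministic implementation emit
# a UserWarning instead of raising a RuntimeError.
torch.use_deterministic_algorithms(True, warn_only=True)
print(torch.are_deterministic_algorithms_enabled())           # True
print(torch.is_deterministic_algorithms_warn_only_enabled())  # True

torch.use_deterministic_algorithms(False)
```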

Example:

>>> torch.use_deterministic_algorithms(True)

# Forward mode nondeterministic error
>>> torch.randn(10, device='cuda').kthvalue(1)
...
RuntimeError: kthvalue CUDA does not have a deterministic implementation...

# Backward mode nondeterministic error
>>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
...
RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...