
Meta device

Created On: Jun 17, 2025 | Last Updated On: Jun 17, 2025

The “meta” device is an abstract device which denotes a tensor that records only metadata, but no actual data. Meta tensors have two primary use cases:

  • Models can be loaded on the meta device, allowing you to load a representation of the model without loading the actual parameters into memory. This can be helpful if you need to make transformations on the model before you load the actual data.

  • Most operations can be performed on meta tensors, producing new meta tensors that describe what the result would have been if you performed the operation on a real tensor. You can use this to perform abstract analysis without needing to spend time on compute or space to represent the actual tensors. Because meta tensors do not have real data, you cannot perform data-dependent operations like torch.nonzero() or item(). In some cases, not all device types (e.g., CPU and CUDA) have exactly the same output metadata for an operation; we typically prefer representing the CUDA behavior faithfully in this situation.
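As a minimal sketch of this kind of abstract analysis (this example is not from the original text), you can propagate shapes through an operation such as a matrix multiply without allocating any real data, while data-dependent calls fail as described above:

```python
import torch

# Two meta tensors: shapes and dtypes are tracked, but no data is allocated.
a = torch.randn(1024, 512, device='meta')
b = torch.randn(512, 256, device='meta')

# The matmul runs "abstractly": the result is another meta tensor whose
# metadata describes what a real computation would have produced.
c = a @ b
print(c.shape)   # torch.Size([1024, 256])
print(c.device)  # meta

# Data-dependent operations raise, since there is no data to inspect.
try:
    c.sum().item()
    raised = False
except Exception:
    raised = True
```

This makes meta tensors useful for things like estimating activation shapes of a large model without paying for the compute.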

Warning

Although in principle meta tensor computation should always be faster than an equivalent CPU/CUDA computation, many meta tensor operations are implemented in Python and have not been ported to C++ for speed, so you may find that you get lower absolute framework latency with small CPU tensors.

Idioms for working with meta tensors

An object can be loaded with torch.load() onto the meta device by specifying map_location='meta':

>>> torch.save(torch.randn(2), 'foo.pt')
>>> torch.load('foo.pt', map_location='meta')
tensor(..., device='meta', size=(2,))

If you have some arbitrary code which performs some tensor construction without explicitly specifying a device, you can override it to instead construct on the meta device by using the torch.device() context manager:

>>> with torch.device('meta'):
...     print(torch.randn(30, 30))
...
tensor(..., device='meta', size=(30, 30))

This is especially helpful for NN module construction, where you often are not able to explicitly pass in a device for initialization:

>>> from torch.nn.modules import Linear
>>> with torch.device('meta'):
...     print(Linear(20, 30))
...
Linear(in_features=20, out_features=30, bias=True)

You cannot convert a meta tensor directly to a CPU/CUDA tensor, because the meta tensor stores no data and we do not know what the correct data values for your new tensor are:

>>> torch.ones(5, device='meta').to("cpu")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Cannot copy out of meta tensor; no data!

Use a factory function like torch.empty_like() to explicitly specify how you would like the missing data to be filled in.
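For example, a sketch of materializing a meta tensor by allocating real storage with torch.empty_like() and then filling it explicitly (the fill value here is an arbitrary choice for illustration):

```python
import torch

meta_t = torch.ones(5, device='meta')  # metadata only, no data

# Allocate a real CPU tensor with the same shape and dtype. Its contents
# are uninitialized, so we must decide what the data should be.
cpu_t = torch.empty_like(meta_t, device='cpu')
cpu_t.fill_(1.0)  # explicitly choose the missing values
```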

NN modules have a convenience method torch.nn.Module.to_empty() that allows you to move the module to another device, leaving all parameters uninitialized. You are expected to explicitly reinitialize the parameters manually:

>>> from torch.nn.modules import Linear
>>> with torch.device('meta'):
...     m = Linear(20, 30)
>>> m.to_empty(device="cpu")
Linear(in_features=20, out_features=30, bias=True)
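One way to do the explicit reinitialization is sketched below, using nn.Linear's reset_parameters() method, which reruns its default initialization; modules without such a method would need torch.nn.init functions applied per parameter instead:

```python
import torch
from torch import nn

# Construct on meta: no parameter storage is allocated yet.
with torch.device('meta'):
    m = nn.Linear(20, 30)

# Materialize real, uninitialized storage on CPU.
m = m.to_empty(device='cpu')

# Explicitly reinitialize; the uninitialized values are garbage until now.
m.reset_parameters()
```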

torch._subclasses.meta_utils contains undocumented utilities for taking an arbitrary Tensor and constructing an equivalent meta Tensor with high fidelity. These APIs are experimental and may be changed in a BC-breaking way at any time.