
torch.load

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)

Loads an object saved with torch.save() from a file.

torch.load() uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g., because the runtime system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.
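One visible consequence of serializing storages as objects in their own right is that tensors sharing a storage when saved share one again after loading, so views keep aliasing their base. A minimal CPU-only sketch; the tensor values and the in-memory io.BytesIO buffer are illustrative choices, not part of the API:

```python
import io

import torch

# A base tensor and a view that shares its underlying storage.
base = torch.arange(6, dtype=torch.float32)
view = base[2:5]

buffer = io.BytesIO()
torch.save({"base": base, "view": view}, buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

loaded = torch.load(buffer, weights_only=True)

# The shared storage was serialized once, so writing through the loaded view
# is visible in the loaded base tensor as well.
loaded["view"][0] = 100.0
print(loaded["base"])
```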

If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The built-in location tags are 'cpu' for CPU tensors and 'cuda:device_id' (e.g., 'cuda:2') for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load() will fall back to the default behavior, as if map_location wasn't specified.
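The callable protocol can be observed with a small recording function. In this sketch `record_location` is a hypothetical helper: it stores each location tag it receives and returns None, so every storage falls back to its default placement. Since the file holds two independent CPU tensors, it is expected to be called twice with the tag 'cpu':

```python
import io

import torch

buffer = io.BytesIO()
torch.save({"a": torch.zeros(2), "b": torch.ones(3)}, buffer)
buffer.seek(0)

seen_tags = []

def record_location(storage, location):
    # `storage` is the CPU-resident copy; `location` is the tag written at
    # save time, such as "cpu" or "cuda:2".
    seen_tags.append(location)
    # Returning None tells torch.load to use the default placement.
    return None

state = torch.load(buffer, map_location=record_location, weights_only=True)
print(seen_tags)
```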

If map_location is a torch.device object or a string containing a device tag, it indicates the location where all tensors should be loaded.
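A string tag and the equivalent torch.device object should behave the same way. This sketch stays on the CPU so it runs anywhere; on a machine with CUDA, a tag such as "cuda:0" would be passed in exactly the same position:

```python
import io

import torch

buffer = io.BytesIO()
torch.save(torch.arange(3), buffer)

devices = []
for target in ("cpu", torch.device("cpu")):
    buffer.seek(0)  # reuse the same serialized bytes for each load
    t = torch.load(buffer, map_location=target, weights_only=True)
    devices.append(t.device)

print(devices)
```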

Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys) to ones that specify where to put the storages (values).
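The dict form only rewrites tags that appear as keys. In this CPU-only sketch the file contains nothing but 'cpu' storages, so the 'cuda:1' entry never matches and the tensor keeps its saved location; on a multi-GPU machine the same dict would move storages saved on GPU 1 to GPU 0:

```python
import io

import torch

buffer = io.BytesIO()
torch.save(torch.tensor([1.0, 2.0]), buffer)
buffer.seek(0)

# Keys are tags found in the file; values are where matching storages go.
# Tags with no matching key (here, "cpu") keep their default placement.
remap = {"cuda:1": "cuda:0"}
t = torch.load(buffer, map_location=remap, weights_only=True)
print(t.device)
```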

User extensions can register their own location tags and tagging and deserialization methods using torch.serialization.register_package().

See Layout Control for more advanced tools to manipulate a checkpoint.

Parameters
  • f (Union[str, PathLike[str], IO[bytes]]) – a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike object containing a file name

  • map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, dict[str, str]]]) – a function, torch.device, string, or dict specifying how to remap storage locations

  • pickle_module (Optional[Any]) – module used for unpickling metadata and objects (has to match the pickle_module used to serialize the file)

  • weights_only (Optional[bool]) – Indicates whether the unpickler should be restricted to loading only tensors, primitive types, dictionaries, and any types added via torch.serialization.add_safe_globals(). See torch.load with weights_only=True for more details.

  • mmap (Optional[bool]) – Indicates whether the file should be memory-mapped rather than loading all the storages into memory. Typically, tensor storages in the file are first moved from disk to CPU memory, after which they are moved to the location they were tagged with when saving, or the one specified by map_location. This second step is a no-op if the final location is CPU. When the mmap flag is set, instead of copying the tensor storages from disk to CPU memory in the first step, f is memory-mapped, which means tensor storages will be lazily loaded when their data is accessed.

  • pickle_load_args (Any) – (Python 3 only) optional keyword arguments passed on to pickle_module.load() and pickle_module.Unpickler(), e.g., errors=....

Return type

Any

Warning

Unless the weights_only parameter is set to True, torch.load() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source, or that could have been tampered with, in an unsafe mode. Only load data you trust.

Note

When you call torch.load() on a file that contains GPU tensors, those tensors will be loaded to the GPU by default. You can call torch.load(.., map_location='cpu') and then load_state_dict() to avoid a GPU RAM surge when loading a model checkpoint.
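The pattern from this note, sketched with a small nn.Linear(3, 2) standing in for a real model (an illustrative assumption). The state dict is loaded onto the CPU first, and load_state_dict() then copies the values into the existing parameters on whatever device the model lives on:

```python
import io

import torch
from torch import nn

model = nn.Linear(3, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Place every storage on the CPU, regardless of where it was saved from.
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)

# Copy the loaded values into an already-constructed model.
restored = nn.Linear(3, 2)
restored.load_state_dict(state_dict)
```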

Note

By default, we decode byte strings as utf-8. This is to avoid a common error case, UnicodeDecodeError: 'ascii' codec can't decode byte 0x..., when loading files saved by Python 2 in Python 3. If this default is incorrect, you may use an extra encoding keyword argument to specify how these objects should be loaded, e.g., encoding='latin1' decodes them to strings using latin1 encoding, and encoding='bytes' keeps them as byte arrays which can be decoded later with byte_array.decode(...).
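Because the encoding keyword is forwarded to the unpickler, its effect can be illustrated with Python's own pickle module. This sketch hand-builds a tiny protocol-2 pickle containing a Python 2-style byte string (the SHORT_BINSTRING opcode) with Latin-1 bytes. It is a stand-in for a legacy checkpoint, not a real one, and `try_load` is a hypothetical helper:

```python
import pickle

# PROTO 2, SHORT_BINSTRING of 3 bytes (b"\xe9t\xe9" is "été" in Latin-1), STOP.
# Python 2 wrote str objects this way; Python 3 has to decode them on load.
legacy = b"\x80\x02U\x03\xe9t\xe9."

def try_load(**kwargs):
    try:
        return pickle.loads(legacy, **kwargs)
    except UnicodeDecodeError as exc:
        return type(exc).__name__

as_utf8 = try_load(encoding="utf-8")     # torch.load's default; these bytes are not valid UTF-8
as_latin1 = try_load(encoding="latin1")  # decodes to the str "été"
as_bytes = try_load(encoding="bytes")    # keeps the raw bytes for later .decode(...)
print(as_utf8, as_latin1, as_bytes)
```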

Example

>>> import io
>>> import torch
>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load(
...     "tensors.pt",
...     map_location=torch.device("cpu"),
...     weights_only=True,
... )
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage,
...     weights_only=True,
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )  # type: ignore[attr-defined]
# Map tensors from GPU 1 to GPU 0
>>> torch.load(
...     "tensors.pt",
...     map_location={"cuda:1": "cuda:0"},
...     weights_only=True,
... )
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load("module.pt", encoding="ascii", weights_only=False)