torch.load
- torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
Loads an object saved with torch.save() from a file.

torch.load() uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the runtime system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the map_location argument.

If map_location is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The built-in location tags are 'cpu' for CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors. map_location should return either None or a storage. If map_location returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, torch.load() will fall back to the default behavior, as if map_location wasn't specified.

If map_location is a torch.device object or a string containing a device tag, it indicates the location where all tensors should be loaded.

Otherwise, if map_location is a dict, it will be used to remap location tags appearing in the file (keys) to ones that specify where to put the storages (values).

User extensions can register their own location tags and tagging and deserialization methods using torch.serialization.register_package().

See Layout Control for more advanced tools to manipulate a checkpoint.
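A minimal sketch of the forms of map_location described above, run on a CPU-only machine so it works anywhere. It assumes a small dict of tensors saved to an io.BytesIO buffer; load_with() and keep_on_cpu() are hypothetical helper names chosen for illustration. Because the tensors are saved from the CPU, every location tag is 'cpu' and the dict form is an identity mapping; a checkpoint saved from a GPU would carry tags such as 'cuda:0' instead.

```python
import io

import torch

# Save a small dict of CPU tensors into an in-memory buffer (illustrative data).
buffer = io.BytesIO()
torch.save({"weight": torch.arange(6.0).reshape(2, 3), "bias": torch.zeros(3)}, buffer)


def load_with(map_location):
    """Hypothetical helper: rewind the buffer and load it with map_location."""
    buffer.seek(0)
    return torch.load(buffer, map_location=map_location, weights_only=True)


# Callable: invoked once per serialized storage with (storage, location_tag).
seen_tags = []


def keep_on_cpu(storage, location):
    seen_tags.append(location)  # 'cpu' here; 'cuda:N' for storages saved from GPU N
    return storage              # a returned storage is used as the final result


by_callable = load_with(keep_on_cpu)

# Callable returning None: torch.load() falls back to the default behavior.
by_fallback = load_with(lambda storage, location: None)

# String or torch.device: every tensor is loaded onto that device.
by_string = load_with("cpu")
by_device = load_with(torch.device("cpu"))

# Dict: location tags found in the file (keys) are remapped to new tags (values).
by_dict = load_with({"cpu": "cpu"})
```

The two tensors have separate storages, so the callable runs twice. On a multi-GPU machine the dict form is where a remapping such as {"cuda:1": "cuda:0"} would go.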
- Parameters
  - f (Union[str, PathLike[str], IO[bytes]]) – a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike object containing a file name
  - map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, dict[str, str]]]) – a function, torch.device, string or a dict specifying how to remap storage locations
  - pickle_module (Optional[Any]) – module used for unpickling metadata and objects (has to match the pickle_module used to serialize the file)
  - weights_only (Optional[bool]) – indicates whether the unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). See torch.load with weights_only=True for more details.
  - mmap (Optional[bool]) – indicates whether the file should be memory-mapped rather than loading all the storages into memory. Typically, tensor storages in the file are first moved from disk to CPU memory, after which they are moved to the location they were tagged with when saving, or to the one specified by map_location. This second step is a no-op if the final location is CPU. When the mmap flag is set, instead of copying the tensor storages from disk to CPU memory in the first step, f is mapped, which means tensor storages will be lazily loaded when their data is accessed.
  - pickle_load_args (Any) – (Python 3 only) optional keyword arguments passed over to pickle_module.load() and pickle_module.Unpickler(), e.g., errors=...
- Return type
  Any
Warning
Unless the weights_only parameter is set to True, torch.load() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source in an unsafe mode, or that could have been tampered with. Only load data you trust.

Note
When you call torch.load() on a file which contains GPU tensors, those tensors will be loaded to GPU by default. You can call torch.load(.., map_location='cpu') and then load_state_dict() to avoid a GPU RAM surge when loading a model checkpoint.

Note
By default, we decode byte strings as utf-8. This is to avoid a common error case, UnicodeDecodeError: 'ascii' codec can't decode byte 0x..., when loading files saved by Python 2 in Python 3. If this default is incorrect, you may pass an extra encoding keyword argument to specify how these objects should be loaded: e.g., encoding='latin1' decodes them to strings using latin1 encoding, and encoding='bytes' keeps them as byte arrays which can be decoded later with byte_array.decode(...).

Example
>>> import io
>>> import torch
>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load(
...     "tensors.pt",
...     map_location=torch.device("cpu"),
...     weights_only=True,
... )
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage,
...     weights_only=True,
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )
# Map tensors from GPU 1 to GPU 0
>>> torch.load(
...     "tensors.pt",
...     map_location={"cuda:1": "cuda:0"},
...     weights_only=True,
... )
# Load a tensor from an io.BytesIO object
# (uses weights_only=False, which can be unsafe; see the warning above)
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# (uses weights_only=False, which can be unsafe; see the warning above)
>>> torch.load("module.pt", encoding="ascii", weights_only=False)
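The note on GPU checkpoints and the warning about weights_only combine into one loading pattern. The sketch below is a minimal illustration, assuming a small nn.Linear model and an in-memory io.BytesIO buffer in place of a checkpoint file; save_to_buffer() is a hypothetical helper, and fractions.Fraction merely stands in for any type that is not on the weights_only allowlist.

```python
import io
import pickle
from fractions import Fraction

import torch
from torch import nn


def save_to_buffer(obj):
    """Hypothetical helper: torch.save obj into a rewound in-memory buffer."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    buffer.seek(0)
    return buffer


# Load the state_dict onto the CPU first, then copy it into the model.
source = nn.Linear(3, 2)
checkpoint = save_to_buffer(source.state_dict())

state_dict = torch.load(checkpoint, map_location="cpu", weights_only=True)
model = nn.Linear(3, 2)
model.load_state_dict(state_dict)

# Restricted mode: a type outside the allowlist is refused instead of unpickled.
try:
    torch.load(save_to_buffer(Fraction(1, 3)), weights_only=True)
    rejected = False
except pickle.UnpicklingError:
    rejected = True

# Full unpickling: only for data you trust.
trusted = torch.load(save_to_buffer(Fraction(1, 3)), weights_only=False)
```

A state_dict holds only tensors and containers the restricted unpickler accepts, so weights_only=True is sufficient for it; load_state_dict() then copies the CPU values into the model's parameters.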