Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit45f3e20

Browse files
mikaylagawareckipytorchmergebot
authored andcommitted
Improve error message for weights_only load (#129705)
As@vmoens pointed out, the current error message does not make the "either/or" between setting `weights_only=False` and using `add_safe_globals` clear enough, and should print the code for the user to call `add_safe_globals`New formatting looks like suchIn the case that `add_safe_globals` can be used```python>>> import torch>>> from torch.testing._internal.two_tensor import TwoTensor>>> torch.save(TwoTensor(torch.randn(2), torch.randn(2)), "two_tensor.pt")>>> torch.load("two_tensor.pt", weights_only=True)Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/data/users/mg1998/pytorch/torch/serialization.py", line 1225, in load raise pickle.UnpicklingError(_get_wo_message(str(e))) from None_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options (1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message. WeightsUnpickler error: Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor was not an allowed global by default. Please use `torch.serialization.add_safe_globals([TwoTensor])` to allowlist this global if you trust this class/function.Check the documentation of torch.load to learn more about types accepted by default with weights_onlyhttps://pytorch.org/docs/stable/generated/torch.load.html.```For other issues (unsupported bytecode)```python>>> import torch>>> t = torch.randn(2, 3)>>> torch.save(t, "protocol_5.pt", pickle_protocol=5)>>> torch.load("protocol_5.pt", weights_only=True)/data/users/mg1998/pytorch/torch/_weights_only_unpickler.py:359: UserWarning: Detected pickle protocol 5 in the checkpoint, which was not the default pickle protocol used by `torch.load` (2). The weights_only Unpickler might not support all instructions implemented by this protocol, please file an issue for adding support if you encounter this. warnings.warn(Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/data/users/mg1998/pytorch/torch/serialization.py", line 1225, in load raise pickle.UnpicklingError(_get_wo_message(str(e))) from None_pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source. Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 149Check the documentation of torch.load to learn more about types accepted by default with weights_onlyhttps://pytorch.org/docs/stable/generated/torch.load.html.```Old formatting would have been like:```pythonTraceback (most recent call last): File "<stdin>", line 1, in <module> File "/data/users/mg1998/pytorch/torch/serialization.py", line 1203, in load raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None_pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. Alternatively, to load with `weights_only` please check the recommended steps in the following error message. WeightsUnpickler error: Unsupported global: GLOBAL torch.testing._internal.two_tensor.TwoTensor was not an allowed global by default. Please use `torch.serialization.add_safe_globals` to allowlist this global if you trust this class/function.```Pull Requestresolved:#129705Approved by:https://github.com/albanD,https://github.com/vmoensghstack dependencies:#129239,#129396,#129509
1 parent99456a6 commit45f3e20

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

‎test/test_serialization.py‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,22 @@ def fake_set_state(obj, *args):
11121112
torch.serialization.clear_safe_globals()
11131113
ClassThatUsesBuildInstruction.__setstate__=None
11141114

1115+
@parametrize("unsafe_global", [True,False])
1116+
deftest_weights_only_error(self,unsafe_global):
1117+
sd= {'t':TwoTensor(torch.randn(2),torch.randn(2))}
1118+
pickle_protocol=torch.serialization.DEFAULT_PROTOCOLifunsafe_globalelse5
1119+
withBytesIOContext()asf:
1120+
torch.save(sd,f,pickle_protocol=pickle_protocol)
1121+
f.seek(0)
1122+
ifunsafe_global:
1123+
withself.assertRaisesRegex(pickle.UnpicklingError,
1124+
r"use `torch.serialization.add_safe_globals\(\[TwoTensor\]\)` to allowlist"):
1125+
torch.load(f,weights_only=True)
1126+
else:
1127+
withself.assertRaisesRegex(pickle.UnpicklingError,
1128+
"file an issue with the following so that we can make `weights_only=True`"):
1129+
torch.load(f,weights_only=True)
1130+
11151131
@parametrize('weights_only', (False,True))
11161132
deftest_serialization_math_bits(self,weights_only):
11171133
t=torch.randn(1,dtype=torch.cfloat)

‎torch/_weights_only_unpickler.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def load(self):
210210
else:
211211
raiseRuntimeError(
212212
f"Unsupported global: GLOBAL{full_path} was not an allowed global by default. "
213-
"Please use `torch.serialization.add_safe_globals` to allowlist this global "
214-
"if you trust this class/function."
213+
f"Please use `torch.serialization.add_safe_globals([{name}])` to allowlist "
214+
"this globalif you trust this class/function."
215215
)
216216
elifkey[0]==NEWOBJ[0]:
217217
args=self.stack.pop()

‎torch/serialization.py‎

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
importio
66
importos
77
importpickle
8+
importre
89
importshutil
910
importstruct
1011
importsys
@@ -1107,12 +1108,33 @@ def load(
11071108
"""
11081109
torch._C._log_api_usage_once("torch.load")
11091110
UNSAFE_MESSAGE= (
1110-
"Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
1111-
" will likely succeed, but it can result in arbitrary code execution."
1112-
" Do it only if you get the file from a trusted source. Alternatively, to load"
1113-
" with `weights_only` please check the recommended steps in the following error message."
1114-
" WeightsUnpickler error: "
1111+
"Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
1112+
"but it can result in arbitrary code execution. Do it only if you got the file from a "
1113+
"trusted source."
11151114
)
1115+
DOCS_MESSAGE= (
1116+
"\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
1117+
"weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
1118+
)
1119+
1120+
def_get_wo_message(message:str)->str:
1121+
pattern=r"GLOBAL (\S+) was not an allowed global by default."
1122+
has_unsafe_global=re.search(pattern,message)isnotNone
1123+
ifhas_unsafe_global:
1124+
updated_message= (
1125+
"Weights only load failed. This file can still be loaded, to do so you have two options "
1126+
f"\n\t(1){UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
1127+
"the recommended steps in the following error message.\n\tWeightsUnpickler error: "
1128+
+message
1129+
)
1130+
else:
1131+
updated_message= (
1132+
f"Weights only load failed.{UNSAFE_MESSAGE}\n Please file an issue with the following "
1133+
"so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
1134+
"error: "+message
1135+
)
1136+
returnupdated_message+DOCS_MESSAGE
1137+
11161138
ifweights_onlyisNone:
11171139
weights_only,warn_weights_only=False,True
11181140
else:
@@ -1200,7 +1222,7 @@ def load(
12001222
**pickle_load_args,
12011223
)
12021224
exceptRuntimeErrorase:
1203-
raisepickle.UnpicklingError(UNSAFE_MESSAGE+str(e))fromNone
1225+
raisepickle.UnpicklingError(_get_wo_message(str(e)))fromNone
12041226
return_load(
12051227
opened_zipfile,
12061228
map_location,
@@ -1224,7 +1246,7 @@ def load(
12241246
**pickle_load_args,
12251247
)
12261248
exceptRuntimeErrorase:
1227-
raisepickle.UnpicklingError(UNSAFE_MESSAGE+str(e))fromNone
1249+
raisepickle.UnpicklingError(_get_wo_message(str(e)))fromNone
12281250
return_legacy_load(
12291251
opened_file,map_location,pickle_module,**pickle_load_args
12301252
)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp