- Notifications
You must be signed in to change notification settings - Fork26.3k
[export] Fix deserialization issue#150515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Conversation
pytorch-botbot commentedApr 2, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
🔗 Helpful Links🧪 See artifacts and rendered test results athud.pytorch.org/pr/150515
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (3 Unrelated Failures)As of commitff50489 with merge base85079e4 ( BROKEN TRUNK - The following jobs failed but were present on the merge base:👉Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
facebook-github-bot commentedApr 2, 2025
@angelayi has imported this pull request. If you are a Meta employee, you can view this diffon Phabricator. |
| "as_string", | ||
| ): | ||
| node_name=self.signature.input_specs[i].arg.name | ||
| node_name=self.signature.input_specs[i].arg.nameorf"arg{i}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
nit: maybe add a comment that this f"arg{i}" is for handling old export schema that doesn't have a name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
I agree with your comment but I need to land this asap to fix the aps trunk issue so I will do it in a followup 😅
angelayi commentedApr 3, 2025
@pytorchbot merge |
pytorchmergebot commentedApr 3, 2025
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in thewiki. Questions? Feedback? Please reach out to thePyTorch DevX Team |
An internal model was serialized in 2023, and is now breaking while loading with the following error:``` File "<eval_with_key>.1675", line 4 def forward(self, arg1163_1, arg1164_1, , arg1166_1, , arg1168_1, arg1169_1, arg1170_1, , arg1172_1, arg1173_1, arg1174_1, arg1175_1, arg1176_1, arg1177_1, arg1178_1, arg1179_1, arg1180_1, arg1181_1, arg1182_1, arg1183_1, arg1184_1, arg1185_1, arg1186_1, arg1187_1, arg1188_1, arg1189_1, arg1190_1, arg1191_1, arg1192_1, arg1193_1, arg1194_1, arg1195_1, arg1196_1, arg1197_1, arg1198_1, arg1199_1, arg1200_1, arg1201_1, arg1202_1, arg1203_1, arg1204_1, arg1205_1, arg1206_1, arg1207_1, arg1208_1, arg1209_1, arg1210_1, arg1211_1, arg1212_1, arg1213_1, arg1214_1, arg1215_1, arg1216_1, , arg1218_1, arg1219_1, arg1220_1, arg1221_1, arg1222_1, arg1223_1, arg1224_1, , arg1226_1, arg1227_1, arg1228_1, , arg1230_1, , , , , , , , , , , , , , , ): ^SyntaxError: invalid syntax```The syntax errors are due to inputs that are `None` when exporting. Prior to changes inpytorch#123590 (landed 4/2024), input specs for none inputs look like `InputSpec(userInput=UserInputSpec(arg=Argument(asNone=True)))`, and during deserialization when creating a node, we would just use a dummy name `arg`. After to those changes, the input specs for none inputs look like `InputSpec(constantInput=InputToConstantInputSpec(name='y', value=ConstantValue(asNone=True)))`, and when creating a node we would use the name `y` as the name. However the PR didn't handle the case if it's loading an old package which doesn't have this name, so ended up putting empty names in the placeholder nodes.This error was uncovered afterpytorch#149717, where we now use the GraphModule's python codegen to run the UnflattenedModule instead of going through the interpreter path. The placeholder nodes having empty names caused the python codegen to fail.Pull Requestresolved:pytorch#150515Approved by:https://github.com/yushangdi
Uh oh!
There was an error while loading.Please reload this page.
An internal model was serialized in 2023, and is now breaking while loading with the following error:
The syntax errors are due to inputs that are
Nonewhen exporting. Prior to changes in#123590 (landed 4/2024), input specs for none inputs look likeInputSpec(userInput=UserInputSpec(arg=Argument(asNone=True))), and during deserialization when creating a node, we would just use a dummy namearg. After to those changes, the input specs for none inputs look likeInputSpec(constantInput=InputToConstantInputSpec(name='y', value=ConstantValue(asNone=True))), and when creating a node we would use the nameyas the name. However the PR didn't handle the case if it's loading an old package which doesn't have this name, so ended up putting empty names in the placeholder nodes.This error was uncovered after#149717, where we now use the GraphModule's python codegen to run the UnflattenedModule instead of going through the interpreter path. The placeholder nodes having empty names caused the python codegen to fail.