[ONNX] Fix bfloat16 support in onnx_program callable #151121
Conversation
Previously, bfloat16 constants were created as uint16. Added a test to guard against this.
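For context, here is a minimal sketch of why the bit-cast matters, assuming only that numpy has no native bfloat16 dtype (the variable names are illustrative, not from the PR):

```py
import torch

# numpy has no bfloat16 dtype, so the exporter stores the raw bits as uint16
# while recording BFLOAT16 as the ONNX element type. Tagging the constant as
# plain uint16 (the old behavior) loses the real dtype.
const = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
bits = const.view(torch.uint16).numpy(force=True)
print(bits.dtype)  # uint16: carries the bit pattern, not the semantics
```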
pytorch-bot bot commented Apr 11, 2025 • edited
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151121
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4bf3801 with merge base d385179.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
Comments suppressed due to low confidence (1)
torch/onnx/_internal/exporter/_onnx_program.py:131
- Using tensor.view assumes that the tensor is contiguous. Consider calling tensor.contiguous() before view conversion to avoid potential issues with non-contiguous tensors.
```py
tensor.view(torch.uint16).numpy(force=True), onnx_element_type=onnx_type
```
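As a hedged illustration of the defensive pattern Copilot suggests (not the PR's final code), a non-contiguous tensor can be made contiguous before the bit-level view:

```py
import torch

t = torch.randn(4, 4, dtype=torch.bfloat16).t()  # transposed, non-contiguous
# Defensive variant: materialize a contiguous copy before reinterpreting bits.
bits = t.contiguous().view(torch.uint16).numpy(force=True)
```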
```py
)
raise RuntimeError(
    f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. "
    "Please ensure that ONNX Runtime is built with DLPack support or is the latest version"
)
```
We should be specific about the version.
I can create a follow-up. I wanted to mention a newer version that doesn't exist yet that can handle things better.
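A hedged sketch of what a version-specific message could look like, assuming onnxruntime is importable; the minimum version below is a hypothetical placeholder, not a real release:

```py
from packaging import version

import onnxruntime

_MIN_ORT = version.Version("1.99.0")  # hypothetical future release

if version.parse(onnxruntime.__version__) < _MIN_ORT:
    raise RuntimeError(
        "Converting this tensor to OrtValue requires onnxruntime >= "
        f"{_MIN_ORT}, found {onnxruntime.__version__}."
    )
```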
justinchuby commented Apr 14, 2025
@pytorchbot merge
pytorchmergebot commented Apr 14, 2025
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
- Added a test to guard bfloat16. The optimizer incorrectly turns bfloat16 initializers into uint16, but this is not relevant to export logic.
- Fix bfloat16 support in onnx_program callable

Tested with the following with CUDA:

```py
import torch


class BfloatModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor(2.0, dtype=torch.bfloat16))

    def forward(self, x):
        return x * torch.tensor(1.0, dtype=torch.bfloat16) * self.param


input = torch.randn(1, 10, dtype=torch.bfloat16)
model = BfloatModel()
onnx_program = torch.onnx.export(model, (input,), dynamo=True, optimize=False, verify=True)
```

Pull Request resolved: pytorch#151121
Approved by: https://github.com/titaiwangms
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
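A short, hedged usage note: after the export above succeeds, the returned ONNXProgram can be invoked directly (this is the "onnx_program callable" the title refers to) and compared against eager mode; the comparison below is illustrative, not from the PR:

```py
actual = onnx_program(input)[0]  # run through ONNX Runtime
expected = model(input)          # eager-mode reference
torch.testing.assert_close(actual, expected)
```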