[ONNX] Fix bfloat16 support in onnx_program callable#151121


Closed

Conversation

@justinchuby (Collaborator) commented on Apr 11, 2025 (edited)

  • Added a test to guard bfloat16 export. The optimizer incorrectly turns bfloat16 initializers into uint16, but that is unrelated to the export logic.
  • Fixed bfloat16 support in the onnx_program callable.

Tested with the following on CUDA:

```py
import torch

class BfloatModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor(2.0, dtype=torch.bfloat16))

    def forward(self, x):
        return x * torch.tensor(1.0, dtype=torch.bfloat16) * self.param

input = torch.randn(1, 10, dtype=torch.bfloat16)
model = BfloatModel()
onnx_program = torch.onnx.export(
    model, (input,), dynamo=True, optimize=False, verify=True
)
```
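As a sanity check beyond `verify=True`, a minimal sketch of comparing the callable against eager PyTorch by hand (assuming the repro above has run; the ONNX program callable returns a sequence of outputs):

```py
# Compare the ONNX program callable against eager PyTorch;
# torch.testing.assert_close supports bfloat16.
expected = model(input)
(actual,) = onnx_program(input)
torch.testing.assert_close(actual, expected)
```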

Previously, bfloat16 constants were created as uint16. Added a test to guard against this.
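To see the dtype regression this guards against, a small sketch assuming the `onnx_program` from the repro above (`ONNXProgram.model_proto` exposes the exported `onnx.ModelProto`):

```py
import onnx

# List each initializer's element type; before the fix, bfloat16
# initializers could surface as UINT16 after optimization.
for init in onnx_program.model_proto.graph.initializer:
    print(init.name, onnx.TensorProto.DataType.Name(init.data_type))
```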
@pytorch-bot (bot) commented on Apr 11, 2025 (edited)

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151121

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4bf3801 with merge base d385179:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@justinchuby added the module: onnx (Related to torch.onnx) and release notes: onnx (torch.onnx related changes that should show up in the release notes) labels and removed the topic: not user facing (topic category) and release notes: onnx labels on Apr 11, 2025
@justinchuby requested review from Copilot and titaiwangms and removed the review request for shubhambhokare1, titaiwangms, and wschin on April 11, 2025 16:58
Copilot AI (Contributor) left a comment

Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@justinchuby changed the title from "[ONNX] Create test to guard bfloat16 export" to "[ONNX] Create test to guard bfloat16 export and support in onnx_program" on Apr 11, 2025
@justinchuby changed the title from "[ONNX] Create test to guard bfloat16 export and support in onnx_program" to "[ONNX] Create test to guard bfloat16 export" on Apr 11, 2025
@justinchuby marked this pull request as draft on April 11, 2025 17:19
@justinchuby marked this pull request as ready for review on April 12, 2025 00:33
@justinchuby added the topic: bug fixes (topic category) and release notes: onnx (torch.onnx related changes that should show up in the release notes) labels and removed the topic: not user facing (topic category) label on Apr 12, 2025
Copilot AI (Contributor) left a comment

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

Comments suppressed due to low confidence (1)

torch/onnx/_internal/exporter/_onnx_program.py:131

  • Using tensor.view assumes that the tensor is contiguous. Consider calling tensor.contiguous() before the view conversion to avoid potential issues with non-contiguous tensors.
tensor.view(torch.uint16).numpy(force=True), onnx_element_type=onnx_type
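For context, a minimal sketch of the view trick this comment refers to (the helper name is illustrative, not the PR's actual code): numpy has no bfloat16 dtype, so the bf16 bits are reinterpreted as uint16 of the same width before the numpy conversion.

```py
import torch

def bfloat16_to_uint16_numpy(tensor: torch.Tensor):
    # Reinterpret bf16 bits as uint16 (same element size), then hand the
    # buffer to numpy; .contiguous() is the defensive copy suggested above,
    # since view() can reject non-contiguous layouts.
    return tensor.contiguous().view(torch.uint16).numpy(force=True)
```

The resulting array carries the raw bf16 bit pattern, so the ONNX element type must be set to BFLOAT16 separately, as the excerpt above does with `onnx_element_type=onnx_type`.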

@justinchuby changed the title from "[ONNX] Create test to guard bfloat16 export" to "[ONNX] Fix bfloat16 support in onnx_program callable" on Apr 12, 2025
@justinchuby added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Apr 12, 2025
```py
raise RuntimeError(
    f"Failed to convert tensor of type '{tensor.dtype}' to OrtValue. "
    "Please ensure that ONNX Runtime is built with DLPack support or is the latest version"
)
```
A reviewer (Collaborator) commented:
We should be specific to the version.

@justinchuby (Author) replied:
I can create a follow-up. I wanted to mention a newer version, which doesn't exist yet, that can handle things better.

@justinchuby (Author) commented:
@pytorchbot merge


@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging
Check the merge workflow status here.

timocafe pushed a commit to timocafe/pytorch that referenced this pull request on Apr 16, 2025
Pull Request resolved: pytorch#151121
Approved by: https://github.com/titaiwangms
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
amathewc pushed a commit to amathewc/pytorch that referenced this pull request on Apr 17, 2025

Reviewers

Copilot (code review) left review comments

@titaiwangms approved these changes

Awaiting requested review from @xadupre


Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: onnx (Related to torch.onnx), open source, release notes: onnx (torch.onnx related changes that should show up in the release notes), topic: bug fixes (topic category)


4 participants

@justinchuby, @pytorchmergebot, @titaiwangms, @pytorchbot
