[MPS] Add support for two more isin variants #154010


Closed

malfet wants to merge 5 commits into gh/malfet/350/base from gh/malfet/350/head

Conversation

malfet (Contributor) commented May 21, 2025 (edited)

Stack from ghstack (oldest at bottom):

- `isin_Tensor_Scalar_out` is just a redispatch to eq/neq
- `isin_Scalar_Tensor_out` redispatches back to the generic `isin` op, but needs a small tweak to handle float scalars
- Make sure that `out` is resized to the expected size in `isin_Tensor_Tensor_out_mps`

Add unit tests to validate that, but skip them on macOS 13, where the MPS op just returns garbage.

Before this change, both of these calls failed:

```python
>>> import torch
>>> t = torch.tensor([0, 1, 2], device='mps')
>>> torch.isin(t, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Tensor_Scalar_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
>>> torch.isin(1, t)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: The operator 'aten::isin.Scalar_Tensor_out' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 3b875c25ea6d8802a0c53af9eb961ddf2f058188. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
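The redispatch strategy described in the PR body can be sketched in pure Python. This is an illustrative sketch only, not the actual ATen/MPS implementation; the function names below mirror the overload names but operate on plain lists in place of tensors.

```python
# Pure-Python sketch (hypothetical, not PyTorch source) of the three
# isin overloads and how the scalar variants redispatch.

def isin_tensor_tensor(elements, test_elements, invert=False):
    # Generic isin: for each element, check membership in test_elements.
    test_set = set(test_elements)
    return [(e not in test_set) if invert else (e in test_set)
            for e in elements]

def isin_tensor_scalar(elements, scalar, invert=False):
    # isin.Tensor_Scalar_out: just a redispatch to elementwise eq/ne.
    if invert:
        return [e != scalar for e in elements]
    return [e == scalar for e in elements]

def isin_scalar_tensor(scalar, test_elements, invert=False):
    # isin.Scalar_Tensor_out: wrap the scalar and redispatch back to the
    # generic isin; the result is a single boolean (0-dim in PyTorch).
    return isin_tensor_tensor([scalar], test_elements, invert)[0]

print(isin_tensor_scalar([0, 1, 2], 1))  # [False, True, False]
print(isin_scalar_tensor(1, [0, 1, 2]))  # True
```

After the PR, the equivalent `torch.isin(t, 1)` and `torch.isin(1, t)` calls succeed on the MPS device instead of raising `NotImplementedError`.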

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

[ghstack-poisoned]
malfet requested a review from kulinseth as a code owner on May 21, 2025 03:48
pytorch-bot (bot) commented May 21, 2025 (edited)
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154010

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 72 Pending

As of commit 64c8b36 with merge base 2e56ce0:
💚 Looks good so far! There are no failures yet. 💚

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot (bot) added the ciflow/mps (Run MPS tests (subset of trunk)) label on May 21, 2025
malfet added a commit that referenced this pull request on May 21, 2025:

> `isin_Tensor_Scalar_out` is just a redispatch to eq/neq. `isin_Scalar_Tensor_out` redispatches back to the generic `isin` op. Add unittests to validate that. Before this change, both calls failed with `NotImplementedError` (see the traceback in the PR description).
> ghstack-source-id: 9c8eb66
> Pull Request resolved: #154010
github-actions (bot) commented:

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.



malfet added the topic: improvements (topic category) and release notes: mps (Release notes category) labels on May 21, 2025
albanD removed their request for review on May 21, 2025 17:28
[ghstack-poisoned]
malfet added a commit that referenced this pull request on May 21, 2025 (same commit message as above; ghstack-source-id: 123db4e, Pull Request resolved: #154010)
[ghstack-poisoned]
malfet added a commit that referenced this pull request on May 22, 2025 (same commit message as above; ghstack-source-id: 99c45ae, Pull Request resolved: #154010)
[ghstack-poisoned]
malfet added a commit that referenced this pull request on May 22, 2025 (same commit message as above; ghstack-source-id: 96faf9a, Pull Request resolved: #154010)
[ghstack-poisoned]
malfet added a commit that referenced this pull request on May 22, 2025 (same commit message as above; ghstack-source-id: e60527d, Pull Request resolved: #154010)
malfet (Contributor, Author) commented:

@pytorchbot merge -f "Roses are red, violets are blue, I want to land this PR now"


pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (`-f`) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use `-f` as a last resort and instead consider `-i`/`--ignore-current` to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.


Reviewers

@Skylion007 approved these changes

@dcci approved these changes

@manuelcandales approved these changes

@kulinseth: awaiting requested review

Assignees

No one assigned

Labels

autoformat, ciflow/inductor, ciflow/mps (Run MPS tests (subset of trunk)), Merged, module: inductor, release notes: mps (Release notes category), topic: improvements (topic category)

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

6 participants

@malfet @pytorchmergebot @Skylion007 @dcci @manuelcandales
