Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Handle slices inmlx_funcify_IncSubtensor#1692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Open
jessegrabowski wants to merge10 commits intopymc-devs:main
base:main
Choose a base branch
Loading
fromjessegrabowski:mlx-incsubtensor

Conversation

@jessegrabowski
Copy link
Member

@jessegrabowskijessegrabowski commentedOct 24, 2025
edited by github-actionsbot
Loading

Description

The MLX dispatch forIncSubtensor was assuming that the indexes would always be integers, but they can actually be either integers or slices. This PR adds logic to handle the slice case.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚:https://pytensor--1692.org.readthedocs.build/en/1692/

Copy link

CopilotAI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Pull Request Overview

This PR fixes a bug in the MLX backend'sIncSubtensor dispatch where it incorrectly assumed all indices would be integers. The fix adds logic to properly handle slice objects by converting their start/stop/step components from potentially symbolic values to actual integers while preservingNone values.

Key Changes

  • Addedget_slice_int helper function to safely convert slice components to integers
  • Modified index processing to reconstruct slice objects with integer components
  • Added comprehensive test coverage for various slice patterns (positive, negative, step-based, and full slices)

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

FileDescription
pytensor/link/mlx/dispatch/subtensor.pyAddedget_slice_int helper and slice reconstruction logic to handle both integer and slice indices
tests/link/mlx/test_subtensor.pyAdded four test cases covering different slice patterns to verify the fix

returnNone
try:
returnint(element)
exceptException:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Using a bareexcept Exception is too broad. This should catch specific exceptions likeTypeError orValueError that would occur when trying to convert a non-integer value. The current implementation could mask unexpected errors.

Suggested change
exceptException:
except(TypeError,ValueError):

Copilot uses AI. Check for mistakes.
@ricardoV94
Copy link
Member

So MLX is okay withx[np.int64(1)], but notx[np.int64(1):]?

The helper definitely doesn't ignore slices, I think you misread the error message.

And I guess this would also apply to the regular Subtensor not just IncSubtensor? And AdvancedSubtensor as well, unless we typify the constant slices with integers? Because all those use numpy integers internally.

This may be worth opening an issue with them, even if we patch here.

@jessegrabowski
Copy link
MemberAuthor

So MLX is okay with x[np.int64(1)], but not x[np.int64(1):]?

Neither of these cases work:

importmlx.coreasmximportnumpyasnpx=mx.array([1,2,3])x[np.int64(1)]---------------------------------------------------------------------------ValueErrorTraceback (mostrecentcalllast)CellIn[42],line42importnumpyasnp3x=mx.array([1,2,3])---->4x[np.int64(1)]ValueError:Cannotindexmlxarrayusingthegiven type.
x[np.int64(1):]ValueErrorTraceback (mostrecentcalllast)CellIn[43],line42importnumpyasnp3x=mx.array([1,2,3])---->4x[np.int64(1):]ValueError:SliceindicesmustbeintegersorNone.

@ricardoV94
Copy link
Member

So we should open an issue with them and I'm surprised any indexing worked before. We're casting constant numpy integers to python int? If so we should do something for the tipefy of slices

cetagostini reacted with eyes emoji

@ricardoV94
Copy link
Member

ricardoV94 commentedOct 27, 2025
edited
Loading

This is how the Subtensor is working around it:

indices=indices_from_subtensor([int(element)forelementinilists],idx_list)

Should do the same thing in both Ops for consistency. The Subtensor approach seems simpler.

Perhaps put in a helper just so we can document why this is needed for future developers

cetagostini reacted with eyes emoji

@cetagostini
Copy link
Contributor

More don't wanted help here#1702

@jessegrabowski@ricardoV94

Introduces normalize_indices_for_mlx to convert NumPy integer and floating types, MLX scalar arrays, and slice components to Python int/float for MLX compatibility. Updates all MLX subtensor dispatch functions to use this normalization, resolving issues with MLX's strict indexing requirements. Adds comprehensive tests for np.int64 indices and slices in subtensor and inc_subtensor operations, including advanced indexing scenarios.
Appended a newline to the end of subtensor.py and test_subtensor.py to conform with POSIX standards and improve code consistency.
@cetagostini
Copy link
Contributor

The math test is kinda flaky, works sometimes others fail. Strange...

@ricardoV94
Copy link
Member

Yeah we know it. I'll take a look to stop it but it need not block us

mx=pytest.importorskip("mlx.core")


deftest_mlx_python_int_indexing():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

why would we be testing mlx directly?

cetagostini reacted with thumbs up emoji
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Ah not at all this was internal for me, let me delete this file. I didn't wanna to push!

ricardoV94 reacted with thumbs up emoji
@cetagostini
Copy link
Contributor

Issue in MLX:ml-explore/mlx#2710

else:
returnelement

indices=indices_from_subtensor(ilist,idx_list)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

move this outside the helper, as it is only relevant for basic Subtensor/IncSubtensor. The advanced methods don't have anidx_list and don't needindices_from_subtensor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Done!

@mlx_funcify.register(AdvancedSubtensor1)
defmlx_funcify_AdvancedSubtensor(op,node,**kwargs):
"""MLX implementation of AdvancedSubtensor."""
idx_list=getattr(op,"idx_list",None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Not a thing for Advanced indexing

Suggested change
idx_list = getattr(op, "idx_list", None)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Done!

@mlx_funcify.register(Subtensor)
defmlx_funcify_Subtensor(op,node,**kwargs):
"""MLX implementation of Subtensor."""
idx_list=getattr(op,"idx_list",None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Not optional. Better be explicit to reduce confusion. Apply this and the other suggestion in all dispatches

Suggested change
idx_list=getattr(op,"idx_list",None)
idx_list=op.idx_list

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Done!

)

# Advanced indexing set with array indices
indices= [np.int64(0),np.int64(2)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Is this actually advanced indexing? To be sure make one of the indices a vector array [0, 1, 2, 3]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

You can probably reuse the same sort of indices from the test just above

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Good catch corrected.

Simplifies index normalization logic in MLX subtensor dispatch functions by separating basic and advanced indexing cases. Updates the advanced incsubtensor test to use vector array indices and a matching value shape for improved coverage.
@codecov
Copy link

codecovbot commentedOct 30, 2025

Codecov Report

❌ Patch coverage is84.37500% with5 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.69%. Comparing base (17c675a) to head (449c2df).
⚠️ Report is 11 commits behind head on main.

Files with missing linesPatch %Lines
pytensor/link/mlx/dispatch/subtensor.py84.37%3 Missing and 2 partials⚠️

❌ Your patch check has failed because the patch coverage (84.37%) is below the target coverage (100.00%). You can increase the patch coverage or adjust thetarget coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@##             main    #1692      +/-   ##==========================================+ Coverage   81.61%   81.69%   +0.08%==========================================  Files         242      246       +4       Lines       53537    53655     +118       Branches     9433     9443      +10     ==========================================+ Hits        43695    43836     +141+ Misses       7366     7334      -32- Partials     2476     2485       +9
Files with missing linesCoverage Δ
pytensor/link/mlx/dispatch/subtensor.py87.50% <84.37%> (-6.35%)⬇️

... and16 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

)

# Advanced indexing set with vector array indices
indices=np.array([0,1,2,3],dtype=np.int64)
Copy link
Member

@ricardoV94ricardoV94Oct 31, 2025
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

this is not testing the "issue of having a scalarnp.int64 or slice with annp.int64 entry". My earlier suggestion was to have an array + one of those. The array is what forces it to be "Advanced"

cetagostini reacted with thumbs up emoji
Sign up for freeto join this conversation on GitHub. Already have an account?Sign in to comment

Reviewers

@ricardoV94ricardoV94ricardoV94 requested changes

Copilot code reviewCopilotCopilot left review comments

@williambdeanwilliambdeanAwaiting requested review from williambdean

+1 more reviewer

@cetagostinicetagostinicetagostini left review comments

Reviewers whose approvals may not affect merge requirements

Requested changes must be addressed to merge this pull request.

Assignees

@cetagostinicetagostini

Labels

bugSomething isn't workingmlx

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

MLXIncsubtensor fails on slices

3 participants

@jessegrabowski@ricardoV94@cetagostini

[8]ページ先頭

©2009-2025 Movatter.jp