- Notifications
You must be signed in to change notification settings - Fork151
Handle slices inmlx_funcify_IncSubtensor#1692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
base:main
Are you sure you want to change the base?
Uh oh!
There was an error while loading.Please reload this page.
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Pull Request Overview
This PR fixes a bug in the MLX backend'sIncSubtensor dispatch where it incorrectly assumed all indices would be integers. The fix adds logic to properly handle slice objects by converting their start/stop/step components from potentially symbolic values to actual integers while preservingNone values.
Key Changes
- Added
get_slice_inthelper function to safely convert slice components to integers - Modified index processing to reconstruct slice objects with integer components
- Added comprehensive test coverage for various slice patterns (positive, negative, step-based, and full slices)
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| pytensor/link/mlx/dispatch/subtensor.py | Addedget_slice_int helper and slice reconstruction logic to handle both integer and slice indices |
| tests/link/mlx/test_subtensor.py | Added four test cases covering different slice patterns to verify the fix |
| returnNone | ||
| try: | ||
| returnint(element) | ||
| exceptException: |
CopilotAIOct 24, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Using a bareexcept Exception is too broad. This should catch specific exceptions likeTypeError orValueError that would occur when trying to convert a non-integer value. The current implementation could mask unexpected errors.
| exceptException: | |
| except(TypeError,ValueError): |
Uh oh!
There was an error while loading.Please reload this page.
ricardoV94 commentedOct 24, 2025
So MLX is okay with The helper definitely doesn't ignore slices, I think you misread the error message. And I guess this would also apply to the regular Subtensor not just IncSubtensor? And AdvancedSubtensor as well, unless we typify the constant slices with integers? Because all those use numpy integers internally. This may be worth opening an issue with them, even if we patch here. |
jessegrabowski commentedOct 26, 2025
Neither of these cases work: importmlx.coreasmximportnumpyasnpx=mx.array([1,2,3])x[np.int64(1)]---------------------------------------------------------------------------ValueErrorTraceback (mostrecentcalllast)CellIn[42],line42importnumpyasnp3x=mx.array([1,2,3])---->4x[np.int64(1)]ValueError:Cannotindexmlxarrayusingthegiven type. x[np.int64(1):]ValueErrorTraceback (mostrecentcalllast)CellIn[43],line42importnumpyasnp3x=mx.array([1,2,3])---->4x[np.int64(1):]ValueError:SliceindicesmustbeintegersorNone. |
ricardoV94 commentedOct 26, 2025
So we should open an issue with them and I'm surprised any indexing worked before. We're casting constant numpy integers to python int? If so we should do something for the tipefy of slices |
ricardoV94 commentedOct 27, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
This is how the Subtensor is working around it:
Should do the same thing in both Ops for consistency. The Subtensor approach seems simpler. Perhaps put in a helper just so we can document why this is needed for future developers |
cetagostini commentedOct 28, 2025
More don't wanted help here#1702 |
Introduces normalize_indices_for_mlx to convert NumPy integer and floating types, MLX scalar arrays, and slice components to Python int/float for MLX compatibility. Updates all MLX subtensor dispatch functions to use this normalization, resolving issues with MLX's strict indexing requirements. Adds comprehensive tests for np.int64 indices and slices in subtensor and inc_subtensor operations, including advanced indexing scenarios.
Appended a newline to the end of subtensor.py and test_subtensor.py to conform with POSIX standards and improve code consistency.
cetagostini commentedOct 28, 2025
The math test is kinda flaky, works sometimes others fail. Strange... |
ricardoV94 commentedOct 28, 2025
Yeah we know it. I'll take a look to stop it but it need not block us |
Uh oh!
There was an error while loading.Please reload this page.
| mx=pytest.importorskip("mlx.core") | ||
| deftest_mlx_python_int_indexing(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
why would we be testing mlx directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Ah not at all this was internal for me, let me delete this file. I didn't wanna to push!
cetagostini commentedOct 28, 2025
Issue in MLX:ml-explore/mlx#2710 |
| else: | ||
| returnelement | ||
| indices=indices_from_subtensor(ilist,idx_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
move this outside the helper, as it is only relevant for basic Subtensor/IncSubtensor. The advanced methods don't have anidx_list and don't needindices_from_subtensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Done!
| @mlx_funcify.register(AdvancedSubtensor1) | ||
| defmlx_funcify_AdvancedSubtensor(op,node,**kwargs): | ||
| """MLX implementation of AdvancedSubtensor.""" | ||
| idx_list=getattr(op,"idx_list",None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Not a thing for Advanced indexing
| idx_list = getattr(op, "idx_list", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Done!
| @mlx_funcify.register(Subtensor) | ||
| defmlx_funcify_Subtensor(op,node,**kwargs): | ||
| """MLX implementation of Subtensor.""" | ||
| idx_list=getattr(op,"idx_list",None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Not optional. Better be explicit to reduce confusion. Apply this and the other suggestion in all dispatches
| idx_list=getattr(op,"idx_list",None) | |
| idx_list=op.idx_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Done!
tests/link/mlx/test_subtensor.py Outdated
| ) | ||
| # Advanced indexing set with array indices | ||
| indices= [np.int64(0),np.int64(2)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Is this actually advanced indexing? To be sure make one of the indices a vector array [0, 1, 2, 3]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
You can probably reuse the same sort of indices from the test just above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Good catch corrected.
Simplifies index normalization logic in MLX subtensor dispatch functions by separating basic and advanced indexing cases. Updates the advanced incsubtensor test to use vector array indices and a matching value shape for improved coverage.
Codecov Report❌ Patch coverage is
❌ Your patch check has failed because the patch coverage (84.37%) is below the target coverage (100.00%). You can increase the patch coverage or adjust thetarget coverage. Additional details and impacted files@@ Coverage Diff @@## main #1692 +/- ##==========================================+ Coverage 81.61% 81.69% +0.08%========================================== Files 242 246 +4 Lines 53537 53655 +118 Branches 9433 9443 +10 ==========================================+ Hits 43695 43836 +141+ Misses 7366 7334 -32- Partials 2476 2485 +9
🚀 New features to boost your workflow:
|
| ) | ||
| # Advanced indexing set with vector array indices | ||
| indices=np.array([0,1,2,3],dtype=np.int64) |
ricardoV94Oct 31, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
this is not testing the "issue of having a scalarnp.int64 or slice with annp.int64 entry". My earlier suggestion was to have an array + one of those. The array is what forces it to be "Advanced"
Uh oh!
There was an error while loading.Please reload this page.
Description
The MLX dispatch for
IncSubtensorwas assuming that the indexes would always be integers, but they can actually be either integers or slices. This PR adds logic to handle the slice case.Related Issue
Incsubtensorfails on slices #1690mlx#1350Checklist
Type of change
📚 Documentation preview 📚:https://pytensor--1692.org.readthedocs.build/en/1692/