- Notifications
You must be signed in to change notification settings - Fork151
Fix non-inplace IfElse on numba mode#1765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
base:main
Are you sure you want to change the base?
Fix non-inplace IfElse on numba mode#1765
Uh oh!
There was an error while loading.Please reload this page.
Conversation
ricardoV94 left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Nice trick with the list, I was thinking we would need to use codegen due to numba limitations.
Left some comments.
We also need to work on the existing tests so they would have failed before the fix and pass now. We have to test both single and multi output and inplace or not
| # Return a tuple of copies | ||
| out= [None]*n_outs | ||
| foriinrange(n_outs): | ||
| out[i]=selected[i].copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
We should only make a copy if not op.inplace
| defifelse(cond,*args): | ||
| ifcond: | ||
| res=args[:n_outs] | ||
| arr=args[0] |
ricardoV94Dec 3, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
This case can be simplified to have signatureifelse(cond, if_true, if_false), without need for indexing internally.
It's unrelated to the copy change
| # Return a copy | ||
| returnarr.copy() | ||
| returnifelse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
| returnifelse | |
| cache_version=1 | |
| returnifelse,cache_version |
We need to tell PyTensor the implementation of ifelse changed to invalidate old caches
ricardoV94 commentedDec 3, 2025
The list trick didn't work. I suspect that would happen. We need to use codegen for the non inplace version of ifelse with multiple outputs. Other cases will work fine. You can see some cases of codegen for the numba funcify |
emekaokoli19 commentedDec 4, 2025
Hey@ricardoV94, I have added codegen. A lot of tests are failing in |
| ifn_outs==1: | ||
| @numba_basic.numba_njit | ||
| defifelse(cond,*args): | ||
| ifcond: | ||
| res=args[:n_outs] | ||
| else: | ||
| res=args[n_outs:] | ||
| defifelse(cond,x_true,x_false): | ||
| arr=x_trueifcondelsex_false | ||
| returnarrifas_viewelsearr.copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
On second thought I guess we can get rid of the special case and stay with the codegen for every case now
| ifelse_numba=numba_basic.numba_njit(ifelse_py) | ||
| returnres[0] | ||
| cache_version=3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
You probably needed to bump a few times, but for the PR we should only bump once. You can erase your previous cache with pytensor-cache purge for local testing
| cache_version=3 | |
| cache_version=1 |
tests/link/numba/test_compile_ops.py Outdated
| assertr2isnotb | ||
| deftest_ifelse_false_branch(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
You can merge this test with the previous ones. Just eval the function twice, in a way that triggers the different branches.
tests/link/numba/test_compile_ops.py Outdated
| y=pt.vector("y") | ||
| out1,out2=ifelse(x.sum()>0, (x,y), (y,x)) | ||
| fn=function([x,y], [out1,out2],mode=Mode("numba",optimizer=None)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
Parametrize this and the single output test to test inplace and not inplace. You can create IfElse inplace manually likeIfElse(as_view=True|False, n_outs=2), and passaccept_inplace=Truetofunction`.
We want to make sure that r1 is a, r2 is b in that case. Right now we are never testing the inplace mode
emekaokoli19 commentedDec 4, 2025
Explanation for failing |
ricardoV94 commentedDec 4, 2025 • edited
Loading Uh oh!
There was an error while loading.Please reload this page.
edited
Uh oh!
There was an error while loading.Please reload this page.
Ah you need to set the borrow flags to allow pytensor to pass back inputs unalterated: importnumpyasnpimportpytensorimportpytensor.tensorasptx=pt.vector("x")fn=pytensor.function([pytensor.In(x,borrow=True)],pytensor.Out(x,borrow=True),mode="NUMBA")x_test=np.zeros(5)fn(x_test)isx_test You can check that without that the final function would have a deepcopy Op, using |
Description
This PR fixes an issue in the
IfElsenumba, where outputs were returned as direct references to the input arrays instead of copies. This violated the semantics ofifelse, which guarantees that the returned value is a distinct object, even when both branches reference the same input.The fix ensures that each selected output is explicitly copied. This matches the behavior of the Python linker and prevents unexpected mutations when the NumPy arrays are assumed to be non-shared.
Related Issue
Checklist
Type of change