#10559 introduced the following code to prevent dispatching `numpy.tensordot` in a case where `einsum` was broadcasting over a singleton dimension.
```python
# Handle broadcasting vs BLAS cases
if blas:
    # Checks have already been handled
    input_str, results_index = einsum_str.split('->')
    input_left, input_right = input_str.split(',')
    if 1 in tmp_operands[0] or 1 in tmp_operands[1]:
        left_dims = {dim: size for dim, size in zip(input_left, tmp_operands[0].shape)}
        right_dims = {dim: size for dim, size in zip(input_right, tmp_operands[1].shape)}
        # If dims do not match we are broadcasting, BLAS off
        if any(left_dims[ind] != right_dims[ind] for ind in idx_rm):
            blas = False
```
However, this checks whether the value 1 occurs within the operand array itself rather than in the operand's shape. Incidentally, this likely also produced a nasty performance regression, since `1 in tmp_operands[0]` scans every element of the operand instead of just inspecting its shape tuple.
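A quick illustration of the difference (a minimal sketch; the array here is arbitrary and only serves to show that the membership test inspects values, not dimensions):

```python
import numpy as np

a = np.random.rand(1000, 1000)   # no singleton dimension anywhere
a[0, 0] = 1.0                    # ...but the data happens to contain the value 1

print(1 in a)        # True  -> the buggy check: scans all 10**6 elements
print(1 in a.shape)  # False -> the intended check: only looks at the shape tuple
```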
Thus the line

```python
if 1 in tmp_operands[0] or 1 in tmp_operands[1]
```

should be

```python
if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape
```
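For reference, this is how the block above reads with the shape-based check in place (same code, only that one line changed):

```python
# Handle broadcasting vs BLAS cases
if blas:
    # Checks have already been handled
    input_str, results_index = einsum_str.split('->')
    input_left, input_right = input_str.split(',')
    if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape:
        left_dims = {dim: size for dim, size in zip(input_left, tmp_operands[0].shape)}
        right_dims = {dim: size for dim, size in zip(input_right, tmp_operands[1].shape)}
        # If dims do not match we are broadcasting, BLAS off
        if any(left_dims[ind] != right_dims[ind] for ind in idx_rm):
            blas = False
```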
This wasn't caught by the unit test because arrays of ones were used 🌌
This leads to the following behavior:
```python
>>> x = np.array([0., 1., 0.])  # contains 1, no blas
>>> y = np.array([0.0])
>>> np.einsum("i,i", x, y, optimize=True)
0.0
```
```python
>>> x = np.array([0., -1., 0.])  # doesn't contain 1, yes blas
>>> y = np.array([0.0])
>>> np.einsum("i,i", x, y, optimize=True)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-184-b0dcea8eedea> in <module>()
      1 x = np.array([0., -1., 0.])
      2 y = np.array([0.0])
----> 3 np.einsum("i,i", x, y, optimize=True)

c:\anaconda\envs\py36\lib\site-packages\numpy\core\einsumfunc.py in einsum(*operands, **kwargs)
   1132
   1133             # Contract!
-> 1134             new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))
   1135
   1136             # Build a new view if needed

c:\anaconda\envs\py36\lib\site-packages\numpy\core\numeric.py in tensordot(a, b, axes)
   1281                 axes_b[k] += ndb
   1282     if not equal:
-> 1283         raise ValueError("shape-mismatch for sum")
   1284
   1285     # Move the axes to sum over to the end of "a"

ValueError: shape-mismatch for sum
```
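A regression test along the following lines would have caught this. It is only a sketch (the test name is hypothetical), and it assumes the un-optimized `einsum` path is the reference behavior; the values deliberately avoid 1.0 so the broken value-based check cannot mask the broadcasting case:

```python
import numpy as np
from numpy.testing import assert_allclose

def test_einsum_broadcast_singleton_without_ones():
    # Contraction requires broadcasting over a singleton dimension, and the
    # data contains no 1.0, so the buggy `1 in operand` check would not
    # accidentally fall back to the non-BLAS path.
    x = np.array([0.5, -2.0, 3.0])
    y = np.array([0.25])
    expected = np.einsum("i,i", x, y)               # un-optimized reference
    result = np.einsum("i,i", x, y, optimize=True)  # optimized (BLAS-dispatch) path
    assert_allclose(result, expected)
```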