
PR #10559 failed to fix einsum (optimize=True) broadcasting bug #10930

Closed
@rsokl

Description


#10559 introduced the following code to prevent dispatching to `numpy.tensordot` in cases where `einsum` broadcasts over a singleton dimension:

```python
# Handle broadcasting vs BLAS cases
if blas:
    # Checks have already been handled
    input_str, results_index = einsum_str.split('->')
    input_left, input_right = input_str.split(',')
    if 1 in tmp_operands[0] or 1 in tmp_operands[1]:
        left_dims = {dim: size for dim, size in zip(input_left, tmp_operands[0].shape)}
        right_dims = {dim: size for dim, size in zip(input_right, tmp_operands[1].shape)}
        # If dims do not match we are broadcasting, BLAS off
        if any(left_dims[ind] != right_dims[ind] for ind in idx_rm):
            blas = False
```

However, this checks whether the value 1 occurs among the operand array's elements rather than in the operand's shape. Incidentally, this likely produced a nasty performance regression as well, since `1 in arr` compares every element of the array against 1.
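A quick illustration of the difference, using the issue's own operands: for a NumPy array, `1 in arr` asks whether any element equals 1, while `1 in arr.shape` asks whether any axis has length 1. The snippet only contrasts those two membership tests; it is not taken from the NumPy source.

```python
import numpy as np

x = np.array([0., 1., 0.])   # contains the value 1, but has no length-1 axis
y = np.array([0.0])          # no element equals 1, but its only axis has length 1

# What the check in #10559 tests: is the *value* 1 among the elements?
value_test = (1 in x, 1 in y)

# What it was meant to test: is there a *dimension* of size 1?
shape_test = (1 in x.shape, 1 in y.shape)

print("1 in arr:      ", value_test)
print("1 in arr.shape:", shape_test)
```

The two tests give opposite answers for both operands, so the guard fires on exactly the wrong inputs.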

Thus the line

if 1 in tmp_operands[0] or 1 in tmp_operands[1]

should be

if 1 in tmp_operands[0].shape or 1 in tmp_operands[1].shape
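To make the fix concrete, here is a standalone sketch of the guard with the corrected condition. The function name `blas_allowed`, its signature, and passing `idx_rm` (the set of index labels summed away) as an argument are assumptions for illustration; in the quoted code this logic sits inline in the contraction loop.

```python
import numpy as np

def blas_allowed(einsum_str, left, right, idx_rm):
    """Hypothetical standalone version of the #10559 guard, with the fix applied.

    einsum_str: a two-operand contraction such as "i,i->"
    left, right: the two operands
    idx_rm: set of index labels that are summed away
    Returns False when a contracted index has different sizes on the two
    sides, i.e. einsum would be broadcasting a length-1 axis.
    """
    input_str, _results_index = einsum_str.split('->')
    input_left, input_right = input_str.split(',')
    # Fixed condition: inspect the shapes, not the array contents.
    if 1 in left.shape or 1 in right.shape:
        left_dims = dict(zip(input_left, left.shape))
        right_dims = dict(zip(input_right, right.shape))
        # If dims do not match we are broadcasting, so BLAS (tensordot) is off.
        if any(left_dims[ind] != right_dims[ind] for ind in idx_rm):
            return False
    return True

y = np.array([0.0])
print(blas_allowed("i,i->", np.array([0., 1., 0.]), y, {"i"}))   # False
print(blas_allowed("i,i->", np.array([0., -1., 0.]), y, {"i"}))  # False
print(blas_allowed("i,i->", np.ones(3), np.ones(3), {"i"}))      # True
```

With the shape-based condition the decision no longer depends on the operands' values, so both `x` arrays from the examples in this issue are routed away from BLAS.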

This wasn't caught by the unit test because it used arrays of ones, which always contain the value 1, so the check passed regardless of the operands' shapes 🌌
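A sketch of a test that would have exposed the bug: pick operand values that avoid 1 (here `x` contains -1 but not 1), give one operand a singleton contracted axis, and check that the optimized path agrees with the unoptimized one and with explicit broadcasting. The function name is hypothetical, and the check assumes a NumPy release that already includes the fix for this issue.

```python
import numpy as np

def check_singleton_broadcast_contraction():
    # Values deliberately avoid 1, so a value-based check cannot mask the bug.
    x = np.array([2., -1., 3.])
    y = np.array([5.0])  # length-1 axis, broadcast along the contracted index "i"

    expected = (x * y).sum()  # explicit broadcasting: 5 * (2 - 1 + 3) = 20
    for opt in (False, True):
        result = np.einsum("i,i", x, y, optimize=opt)
        assert np.allclose(result, expected), (opt, result, expected)
    return expected

check_singleton_broadcast_contraction()
```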

With the buggy check, `einsum(..., optimize=True)` behaves as follows:

```
>>> x = np.array([0., 1., 0.])  # contains 1, no blas
>>> y = np.array([0.0])
>>> np.einsum("i,i", x, y, optimize=True)
0.

>>> x = np.array([0., -1., 0.])  # doesn't contain 1, yes blas
>>> y = np.array([0.0])
>>> np.einsum("i,i", x, y, optimize=True)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-184-b0dcea8eedea> in <module>()
      1 x = np.array([0., -1., 0.])
      2 y = np.array([0.0])
----> 3 np.einsum("i,i", x, y, optimize=True)

c:\anaconda\envs\py36\lib\site-packages\numpy\core\einsumfunc.py in einsum(*operands, **kwargs)
   1132
   1133             # Contract!
-> 1134             new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))
   1135
   1136             # Build a new view if needed

c:\anaconda\envs\py36\lib\site-packages\numpy\core\numeric.py in tensordot(a, b, axes)
   1281                 axes_b[k] += ndb
   1282     if not equal:
-> 1283         raise ValueError("shape-mismatch for sum")
   1284
   1285     # Move the axes to sum over to the end of "a"

ValueError: shape-mismatch for sum
```
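The traceback bottoms out in `numpy.tensordot`, which requires each pair of contracted axes to have the same length and does not broadcast a length-1 axis the way elementwise arithmetic does. The sketch below uses the failing operands to show that contrast; it only illustrates why the BLAS/`tensordot` path must be skipped whenever `einsum` is broadcasting.

```python
import numpy as np

x = np.array([0., -1., 0.])
y = np.array([0.0])

# Elementwise multiply broadcasts y's length-1 axis up to length 3.
broadcast_sum = (x * y).sum()

# tensordot contracts axis 0 of x (length 3) with axis 0 of y (length 1);
# the lengths differ and tensordot does not broadcast, so it raises ValueError.
tensordot_error = None
try:
    np.tensordot(x, y, axes=([0], [0]))
except ValueError as exc:
    tensordot_error = str(exc)

print("broadcast sum:", broadcast_sum)
print("tensordot error:", tensordot_error)
```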
