Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitee6bb20

Browse files
authored
Fix linearly indexed array math by reshaping arrays (#60164)
This was broken in#59961, as`map` deals with trailing singleton axes differently from broadcasting:```juliajulia> map(+, ones(1), ones(1,1)) |> size(1,)julia> broadcast(+, ones(1), ones(1,1)) |> size(1, 1)```This PR limits the new method to the case where the ndims match, inwhich case there are no trailing axes and the two are equivalent. Thealternate approach suggested in#59961 (comment) isto reshape the arrays, but this adds overhead that nullifies theperformance improvement for small arrays.
1 parentcab2c74 commitee6bb20

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

‎base/arraymath.jl‎

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,19 @@ end
99
# Using map over broadcast enables vectorization for wide matrices with few rows.
1010
# This is because we use linear indexing in `map` as opposed to Cartesian indexing in broadcasting.
1111
# https://github.com/JuliaLang/julia/issues/47873#issuecomment-1352472461
12-
function_broadcast_preserving_zero_d(f, A::Array, B::Array)
13-
map(f, A, B)
12+
function_broadcast_preserving_zero_d(f, A::Array{<:Any,N}, B::Array{<:Any,N}, Cs::Array{<:Any,N}...)where {N}
13+
map(f, A, B, Cs...)
14+
end
15+
16+
function_broadcast_preserving_zero_d(f, A::Array, B::Array, Cs::Array...)
17+
# we already know that the shapes are compatible.
18+
# We just need to select the size corresponding to the higest ndims
19+
# and reshape all the arrays to that size
20+
arrays= (A, B, Cs...)
21+
sz=mapreduce(size, (x,y)->length(x)>length(y)? x: y, arrays)
22+
# Skip reshaping where possible to avoid the overhead
23+
arrays_sameshape=map(x->length(sz)==ndims(x)? x:reshape(x, sz), arrays)
24+
map(f, arrays_sameshape...)
1425
end
1526

1627
function_broadcast_preserving_zero_d(f, A::Array, B::Number)
@@ -28,11 +39,12 @@ for f in (:+, :-)
2839
end
2940
end
3041

31-
function+(A::Array, Bs::Array...)
32-
for Bin Bs
33-
promote_shape(A, B)# check size compatibility
42+
function+(A::Array, B::Array, Cs::Array...)
43+
promote_shape(A, B)
44+
for Cin Cs
45+
promote_shape(A, C)# check size compatibility
3446
end
35-
map(+, A,Bs...)
47+
_broadcast_preserving_zero_d(+, A,B, Cs...)
3648
end
3749

3850
for fin (:/, :\, :*)

‎test/abstractarray.jl‎

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,3 +2376,16 @@ end
23762376
show(io, m2)
23772377
@testString(take!(io))=="Any[#= circular reference @-1 =# 3; 2 4;;; 5 7; 6 8]"
23782378
end
2379+
2380+
@testset"size promotion in addition/subtraction"begin
2381+
for Ain Any[ones(),ones(1),ones(1,1,1)]
2382+
@test+(A)== A
2383+
for Bin Any[ones(),ones(1),ones(1,1,1)]
2384+
sz=ndims(A)>ndims(B)?size(A):size(B)
2385+
@test A+ B==fill(2.0,sz)
2386+
@test A- B==zeros(sz)
2387+
@test A+ B+zeros()== A+ B
2388+
@test A- B-zeros()== A- B
2389+
end
2390+
end
2391+
end

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp