Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit439e7f2

Browse files
Fix GPU expv! allocation in hermitian branch for Real t
Use lmul!(t, F.values) for Real t to avoid allocation in the hermitianbranch. For Complex t, allocation is unavoidable since F.values is Realand multiplying by complex t requires type conversion.🤖 Generated with [Claude Code](https://claude.com/claude-code)Co-Authored-By: Claude <noreply@anthropic.com>
1 parentc8d5d9e commit439e7f2

File tree

1 file changed

+35
-15
lines changed

1 file changed

+35
-15
lines changed

‎src/krylov_phiv.jl‎

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,11 @@ function expv!(w::AbstractVector{Complex{Tw}}, t::Complex{Tt}, Ks::KrylovSubspac
129129
lmul!(beta,mul!(w,@view(V[:,1:m]),compatible_multiplicative_operand(V, expHe)))# exp(A) ≈ norm(b) * V * exp(H)e
130130
end
131131

132-
# Internal GPU implementation shared by Real and Complex t methods
133-
function_expv_gpu_impl!(w::GPUArraysCore.AbstractGPUVector, t, Ks::KrylovSubspace{T, U},
134-
cache, expmethod)where {T, U}
132+
# GPU expv! for Real t (non-allocating in hermitian branch)
133+
function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
134+
t::Real, Ks::KrylovSubspace{T, U};
135+
cache=nothing,
136+
expmethod=ExpMethodHigham2005Base())where {Tw, T, U}
135137
m, beta, V, H= Ks.m, Ks.beta,getV(Ks),getH(Ks)
136138
@assertlength(w)==size(V,1)"Dimension mismatch"
137139
ifisnothing(cache)
@@ -149,29 +151,47 @@ function _expv_gpu_impl!(w::GPUArraysCore.AbstractGPUVector, t, Ks::KrylovSubspa
149151
ifishermitian(cache)
150152
# Optimize the case for symtridiagonal H
151153
F=eigen!(SymTridiagonal(cache))
152-
expHe= F.vectors* (exp.(t* F.values).*@view(F.vectors[1, :]))
154+
# Use lmul! to avoid allocation (modifies F.values in place)
155+
expHe= F.vectors* (exp.(lmul!(t, F.values)).*@view(F.vectors[1, :]))
153156
else
154-
expH=exponential!(t* cache, expmethod)
157+
lmul!(t, cache)
158+
expH=exponential!(cache, expmethod)
155159
expHe=@view(expH[:,1])
156160
end
157161

158162
lmul!(beta,mul!(w,@view(V[:,1:m]), Adapt.adapt(parameterless_type(w), expHe)))# exp(A) ≈ norm(b) * V * exp(H)e
159163
end
160164

161-
# GPU expv! for Real t
162-
function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Tw},
163-
t::Real, Ks::KrylovSubspace{T, U};
164-
cache=nothing,
165-
expmethod=ExpMethodHigham2005Base())where {Tw, T, U}
166-
_expv_gpu_impl!(w, t, Ks, cache, expmethod)
167-
end
168-
169-
# GPU expv! for Complex t
165+
# GPU expv! for Complex t (allocates in hermitian branch due to Real->Complex conversion)
170166
function ExponentialUtilities.expv!(w::GPUArraysCore.AbstractGPUVector{Complex{Tw}},
171167
t::Complex{Tt}, Ks::KrylovSubspace{T, U};
172168
cache=nothing,
173169
expmethod=ExpMethodHigham2005Base())where {Tw, Tt, T, U}
174-
_expv_gpu_impl!(w, t, Ks, cache, expmethod)
170+
m, beta, V, H= Ks.m, Ks.beta,getV(Ks),getH(Ks)
171+
@assertlength(w)==size(V,1)"Dimension mismatch"
172+
ifisnothing(cache)
173+
cache=Matrix{U}(undef, m, m)
174+
elseifisa(cache, ExpvCache)
175+
cache=get_cache(cache, m)
176+
else
177+
throw(ArgumentError("Cache must be an ExpvCache"))
178+
end
179+
ifiszero(Ks.beta)
180+
w .=false
181+
return w
182+
end
183+
copyto!(cache,@view(H[1:m, :]))
184+
ifishermitian(cache)
185+
# Optimize the case for symtridiagonal H
186+
F=eigen!(SymTridiagonal(cache))
187+
# Must allocate here: F.values is Real, t is Complex
188+
expHe= F.vectors* (exp.(t* F.values).*@view(F.vectors[1, :]))
189+
else
190+
expH=exponential!(t* cache, expmethod)
191+
expHe=@view(expH[:,1])
192+
end
193+
194+
lmul!(beta,mul!(w,@view(V[:,1:m]), Adapt.adapt(parameterless_type(w), expHe)))# exp(A) ≈ norm(b) * V * exp(H)e
175195
end
176196

177197
compatible_multiplicative_operand(::AbstractArray, source::AbstractArray)= source

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp