Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak bookkeeping in LOBPCG #980

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions src/eigen/lobpcg_hyper_impl.jl
Original file line number | Diff line number | Diff line change
Expand Up @@ -298,21 +298,17 @@ end
end


- function final_retval(X, AX, BX, resid_history, niter, n_matvec)
- λ = @views [real((X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n])) for n=1:size(X, 2)]
- λ_device = oftype(X[:, 1], λ) # Offload to GPU if needed
- residuals = AX .- BX .* λ_device'
+ function final_retval(X, AX, BX, λ, resid_history, niter, n_matvec)
if !issorted(λ)
p = sortperm(λ)
λ = λ[p]
- residuals = residuals[:, p]
X = X[:, p]
AX = AX[:, p]
BX = BX[:, p]
resid_history = resid_history[p, :]
end
- (; λ=λ_device, X, AX, BX,
- residual_norms=norm.(eachcol(residuals)),
+ (; λ=λ, X, AX, BX,
+ residual_norms=resid_history[:, niter+1],
residual_history=resid_history[:, 1:niter+1], n_matvec)
end

Expand Down Expand Up @@ -368,14 +364,15 @@ end
nlocked = 0
niter = 0 # the first iteration is fake
λs = @views [real((X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n])) for n=1:M]
- λs = oftype(X[:, 1], λs) # Offload to GPU if needed
+ λs = oftype(real(X[:, 1]), λs) # Offload to GPU if needed
new_X = X
new_AX = AX
new_BX = BX
# The full_ arrays contain all the vectors, the others only get the active ones
full_X = X
full_AX = AX
full_BX = BX
+ full_λs = λs

while true
if niter > 0 # first iteration is just to compute the residuals (no X update)
Expand All @@ -393,7 +390,8 @@ end
AY = LazyHcat(AX, AR)
BY = LazyHcat(BX, BR) # data shared with (X, R) in non-general case
end
- cX, λs = rayleigh_ritz(Y, AY, M-nlocked)
+ cX, λs_RR = rayleigh_ritz(Y, AY, M-nlocked)
+ λs .= λs_RR

# Update X. By contrast to some other implementations, we
# wait on updating P because we have to know which vectors
Expand Down Expand Up @@ -446,7 +444,7 @@ end
if nlocked >= n_conv_check # Converged!
X .= new_X # Update the part of X which is still active
AX .= new_AX
- return final_retval(full_X, full_AX, full_BX, resid_history, niter, n_matvec)
+ return final_retval(full_X, full_AX, full_BX, full_λs, resid_history, niter, n_matvec)
end
newly_locked = nlocked - prev_nlocked
active = newly_locked+1:size(X,2) # newly active vectors
Expand Down Expand Up @@ -531,9 +529,9 @@ end
B_ortho!(R, BR)
end

- niter < maxiter || break
+ niter >= maxiter && break
niter = niter + 1
end

- final_retval(full_X, full_AX, full_BX, resid_history, maxiter, n_matvec)
+ final_retval(full_X, full_AX, full_BX, full_λs, resid_history, maxiter, n_matvec)
end
Loading