Skip to content

Commit

Permalink
non-allocating pb2 with FD
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Jun 9, 2024
1 parent 3abdd19 commit dcdd4bb
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 9 deletions.
8 changes: 8 additions & 0 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ function pullback_evaluate!(∂BB, ∂A, basis::PooledSparseProduct{2}, BB::TupM
@assert all(size(∂BB[i], 2) >= size(BB[i], 2) for i = 1:NB)
BB1, BB2 = BB
∂BB1, ∂BB2 = ∂BB

for i = 1:length(∂BB)
fill!(∂BB[i], zero(eltype(∂BB[i])))
end

@inbounds for (iA, ϕ) in enumerate(basis.spec)
∂A_iA = ∂A[iA]
Expand Down Expand Up @@ -258,6 +262,10 @@ function pullback_evaluate!(∂BB, ∂A, basis::PooledSparseProduct{3}, BB::TupM
@assert length(∂BB) == NB
end

for i = 1:length(∂BB)
fill!(∂BB[i], zero(eltype(∂BB[i])))
end

B1 = BB[1]; B2 = BB[2]; B3 = BB[3]
∂B1 = ∂BB[1]; ∂B2 = ∂BB[2]; ∂B3 = ∂BB[3]

Expand Down
93 changes: 93 additions & 0 deletions src/generic_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# ---------------------------------------------------------------
# general rrules and frules interface for AbstractP4MLBasis


# Allocation descriptor for the output `∂X` of `pullback_evaluate!`:
# a vector with one entry per input point, whose element type can hold
# both the basis gradient type and the cotangent element type.
function whatalloc(::typeof(pullback_evaluate!),
                   ∂P, basis::AbstractP4MLBasis, X::AbstractVector)
    npoints = length(X)
    eltype_out = promote_type(_gradtype(basis, X), eltype(∂P))
    return (eltype_out, npoints)
end

"""
    pullback_evaluate!(∂X, ∂P, basis::AbstractP4MLBasis, X::AbstractVector; dP = ...)

Overwrite `∂X` with the pullback of the cotangent `∂P` through the basis
evaluation at the points `X`, and return `∂X`:

    ∂X[a] = ∑_n dP[a, n] * ∂P[a, n]

# Arguments
- `∂X`: output buffer with `length(∂X) == length(X)`; its previous contents
  are discarded.
- `∂P`: cotangent of size `(length(X), length(basis))`.
- `basis`, `X`: the basis and the input points.

# Keywords
- `dP`: array of the same size as `∂P`, used as the derivative factor in the
  sum above. Defaults to the second output of `evaluate_ed(basis, X)`; pass it
  explicitly to avoid recomputing it.
"""
function pullback_evaluate!(∂X,
                            ∂P, basis::AbstractP4MLBasis, X::AbstractVector;
                            dP = evaluate_ed(basis, X)[2])
    @assert size(∂P) == size(dP) == (length(X), length(basis))
    @assert length(∂X) == length(X)
    # The loop below accumulates with `+=`, so the output must start from zero.
    # Buffers handed in by an allocator (e.g. via `whatalloc`) are not
    # guaranteed to be zero-initialised.
    fill!(∂X, zero(eltype(∂X)))
    # manual loops to avoid any broadcasting of StrideArrays
    # ∂_xa ( ∂P : P ) = ∑_ij ∂_xa ( ∂P_ij * P_ij )
    #                 = ∑_ij ∂P_ij * ∂_xa ( P_ij )
    #                 = ∑_ij ∂P_ij * dP_ij δ_ia
    for n = 1:size(dP, 2)
        @simd ivdep for a = 1:length(X)
            ∂X[a] += dP[a, n] * ∂P[a, n]
        end
    end
    return ∂X
end

# Reverse-mode rule for `evaluate(basis, X)` on an `AbstractP4MLBasis`:
# returns the primal value together with a pullback closure. Only `X`
# receives a tangent; the function and the basis get `NoTangent()`.
function rrule(::typeof(evaluate), basis::AbstractP4MLBasis, X::AbstractVector)
    primal = evaluate(basis, X)
    # TODO: here we could do evaluate_ed, but need to think about how this
    #       works with the kwarg trick above...
    function evaluate_pullback(∂P)
        ∂X = pullback_evaluate(∂P, basis, X)
        return (NoTangent(), NoTangent(), ∂X)
    end
    return primal, evaluate_pullback
end


#=
function whatalloc(::typeof(pb_pb_evaluate!),
∂∂X, ∂P, basis::AbstractP4MLBasis, X::AbstractVector)
Nbasis = length(basis)
Nx = length(X)
@assert ∂∂X isa AbstractVector
@assert length(∂∂X) == Nx
@assert size(∂P) == (Nx, Nbasis)
T∂²P = promote_type(_valtype(basis, X), eltype(∂P), eltype(∂∂X))
T∂²X = promote_type(_gradtype(basis, X), eltype(∂P), eltype(∂∂X))
return (T∂²P, Nx, Nbasis), (T∂²X, Nx)
end
function pb_pb_evaluate!(∂²P, ∂²X, # output
∂∂X, # input / perturbation of ∂X
∂P, basis::AbstractP4MLBasis, # inputs
X::AbstractVector{<: Real})
@no_escape begin
P, dP, ddP = @withalloc evaluate_ed2!(basis, X)
for n = 1:Nbasis
@simd ivdep for a = 1:Nx
∂²P[a, n] = ∂∂X[a] * dP[a, n]
∂²X[a] += ∂∂X[a] * ddP[a, n] * ∂P[a, n]
end
end
end
return ∂²P, ∂²X
end
function rrule(::typeof(pullback_evaluate),
∂P, basis::AbstractP4MLBasis, X::AbstractVector{<: Real})
∂X = pullback_evaluate(∂P, basis, X)
function _pb(∂2)
∂∂P, ∂X = pb_pb_evaluate(∂2, ∂P, basis, X)
return NoTangent(), ∂∂P, NoTangent(), ∂X
end
return ∂X, _pb
end
=#


# -------------------------------------------------------------
# general rrules and frules for AbstractP4MLTensor


# Reverse-mode rule for `evaluate(basis, X)` on an `AbstractP4MLTensor`:
# returns the primal value together with a pullback closure. Only `X`
# receives a tangent; the function and the basis get `NoTangent()`.
function rrule(::typeof(evaluate), basis::AbstractP4MLTensor, X)
    primal = evaluate(basis, X)
    function evaluate_pullback(∂P)
        ∂X = pullback_evaluate(∂P, basis, X)
        return (NoTangent(), NoTangent(), ∂X)
    end
    return primal, evaluate_pullback
end

56 changes: 47 additions & 9 deletions test/ace/test_sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,24 +165,62 @@ function auto_pb_pb(∂BB, ∂A, basis, BB)
# = (∂BB ⋅ ∇_BB) pullback(∂A, basis, BB)
d = Dual{Float64}(0.0, 1.0)
BB_d = ntuple(i -> BB[i] .+ d .* ∂BB[i], length(BB))
A_d = evaluate(basis, BB_d)
∂BB_d = pullback_evaluate(∂A, basis, BB_d)
∇_∂A = extract_derivative.(Float64, A_d)
∇_BB = ntuple(i -> extract_derivative.(Float64, ∂BB_d[i]), length(∂BB_d))
@no_escape begin
A_d = @withalloc evaluate!(basis, BB_d)
∂BB_d = @withalloc pullback_evaluate!(∂A, basis, BB_d)
∇_∂A = extract_derivative.(Float64, A_d)
∇_BB = ntuple(i -> extract_derivative.(Float64, ∂BB_d[i]), length(∂BB_d))
end
return ∇_∂A, ∇_BB
end

using Bumper, WithAlloc

"""
    auto_pb_pb!(∇_∂A, ∇_BB, ∂BB, ∂A, basis, BB)

In-place variant of `auto_pb_pb` for a two-factor `BB`.

Seeds `BB[1]` and `BB[2]` with dual numbers in the direction `∂BB`, runs
`evaluate!` and `pullback_evaluate!` on the dual inputs, and writes the
extracted derivative parts into `∇_∂A` and `∇_BB`. All temporaries are
requested through `@alloc` / `@withalloc` inside one `@no_escape` block.
Returns `(∇_∂A, ∇_BB)`.
"""
function auto_pb_pb!(∇_∂A, ∇_BB, ∂BB, ∂A, basis, BB)
    @assert all(eltype(BB[i]) == eltype(BB[1]) for i = 2:length(BB))
    @no_escape begin
        T = eltype(BB[1])
        # dual number with zero value and unit derivative part
        ε = Dual{T}(zero(T), one(T))
        DualT = typeof(ε)

        # scratch arrays holding BB + ε * ∂BB for each of the two factors
        X1 = @alloc(DualT, size(BB[1])...)
        X2 = @alloc(DualT, size(BB[2])...)
        for (t, b) in enumerate(BB[1])
            X1[t] = b + ε * ∂BB[1][t]
        end
        for (t, b) in enumerate(BB[2])
            X2[t] = b + ε * ∂BB[2][t]
        end
        BB_dual = (X1, X2)

        A_dual = @withalloc evaluate!(basis, BB_dual)
        ∂BB_dual = @withalloc pullback_evaluate!(∂A, basis, BB_dual)

        # copy the derivative parts into the caller's output buffers
        for k in 1:length(A_dual)
            ∇_∂A[k] = extract_derivative(T, A_dual[k])
        end
        for i in 1:length(∂BB_dual)
            for j in 1:length(∂BB_dual[i])
                ∇_BB[i][j] = extract_derivative(T, ∂BB_dual[i][j])
            end
        end
    end
    return ∇_∂A, ∇_BB
end

auto_pb_pb(∂2, ∂A, basis, bBB);

∇_∂A, ∇_BB = auto_pb_pb(∂2, ∂A, basis, bBB);
auto_pb_pb!(∇_∂A2, ∇_BB2, ∂2, ∂A, basis, bBB)
@btime auto_pb_pb($∂2, $∂A, $basis, $bBB);
@btime auto_pb_pb!($∇_∂A2, $∇_BB2, $∂2, $∂A, $basis, $bBB);

##

∇_∂A1, ∇_BB1 = auto_pb_pb(∂2, ∂A, basis, bBB)
∇_∂A2, ∇_BB2 = P4ML.pb_pb_evaluate(∂2, ∂A, basis, bBB)

∇_∂A1 ≈ ∇_∂A2
all(∇_BB1 .≈ ∇_BB2)

∇_∂A3 = deepcopy(∇_∂A2); ∇_BB3 = deepcopy(∇_BB2)
auto_pb_pb!(∇_∂A3, ∇_BB3, ∂2, ∂A, basis, bBB)
∇_∂A1 ≈ ∇_∂A2 ≈ ∇_∂A3
all(∇_BB1 .≈ ∇_BB2 .≈ ∇_BB3)

##
@info("Testing pushforward for PooledSparseProduct")
Expand Down

0 comments on commit dcdd4bb

Please sign in to comment.