Skip to content

Commit

Permalink
Allow linsolv to have const args (#1121)
Browse files Browse the repository at this point in the history
* Allow linsolv to have const args

* Add tests
  • Loading branch information
wsmoses authored and michel2323 committed Nov 7, 2023
1 parent bcce97a commit 8cf4c50
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 18 deletions.
50 changes: 32 additions & 18 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,18 +337,6 @@ function EnzymeRules.augmented_primal(config, func::Const{typeof(\)}, ::Type{RT}
else
res
end

dAs = if EnzymeRules.width(config) == 1
(A.dval,)
else
A.dval
end

dbs = if EnzymeRules.width(config) == 1
(b.dval,)
else
b.dval
end

cache_b = if EnzymeRules.overwritten(config)[3]
copy(b.val)
Expand Down Expand Up @@ -394,21 +382,47 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
end

dAs = if EnzymeRules.width(config) == 1
(A.dval,)
if typeof(A) <: Const
(nothing,)
else
(A.dval,)
end
else
A.dval
if typeof(A) <: Const
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
nothing
end
else
A.dval
end
end

dbs = if EnzymeRules.width(config) == 1
(b.dval,)
if typeof(b) <: Const
(nothing,)
else
(b.dval,)
end
else
b.dval
if typeof(b) <: Const
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
nothing
end
else
b.dval
end
end

for (dA, db, dy) in zip(dAs, dbs, dys)
z = transpose(cache_A) \ dy
dA .-= z * transpose(y)
db .+= z
if !(typeof(A) <: Const)
dA .-= z * transpose(y)
end
if !(typeof(b) <: Const)
db .+= z
end
dy .= eltype(dy)(0)
end

Expand Down
32 changes: 32 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2519,4 +2519,36 @@ end
y = A \ b
@test dA (-z * transpose(y))
@test db z

db = zero(b)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Const{typeof(A)}, Duplicated{typeof(b)})

tape, primal, shadow = forward(Const(\), Const(A), Duplicated(b, db))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Const(A), Duplicated(b, db), tape)

z = transpose(A) \ dy

y = A \ b
@test db z

dA = zero(A)

forward, pullback = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(\)}, Duplicated, Duplicated{typeof(A)}, Const{typeof(b)})

tape, primal, shadow = forward(Const(\), Duplicated(A, dA), Const(b))

dy = Float64[17, 19]
copyto!(shadow, dy)

pullback(Const(\), Duplicated(A, dA), Const(b), tape)

z = transpose(A) \ dy

y = A \ b
@test dA (-z * transpose(y))
end

0 comments on commit 8cf4c50

Please sign in to comment.