Skip to content

Commit

Permalink
stop manually unthunking composite
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Oct 15, 2020
1 parent b9c72ce commit 2eb203f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.12"
ChainRulesCore = "0.9.16"
ChainRulesTestUtils = "0.5"
Compat = "3"
FiniteDifferences = "0.11"
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback::Composite)
# svd_rev does a lot of linear algebra, it is is efficient to unthunk before
∂X = svd_rev(F, unthunk(Ȳ.U), unthunk(Ȳ.S), unthunk(Ȳ.V))
# `getproperty` on `Composite`s ensures we have no thunks.
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
Expand Down
7 changes: 5 additions & 2 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo

_, dF_unthunked, _ = dF_pullback(Ȳ)

@assert !(getproperty(dF_unthunked, p) isa AbstractThunk)
# helper to let us check how things are stored.
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p)
@assert !(backing_field(dF_unthunked, p) isa AbstractThunk)

dF_thunked = map(f->Thunk(()->f), dF_unthunked)
@assert getproperty(dF_thunked, p) isa AbstractThunk
@assert backing_field(dF_thunked, p) isa AbstractThunk

dself_thunked, dX_thunked = dX_pullback(dF_thunked)
dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked)
Expand Down

0 comments on commit 2eb203f

Please sign in to comment.