Depend on ChainRulesCore? #716

Open
mzgubic opened this issue Sep 6, 2021 · 3 comments

mzgubic commented Sep 6, 2021

I recently ran into this error in the wild; see the MWE:

julia> using Zygote

julia> using StatsBase

julia> gradient(v->sum(AnalyticWeights(v)), rand(3))
ERROR: Need an adjoint for constructor AnalyticWeights{Float64, Float64, Vector{Float64}}. Gradient is of type FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{AnalyticWeights{Float64, Float64, Vector{Float64}}, Vector{Any}, false})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/lib/lib.jl:354
  [3] (::Zygote.var"#1812#back#229"{Zygote.Jnew{AnalyticWeights{Float64, Float64, Vector{Float64}}, Vector{Any}, false}})(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:13 [inlined]
  [5] (::typeof((AnalyticWeights{Float64, Float64, Vector{Float64}})))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:13 [inlined]
  [7] (::typeof((AnalyticWeights)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/JuliaEnvs/PortfolioNets.jl/dev/StatsBase/src/weights.jl:16 [inlined]
  [9] (::typeof((AnalyticWeights)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./REPL[4]:1 [inlined]
 [11] (::typeof((#5)))(Δ::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#50#51"{typeof((#5))})(Δ::Float64)
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/JuliaEnvs/PortfolioNets.jl/dev/Zygote/src/compiler/interface.jl:76
 [14] top-level scope
    @ REPL[4]:1

I've fixed it by pirating

using ChainRulesCore

function ChainRulesCore.rrule(::Type{StatsBase.AnalyticWeights}, values)
    # One pullback method per kind of cotangent that can arrive here
    AnalyticWeights_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)
    AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), ȳ.values)
    AnalyticWeights_pullback(ȳ::AbstractThunk) = (NoTangent(), unthunk(ȳ))
    return AnalyticWeights(values), AnalyticWeights_pullback
end

which solves the problem.

I wonder whether a PR adding this would be welcome? It would need to add a dependency on ChainRulesCore, which is quite lightweight (order of 0.1s precompile time IIRC)

oxinabox (Contributor) commented Sep 6, 2021

> (order of 0.1s precompile time IIRC)

More importantly, about 0.05s load time.
Precompile time is cheap.

I am surprised there is not already a transitive dependency on ChainRulesCore.
But indeed there isn't.


If https://github.com/JuliaLang/Statistics.jl/issues/4 is done then for newer versions of Julia we will be able to define this in ChainRules.jl with the other rules that we have for the Statistics stdlib.


The rule might actually require a little care, since the structural Tangent might in theory contain the values or the sum or both (and if both, they may or may not be consistent), depending on what path it has taken to get here.
AnalyticWeights_pullback(ȳ::Tangent) = (NoTangent(), zero(values) .+ ȳ.values .+ ȳ.sum)

Also, I think the unthunk case should either not be unthunking or should be calling AnalyticWeights_pullback afterwards.
(The latter is a case of needing to be careful about calling functions defined locally.)
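
A minimal sketch of how the pirated rule might look with both suggestions folded in. It assumes that `AnalyticWeights` stores its data in fields named `values` and `sum`, and that the two cotangent contributions can simply be added (which, as noted above, could double count if they are not independent). The thunk method re-dispatches on the unthunked value instead of returning it directly. This is an illustration, not a vetted rule:

```julia
using ChainRulesCore
using StatsBase

# Type piracy, as in the original post: neither `rrule` nor `AnalyticWeights`
# is owned by the code defining this method.
function ChainRulesCore.rrule(::Type{StatsBase.AnalyticWeights}, values)
    # Array cotangent, e.g. when the weights were used as a plain vector.
    AnalyticWeights_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ)

    # Structural cotangent: may carry `values`, `sum`, or both.
    # A field absent from the Tangent reads as ZeroTangent(), which drops out.
    # The `sum` cotangent is spread over every element because
    # d(sum(values))/d(values[i]) == 1.
    function AnalyticWeights_pullback(ȳ::Tangent)
        return (NoTangent(), zero(values) .+ ȳ.values .+ ȳ.sum)
    end

    # Thunk: unwrap, then dispatch again on whatever is inside.
    AnalyticWeights_pullback(ȳ::AbstractThunk) = AnalyticWeights_pullback(unthunk(ȳ))

    return AnalyticWeights(values), AnalyticWeights_pullback
end

# Calling the rule directly, without an AD backend:
v = [1.0, 2.0, 3.0]
w, back = ChainRulesCore.rrule(AnalyticWeights, v)
v̄_from_array = back(ones(3))[2]
v̄_from_sum = back(Tangent{typeof(w)}(; sum = 2.0))[2]
```

The self-call inside the thunk method is where the caveat about locally defined functions applies: a closure that refers to itself may be boxed by the compiler, so it would be worth checking type stability before relying on this in hot code.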

nalimilan (Member) commented

Yes, I'd rather have ChainRules depend on Statistics once the functions are moved there. Otherwise the dependency on ChainRulesCore will make that move impossible.

devmotion (Member) commented

> I am surprised there is not already a transitive dependency on ChainRulesCore.
> But indeed there isn't.

It is in StatsBase >= 0.33.11 through LogExpFunctions.
