
Making rand gradients work for more distributions #123

Open
dcjones opened this issue Oct 28, 2020 · 7 comments
dcjones commented Oct 28, 2020

Trying to compute gradients of the rand function with respect to the parameters of certain distributions produces incorrect results, because some of these samplers use branching or iterative algorithms (e.g. rejection sampling, where proposals are accepted or rejected based on the parameters) and AD can't take into account how the parameters affect control flow.

A simple demonstration of this is to estimate d/dθ E[x] by estimating E[d/dθ x].

Normal of course works: d/dμ E[x] = d/dμ μ = 1, and

julia> mean(gradient(μ -> rand(Normal(μ, 1.0)), 1.0)[1] for _ in 1:10000) # should be ≈ 1.0
1.0

(which works for any values of μ, σ)

Gamma will not return a gradient for some values, and returns incorrect results for others. E.g. d/dα E[x] = d/dα αβ = β, yet

julia> mean(gradient(α -> rand(Gamma(α, 2.0)), 1.01)[1] for _ in 1:10000) # should be ≈ 2.0
2.782440982911109
julia> mean(gradient(α -> rand(Gamma(α, 2.0)), 1.00)[1] for _ in 1:10000) # should be ≈ 2.0
ERROR: MethodError: no method matching /(::Nothing, ::Int64)

Beta behaves similarly: d/dα E[x] = d/dα α/(α+β) = β/(α+β)^2, yet

julia> mean(gradient(α -> rand(Beta(α, 2.0)), 2.0)[1] for _ in 1:10000) # should be ≈ 0.125
0.14264055366214703
julia> mean(gradient(α -> rand(Beta(α, 3.0)), 1.0)[1] for _ in 1:10000) # should be ≈ 0.1875
ERROR: MethodError: no method matching /(::Nothing, ::Int64)

It's well known that some distributions (e.g. Gamma, Beta, Dirichlet) don't lend themselves easily to this kind of pathwise gradient, which makes them infrequently used as surrogate posteriors for VI, but there have been papers on working around this with implicit differentiation and other numerical techniques (the core identity is sketched after the references). See for example:

Figurnov, Mikhail, Shakir Mohamed, and Andriy Mnih. 2018. “Implicit Reparameterization Gradients.” In Advances in Neural Information Processing Systems 31, edited by S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, 441–52. Curran Associates, Inc.

Jankowiak, Martin, and Fritz Obermeyer. 2018. “Pathwise Derivatives Beyond the Reparameterization Trick.” arXiv [stat.ML]. http://arxiv.org/abs/1806.01851.
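
The core idea in these papers (a minimal sketch, not code from either paper; the helper name dz_dα and the finite-difference step h are illustrative): for z ~ q(z; θ) with CDF F(z; θ), hold the uniform u = F(z; θ) fixed and differentiate implicitly to get dz/dθ = -(∂F/∂θ)/(∂F/∂z), where ∂F/∂z is just the pdf.

# Implicit reparameterization gradient for the shape of a Gamma, with the
# α-derivative of the CDF approximated by a central finite difference
# (gamma_inc itself is not AD-friendly):
using Distributions, Statistics

function dz_dα(z, α, β; h=1e-6)  # illustrative helper, not library code
    dFdα = (cdf(Gamma(α + h, β), z) - cdf(Gamma(α - h, β), z)) / (2h)
    return -dFdα / pdf(Gamma(α, β), z)
end

# Monte Carlo check: d/dα E[z] = β for Gamma(α, β)
α, β = 1.0, 2.0
mean(dz_dα(rand(Gamma(α, β)), α, β) for _ in 1:100_000)  # ≈ β = 2.0, up to MC error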

I'd love to help improve the rand situation, but I'm still getting my bearings with this code, so I was hoping for some pointers.

My vague thought was that there might be a TuringGamma, TuringBeta, etc. that implement alternative rand functions that are correctly differentiated. Is there a nicer approach, or is this the best option?

Second, for distributions where there is no viable way to AD rand, is there something better that can be done than reporting incorrect numbers? Should the remedy be in Distributions?
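
One conceivable remedy for that second case (purely a sketch of one option, not what Turing or Distributions actually do; Poisson is used here only as an example of a distribution with no pathwise gradient) is an adjoint that fails loudly instead of silently returning a wrong gradient:

using Random, Distributions, ZygoteRules

# Hypothetical: make the pullback error for a sampler that cannot be
# differentiated, rather than letting AD return garbage.
ZygoteRules.@adjoint function Distributions.rand(rng::AbstractRNG, d::Poisson)
    z = rand(rng, d)
    pullback(_) = error("rand(::Poisson) has no pathwise gradient; refusing to return an incorrect one")
    return z, pullback
end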

(Related issue is #113)


dcjones commented Oct 29, 2020

OK, I understand this a bit more and was able to get correct rand(::Gamma) gradients using the Figurnov et al. technique, by adding a custom Zygote adjoint and writing a version of gamma_inc that works with AD. Then rand(::Beta) comes for free (see the sketch below). No new types required!
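
For context, the construction presumably behind "comes for free" (an illustration; rand_beta is a made-up name, not Turing or Distributions code): a Beta(α, β) draw can be built from two Gamma draws, so once rand(::Gamma) has a correct pullback, AD composes through this on its own.

using Random, Distributions

# A Beta(α, β) draw as a ratio of two Gamma(·, 1) draws.
function rand_beta(rng::AbstractRNG, α::Real, β::Real)
    x = rand(rng, Gamma(α, 1.0))
    y = rand(rng, Gamma(β, 1.0))
    return x / (x + y)  # distributed Beta(α, β)
end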

I'll make a PR if this sounds like something useful for Turing.

devmotion (Member) commented

Great! This definitely sounds useful. It would be even better to add an adjoint via ChainRules instead of Zygote (ChainRules is the new way of defining forward- and reverse-mode rules across AD backends, and is already used by Zygote).


dcjones commented Oct 29, 2020

I think a more general ChainRules adjoint may be blocked by JuliaDiff/ChainRulesCore.jl#68. The adjoint is peculiar in that it relies on running AD on the incomplete gamma function, and it looks like there's currently no way of doing that without assuming a specific AD system.

So I think it has to be for a specific AD package for now; it can be generalized once ChainRules supports it.

devmotion (Member) commented

Yeah, I've run into this issue before.

But if the only blocker is that the implementation "relies on running AD on the incomplete gamma function", wouldn't it be even better to add the adjoint for the incomplete gamma function to https://github.com/JuliaDiff/ChainRules.jl/blob/master/src/rulesets/packages/SpecialFunctions.jl instead of relying on a specific AD backend?
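
Such a rule might look roughly like the following (a sketch only, using current ChainRulesCore names; it covers the easy x-derivative and leaves the a-derivative not implemented, which is exactly the hard part discussed above):

using ChainRulesCore, SpecialFunctions

# Sketch of an rrule for (p, q) = gamma_inc(a, x, ind), using
# ∂P(a, x)/∂x = x^(a-1) e^(-x) / Γ(a) and ∂Q/∂x = -∂P/∂x.
function ChainRulesCore.rrule(::typeof(SpecialFunctions.gamma_inc), a::Real, x::Real, ind::Integer)
    p, q = SpecialFunctions.gamma_inc(a, x, ind)
    function gamma_inc_pullback(Δ)
        dpdx = exp((a - 1) * log(x) - x - loggamma(a))
        x̄ = dpdx * (Δ[1] - Δ[2])
        ā = @not_implemented("∂gamma_inc/∂a")
        return NoTangent(), ā, x̄, NoTangent()
    end
    return (p, q), gamma_inc_pullback
end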


dcjones commented Oct 29, 2020

Well, rand(::Gamma) doesn't come automatically from gamma_inc. The trick is pretty simple (code below). The hard part is that SpecialFunctions.gamma_inc mutates arrays and doesn't work with AD, so I implemented a (probably somewhat inferior) algorithm in _gamma_inc_lower that does.

I'm just learning this stuff, so I'm very open to a better way of handling this.

using Random, Distributions, Zygote, ZygoteRules
using ChainRulesCore: DoesNotExist

# Implicit reparameterization for z ~ Gamma(α, θ): with y = z/θ and
# P = _gamma_inc_lower(α, y) the regularized lower incomplete gamma,
# implicit differentiation of P(α, y) = u gives ∂z/∂α = -θ (∂P/∂α)/(∂P/∂y),
# and scale invariance gives ∂z/∂θ = z/θ = y.
ZygoteRules.@adjoint function Distributions.rand(rng::AbstractRNG, d::Gamma{T}) where {T<:Real}
    z = rand(rng, d)
    function rand_gamma_pullback(c)
        y = z/d.θ
        ∂α, ∂y = gradient(_gamma_inc_lower, d.α, y)  # AD through the AD-friendly reimplementation
        return (
            DoesNotExist(),          # no gradient w.r.t. the RNG
            (α=(-d.θ*∂α/∂y)*c,       # implicit reparameterization gradient
             θ=y*c))
    end
    return z, rand_gamma_pullback
end
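
With this adjoint in place (and _gamma_inc_lower defined), the Monte Carlo check from the opening comment should now recover β; a quick sanity check, assuming the same setup:

using Statistics
mean(gradient(α -> rand(Gamma(α, 2.0)), 1.0)[1] for _ in 1:10000)  # should be ≈ 2.0 instead of erroring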

BioTurboNick commented

I don't know if this is exactly the same issue; I was trying to use autodiff in an optimizer whose objective function uses the Gamma distribution, but it chokes at gamma_inc:

ERROR: MethodError: no method matching _gamma_inc(::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, …}, ::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, …}, ::Int64)
Stacktrace:
  [1] gamma_inc(a::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, 7}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, 7}, ind::Int64) (repeats 2 times)
    @ SpecialFunctions C:\Users\nicho\.julia\packages\SpecialFunctions\CQMHW\src\gamma_inc.jl:858
  [2] gammacdf(k::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, 7}, θ::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, 7}, x::ForwardDiff.Dual{ForwardDiff.Tag{var"#11#12", Float64}, Float64, 7})
    @ StatsFuns C:\Users\nicho\.julia\packages\StatsFuns\6HmgG\src\distrs\gamma.jl:34

Or is this not expected to work at all?

devmotion (Member) commented

It seems this is caused by a call to cdf(Gamma(...), ...) or something similar? Such calls are forwarded to gammacdf in StatsFuns. In StatsFuns >= 1.0.0 we use Julia implementations instead of Rmath implementations there, which call SpecialFunctions.gamma_inc. However, there's no method implemented for ForwardDiff.Dual numbers yet; it would require fixing JuliaDiff/ForwardDiff.jl#424, as outlined in https://github.com/JuliaDiff/ForwardDiff.jl/issues/424#issuecomment-558627378 (similar to JuliaDiff/ForwardDiff.jl#585).
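
For a sense of what such a method could look like, here is a sketch covering only the easy direction, a Dual in the x argument with a real a (the a-derivative is the hard part tracked in that issue; the method placement and default ind are illustrative, not an actual fix):

using ForwardDiff, SpecialFunctions

# Forward rule for the x-argument of gamma_inc, using
# ∂P(a, x)/∂x = x^(a-1) e^(-x) / Γ(a) and ∂Q/∂x = -∂P/∂x.
function SpecialFunctions.gamma_inc(a::Real, x::ForwardDiff.Dual{T}, ind::Integer=0) where {T}
    xv = ForwardDiff.value(x)
    p, q = SpecialFunctions.gamma_inc(a, xv, ind)
    dpdx = exp((a - 1) * log(xv) - xv - loggamma(a))
    ∂x = ForwardDiff.partials(x)
    return (ForwardDiff.Dual{T}(p, dpdx * ∂x), ForwardDiff.Dual{T}(q, -dpdx * ∂x))
end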
