Skip to content

Commit

Permalink
Change order of arguments in expectation (#1420)
Browse files Browse the repository at this point in the history
* Change order of arguments in `expectation`

* Fix deprecation
  • Loading branch information
devmotion authored Nov 12, 2021
1 parent 837d312 commit 8887d7a
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 15 deletions.
7 changes: 4 additions & 3 deletions src/deprecates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ end
@deprecate Wishart(df::Real, S::Matrix, warn::Bool) Wishart(df, S)
@deprecate Wishart(df::Real, S::Cholesky, warn::Bool) Wishart(df, S)

# Deprecate 3 arguments expectation
@deprecate expectation(distr::DiscreteUnivariateDistribution, g::Function, epsilon::Real) expectation(distr, g; epsilon=epsilon) false
@deprecate expectation(distr::ContinuousUnivariateDistribution, g::Function, epsilon::Real) expectation(distr, g) false
# Deprecate 3 arguments expectation and once with function in second place
@deprecate expectation(distr::DiscreteUnivariateDistribution, g::Function, epsilon::Real) expectation(g, distr; epsilon=epsilon) false
@deprecate expectation(distr::ContinuousUnivariateDistribution, g::Function, epsilon::Real) expectation(g, distr) false
@deprecate expectation(distr::Union{UnivariateDistribution,MultivariateDistribution}, g::Function; kwargs...) expectation(g, distr; kwargs...) false
11 changes: 5 additions & 6 deletions src/functionals.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
function expectation(distr::ContinuousUnivariateDistribution, g::Function; kwargs...)
function expectation(g, distr::ContinuousUnivariateDistribution; kwargs...)
return first(quadgk(x -> pdf(distr, x) * g(x), extrema(distr)...; kwargs...))
end

## Assuming that discrete distributions only take integer values.
function expectation(distr::DiscreteUnivariateDistribution, g::Function; epsilon::Real=1e-10)
function expectation(g, distr::DiscreteUnivariateDistribution; epsilon::Real=1e-10)
mindist, maxdist = extrema(distr)
# We want to avoid taking values up to infinity
minval = isfinite(mindist) ? mindist : quantile(distr, epsilon)
maxval = isfinite(maxdist) ? maxdist : quantile(distr, 1 - epsilon)
return sum(x -> pdf(distr, x) * g(x), minval:maxval)
end

function expectation(distr::MultivariateDistribution, g::Function; nsamples::Int=100, rng::AbstractRNG=GLOBAL_RNG)
function expectation(g, distr::MultivariateDistribution; nsamples::Int=100, rng::AbstractRNG=GLOBAL_RNG)
nsamples > 0 || throw(ArgumentError("number of samples should be > 0"))
# We use a function barrier to work around type instability of `sampler(dist)`
return mcexpectation(rng, g, sampler(distr), nsamples)
Expand All @@ -27,9 +27,8 @@ mcexpectation(rng, f, sampler, n) = sum(f, rand(rng, sampler) for _ in 1:n) / n
# end

function kldivergence(P::Distribution{V}, Q::Distribution{V}; kwargs...) where {V<:VariateForm}
function logdiff(x)
return expectation(P; kwargs...) do x
logp = logpdf(P, x)
return (logp > oftype(logp, -Inf)) * (logp - logpdf(Q, x))
end
expectation(P, logdiff; kwargs...)
end
end
4 changes: 2 additions & 2 deletions test/binomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ for (p, n) in [(0.6, 10), (0.8, 6), (0.5, 40), (0.04, 20), (1., 100), (0., 10),
end

# Test calculation of expectation value for Binomial distribution
@test Distributions.expectation(Binomial(6), identity) ≈ 3.0
@test Distributions.expectation(Binomial(10, 0.2), x->-x) ≈ -2.0
@test Distributions.expectation(identity, Binomial(6)) ≈ 3.0
@test Distributions.expectation(x -> -x, Binomial(10, 0.2)) ≈ -2.0

# Test mode
@test Distributions.mode(Binomial(100, 0.4)) == 40
Expand Down
11 changes: 8 additions & 3 deletions test/functionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ end
@testset "Expectations" begin
# univariate distributions
for d in (Normal(), Poisson(2.0), Binomial(10, 0.4))
@test Distributions.expectation(d, identity) ≈ mean(d) atol=1e-3
@test @test_deprecated(Distributions.expectation(d, identity, 1e-10)) ≈ mean(d) atol=1e-3
m = Distributions.expectation(identity, d)
@test m ≈ mean(d) atol=1e-3
@test Distributions.expectation(x -> (x - mean(d))^2, d) ≈ var(d) atol=1e-3

@test @test_deprecated(Distributions.expectation(d, identity, 1e-10)) == m
@test @test_deprecated(Distributions.expectation(d, identity)) == m
end

# multivariate distribution
d = MvNormal([1.5, -0.5], I)
@test Distributions.expectation(d, identity; nsamples=10_000) ≈ mean(d) atol=1e-2
@test Distributions.expectation(identity, d; nsamples=10_000) ≈ mean(d) atol=5e-2
@test @test_deprecated(Distributions.expectation(d, identity; nsamples=10_000)) ≈ mean(d) atol=5e-2
end

@testset "KL divergences" begin
Expand Down
2 changes: 1 addition & 1 deletion test/loguniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ import Random
x = rand(rng, dist)
@test cdf(u, log(x)) ≈ cdf(dist, x)

@test @inferred(entropy(dist)) ≈ Distributions.expectation(dist, x->-logpdf(dist,x))
@test @inferred(entropy(dist)) ≈ Distributions.expectation(x->-logpdf(dist,x), dist)
end

@test kldivergence(LogUniform(1,2), LogUniform(1,2)) ≈ 0 atol=100eps(Float64)
Expand Down

0 comments on commit 8887d7a

Please sign in to comment.