Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle degenerate Beta distribution cases #1881

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 111 additions & 16 deletions src/univariate/continuous/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ Beta() = Beta{Float64}(1.0, 1.0)

@distr_support Beta 0.0 1.0



#### Conversions
function convert(::Type{Beta{T}}, α::Real, β::Real) where T<:Real
Beta(T(α), T(β))
Expand All @@ -59,39 +61,120 @@ Base.convert(::Type{Beta{T}}, d::Beta{T}) where {T<:Real} = d
params(d::Beta) = (d.α, d.β)
@inline partype(d::Beta{T}) where {T<:Real} = T

#### Support

function minimum(d::Beta)
(α, β) = params(d)

if isinf(α)
return isinf(β) ? 0.5 : 1.0
end

return 0.0
end

function maximum(d::Beta)
(α, β) = params(d)

if isinf(β)
return isinf(α) ? 0.5 : 0.0
end

return 1.0
end

minimum(d::Type{Beta}) = 0.0
maximum(d::Type{Beta}) = 1.0

#### Statistics

mean(d::Beta) = ((α, β) = params(d); α / (α + β))
function mean(d::Beta{T})::float(T) where T
(α, β) = params(d)

function mode(d::Beta; check_args::Bool=true)
if !isinf(α)
return α / (α + β)
else # α is infinite
return isinf(β) ? 0.5 : 1.0
end
end

function mode(d::Beta{T}; check_args::Bool=true)::float(T) where T
α, β = params(d)
@check_args(
Beta,
(α, α > 1, "mode is defined only when α > 1."),
(β, β > 1, "mode is defined only when β > 1."),
)
return (α - 1) / (α + β - 2)

if !isinf(α)
if !isinf(β)
# non-degenerate case
@check_args(
Beta,
(α, α > 1, "mode is defined only when α > 1 for non-degenerate cases."),
(β, β > 1, "mode is defined only when β > 1 for non-degenerate cases."),
)
return (α - 1) / (α + β - 2)
else # β is infinite
return 0.0
end
else # α is infinite
return isinf(β) ? 0.5 : 1.0
end
end

modes(d::Beta) = [mode(d)]

function var(d::Beta)
(α, β) = params(d)

if isinf(α) || isinf(β)
# degenerate cases
return zero(α)
end

s = α + β
return (α * β) / (abs2(s) * (s + 1))
end

meanlogx(d::Beta) = ((α, β) = params(d); digamma(α) - digamma(α + β))
function meanlogx(d::Beta{T})::float(T) where T
(α, β) = params(d)

if !isinf(α)
if !isinf(β)
# non-degenerate case
return digamma(α) - digamma(α + β)
else # β is infinite
return log(0.0) # this is -Inf in Julia
end
else # α is infinite
if isinf(β)
return log(0.5)
else
return 0 # log(1)
end
end
end

function varlogx(d::Beta)
(α, β) = params(d)

if isinf(α) || isinf(β)
# degenerate cases
return zero(α)
end

trigamma(α) - trigamma(α + β)
end

varlogx(d::Beta) = ((α, β) = params(d); trigamma(α) - trigamma(α + β))
stdlogx(d::Beta) = sqrt(varlogx(d))

function skewness(d::Beta)
function skewness(d::Beta{T})::float(T) where T
(α, β) = params(d)
if α == β

if isinf(α) || isinf(β)
# degenerate cases
return NaN
elseif α == β
# symmetric non-degenerate
return zero(α)
else
# asymmetric non-degenerate
s = α + β
(2(β - α) * sqrt(s + 1)) / ((s + 2) * sqrt(α * β))
end
Expand All @@ -106,6 +189,9 @@ end

function entropy(d::Beta)
α, β = params(d)
if isinf(α) || isinf(β)
return zero(α)
end
s = α + β
logbeta(α, β) - (α - 1) * digamma(α) - (β - 1) * digamma(β) +
(s - 2) * digamma(s)
Expand Down Expand Up @@ -151,14 +237,19 @@ function sampler(d::Beta{T}) where T
end

# From Knuth
function rand(rng::AbstractRNG, s::BetaSampler)
function rand(rng::AbstractRNG, s::BetaSampler{T, S1, S2})::float(T) where {T, S1, S2}
iα = s.iα
iβ = s.iβ
if s.γ
if iszero(iα)
return iszero(iβ) ? 0.5 : 1.0
elseif iszero(iβ)
return 0
end
g1 = rand(rng, s.s1)
g2 = rand(rng, s.s2)
return g1 / (g1 + g2)
else
iα = s.iα
iβ = s.iβ
while true
u = rand(rng) # the Uniform sampler just calls rand()
v = rand(rng)
Expand All @@ -180,7 +271,7 @@ function rand(rng::AbstractRNG, s::BetaSampler)
end
end

function rand(rng::AbstractRNG, d::Beta{T}) where T
function rand(rng::AbstractRNG, d::Beta{T})::float(T) where T
(α, β) = params(d)
if (α ≤ 1.0) && (β ≤ 1.0)
while true
Expand All @@ -201,6 +292,10 @@ function rand(rng::AbstractRNG, d::Beta{T}) where T
end
end
end
elseif isinf(α)
return isinf(β) ? 0.5 : 1.0
elseif isinf(β)
return 0
else
g1 = rand(rng, Gamma(α, one(T)))
g2 = rand(rng, Gamma(β, one(T)))
Expand Down
64 changes: 64 additions & 0 deletions test/degenerate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using Test, Distributions, StatsBase

@testset "Degenerate Beta" begin
d1 = Beta(Inf, .5)
d2 = Beta(Inf, Inf)
d3 = Beta(14, Inf)

@test minimum(d1) == 1
@test minimum(d2) == .5
@test minimum(d3) == 0

@test maximum(d1) == 1
@test maximum(d2) == .5
@test maximum(d3) == 0

@test mean(d1) == 1
@test mean(d2) == .5
@test mean(d3) == 0

# Currently hangs due to StatsFuns
# @test median(d1) == 1
# @test median(d2) == .5
# @test median(d3) == 0

@test mode(d1) == 1
@test mode(d2) == .5
@test mode(d3) == 0

@test var(d1) == 0
@test var(d2) == 0
@test var(d3) == 0

@test std(d3) == 0

@test isnan(skewness(d1))
@test isnan(skewness(d2))
@test isnan(skewness(d3))

@test isnan(kurtosis(d1))
@test isnan(kurtosis(d2))
@test isnan(kurtosis(d3))

@test entropy(d1) == 0
@test entropy(d2) == 0
@test entropy(d3) == 0

@test meanlogx(d1) == 0
@test meanlogx(d2) == log(.5)
@test meanlogx(d3) == log(0)

@test varlogx(d1) == 0
@test varlogx(d2) == 0
@test varlogx(d3) == 0

@test stdlogx(d1) == 0

@test rand(d1) == 1
@test rand(d2) == .5
@test rand(d3) == 0

@test rand(sampler(d1)) == 1
@test rand(sampler(d2)) == .5
@test rand(sampler(d3)) == 0
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ const tests = [
"eachvariate",
"univariate/continuous/triangular",
"statsapi",
"degenerate", # extra file compared to /src

### missing files compared to /src:
# "common",
Expand Down
Loading