Skip to content

Commit

Permalink
Handle Beta degen from Inf params
Browse files Browse the repository at this point in the history
  • Loading branch information
quildtide committed Aug 8, 2024
1 parent 5e866ba commit aaab071
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 17 deletions.
80 changes: 69 additions & 11 deletions src/distrs/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ betapdf(α::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x))

betalogpdf::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...)
function betalogpdf::T, β::T, x::T) where {T<:Real}
# Handle degenerate cases
xf = float(typeof(x))
if isinf(α)
if isinf(β)
return float(last(promote(α, β, x,
x == .5 ? convert(xf, NaN) : convert(xf, -Inf)
)))
else
return float(last(promote(α, β, x,
x == 1 ? convert(xf, NaN) : convert(xf, -Inf)
)))
end
elseif (iszero(α) && β > 0) || isinf(β)
return float(last(promote(α, β, x,
x == 0 ? convert(xf, NaN) : convert(xf, -Inf)
)))
elseif iszero(β) && α > 0
return float(last(promote(α, β, x,
x == 1 ? convert(xf, NaN) : convert(xf, -Inf)
)))
end

# we ensure that `log(x)` and `log1p(-x)` do not error
y = clamp(x, 0, 1)
val = xlogy- 1, y) + xlog1py- 1, -y) - logbeta(α, β)
Expand All @@ -28,7 +50,13 @@ end

function betacdf::Real, β::Real, x::Real)
# Handle degenerate cases
if iszero(α) && β > 0
if isinf(α)
if isinf(β)
return float(last(promote(α, β, x, x >= 0.5f0)))
else
return float(last(promote(α, β, x, x >= 1)))
end
elseif (iszero(α) && β > 0) || isinf(β)
return float(last(promote(α, β, x, x >= 0)))
elseif iszero(β) && α > 0
return float(last(promote(α, β, x, x >= 1)))
Expand All @@ -39,7 +67,13 @@ end

function betaccdf::Real, β::Real, x::Real)
# Handle degenerate cases
if iszero(α) && β > 0
if isinf(α)
if isinf(β)
return float(last(promote(α, β, x, x < 0.5f0)))
else
return float(last(promote(α, β, x, x < 1)))
end
elseif (iszero(α) && β > 0) || isinf(β)
return float(last(promote(α, β, x, x < 0)))
elseif iszero(β) && α > 0
return float(last(promote(α, β, x, x < 1)))
Expand All @@ -52,7 +86,13 @@ end
# to an implementation based on the hypergeometric function ₂F₁ to avoid underflow.
function betalogcdf::T, β::T, x::T) where {T<:Real}
# Handle degenerate cases
if iszero(α) && β > 0
if isinf(α)
if isinf(β)
return log(last(promote(x, x >= 0.5f0)))
else
return log(last(promote(x, x >= 1)))
end
elseif (iszero(α) && β > 0) || isinf(β)
return log(last(promote(x, x >= 0)))
elseif iszero(β) && α > 0
return log(last(promote(x, x >= 1)))
Expand All @@ -74,7 +114,13 @@ betalogcdf(α::Real, β::Real, x::Real) = betalogcdf(promote(α, β, x)...)

function betalogccdf::Real, β::Real, x::Real)
# Handle degenerate cases
if iszero(α) && β > 0
if isinf(α)
if isinf(β)
return log(last(promote(α, β, x, x < 0.5f0)))
else
return log(last(promote(α, β, x, x < 1)))
end
elseif (iszero(α) && β > 0) || isinf(β)
return log(last(promote(α, β, x, x < 0)))
elseif iszero(β) && α > 0
return log(last(promote(α, β, x, x < 1)))
Expand All @@ -91,10 +137,16 @@ end
function betainvcdf::Real, β::Real, p::Real)
# Handle degenerate cases
if 0 p 1
if iszero(α) && β > 0
return last(promote(α, β, p, false))
if isinf(α)
if isinf(β)
return last(promote(α, β, p, convert(float(typeof(p)), 0.5)))
else
return last(promote(α, β, p, 1))
end
elseif (iszero(α) && β > 0) || isinf(β)
return last(promote(α, β, p, 0))
elseif iszero(β) && α > 0
return last(promote(α, β, p, p > 0))
return last(promote(α, β, p, 1))
end
end

Expand All @@ -104,12 +156,18 @@ end
function betainvccdf::Real, β::Real, p::Real)
# Handle degenerate cases
if 0 p 1
if iszero(α) && β > 0
return last(promote(α, β, p, p == 0))
if isinf(α)
if isinf(β)
return last(promote(α, β, p, convert(float(typeof(p)), 0.5)))
else
return last(promote(α, β, p, 1))
end
elseif (iszero(α) && β > 0) || isinf(β)
return last(promote(α, β, p, 0))
elseif iszero(β) && α > 0
return last(promote(α, β, p, true))
return last(promote(α, β, p, 1))
end
end

return last(beta_inc_inv(β, α, p))
end
end
45 changes: 39 additions & 6 deletions test/rmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,28 +191,61 @@ end
# Beta(α, 0) is a Dirac distribution at x=1
α = β = 1//2

for x in 0f0:0.01f0:1f0
for x in -1f0:0.05f0:1f0
# Check betapdf
@test @inferred(betapdf(0, β, x)) === (x == 0 ? NaN32 : 0f0)
@test @inferred(betapdf(α, 0, x)) === (x == 1 ? NaN32 : 0f0)
@test @inferred(betapdf(Inf32, β, x)) === (x == 1 ? NaN32 : 0f0)
@test @inferred(betapdf(α, Inf32, x)) === (x == 0 ? NaN32 : 0f0)
@test @inferred(betapdf(Inf32, Inf32, x)) === (x === 0.5f0 ? NaN32 : 0f0)

# Check betalogpdf
@test @inferred(betalogpdf(0, β, x)) === (x == 0 ? NaN32 : -Inf32)
@test @inferred(betalogpdf(α, 0, x)) === (x == 1 ? NaN32 : -Inf32)
@test @inferred(betalogpdf(Inf32, β, x)) === (x == 1 ? NaN32 : -Inf32)
@test @inferred(betalogpdf(α, Inf32, x)) === (x == 0 ? NaN32 : -Inf32)
@test @inferred(betalogpdf(Inf32, Inf32, x)) === (x === 0.5f0 ? NaN32 : -Inf32)

# Check betacdf
@test @inferred(betacdf(0, β, x)) === 1f0
@test @inferred(betacdf(0, β, x)) === (x < 0 ? 0f0 : 1f0)
@test @inferred(betacdf(α, 0, x)) === (x < 1 ? 0f0 : 1f0)

@test @inferred(betacdf(Inf32, β, x)) === (x < 1 ? 0f0 : 1f0)
@test @inferred(betacdf(α, Inf32, x)) === (x < 0 ? 0f0 : 1f0)
@test @inferred(betacdf(Inf32, Inf32, x)) === (x < .5 ? 0f0 : 1f0)

# Check betaccdf, betalogcdf, and betalogccdf based on betacdf
@test @inferred(betaccdf(0, β, x)) === 1 - betacdf(0, β, x)
@test @inferred(betaccdf(α, 0, x)) === 1 - betacdf(α, 0, x)
@test @inferred(betaccdf(Inf32, β, x)) === 1 - betacdf(Inf32, β, x)
@test @inferred(betaccdf(α, Inf32, x)) === 1 - betacdf(α, Inf32, x)
@test @inferred(betaccdf(Inf32, Inf32, x)) === 1 - betacdf(Inf32, Inf32, x)

@test @inferred(betalogcdf(0, β, x)) === log(betacdf(0, β, x))
@test @inferred(betalogcdf(α, 0, x)) === log(betacdf(α, 0, x))
@test @inferred(betalogcdf(Inf32, β, x)) === log(betacdf(Inf32, β, x))
@test @inferred(betalogcdf(α, Inf32, x)) === log(betacdf(α, Inf32, x))
@test @inferred(betalogcdf(Inf32, Inf32, x)) === log(betacdf(Inf32, Inf32, x))

@test @inferred(betalogccdf(0, β, x)) === log(betaccdf(0, β, x))
@test @inferred(betalogccdf(α, 0, x)) === log(betaccdf(α, 0, x))
@test @inferred(betalogccdf(Inf32, β, x)) === log(betaccdf(Inf32, β, x))
@test @inferred(betalogccdf(α, Inf32, x)) === log(betaccdf(α, Inf32, x))
@test @inferred(betalogccdf(Inf32, Inf32, x)) === log(betaccdf(Inf32, Inf32, x))
end

for p in 0f0:0.01f0:1f0
for p in 0f0:0.05f0:1f0
# Check betainvcdf
@test @inferred(betainvcdf(0, β, p)) === 0f0
@test @inferred(betainvcdf(α, 0, p)) === (p > 0 ? 1f0 : 0f0)
@test @inferred(betainvcdf(α, 0, p)) === 1f0
@test @inferred(betainvcdf(Inf32, β, p)) === 1f0
@test @inferred(betainvcdf(α, Inf32, p)) === 0f0
@test @inferred(betainvcdf(Inf32, Inf32, p)) === 0.5f0

# Check betainvccdf
@test @inferred(betainvccdf(0, β, p)) === (p > 0 ? 0f0 : 1f0)
@test @inferred(betainvccdf(0, β, p)) === 0f0
@test @inferred(betainvccdf(α, 0, p)) === 1f0
@test @inferred(betainvccdf(Inf32, β, p)) === 1f0
@test @inferred(betainvccdf(Inf32, Inf32, p)) === 0.5f0
end
end

Expand Down

0 comments on commit aaab071

Please sign in to comment.