Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for Float16 #74

Merged
merged 5 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/LogExpFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp,
softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
softmax!, logcosh, cloglog, cexpexp

# expm1(::Float16) is not defined in older Julia versions,
# hence for better Float16 support we use an internal function instead
# https://github.com/JuliaLang/julia/pull/40867
if VERSION < v"1.7.0-DEV.1172"
_expm1(x) = expm1(x)
_expm1(x::Float16) = Float16(expm1(Float32(x)))
else
const _expm1 = expm1
end

include("basicfuns.jl")
include("logsumexp.jl")

Expand Down
8 changes: 4 additions & 4 deletions src/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,22 @@ See:

Note: different than Maechler (2012), no negation inside parentheses
"""
log1mexp(x::Real) = x < IrrationalConstants.loghalf ? log1p(-exp(x)) : log(-expm1(x))
log1mexp(x::Real) = x < IrrationalConstants.loghalf ? log1p(-exp(x)) : log(-_expm1(x))

"""
$(SIGNATURES)

Return `log(2 - exp(x))` evaluated as `log1p(-expm1(x))`
"""
log2mexp(x::Real) = log1p(-expm1(x))
log2mexp(x::Real) = log1p(-_expm1(x))

"""
$(SIGNATURES)

Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse of
[`log1pexp`](@ref) (aka “softplus”).
"""
logexpm1(x::Real) = x <= 18.0 ? log(expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)

const softplus = log1pexp
Expand Down Expand Up @@ -420,4 +420,4 @@ $(SIGNATURES)

Compute the complementary double exponential, `1 - exp(-exp(x))`.
"""
cexpexp(x) = -expm1(-exp(x))
cexpexp(x) = -_expm1(-exp(x))
37 changes: 23 additions & 14 deletions test/basicfuns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,23 +154,28 @@ end
end

@testset "log1mexp" begin
@test log1mexp(-1.0) ≈ log1p(- exp(-1.0))
@test log1mexp(-10.0) ≈ log1p(- exp(-10.0))
for T in (Float64, Float32, Float16)
@test @inferred(log1mexp(-T(1))) isa T
@test log1mexp(-T(1)) ≈ log1p(- exp(-T(1)))
@test log1mexp(-T(10)) ≈ log1p(- exp(-T(10)))
end
end

@testset "log2mexp" begin
@test log2mexp(0.0) ≈ 0.0
@test log2mexp(-1.0) ≈ log(2.0 - exp(-1.0))
for T in (Float64, Float32, Float16)
@test @inferred(log2mexp(T(0))) isa T
@test iszero(log2mexp(T(0)))
@test log2mexp(-T(1)) ≈ log(2 - exp(-T(1)))
end
end

@testset "logexpm1" begin
@test logexpm1(2.0) ≈ log(exp(2.0) - 1.0)
@test logexpm1(log1pexp(2.0)) ≈ 2.0
@test logexpm1(log1pexp(-2.0)) ≈ -2.0

@test logexpm1(2f0) ≈ log(exp(2f0) - 1f0)
@test logexpm1(log1pexp(2f0)) ≈ 2f0
@test logexpm1(log1pexp(-2f0)) ≈ -2f0
for T in (Float64, Float32, Float16)
@test @inferred(logexpm1(T(2))) isa T
@test logexpm1(T(2)) ≈ log(exp(T(2)) - 1)
@test logexpm1(log1pexp(T(2))) ≈ T(2)
@test logexpm1(log1pexp(-T(2))) ≈ -T(2)
end
end

@testset "log1pmx" begin
Expand Down Expand Up @@ -428,9 +433,13 @@ end
cloglog_big(x::T) where {T} = T(log(-log(1 - BigFloat(x))))
cexpexp_big(x::T) where {T} = 1 - exp(-exp(BigFloat(x)))

for x in 0.1:0.1:0.9
@test cloglog(x) ≈ cloglog_big(x)
@test cexpexp(x) ≈ cexpexp_big(x)
for T in (Float64, Float32, Float16)
@test @inferred(cloglog(T(1//2))) isa T
@test @inferred(cexpexp(T(0))) isa T
for x in 0.1:0.1:0.9
@test cloglog(T(x)) ≈ cloglog_big(T(x))
@test cexpexp(T(x)) ≈ cexpexp_big(T(x))
end
end
for _ in 1:10
randf = rand(Float64)
Expand Down
Loading