Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pairwise summation in loglikelihood #1409

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/matrixvariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,17 @@ rows and `size(d, 2)` columns, or an array of matrices of size `size(d)`.
"""
loglikelihood(d::MatrixDistribution, X::AbstractMatrix{<:Real}) = logpdf(d, X)
function loglikelihood(d::MatrixDistribution, X::AbstractArray{<:Real,3})
(size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array dimensions."))
return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3))
(size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("inconsistent array dimensions"))
# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
mschauer marked this conversation as resolved.
Show resolved Hide resolved
# to compute `sum(Base.Fix1(_logpdf, d), eachslice(X; dims=3))`
broadcasted = Broadcast.broadcasted(_logpdf, (d,), (view(X, :, :, i) for i in axes(X, 3)))
return sum(Broadcast.instantiate(broadcasted))
end
function loglikelihood(d::MatrixDistribution, X::AbstractArray{<:AbstractMatrix{<:Real}})
return sum(x -> logpdf(d, x), X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this call already use pairwise summation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I thought it doesn't but seems you're right (was it always the case or did it change at some point?):

julia> A = vcat([Float32(1E0)], fill(Float32(1E-8), 10^8));

julia> sum(identity, A)
1.9999989f0

julia> sum(x -> x, A)
1.9999989f0

julia> sum(Broadcast.instantiate(Broadcast.broadcasted(x -> x, A)))
1.9999989f0

In any case, I'll close this issue in favour of #1391 which will generalize logpdf and loglikelihood.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea. That said, I get slightly different results in some cases:

julia> x = vcat([Float32(1E0)], fill(Float32(1E-8), 10^8));

julia> sum(sin, x) - sum(sin.(x))
-1.41859055f-5

# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
# to compute `sum(Base.Fix1(logpdf, d), X)`
broadcasted = Broadcast.broadcasted(logpdf, (d,), X)
mschauer marked this conversation as resolved.
Show resolved Hide resolved
return sum(Broadcast.instantiate(broadcasted))
end

# for testing
Expand Down
10 changes: 7 additions & 3 deletions src/multivariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,15 @@ vectors of length `dim(d)`.
"""
loglikelihood(d::MultivariateDistribution, X::AbstractVector{<:Real}) = logpdf(d, X)
function loglikelihood(d::MultivariateDistribution, X::AbstractMatrix{<:Real})
size(X, 1) == length(d) || throw(DimensionMismatch("Inconsistent array dimensions."))
return sum(i -> _logpdf(d, view(X, :, i)), 1:size(X, 2))
size(X, 1) == length(d) || throw(DimensionMismatch("inconsistent array dimensions"))
# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
broadcasted = Broadcast.broadcasted(_logpdf, (d,), (view(X, :, i) for i in axes(X, 2)))
return sum(Broadcast.instantiate(broadcasted))
end
function loglikelihood(d::MultivariateDistribution, X::AbstractArray{<:AbstractVector})
return sum(x -> logpdf(d, x), X)
# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
broadcasted = Broadcast.broadcasted(logpdf, (d,), X)
return sum(Broadcast.instantiate(broadcasted))
end

##### Specific distributions #####
Expand Down
16 changes: 11 additions & 5 deletions src/univariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,11 @@ The log-likelihood of distribution `d` with respect to all samples contained in

Here `x` can be a single scalar sample or an array of samples.
"""
loglikelihood(d::UnivariateDistribution, X::AbstractArray) = sum(x -> logpdf(d, x), X)
function loglikelihood(d::UnivariateDistribution, X::AbstractArray)
# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
broadcasted = Broadcast.broadcasted(logpdf, (d,), X)
return sum(Broadcast.instantiate(broadcasted))
end
loglikelihood(d::UnivariateDistribution, x::Real) = logpdf(d, x)

### special definitions for distributions with integer-valued support
Expand Down Expand Up @@ -550,11 +554,12 @@ function integerunitrange_cdf(d::DiscreteUnivariateDistribution, x::Integer)
minimum_d, maximum_d = extrema(d)
isfinite(minimum_d) || isfinite(maximum_d) || error("support is unbounded")

# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
result = if isfinite(minimum_d) && !(isfinite(maximum_d) && x >= div(minimum_d + maximum_d, 2))
c = sum(Base.Fix1(pdf, d), minimum_d:(max(x, minimum_d)))
c = sum(Broadcast.instantiate(Broadcast.broadcasted(pdf, (d,), minimum_d:(max(x, minimum_d)))))
x < minimum_d ? zero(c) : c
else
c = 1 - sum(Base.Fix1(pdf, d), (min(x + 1, maximum_d)):maximum_d)
c = 1 - sum(Broadcast.instantiate(Broadcast.broadcasted(pdf, (d,), min(x + 1, maximum_d):maximum_d)))
x >= maximum_d ? one(c) : c
end

Expand All @@ -565,11 +570,12 @@ function integerunitrange_ccdf(d::DiscreteUnivariateDistribution, x::Integer)
minimum_d, maximum_d = extrema(d)
isfinite(minimum_d) || isfinite(maximum_d) || error("support is unbounded")

# we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020)
result = if isfinite(minimum_d) && !(isfinite(maximum_d) && x >= div(minimum_d + maximum_d, 2))
c = 1 - sum(Base.Fix1(pdf, d), minimum_d:(max(x, minimum_d)))
c = 1 - sum(Broadcast.instantiate(Broadcast.broadcasted(pdf, (d,), minimum_d:(max(x, minimum_d)))))
x < minimum_d ? one(c) : c
else
c = sum(Base.Fix1(pdf, d), (min(x + 1, maximum_d)):maximum_d)
c = sum(Broadcast.instantiate(Broadcast.broadcasted(pdf, (d,), min(x + 1, maximum_d):maximum_d)))
x >= maximum_d ? zero(c) : c
end

Expand Down