From e5184dcc32502fd2eb89168394a0dbf113c3f87b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 15 Jan 2023 01:23:48 +0000 Subject: [PATCH] initial work on addressing perf issues for broadcasted logpdf --- Project.toml | 2 ++ src/DistributionsAD.jl | 1 + src/arraydist.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/Project.toml b/Project.toml index 185b08a..c4a0420 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] @@ -36,5 +37,6 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StaticArrays = "0.12, 1.0" StatsBase = "0.32, 0.33" StatsFuns = "0.9.10, 1" +StructArrays = "0.6" ZygoteRules = "0.2" julia = "1.6" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 9be245a..bfe74cb 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -14,6 +14,7 @@ using PDMats, FillArrays, Adapt +using StructArrays: StructArrays using SpecialFunctions: logabsgamma, digamma using LinearAlgebra: copytri!, AbstractTriangular using Distributions: AbstractMvLogNormal, diff --git a/src/arraydist.jl b/src/arraydist.jl index 28e9e2b..0444c31 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -64,3 +64,45 @@ function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init) end + +# Lazy array dist +# HACK: Constructor which doesn't enforce the schema. +""" + StructArrayNoSchema(::Type{T}, cols::C) where {T, C<:StructArrays.Tup} + +Construct a `StructArray` without enforcing the schema of `T`. + +This is useful in scenarios where there's a mismatch between the constructor of `T` +and the `fieldnames(T)`. + +# Examples +```jldoctest +julia> using StructArrays, Distributions + +julia> # `Normal` has two fields `μ` and `σ`, but here we only provide `μ`. + StructArrayNoSchema(Normal, (zeros(2),)) +2-element StructArray(::Vector{Float64}) with eltype Normal: + Normal{Float64}(μ=0.0, σ=1.0) + Normal{Float64}(μ=0.0, σ=1.0) + +julia> # This is not allowed by `StructArray`: + StructArray{Normal}((zeros(2),)) +ERROR: NamedTuple names and field types must have matching lengths +[...] +``` +""" +function StructArrayNoSchema(::Type{T}, cols::C) where {T, C<:StructArrays.Tup} + N = isempty(cols) ? 1 : ndims(cols[1]) + StructArrays.StructArray{T, N, typeof(cols)}(cols) +end + + +arraydist(D::Type, args...) = arraydist(D, args) +arraydist(D::Type, args::Tuple) = arraydist(StructArrayNoSchema(D, args)) + +make_logpdf_closure(::Type{D}) where {D} = (x, args...) -> logpdf(D(args...), x) + +function Distributions.logpdf(dist::Product{<:Any,D,<:StructArrays.StructArray}, x::AbstractVector{<:Real}) where {D} + f = make_logpdf_closure(D) + return sum(f.(x, StructArrays.components(dist.v)...)) +end