From 0ea55020385a11ef372078eab64deb5d0e9b9725 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Thu, 26 Sep 2024 01:40:09 +0200
Subject: [PATCH] Decouple `rand` and `eltype`

---
 src/genericrand.jl                       | 34 ++++++++----------------
 src/multivariate/dirichlet.jl            |  9 +++++++
 src/multivariate/dirichletmultinomial.jl |  2 ++
 src/multivariate/jointorderstatistics.jl | 21 +++++++++++++++
 src/multivariate/multinomial.jl          |  1 +
 src/multivariate/mvlogitnormal.jl        | 13 +++++++++
 src/multivariate/mvlognormal.jl          | 11 ++++++++
 src/multivariate/mvnormal.jl             | 11 ++++++++
 src/multivariate/mvnormalcanon.jl        | 11 ++++++++
 src/multivariate/mvtdist.jl              | 15 +++++++++++
 src/multivariates.jl                     | 18 ++++++++++---
 src/samplers/multinomial.jl              |  6 +++++
 src/univariate/continuous/uniform.jl     |  2 +-
 src/univariate/orderstatistic.jl         |  3 +--
 src/univariates.jl                       |  4 ++-
 test/testutils.jl                        | 14 ++++++++--
 test/univariate/continuous/logistic.jl   | 17 ++++++++++--
 test/univariate/continuous/tdist.jl      |  2 +-
 test/univariate/continuous/uniform.jl    |  8 ++++++
 19 files changed, 166 insertions(+), 36 deletions(-)

diff --git a/src/genericrand.jl b/src/genericrand.jl
index 6b3b213f16..7ec10f12fa 100644
--- a/src/genericrand.jl
+++ b/src/genericrand.jl
@@ -30,35 +30,23 @@ function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate})
 end
 
 # multiple samples
-function rand(rng::AbstractRNG, s::Sampleable{Univariate}, dims::Dims)
-    out = Array{eltype(s)}(undef, dims)
-    return @inbounds rand!(rng, sampler(s), out)
+# we use function barriers since for some distributions `sampler(s)` is not type-stable:
+# https://github.com/JuliaStats/Distributions.jl/pull/1281
+function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
+    return _rand(rng, sampler(s), dims)
 end
-function rand(
-    rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims,
-)
-    sz = size(s)
-    ax = map(Base.OneTo, dims)
-    out = [Array{eltype(s)}(undef, sz) for _ in Iterators.product(ax...)]
-    return @inbounds rand!(rng, sampler(s), out, false)
+function _rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
+    r = rand(rng, s)
+    out = Array{typeof(r)}(undef, dims)
+    out[1] = r
+    rand!(rng, s, @view(out[2:end]))
+    return out
 end
 
-# these are workarounds for sampleables that incorrectly base `eltype` on the parameters
+# this is a workaround for sampleables that incorrectly base `eltype` on the parameters
 function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
     return @inbounds rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
 end
-function rand(rng::AbstractRNG, s::Sampleable{Univariate,Continuous}, dims::Dims)
-    out = Array{float(eltype(s))}(undef, dims)
-    return @inbounds rand!(rng, sampler(s), out)
-end
-function rand(
-    rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous}, dims::Dims,
-)
-    sz = size(s)
-    ax = map(Base.OneTo, dims)
-    out = [Array{float(eltype(s))}(undef, sz) for _ in Iterators.product(ax...)]
-    return @inbounds rand!(rng, sampler(s), out, false)
-end
 
 """
     rand!([rng::AbstractRNG,] s::Sampleable, A::AbstractArray)
diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl
index b24980ec98..4c2e6628ba 100644
--- a/src/multivariate/dirichlet.jl
+++ b/src/multivariate/dirichlet.jl
@@ -154,6 +154,15 @@ end
 
 # sampling
 
+function rand(rng::AbstractRNG, d::Union{Dirichlet,DirichletCanon})
+    x = map(αi -> rand(rng, Gamma(αi)), d.alpha)
+    return lmul!(inv(sum(x)), x)
+end
+function rand(rng::AbstractRNG, d::Dirichlet{<:Real,<:FillArrays.AbstractFill{<:Real}})
+    x = rand(rng, Gamma(FillArrays.getindex_value(d.alpha)), length(d))
+    return lmul!(inv(sum(x)), x)
+end
+
 function _rand!(rng::AbstractRNG,
                 d::Union{Dirichlet,DirichletCanon},
                 x::AbstractVector{<:Real})
diff --git a/src/multivariate/dirichletmultinomial.jl b/src/multivariate/dirichletmultinomial.jl
index eb15990cb2..b8eb519918 100644
--- a/src/multivariate/dirichletmultinomial.jl
+++ b/src/multivariate/dirichletmultinomial.jl
@@ -97,6 +97,8 @@ end
 
 # Sampling
 
+rand(rng::AbstractRNG, d::DirichletMultinomial) =
+    multinom_rand(rng, ntrials(d), rand(rng, Dirichlet(d.α)))
 _rand!(rng::AbstractRNG, d::DirichletMultinomial, x::AbstractVector{<:Real}) =
     multinom_rand!(rng, ntrials(d), rand(rng, Dirichlet(d.α)), x)
 
diff --git a/src/multivariate/jointorderstatistics.jl b/src/multivariate/jointorderstatistics.jl
index 1fbed0d1b6..9c4811bfb7 100644
--- a/src/multivariate/jointorderstatistics.jl
+++ b/src/multivariate/jointorderstatistics.jl
@@ -125,6 +125,27 @@ function _marginalize_range(dist, i, j, xᵢ, xⱼ, T)
     return k * T(logdiffcdf(dist, xⱼ, xᵢ)) - loggamma(T(k + 1))
 end
 
+function rand(rng::AbstractRNG, d::JointOrderStatistics)
+    n = d.n
+    if n == length(d.ranks)  # ranks == 1:n
+        # direct method, slower than inversion method for large `n` and distributions with
+        # fast quantile function or that use inversion sampling
+        x = rand(rng, d.dist, n)
+        sort!(x)
+    else
+        # use exponential generation method with inversion, where for gaps in the ranks, we
+        # use the fact that the sum Y of k IID variables xₘ ~ Exp(1) is Y ~ Gamma(k, 1).
+        # Lurie, D., and H. O. Hartley. "Machine-generation of order statistics for Monte
+        # Carlo computations." The American Statistician 26.1 (1972): 26-27.
+        # this is slow if length(d.ranks) is close to n and quantile for d.dist is expensive,
+        # but this branch is probably taken when length(d.ranks) is small or much smaller than n.
+        xi = rand(rng, d.dist)  # this is only used to obtain the type of samples from `d.dist`
+        x = Vector{typeof(xi)}(undef, length(d.ranks))
+        _rand!(rng, d, x)
+    end
+    return x
+end
+
 function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:Real})
     n = d.n
     if n == length(d.ranks)  # ranks == 1:n
diff --git a/src/multivariate/multinomial.jl b/src/multivariate/multinomial.jl
index c4db44b85b..76d4776e06 100644
--- a/src/multivariate/multinomial.jl
+++ b/src/multivariate/multinomial.jl
@@ -165,6 +165,7 @@ end
 # Sampling
 
 # if only a single sample is requested, no alias table is created
+rand(rng::AbstractRNG, d::Multinomial) = multinom_rand(rng, ntrials(d), probs(d))
 _rand!(rng::AbstractRNG, d::Multinomial, x::AbstractVector{<:Real}) =
     multinom_rand!(rng, ntrials(d), probs(d), x)
 
diff --git a/src/multivariate/mvlogitnormal.jl b/src/multivariate/mvlogitnormal.jl
index 0d60ddf654..893dca9733 100644
--- a/src/multivariate/mvlogitnormal.jl
+++ b/src/multivariate/mvlogitnormal.jl
@@ -88,6 +88,19 @@ kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.norm
 
 # Sampling
 
+function rand(rng::AbstractRNG, d::MvLogitNormal)
+    x = rand(rng, d.normal)
+    push!(x, zero(eltype(x)))
+    StatsFuns.softmax!(x)
+    return x
+end
+function rand(rng::AbstractRNG, d::MvLogitNormal, n::Int)
+    r = rand(rng, d.normal, n)
+    x = vcat(r, zeros(eltype(r), 1, n))
+    StatsFuns.softmax!(x; dims=1)
+    return x
+end
+
 function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
     y = @views _drop1(x)
     rand!(rng, d.normal, y)
diff --git a/src/multivariate/mvlognormal.jl b/src/multivariate/mvlognormal.jl
index 1eecd38c2f..eaf11ef4be 100644
--- a/src/multivariate/mvlognormal.jl
+++ b/src/multivariate/mvlognormal.jl
@@ -232,6 +232,17 @@ var(d::MvLogNormal) = diag(cov(d))
 
 entropy(d::MvLogNormal) = length(d)*(1+log2π)/2 + logdetcov(d.normal)/2 + sum(mean(d.normal)) #See https://en.wikipedia.org/wiki/Log-normal_distribution
 
+function rand(rng::AbstractRNG, d::MvLogNormal)
+    x = rand(rng, d.normal)
+    map!(exp, x, x)
+    return x
+end
+function rand(rng::AbstractRNG, d::MvLogNormal, n::Int)
+    xs = rand(rng, d.normal, n)
+    map!(exp, xs, xs)
+    return xs
+end
+
 function _rand!(rng::AbstractRNG, d::MvLogNormal, x::AbstractVecOrMat{<:Real})
     _rand!(rng, d.normal, x)
     map!(exp, x, x)
diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl
index 6126c1d8f1..5b6e39b41e 100644
--- a/src/multivariate/mvnormal.jl
+++ b/src/multivariate/mvnormal.jl
@@ -273,6 +273,17 @@ gradlogpdf(d::MvNormal, x::AbstractVector{<:Real}) = -(d.Σ \ (x .- d.μ))
 
 # Sampling (for GenericMvNormal)
 
+function rand(rng::AbstractRNG, d::MvNormal)
+    x = unwhiten!(d.Σ, randn(rng, float(partype(d)), length(d)))
+    x .+= d.μ
+    return x
+end
+function rand(rng::AbstractRNG, d::MvNormal, n::Int)
+    x = unwhiten!(d.Σ, randn(rng, float(partype(d)), length(d), n))
+    x .+= d.μ
+    return x
+end
+
 function _rand!(rng::AbstractRNG, d::MvNormal, x::VecOrMat)
     unwhiten!(d.Σ, randn!(rng, x))
     x .+= d.μ
diff --git a/src/multivariate/mvnormalcanon.jl b/src/multivariate/mvnormalcanon.jl
index 587b20ba5c..067bdd73e1 100644
--- a/src/multivariate/mvnormalcanon.jl
+++ b/src/multivariate/mvnormalcanon.jl
@@ -177,6 +177,17 @@ if isdefined(PDMats, :PDSparseMat)
     unwhiten_winv!(J::PDSparseMat, x::AbstractVecOrMat) = x[:] = J.chol.PtL' \ x
 end
 
+function rand(rng::AbstractRNG, d::MvNormalCanon)
+    x = unwhiten_winv!(d.J, randn(rng, float(partype(d)), length(d)))
+    x .+= d.μ
+    return x
+end
+function rand(rng::AbstractRNG, d::MvNormalCanon, n::Int)
+    x = unwhiten_winv!(d.J, randn(rng, float(partype(d)), length(d), n))
+    x .+= d.μ
+    return x
+end
+
 function _rand!(rng::AbstractRNG, d::MvNormalCanon, x::AbstractVector)
     unwhiten_winv!(d.J, randn!(rng, x))
     x .+= d.μ
diff --git a/src/multivariate/mvtdist.jl b/src/multivariate/mvtdist.jl
index 9076c364b5..7359fa3610 100644
--- a/src/multivariate/mvtdist.jl
+++ b/src/multivariate/mvtdist.jl
@@ -155,6 +155,21 @@ function gradlogpdf(d::GenericMvTDist, x::AbstractVector{<:Real})
 end
 
 # Sampling (for GenericMvTDist)
+function rand(rng::AbstractRNG, d::GenericMvTDist)
+    chisqd = Chisq{partype(d)}(d.df)
+    y = sqrt(rand(rng, chisqd) / d.df)
+    x = unwhiten!(d.Σ, randn(rng, typeof(y), length(d)))
+    x .= x ./ y .+ d.μ
+    x
+end
+function rand(rng::AbstractRNG, d::GenericMvTDist, n::Int)
+    chisqd = Chisq{partype(d)}(d.df)
+    y = rand(rng, chisqd, n)
+    x = unwhiten!(d.Σ, randn(rng, eltype(y), length(d), n))
+    x .= x ./ sqrt.(y' ./ d.df) .+ d.μ
+    x
+end
+
 function _rand!(rng::AbstractRNG, d::GenericMvTDist, x::AbstractVector{<:Real})
     chisqd = Chisq{partype(d)}(d.df)
     y = sqrt(rand(rng, chisqd) / d.df)
diff --git a/src/multivariates.jl b/src/multivariates.jl
index 56d91233cf..6d6a456786 100644
--- a/src/multivariates.jl
+++ b/src/multivariates.jl
@@ -18,10 +18,20 @@ size(d::MultivariateDistribution)
 
 # multiple multivariate, must allocate matrix
 # TODO: inconsistency with other `ArrayLikeVariate`s and `rand(s, (n,))` - maybe remove?
-rand(rng::AbstractRNG, s::Sampleable{Multivariate}, n::Int) =
-    @inbounds rand!(rng, sampler(s), Matrix{eltype(s)}(undef, length(s), n))
-rand(rng::AbstractRNG, s::Sampleable{Multivariate,Continuous}, n::Int) =
-    @inbounds rand!(rng, sampler(s), Matrix{float(eltype(s))}(undef, length(s), n))
+function rand(rng::AbstractRNG, s::Sampleable{Multivariate}, n::Int)
+    return _rand(rng, sampler(s), n)
+end
+function _rand(rng, s::Sampleable{Multivariate}, n::Int)
+    r = rand(rng, s)
+    out = Matrix{eltype(r)}(undef, length(r), n)
+    if n > 0
+        copyto!(out, r)
+        if n > 1
+            rand!(rng, s, @view(out[:, 2:n]))
+        end
+    end
+    return out
+end
 
 ## domain
 
diff --git a/src/samplers/multinomial.jl b/src/samplers/multinomial.jl
index 7ce8b89730..c09c92acea 100644
--- a/src/samplers/multinomial.jl
+++ b/src/samplers/multinomial.jl
@@ -1,3 +1,6 @@
+function multinom_rand(rng::AbstractRNG, n::Int, p::AbstractVector{<:Real})
+    return multinom_rand!(rng, n, p, Vector{Int}(undef, length(p)))
+end
 function multinom_rand!(rng::AbstractRNG, n::Int, p::AbstractVector{<:Real},
                         x::AbstractVector{<:Real})
     k = length(p)
@@ -49,6 +52,9 @@ function MultinomialSampler(n::Int, prob::Vector{<:Real})
     return MultinomialSampler(n, prob, AliasTable(prob))
 end
 
+function rand(rng::AbstractRNG, s::MultinomialSampler)
+    return _rand!(rng, s, Vector{Int}(undef, length(s.prob)))
+end
 function _rand!(rng::AbstractRNG, s::MultinomialSampler,
                 x::AbstractVector{<:Real})
     n = s.n
diff --git a/src/univariate/continuous/uniform.jl b/src/univariate/continuous/uniform.jl
index d4beca7d79..3fd4c507de 100644
--- a/src/univariate/continuous/uniform.jl
+++ b/src/univariate/continuous/uniform.jl
@@ -151,7 +151,7 @@ Base.:*(c::Real, d::Uniform) = Uniform(minmax(c * d.a, c * d.b)...)
 
 #### Sampling
 
-rand(rng::AbstractRNG, d::Uniform) = d.a + (d.b - d.a) * rand(rng)
+rand(rng::AbstractRNG, d::Uniform{T}) where {T} = d.a + (d.b - d.a) * rand(rng, float(T))
 
 _rand!(rng::AbstractRNG, d::Uniform, A::AbstractArray{<:Real}) =
     A .= Base.Fix1(quantile, d).(rand!(rng, A))
diff --git a/src/univariate/orderstatistic.jl b/src/univariate/orderstatistic.jl
index 1a7055ef91..5424194504 100644
--- a/src/univariate/orderstatistic.jl
+++ b/src/univariate/orderstatistic.jl
@@ -102,7 +102,6 @@ end
 function rand(rng::AbstractRNG, d::OrderStatistic)
     # inverse transform sampling. Since quantile function is Qₓ(Uᵢₙ⁻¹(p)), we draw a random
     # variable from Uᵢₙ and pass it through the quantile function of `d.dist`
-    T = eltype(d.dist)
     b = _uniform_orderstatistic(d)
-    return T(quantile(d.dist, rand(rng, b)))
+    return quantile(d.dist, float(partype(d.dist))(rand(rng, b)))
 end
diff --git a/src/univariates.jl b/src/univariates.jl
index b60e5a2949..68daebb0cf 100644
--- a/src/univariates.jl
+++ b/src/univariates.jl
@@ -154,7 +154,9 @@ end
 
 Generate a scalar sample from `d`. The general fallback is `quantile(d, rand())`.
 """
-rand(rng::AbstractRNG, d::UnivariateDistribution) = quantile(d, rand(rng))
+function rand(rng::AbstractRNG, d::UnivariateDistribution)
+    return quantile(d, rand(rng, float(partype(d))))
+end
 
 ## statistics
 
diff --git a/test/testutils.jl b/test/testutils.jl
index 8acd378535..9af33584a4 100644
--- a/test/testutils.jl
+++ b/test/testutils.jl
@@ -165,7 +165,12 @@ function test_samples(s::Sampleable{Univariate, Discrete},    # the sampleable
         samples3 = [rand(rng3, s) for _ in 1:n]
         samples4 = [rand(rng4, s) for _ in 1:n]
     end
-    @test length(samples) == n
+    T = typeof(rand(s))
+    @test samples isa Vector{T}
+    @test samples2 isa Vector{T}
+    @test samples3 isa Vector{T}
+    @test samples4 isa Vector{T}
+    @test length(samples) == length(samples2) == length(samples3) == length(samples4) == n
     @test samples2 == samples
     @test samples3 == samples4
 
@@ -289,7 +294,12 @@ function test_samples(s::Sampleable{Univariate, Continuous},  # the sampleable
         samples3 = [rand(rng3, s) for _ in 1:n]
         samples4 = [rand(rng4, s) for _ in 1:n]
     end
-    @test length(samples) == n
+    T = typeof(rand(s))
+    @test samples isa Vector{T}
+    @test samples2 isa Vector{T}
+    @test samples3 isa Vector{T}
+    @test samples4 isa Vector{T}
+    @test length(samples) == length(samples2) == length(samples3) == length(samples4) == n
     @test samples2 == samples
     @test samples3 == samples4
 
diff --git a/test/univariate/continuous/logistic.jl b/test/univariate/continuous/logistic.jl
index 3eb0b8f2d0..1cd0772fde 100644
--- a/test/univariate/continuous/logistic.jl
+++ b/test/univariate/continuous/logistic.jl
@@ -1,2 +1,15 @@
-test_cgf(Logistic(0, 1), (-0.99,0.99, 1f-2, -1f-2))
-test_cgf(Logistic(100,10), (-0.099,0.099, 1f-2, -1f-2))
+using Distributions
+using Test
+
+@testset "Logistic" begin
+    test_cgf(Logistic(0, 1), (-0.99,0.99, 1f-2, -1f-2))
+    test_cgf(Logistic(100,10), (-0.099,0.099, 1f-2, -1f-2))
+
+    # issue 1082
+    @testset "rand consistency" begin
+        for T in (Float32, Float64, BigFloat)
+            @test @inferred(rand(Logistic(T(0), T(1)))) isa T
+            @test @inferred(rand(Logistic(T(0), T(1)), 5)) isa Vector{T}
+        end
+    end
+end
diff --git a/test/univariate/continuous/tdist.jl b/test/univariate/continuous/tdist.jl
index 127b992434..652af9cbd9 100644
--- a/test/univariate/continuous/tdist.jl
+++ b/test/univariate/continuous/tdist.jl
@@ -10,10 +10,10 @@ using Test
             @inferred(rand(TDist(big"1.0")))
         end
         @inferred(rand(TDist(ForwardDiff.Dual(1.0))))
-
     end
 
     for T in (Float32, Float64)
         @test @inferred(rand(TDist(T(1)))) isa T
+        @test @inferred(rand(TDist(T(1)), 5)) isa Vector{T}
     end
 end
diff --git a/test/univariate/continuous/uniform.jl b/test/univariate/continuous/uniform.jl
index e3a5d729ed..0d928d0870 100644
--- a/test/univariate/continuous/uniform.jl
+++ b/test/univariate/continuous/uniform.jl
@@ -114,4 +114,12 @@ using Test
             end
        end
     end
+
+    # issues 1252 and 1783
+    @testset "rand consistency" begin
+        for T in (Float32, Float64, BigFloat)
+            @test @inferred(rand(Uniform(T(0), T(1)))) isa T
+            @test @inferred(rand(Uniform(T(0), T(1)), 5)) isa Vector{T}
+        end
+    end
 end
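
Usage sketch (not part of the patch itself; it mirrors the test cases added above and assumes the
changes are applied): the type of a sample is expected to follow the distribution's parameter type
rather than its hard-coded `eltype`, both for single draws and for `rand(d, n)`.

    using Distributions

    d = Uniform(big"0.0", big"1.0")
    rand(d)       # expected to be a BigFloat rather than a Float64
    rand(d, 5)    # expected to be a Vector{BigFloat}

    d32 = Logistic(0f0, 1f0)
    rand(d32)     # expected to be a Float32
    rand(d32, 5)  # expected to be a Vector{Float32}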