Decouple rand and eltype #1905

Open · wants to merge 5 commits into base: master
79 changes: 35 additions & 44 deletions docs/src/extends.md
Original file line number Diff line number Diff line change
@@ -4,11 +4,8 @@ Whereas this package already provides a large collection of common distributions

Generally, you don't have to implement every API method listed in the documentation. This package provides a series of generic functions that turn a small number of internal methods into user-end API methods. What you need to do is to implement this small set of internal methods for your distributions.

By default, `Discrete` sampleables have the support of type `Int` while `Continuous` sampleables have the support of type `Float64`. If this assumption does not hold for your new distribution or sampler, or its `ValueSupport` is neither `Discrete` nor `Continuous`, you should implement the `eltype` method in addition to the other methods listed below.

**Note:** The methods that need to be implemented are different for distributions of different variate forms.


## Create a Sampler

Unlike full-fledged distributions, a sampler, in general, only provides limited functionalities, mainly to support sampling.
@@ -18,60 +15,48 @@ Unlike full-fledged distributions, a sampler, in general, only provides limited
To implement a univariate sampler, one can define a subtype (say `Spl`) of `Sampleable{Univariate,S}` (where `S` can be `Discrete` or `Continuous`), and provide a `rand` method, as

```julia
function rand(rng::AbstractRNG, s::Spl)
function Base.rand(rng::AbstractRNG, s::Spl)
# ... generate a single sample from s
end
```

The package already implements vectorized versions of `rand!` and `rand` that repeatedly call the scalar version to generate multiple samples, as well as one-argument versions that use the default random number generator.

### Multivariate Sampler
The package already implements vectorized versions `rand!(rng::AbstractRNG, s::Spl, x::AbstractArray{<:Real})` and `rand(rng::AbstractRNG, s::Spl, dims::Int...)` that repeatedly call the scalar version to generate multiple samples.
Additionally, the package implements versions of these functions without the `rng::AbstractRNG` argument that use the default random number generator.

To implement a multivariate sampler, one can define a subtype of `Sampleable{Multivariate,S}`, and provide both `length` and `_rand!` methods, as
If there is a more efficient method to generate multiple samples, one should provide the following method

```julia
Base.length(s::Spl) = ... # return the length of each sample

function _rand!(rng::AbstractRNG, s::Spl, x::AbstractVector{T}) where T<:Real
# ... generate a single vector sample to x
function Random.rand!(rng::AbstractRNG, s::Spl, x::AbstractArray{<:Real})
# ... generate multiple samples from s in x
end
```
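As a concrete illustration of this interface, here is a minimal sketch of a univariate discrete sampler. The type name `DieSampler` and the six-sided-die behavior are illustrative assumptions, not part of the package:

```julia
using Random
using Distributions

# Hypothetical sampler for a fair six-sided die.
struct DieSampler <: Sampleable{Univariate,Discrete} end

# The only method we must provide: draw a single sample.
Base.rand(rng::Random.AbstractRNG, ::DieSampler) = rand(rng, 1:6)

# The generic fallbacks now give us the rest:
s = DieSampler()
x = rand(s)        # single draw using the default RNG
xs = rand(s, 10)   # vectorized draws built on the scalar method
```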

This function can assume that the dimension of `x` is correct, and doesn't need to perform dimension checking.
### Multivariate Sampler

The package implements both `rand` and `rand!` as follows (which you don't need to implement in general):
To implement a multivariate sampler, one can define a subtype of `Sampleable{Multivariate,S}`, and provide `length`, `rand`, and `rand!` methods, as

```julia
function _rand!(rng::AbstractRNG, s::Sampleable{Multivariate}, A::DenseMatrix)
for i = 1:size(A,2)
_rand!(rng, s, view(A,:,i))
end
return A
end
Base.length(s::Spl) = ... # return the length of each sample

function rand!(rng::AbstractRNG, s::Sampleable{Multivariate}, A::AbstractVector)
length(A) == length(s) ||
throw(DimensionMismatch("Output size inconsistent with sample length."))
_rand!(rng, s, A)
function Base.rand(rng::AbstractRNG, s::Spl)
# ... generate a single vector sample from s
end

function rand!(rng::AbstractRNG, s::Sampleable{Multivariate}, A::DenseMatrix)
size(A,1) == length(s) ||
throw(DimensionMismatch("Output size inconsistent with sample length."))
_rand!(rng, s, A)
@inline function Random.rand!(rng::AbstractRNG, s::Spl, x::AbstractVector{<:Real})
# `@inline` + `@boundscheck` allows users to skip bound checks by calling `@inbounds rand!(...)`
# Ref https://docs.julialang.org/en/v1/devdocs/boundscheck/#Eliding-bounds-checks
@boundscheck # ... check size (and possibly indices) of `x`
# ... generate a single vector sample from s in x
end

rand(rng::AbstractRNG, s::Sampleable{Multivariate,S}) where {S<:ValueSupport} =
_rand!(rng, s, Vector{eltype(S)}(length(s)))

rand(rng::AbstractRNG, s::Sampleable{Multivariate,S}, n::Int) where {S<:ValueSupport} =
_rand!(rng, s, Matrix{eltype(S)}(length(s), n))
```

If there is a more efficient method to generate multiple vector samples in a batch, one should provide the following method

```julia
function _rand!(rng::AbstractRNG, s::Spl, A::DenseMatrix{T}) where T<:Real
@inline function Random.rand!(rng::AbstractRNG, s::Spl, A::AbstractMatrix{<:Real})
# `@inline` + `@boundscheck` allows users to skip bound checks by calling `@inbounds rand!(...)`
# Ref https://docs.julialang.org/en/v1/devdocs/boundscheck/#Eliding-bounds-checks
@boundscheck # ... check size (and possibly indices) of `A`
# ... generate multiple vector samples in batch
end
```
@@ -80,17 +65,22 @@ Remember that each *column* of A is a sample.
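A minimal sketch of this multivariate interface under the PR's proposed `rand`/`rand!` contract; the `UnitSquare` sampler and its behavior are hypothetical, not part of the package:

```julia
using Random
using Distributions

# Hypothetical sampler: a point drawn uniformly from the unit square, as a 2-vector.
struct UnitSquare <: Sampleable{Multivariate,Continuous} end

Base.length(::UnitSquare) = 2

# draw one sample, allocating the output vector
Base.rand(rng::Random.AbstractRNG, ::UnitSquare) = rand(rng, 2)

# draw one sample into a user-provided vector
@inline function Random.rand!(rng::Random.AbstractRNG, ::UnitSquare, x::AbstractVector{<:Real})
    @boundscheck length(x) == 2 ||
        throw(DimensionMismatch("output length must equal length(s) = 2"))
    rand!(rng, x)
    return x
end
```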

### Matrix-variate Sampler

To implement a multivariate sampler, one can define a subtype of `Sampleable{Multivariate,S}`, and provide both `size` and `_rand!` methods, as
To implement a matrix-variate sampler, one can define a subtype of `Sampleable{Matrixvariate,S}`, and provide `size`, `rand`, and `rand!` methods, as

```julia
Base.size(s::Spl) = ... # the size of each matrix sample

function _rand!(rng::AbstractRNG, s::Spl, x::DenseMatrix{T}) where T<:Real
# ... generate a single matrix sample to x
function Base.rand(rng::AbstractRNG, s::Spl)
# ... generate a single matrix sample from s
end
```

Note that you can assume `x` has correct dimensions in `_rand!` and don't have to perform dimension checking; the generic `rand` and `rand!` will do dimension checking and array allocation for you.
@inline function Random.rand!(rng::AbstractRNG, s::Spl, x::AbstractMatrix{<:Real})
# `@inline` + `@boundscheck` allows users to skip bound checks by calling `@inbounds rand!(...)`
# Ref https://docs.julialang.org/en/v1/devdocs/boundscheck/#Eliding-bounds-checks
@boundscheck # ... check size (and possibly indices) of `x`
# ... generate a single matrix sample from s in x
end
```
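A matching sketch for the matrix-variate case, again assuming the PR's proposed interface; `RandomSquareMatrix` is an illustrative name:

```julia
using Random
using Distributions

# Hypothetical sampler: 2×2 matrices with independent uniform entries.
struct RandomSquareMatrix <: Sampleable{Matrixvariate,Continuous} end

Base.size(::RandomSquareMatrix) = (2, 2)

# draw one sample, allocating the output matrix
Base.rand(rng::Random.AbstractRNG, ::RandomSquareMatrix) = rand(rng, 2, 2)

# draw one sample into a user-provided matrix
@inline function Random.rand!(rng::Random.AbstractRNG, ::RandomSquareMatrix, x::AbstractMatrix{<:Real})
    @boundscheck size(x) == (2, 2) ||
        throw(DimensionMismatch("output size must equal size(s) = (2, 2)"))
    rand!(rng, x)
    return x
end
```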

## Create a Distribution

@@ -106,7 +96,7 @@ A univariate distribution type should be defined as a subtype of `DiscreteUnivar

The following methods need to be implemented for each univariate distribution type:

- [`rand(::AbstractRNG, d::UnivariateDistribution)`](@ref)
- [`Base.rand(::AbstractRNG, d::UnivariateDistribution)`](@ref)
- [`sampler(d::Distribution)`](@ref)
- [`logpdf(d::UnivariateDistribution, x::Real)`](@ref)
- [`cdf(d::UnivariateDistribution, x::Real)`](@ref)
@@ -138,8 +128,8 @@ The following methods need to be implemented for each multivariate distribution

- [`length(d::MultivariateDistribution)`](@ref)
- [`sampler(d::Distribution)`](@ref)
- [`eltype(d::Distribution)`](@ref)
- [`Distributions._rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractArray)`](@ref)
- [`Base.rand(::AbstractRNG, d::MultivariateDistribution)`](@ref)
- [`Random.rand!(::AbstractRNG, d::MultivariateDistribution, x::AbstractVector{<:Real})`](@ref)
- [`Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)`](@ref)

Note that if there exist faster methods for batch evaluation, one should override `_logpdf!` and `_pdf!`.
@@ -161,6 +151,7 @@ A matrix-variate distribution type should be defined as a subtype of `DiscreteMa
The following methods need to be implemented for each matrix-variate distribution type:

- [`size(d::MatrixDistribution)`](@ref)
- [`Distributions._rand!(rng::AbstractRNG, d::MatrixDistribution, A::AbstractMatrix)`](@ref)
- [`Base.rand(rng::AbstractRNG, d::MatrixDistribution)`](@ref)
- [`Random.rand!(rng::AbstractRNG, d::MatrixDistribution, A::AbstractMatrix{<:Real})`](@ref)
- [`sampler(d::MatrixDistribution)`](@ref)
- [`Distributions._logpdf(d::MatrixDistribution, x::AbstractArray)`](@ref)
1 change: 0 additions & 1 deletion docs/src/multivariate.md
@@ -18,7 +18,6 @@ The methods listed below are implemented for each multivariate distribution, whi
```@docs
length(::MultivariateDistribution)
size(::MultivariateDistribution)
eltype(::Type{MultivariateDistribution})
mean(::MultivariateDistribution)
var(::MultivariateDistribution)
cov(::MultivariateDistribution)
1 change: 0 additions & 1 deletion docs/src/types.md
@@ -57,7 +57,6 @@ The basic functionalities that a sampleable object provides are to *retrieve inf
length(::Sampleable)
size(::Sampleable)
nsamples(::Type{Sampleable}, ::Any)
eltype(::Type{Sampleable})
rand(::AbstractRNG, ::Sampleable)
rand!(::AbstractRNG, ::Sampleable, ::AbstractArray)
```
2 changes: 0 additions & 2 deletions src/censored.jl
@@ -112,8 +112,6 @@ function partype(d::Censored{<:UnivariateDistribution,<:ValueSupport,T}) where {
return promote_type(partype(d.uncensored), T)
end

Base.eltype(::Type{<:Censored{D,S,T}}) where {D,S,T} = promote_type(T, eltype(D))

#### Range and Support

isupperbounded(d::LeftCensored) = isupperbounded(d.uncensored)
2 changes: 0 additions & 2 deletions src/cholesky/lkjcholesky.jl
@@ -82,8 +82,6 @@ end
# Properties
# -----------------------------------------------------------------------------

Base.eltype(::Type{LKJCholesky{T}}) where {T} = T

function Base.size(d::LKJCholesky)
p = d.d
return (p, p)
37 changes: 29 additions & 8 deletions src/common.jl
@@ -94,14 +94,35 @@ Base.size(s::Sampleable{Univariate}) = ()
Base.size(s::Sampleable{Multivariate}) = (length(s),)

"""
eltype(::Type{Sampleable})

The default element type of a sample. This is the type of elements of the samples generated
by the `rand` method. However, one can provide an array of different element types to
store the samples using `rand!`.
"""
Base.eltype(::Type{<:Sampleable{F,Discrete}}) where {F} = Int
Base.eltype(::Type{<:Sampleable{F,Continuous}}) where {F} = Float64
eltype(::Type{S}) where {S<:Distributions.Sampleable}

The default element type of a sample from a sampler of type `S`.

This is the type of elements of the samples generated by the `rand` method.
However, one can provide an array of different element types to store the samples using `rand!`.

!!! warning
This method is deprecated and will be removed in an upcoming breaking release.

"""
function Base.eltype(::Type{S}) where {VF<:VariateForm,VS<:Union{Discrete,Continuous},S<:Sampleable{VF,VS}}
Base.depwarn("`eltype(::Type{<:Distributions.Sampleable})` is deprecated and will be removed", :eltype)
T = Base.promote_op(rand, S)
if T === Union{}
# `T` can be `Union{}` if
# - `rand(::S)` is not defined and/or
# - `rand(::S)` calls `eltype(::S)` or `eltype(::Type{S})` but they are not specialized and fall back to the generic definition here
# (the latter case happens e.g. for non-univariate samplers that only implement `_rand!` and rely on the generic fallback for `rand`)
# In all these cases we return 1) `Int` for `Discrete` samplers and 2) `Float64` for `Continuous` samplers
if VS === Discrete
return Int
elseif VS === Continuous
return Float64
end
else
return eltype(T)
end
end

"""
nsamples(s::Sampleable)
34 changes: 11 additions & 23 deletions src/genericrand.jl
@@ -30,35 +30,23 @@ function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate})
end

# multiple samples
function rand(rng::AbstractRNG, s::Sampleable{Univariate}, dims::Dims)
out = Array{eltype(s)}(undef, dims)
return @inbounds rand!(rng, sampler(s), out)
# we use function barriers since for some distributions `sampler(s)` is not type-stable:
# https://github.com/JuliaStats/Distributions.jl/pull/1281
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
return _rand(rng, sampler(s), dims)
end
function rand(
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims,
)
sz = size(s)
ax = map(Base.OneTo, dims)
out = [Array{eltype(s)}(undef, sz) for _ in Iterators.product(ax...)]
return @inbounds rand!(rng, sampler(s), out, false)
function _rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
r = rand(rng, s)
out = Array{typeof(r)}(undef, dims)
out[1] = r
rand!(rng, s, @view(out[2:end]))
return out
end
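The function-barrier comment above refers to a general Julia pattern: when an intermediate value (here, `sampler(s)`) has a type the compiler cannot infer, passing it to an inner function re-specializes the remaining work on its concrete type. A self-contained sketch (`unstable_choice`, `sum_draws`, and `_sum_draws` are illustrative names):

```julia
unstable_choice(flag) = flag ? 1.0 : 1   # returns Float64 or Int, so callers can't infer the type

function sum_draws(flag, n)
    x = unstable_choice(flag)    # the type of `x` is unknown to the compiler here
    return _sum_draws(x, n)      # function barrier: dispatch specializes `_sum_draws` on typeof(x)
end

# the "hot loop" runs in a method compiled for the concrete type of `x`
_sum_draws(x, n) = sum(x for _ in 1:n)
```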

# these are workarounds for sampleables that incorrectly base `eltype` on the parameters
# this is a workaround for sampleables that incorrectly base `eltype` on the parameters
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
return @inbounds rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
end
function rand(rng::AbstractRNG, s::Sampleable{Univariate,Continuous}, dims::Dims)
out = Array{float(eltype(s))}(undef, dims)
return @inbounds rand!(rng, sampler(s), out)
end
function rand(
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous}, dims::Dims,
)
sz = size(s)
ax = map(Base.OneTo, dims)
out = [Array{float(eltype(s))}(undef, sz) for _ in Iterators.product(ax...)]
return @inbounds rand!(rng, sampler(s), out, false)
end

"""
rand!([rng::AbstractRNG,] s::Sampleable, A::AbstractArray)
11 changes: 9 additions & 2 deletions src/multivariate/dirichlet.jl
@@ -50,8 +50,6 @@ end

length(d::DirichletCanon) = length(d.alpha)

Base.eltype(::Type{<:Dirichlet{T}}) where {T} = T

#### Conversions
convert(::Type{Dirichlet{T}}, cf::DirichletCanon) where {T<:Real} =
Dirichlet(convert(AbstractVector{T}, cf.alpha))
@@ -154,6 +152,15 @@ end

# sampling

function rand(rng::AbstractRNG, d::Union{Dirichlet,DirichletCanon})
x = map(αi -> rand(rng, Gamma(αi)), d.alpha)
return lmul!(inv(sum(x)), x)
end
function rand(rng::AbstractRNG, d::Dirichlet{<:Real,<:FillArrays.AbstractFill{<:Real}})
x = rand(rng, Gamma(FillArrays.getindex_value(d.alpha)), length(d))
return lmul!(inv(sum(x)), x)
end

function _rand!(rng::AbstractRNG,
d::Union{Dirichlet,DirichletCanon},
x::AbstractVector{<:Real})
2 changes: 2 additions & 0 deletions src/multivariate/dirichletmultinomial.jl
@@ -97,6 +97,8 @@ end


# Sampling
rand(rng::AbstractRNG, d::DirichletMultinomial) =
multinom_rand(rng, ntrials(d), rand(rng, Dirichlet(d.α)))
_rand!(rng::AbstractRNG, d::DirichletMultinomial, x::AbstractVector{<:Real}) =
multinom_rand!(rng, ntrials(d), rand(rng, Dirichlet(d.α)), x)

23 changes: 21 additions & 2 deletions src/multivariate/jointorderstatistics.jl
@@ -88,8 +88,6 @@ maximum(d::JointOrderStatistics) = Fill(maximum(d.dist), length(d))

params(d::JointOrderStatistics) = tuple(params(d.dist)..., d.n, d.ranks)
partype(d::JointOrderStatistics) = partype(d.dist)
Base.eltype(::Type{<:JointOrderStatistics{D}}) where {D} = Base.eltype(D)
Base.eltype(d::JointOrderStatistics) = eltype(d.dist)

function logpdf(d::JointOrderStatistics, x::AbstractVector{<:Real})
n = d.n
@@ -125,6 +123,27 @@ function _marginalize_range(dist, i, j, xᵢ, xⱼ, T)
return k * T(logdiffcdf(dist, xⱼ, xᵢ)) - loggamma(T(k + 1))
end

function rand(rng::AbstractRNG, d::JointOrderStatistics)
n = d.n
if n == length(d.ranks) # ranks == 1:n
# direct method, slower than inversion method for large `n` and distributions with
# fast quantile function or that use inversion sampling
x = rand(rng, d.dist, n)
sort!(x)
else
# use exponential generation method with inversion, where for gaps in the ranks, we
# use the fact that the sum Y of k IID variables xₘ ~ Exp(1) is Y ~ Gamma(k, 1).
# Lurie, D., and H. O. Hartley. "Machine-generation of order statistics for Monte
# Carlo computations." The American Statistician 26.1 (1972): 26-27.
# this is slow if length(d.ranks) is close to n and quantile for d.dist is expensive,
# but this branch is probably taken when length(d.ranks) is small or much smaller than n.
xi = rand(rng, d.dist) # this is only used to obtain the type of samples from `d.dist`
x = Vector{typeof(xi)}(undef, length(d.ranks))
_rand!(rng, d, x)
end
return x
end

function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:Real})
n = d.n
if n == length(d.ranks) # ranks == 1:n
1 change: 1 addition & 0 deletions src/multivariate/multinomial.jl
@@ -165,6 +165,7 @@ end
# Sampling

# if only a single sample is requested, no alias table is created
rand(rng::AbstractRNG, d::Multinomial) = multinom_rand(rng, ntrials(d), probs(d))
_rand!(rng::AbstractRNG, d::Multinomial, x::AbstractVector{<:Real}) =
multinom_rand!(rng, ntrials(d), probs(d), x)

15 changes: 13 additions & 2 deletions src/multivariate/mvlogitnormal.jl
@@ -52,8 +52,6 @@ canonform(d::MvLogitNormal{<:MvNormal}) = MvLogitNormal(canonform(d.normal))
# Properties

length(d::MvLogitNormal) = length(d.normal) + 1
Base.eltype(::Type{<:MvLogitNormal{D}}) where {D} = eltype(D)
Base.eltype(d::MvLogitNormal) = eltype(d.normal)
params(d::MvLogitNormal) = params(d.normal)
@inline partype(d::MvLogitNormal) = partype(d.normal)

@@ -88,6 +86,19 @@ kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.norm

# Sampling

function rand(rng::AbstractRNG, d::MvLogitNormal)
x = rand(rng, d.normal)
push!(x, zero(eltype(x)))
StatsFuns.softmax!(x)
return x
end
function rand(rng::AbstractRNG, d::MvLogitNormal, n::Int)
r = rand(rng, d.normal, n)
x = vcat(r, zeros(eltype(r), 1, n))
StatsFuns.softmax!(x; dims=1)
return x
end
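The `MvLogitNormal` draws above append a fixed zero coordinate and apply a softmax. The package uses `StatsFuns.softmax!`, but the transform itself can be sketched with the standard library alone (`softmax0` is a local helper name, not part of the package):

```julia
# Map an unconstrained vector y to the probability simplex by appending a zero
# coordinate and applying a numerically stable softmax.
function softmax0(y::AbstractVector{<:Real})
    x = vcat(float.(y), zero(float(eltype(y))))  # append the fixed zero coordinate
    x .= exp.(x .- maximum(x))                   # subtract the max for numerical stability
    x ./= sum(x)
    return x                                     # positive entries summing to 1
end
```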

function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
y = @views _drop1(x)
rand!(rng, d.normal, y)