diff --git a/Project.toml b/Project.toml index c82c202..185b08a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.42" +version = "0.6.43" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/README.md b/README.md index d10cae4..42b6948 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ AD of `logpdf` is fully supported and tested for the following distributions wrt - `Chi` - `Chisq` - `Cosine` + - `Distributions.AffineDistribution` - `Epanechnikov` - `Erlang` - `Exponential` @@ -38,7 +39,6 @@ AD of `logpdf` is fully supported and tested for the following distributions wrt - `Kolmogorov` - `Laplace` - `Levy` - - `LocationScale` - `Logistic` - `LogitNormal` - `LogNormal` diff --git a/src/common.jl b/src/common.jl index e967b3a..ee094ba 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,11 +1,13 @@ ## Linear Algebra ## +const CHOLESKY_NoPivot = VERSION >= v"1.8.0-rc1" ? LinearAlgebra.NoPivot() : Val(false) + function turing_chol(A::AbstractMatrix, check) chol = cholesky(A, check=check) (chol.factors, chol.info) end function turing_chol_back(A::AbstractMatrix, check) - C, chol_pullback = rrule(cholesky, A, Val(false), check=check) + C, chol_pullback = rrule(cholesky, A, CHOLESKY_NoPivot; check=check) function back(Δ) Ȳ = Tangent{typeof(C)}(; factors=Δ[1]) ∂C = chol_pullback(Ȳ)[2] @@ -19,7 +21,7 @@ function symm_turing_chol(A::AbstractMatrix, check, uplo) (chol.factors, chol.info) end function symm_turing_chol_back(A::AbstractMatrix, check, uplo) - C, chol_pullback = rrule(cholesky, Symmetric(A,uplo), Val(false), check=check) + C, chol_pullback = rrule(cholesky, Symmetric(A,uplo), CHOLESKY_NoPivot; check=check) function back(Δ) Ȳ = Tangent{typeof(C)}(; factors=Δ[1]) ∂C = chol_pullback(Ȳ)[2] diff --git a/src/flatten.jl b/src/flatten.jl index d175049..a8e1827 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -46,7 +46,7 @@ const flattened_dists = [ Bernoulli, Kolmogorov, Laplace, Levy, - LocationScale, + Distributions.AffineDistribution, Logistic, LogitNormal, LogNormal, diff --git a/src/reversediff.jl b/src/reversediff.jl index c01d09a..54c97ef 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -148,14 +148,14 @@ function Gamma(α::T, θ::T; check_args=true) where {T <: TrackedReal} end # Work around to stop TrackedReal of Inf and -Inf from producing NaN in the derivative -function Base.minimum(d::LocationScale{T}) where {T <: TrackedReal} +function Base.minimum(d::Distributions.AffineDistribution{T}) where {T <: TrackedReal} if isfinite(minimum(d.ρ)) return d.μ + d.σ * minimum(d.ρ) else return convert(T, ReverseDiff.@skip(minimum)(d.ρ)) end end -function Base.maximum(d::LocationScale{T}) where {T <: TrackedReal} +function Base.maximum(d::Distributions.AffineDistribution{T}) where {T <: TrackedReal} if isfinite(minimum(d.ρ)) return d.μ + d.σ * maximum(d.ρ) else diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index d3ce2eb..8baa50c 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -143,7 +143,8 @@ DistSpec(Levy, (0.0,), 0.5), DistSpec(Levy, (0.0, 2.0), 0.5), - DistSpec((a, b) -> LocationScale(a, b, Normal()), (1.0, 2.0), 0.5), + # Test AffineDistribution + DistSpec((a, b) -> a + b * Beta(), (1.0, 2.0), 2.0), DistSpec(Logistic, (), 0.5), DistSpec(Logistic, (1.0,), 0.5),