From 81bedf24c5f227e89c19c6f7407d02b4f79eaf19 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 18:40:17 +0000 Subject: [PATCH] Use logcosh Co-authored-by: David Widmann --- src/bijectors/corr.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index af46b1b9..cc90523a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -353,7 +353,7 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) for i in 1:(j - 1) z = tanh(Y[i, j]) W[i, j] = z * exp(log_remainder) - log_remainder += IrrationalConstants.logtwo + Y[i, j] - LogExpFunctions.log1pexp(2 * Y[i, j]) + log_remainder -= LogExpFunctions.logcosh(Y[i, j]) logJ += log_remainder end logJ += log_remainder @@ -380,7 +380,7 @@ function _inv_link_chol_lkj(y::AbstractVector) for i in 1:(j - 1) z = tanh(y[idx]) W[i, j] = z * exp(log_remainder) - log_remainder += IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx]) + log_remainder -= LogExpFunctions.logcosh(y[idx]) logJ += log_remainder idx += 1 end @@ -460,13 +460,8 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) result = float(zero(eltype(Y))) - for j in 2:K, i in 1:(j - 1) - @inbounds abs_y_i_j = abs(Y[i, j]) - result += - (K - i + 1) * ( - IrrationalConstants.logtwo - - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) - ) + @inbounds for j in 2:K, i in 1:(j - 1) + result += (K - i + 1) * (-LogExpFunctions.logcosh(Y[i, j])) end return result end @@ -495,7 +490,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector) @inbounds for j in 2:K tmp = zero(result) for _ in 1:(j - 1) - logz = 2 * (IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx])) + logz = -2 * LogExpFunctions.logcosh(y[idx]) result += logz + (tmp / 2) tmp += logz idx += 1