|
| 1 | +""" |
| 2 | + LKJ(d, η) |
| 3 | +
|
| 4 | +```julia |
| 5 | +d::Int dimension |
| 6 | +η::Real positive shape |
| 7 | +``` |
| 8 | +The [LKJ](https://doi.org/10.1016/j.jmva.2009.04.008) distribution is a distribution over |
| 9 | +``d\\times d`` real correlation matrices (positive-definite matrices with ones on the diagonal). |
| 10 | +If ``\\mathbf{R}\\sim \\textrm{LKJ}_{d}(\\eta)``, then its probability density function is |
| 11 | +
|
| 12 | +```math |
| 13 | +f(\\mathbf{R};\\eta) = \\left[\\prod_{k=1}^{d-1}\\pi^{\\frac{k}{2}} |
| 14 | +\\frac{\\Gamma\\left(\\eta+\\frac{d-1-k}{2}\\right)}{\\Gamma\\left(\\eta+\\frac{d-1}{2}\\right)}\\right]^{-1} |
| 15 | +|\\mathbf{R}|^{\\eta-1}. |
| 16 | +``` |
| 17 | +
|
| 18 | +If ``\\eta = 1``, then the LKJ distribution is uniform over |
| 19 | +[the space of correlation matrices](https://www.jstor.org/stable/2684832). |
| 20 | +""" |
| 21 | +struct LKJ{T <: Real, D <: Integer} <: ContinuousMatrixDistribution |
| 22 | + d::D |
| 23 | + η::T |
| 24 | + logc0::T |
| 25 | +end |
| 26 | + |
| 27 | +# ----------------------------------------------------------------------------- |
| 28 | +# Constructors |
| 29 | +# ----------------------------------------------------------------------------- |
| 30 | + |
| 31 | +function LKJ(d::Integer, η::Real; check_args = true) |
| 32 | + if check_args |
| 33 | + d > 0 || throw(ArgumentError("Matrix dimension must be positive.")) |
| 34 | + η > 0 || throw(ArgumentError("Shape parameter must be positive.")) |
| 35 | + end |
| 36 | + logc0 = lkj_logc0(d, η) |
| 37 | + T = Base.promote_eltype(η, logc0) |
| 38 | + LKJ{T, typeof(d)}(d, T(η), T(logc0)) |
| 39 | +end |
| 40 | + |
| 41 | +# ----------------------------------------------------------------------------- |
| 42 | +# REPL display |
| 43 | +# ----------------------------------------------------------------------------- |
| 44 | + |
| 45 | +show(io::IO, d::LKJ) = show_multline(io, d, [(:d, d.d), (:η, d.η)]) |
| 46 | + |
| 47 | +# ----------------------------------------------------------------------------- |
| 48 | +# Conversion |
| 49 | +# ----------------------------------------------------------------------------- |
| 50 | + |
| 51 | +function convert(::Type{LKJ{T}}, d::LKJ) where T <: Real |
| 52 | + LKJ{T, typeof(d.d)}(d.d, T(d.η), T(d.logc0)) |
| 53 | +end |
| 54 | + |
| 55 | +function convert(::Type{LKJ{T}}, d::Integer, η, logc0) where T <: Real |
| 56 | + LKJ{T, typeof(d)}(d, T(η), T(logc0)) |
| 57 | +end |
| 58 | + |
| 59 | +# ----------------------------------------------------------------------------- |
| 60 | +# Properties |
| 61 | +# ----------------------------------------------------------------------------- |
| 62 | + |
| 63 | +dim(d::LKJ) = d.d |
| 64 | + |
| 65 | +size(d::LKJ) = (dim(d), dim(d)) |
| 66 | + |
| 67 | +rank(d::LKJ) = dim(d) |
| 68 | + |
| 69 | +insupport(d::LKJ, R::AbstractMatrix) = isreal(R) && size(R) == size(d) && isone(Diagonal(R)) && isposdef(R) |
| 70 | + |
| 71 | +mean(d::LKJ) = Matrix{partype(d)}(I, dim(d), dim(d)) |
| 72 | + |
| 73 | +function mode(d::LKJ; check_args = true) |
| 74 | + η = params(d) |
| 75 | + if check_args |
| 76 | + η > 1 || throw(ArgumentError("mode is defined only when η > 1.")) |
| 77 | + end |
| 78 | + return mean(d) |
| 79 | +end |
| 80 | + |
| 81 | +function var(lkj::LKJ) |
| 82 | + d = dim(lkj) |
| 83 | + d > 1 || return zeros(d, d) |
| 84 | + σ² = var(_marginal(lkj)) |
| 85 | + σ² * (ones(partype(lkj), d, d) - I) |
| 86 | +end |
| 87 | + |
| 88 | +params(d::LKJ) = d.η |
| 89 | + |
| 90 | +@inline partype(d::LKJ{T}) where {T <: Real} = T |
| 91 | + |
| 92 | +# ----------------------------------------------------------------------------- |
| 93 | +# Evaluation |
| 94 | +# ----------------------------------------------------------------------------- |
| 95 | + |
| 96 | +function lkj_logc0(d::Integer, η::Real) |
| 97 | + d > 1 || return zero(η) |
| 98 | + if isone(η) |
| 99 | + if iseven(d) |
| 100 | + logc0 = -lkj_onion_loginvconst_uniform_even(d) |
| 101 | + else |
| 102 | + logc0 = -lkj_onion_loginvconst_uniform_odd(d) |
| 103 | + end |
| 104 | + else |
| 105 | + logc0 = -lkj_onion_loginvconst(d, η) |
| 106 | + end |
| 107 | + return logc0 |
| 108 | +end |
| 109 | + |
| 110 | +logkernel(d::LKJ, R::AbstractMatrix) = (d.η - 1) * logdet(R) |
| 111 | + |
| 112 | +_logpdf(d::LKJ, R::AbstractMatrix) = logkernel(d, R) + d.logc0 |
| 113 | + |
| 114 | +# ----------------------------------------------------------------------------- |
| 115 | +# Sampling |
| 116 | +# ----------------------------------------------------------------------------- |
| 117 | + |
| 118 | +function _rand!(rng::AbstractRNG, d::LKJ, R::AbstractMatrix) |
| 119 | + R .= _lkj_onion_sampler(d.d, d.η, rng) |
| 120 | +end |
| 121 | + |
| 122 | +function _lkj_onion_sampler(d::Integer, η::Real, rng::AbstractRNG = Random.GLOBAL_RNG) |
| 123 | + # Section 3.2 in LKJ (2009 JMA) |
| 124 | + # 1. Initialization |
| 125 | + R = ones(typeof(η), d, d) |
| 126 | + d > 1 || return R |
| 127 | + β = η + 0.5d - 1 |
| 128 | + u = rand(rng, Beta(β, β)) |
| 129 | + R[1, 2] = 2u - 1 |
| 130 | + R[2, 1] = R[1, 2] |
| 131 | + # 2. |
| 132 | + for k in 2:d - 1 |
| 133 | + # (a) |
| 134 | + β -= 0.5 |
| 135 | + # (b) |
| 136 | + y = rand(rng, Beta(k / 2, β)) |
| 137 | + # (c) |
| 138 | + u = randn(rng, k) |
| 139 | + u = u / norm(u) |
| 140 | + # (d) |
| 141 | + w = sqrt(y) * u |
| 142 | + A = cholesky(R[1:k, 1:k]).L |
| 143 | + z = A * w |
| 144 | + # (e) |
| 145 | + R[1:k, k + 1] = z |
| 146 | + R[k + 1, 1:k] = z' |
| 147 | + end |
| 148 | + # 3. |
| 149 | + return R |
| 150 | +end |
| 151 | + |
| 152 | +# ----------------------------------------------------------------------------- |
| 153 | +# The free elements of an LKJ matrix each have the same marginal distribution |
| 154 | +# ----------------------------------------------------------------------------- |
| 155 | + |
| 156 | +function _marginal(lkj::LKJ) |
| 157 | + d = lkj.d |
| 158 | + η = lkj.η |
| 159 | + α = η + 0.5d - 1 |
| 160 | + LocationScale(-1, 2, Beta(α, α)) |
| 161 | +end |
| 162 | + |
| 163 | +# ----------------------------------------------------------------------------- |
| 164 | +# Several redundant implementations of the recipricol integrating constant. |
| 165 | +# If f(R; n) = c₀ |R|ⁿ⁻¹, these give log(1 / c₀). |
| 166 | +# Every integrating constant formula given in LKJ (2009 JMA) is an expression |
| 167 | +# for 1 / c₀, even if they say that it is not. |
| 168 | +# ----------------------------------------------------------------------------- |
| 169 | + |
| 170 | +function lkj_onion_loginvconst(d::Integer, η::Real) |
| 171 | + # Equation (17) in LKJ (2009 JMA) |
| 172 | + sumlogs = zero(η) |
| 173 | + for k in 2:d - 1 |
| 174 | + sumlogs += 0.5k*logπ + loggamma(η + 0.5(d - 1 - k)) |
| 175 | + end |
| 176 | + α = η + 0.5d - 1 |
| 177 | + loginvconst = (2η + d - 3)*logtwo + logbeta(α, α) + sumlogs - (d - 2) * loggamma(η + 0.5(d - 1)) |
| 178 | + return loginvconst |
| 179 | +end |
| 180 | + |
| 181 | +function lkj_onion_loginvconst_uniform_odd(d::Integer) |
| 182 | + # Theorem 5 in LKJ (2009 JMA) |
| 183 | + sumlogs = 0.0 |
| 184 | + for k in 1:div(d - 1, 2) |
| 185 | + sumlogs += loggamma(2k) |
| 186 | + end |
| 187 | + loginvconst = 0.25(d^2 - 1)*logπ + sumlogs - 0.25(d - 1)^2*logtwo - (d - 1)*loggamma(0.5(d + 1)) |
| 188 | + return loginvconst |
| 189 | +end |
| 190 | + |
| 191 | +function lkj_onion_loginvconst_uniform_even(d::Integer) |
| 192 | + # Theorem 5 in LKJ (2009 JMA) |
| 193 | + sumlogs = 0.0 |
| 194 | + for k in 1:div(d - 2, 2) |
| 195 | + sumlogs += loggamma(2k) |
| 196 | + end |
| 197 | + loginvconst = 0.25d*(d - 2)*logπ + 0.25(3d^2 - 4d)*logtwo + d*loggamma(0.5d) + sumlogs - (d - 1)*loggamma(d) |
| 198 | +end |
| 199 | + |
| 200 | +function lkj_vine_loginvconst(d::Integer, η::Real) |
| 201 | + # Equation (16) in LKJ (2009 JMA) |
| 202 | + expsum = zero(η) |
| 203 | + betasum = zero(η) |
| 204 | + for k in 1:d - 1 |
| 205 | + α = η + 0.5(d - k - 1) |
| 206 | + expsum += (2η - 2 + d - k) * (d - k) |
| 207 | + betasum += (d - k) * logbeta(α, α) |
| 208 | + end |
| 209 | + loginvconst = expsum * logtwo + betasum |
| 210 | + return loginvconst |
| 211 | +end |
| 212 | + |
| 213 | +function lkj_vine_loginvconst_uniform(d::Integer) |
| 214 | + # Equation after (16) in LKJ (2009 JMA) |
| 215 | + expsum = 0.0 |
| 216 | + betasum = 0.0 |
| 217 | + for k in 1:d - 1 |
| 218 | + α = (k + 1) / 2 |
| 219 | + expsum += k ^ 2 |
| 220 | + betasum += k * logbeta(α, α) |
| 221 | + end |
| 222 | + loginvconst = expsum * logtwo + betasum |
| 223 | + return loginvconst |
| 224 | +end |
| 225 | + |
| 226 | +function lkj_loginvconst_alt(d::Integer, η::Real) |
| 227 | + # Third line in first proof of Section 3.3 in LKJ (2009 JMA) |
| 228 | + loginvconst = zero(η) |
| 229 | + for k in 1:d - 1 |
| 230 | + loginvconst += 0.5k*logπ + loggamma(η + 0.5(d - 1 - k)) - loggamma(η + 0.5(d - 1)) |
| 231 | + end |
| 232 | + return loginvconst |
| 233 | +end |
| 234 | + |
| 235 | +function corr_logvolume(n::Integer) |
| 236 | + # https://doi.org/10.4169/amer.math.monthly.123.9.909 |
| 237 | + logvol = 0.0 |
| 238 | + for k in 1:n - 1 |
| 239 | + logvol += 0.5k*logπ + k*loggamma((k+1)/2) - k*loggamma((k+2)/2) |
| 240 | + end |
| 241 | + return logvol |
| 242 | +end |
0 commit comments