Skip to content

Commit d521695

Browse files
authored
Add LKJ (#1066)
* add LKJ * test LKJ * Document LKJ * update integrating constant helpers, fix sign, update docstring * add explicit volume tests, and an importance sampling check * use nifty comprehension to simplify IS test * lil wording change * fix mode(), handle d = 1 edge case * update tests * constant behavior in edge case * short circuit arg check and add little space * test short circuit * Test LKJ logpdf against archived output from Stan * add links to Stan test set
1 parent fb90c3d commit d521695

File tree

7 files changed

+428
-2
lines changed

7 files changed

+428
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3232
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
3333
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
34+
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
3435
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3536

3637
[targets]
37-
test = ["Calculus", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON", "StaticArrays", "Test"]
38+
test = ["Calculus", "Distributed", "FiniteDifferences", "ForwardDiff", "JSON",
39+
"StaticArrays", "HypothesisTests", "Test"]

docs/src/matrix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ MatrixNormal
3838
MatrixTDist
3939
MatrixBeta
4040
MatrixFDist
41+
LKJ
4142
```
4243

4344
## Internal Methods (for creating your own matrix-variate distributions)

src/Distributions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ export
106106
KSOneSided,
107107
Laplace,
108108
Levy,
109+
LKJ,
109110
LocationScale,
110111
Logistic,
111112
LogNormal,

src/matrix/lkj.jl

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""
2+
LKJ(d, η)
3+
4+
```julia
5+
d::Int dimension
6+
η::Real positive shape
7+
```
8+
The [LKJ](https://doi.org/10.1016/j.jmva.2009.04.008) distribution is a distribution over
9+
``d\\times d`` real correlation matrices (positive-definite matrices with ones on the diagonal).
10+
If ``\\mathbf{R}\\sim \\textrm{LKJ}_{d}(\\eta)``, then its probability density function is
11+
12+
```math
13+
f(\\mathbf{R};\\eta) = \\left[\\prod_{k=1}^{d-1}\\pi^{\\frac{k}{2}}
14+
\\frac{\\Gamma\\left(\\eta+\\frac{d-1-k}{2}\\right)}{\\Gamma\\left(\\eta+\\frac{d-1}{2}\\right)}\\right]^{-1}
15+
|\\mathbf{R}|^{\\eta-1}.
16+
```
17+
18+
If ``\\eta = 1``, then the LKJ distribution is uniform over
19+
[the space of correlation matrices](https://www.jstor.org/stable/2684832).
20+
"""
21+
struct LKJ{T <: Real, D <: Integer} <: ContinuousMatrixDistribution
22+
d::D
23+
η::T
24+
logc0::T
25+
end
26+
27+
# -----------------------------------------------------------------------------
28+
# Constructors
29+
# -----------------------------------------------------------------------------
30+
31+
function LKJ(d::Integer, η::Real; check_args = true)
32+
if check_args
33+
d > 0 || throw(ArgumentError("Matrix dimension must be positive."))
34+
η > 0 || throw(ArgumentError("Shape parameter must be positive."))
35+
end
36+
logc0 = lkj_logc0(d, η)
37+
T = Base.promote_eltype(η, logc0)
38+
LKJ{T, typeof(d)}(d, T(η), T(logc0))
39+
end
40+
41+
# -----------------------------------------------------------------------------
42+
# REPL display
43+
# -----------------------------------------------------------------------------
44+
45+
show(io::IO, d::LKJ) = show_multline(io, d, [(:d, d.d), (, d.η)])
46+
47+
# -----------------------------------------------------------------------------
48+
# Conversion
49+
# -----------------------------------------------------------------------------
50+
51+
function convert(::Type{LKJ{T}}, d::LKJ) where T <: Real
52+
LKJ{T, typeof(d.d)}(d.d, T(d.η), T(d.logc0))
53+
end
54+
55+
function convert(::Type{LKJ{T}}, d::Integer, η, logc0) where T <: Real
56+
LKJ{T, typeof(d)}(d, T(η), T(logc0))
57+
end
58+
59+
# -----------------------------------------------------------------------------
60+
# Properties
61+
# -----------------------------------------------------------------------------
62+
63+
dim(d::LKJ) = d.d
64+
65+
size(d::LKJ) = (dim(d), dim(d))
66+
67+
rank(d::LKJ) = dim(d)
68+
69+
insupport(d::LKJ, R::AbstractMatrix) = isreal(R) && size(R) == size(d) && isone(Diagonal(R)) && isposdef(R)
70+
71+
mean(d::LKJ) = Matrix{partype(d)}(I, dim(d), dim(d))
72+
73+
function mode(d::LKJ; check_args = true)
74+
η = params(d)
75+
if check_args
76+
η > 1 || throw(ArgumentError("mode is defined only when η > 1."))
77+
end
78+
return mean(d)
79+
end
80+
81+
function var(lkj::LKJ)
82+
d = dim(lkj)
83+
d > 1 || return zeros(d, d)
84+
σ² = var(_marginal(lkj))
85+
σ² * (ones(partype(lkj), d, d) - I)
86+
end
87+
88+
params(d::LKJ) = d.η
89+
90+
@inline partype(d::LKJ{T}) where {T <: Real} = T
91+
92+
# -----------------------------------------------------------------------------
93+
# Evaluation
94+
# -----------------------------------------------------------------------------
95+
96+
function lkj_logc0(d::Integer, η::Real)
97+
d > 1 || return zero(η)
98+
if isone(η)
99+
if iseven(d)
100+
logc0 = -lkj_onion_loginvconst_uniform_even(d)
101+
else
102+
logc0 = -lkj_onion_loginvconst_uniform_odd(d)
103+
end
104+
else
105+
logc0 = -lkj_onion_loginvconst(d, η)
106+
end
107+
return logc0
108+
end
109+
110+
logkernel(d::LKJ, R::AbstractMatrix) = (d.η - 1) * logdet(R)
111+
112+
_logpdf(d::LKJ, R::AbstractMatrix) = logkernel(d, R) + d.logc0
113+
114+
# -----------------------------------------------------------------------------
115+
# Sampling
116+
# -----------------------------------------------------------------------------
117+
118+
function _rand!(rng::AbstractRNG, d::LKJ, R::AbstractMatrix)
119+
R .= _lkj_onion_sampler(d.d, d.η, rng)
120+
end
121+
122+
function _lkj_onion_sampler(d::Integer, η::Real, rng::AbstractRNG = Random.GLOBAL_RNG)
123+
# Section 3.2 in LKJ (2009 JMA)
124+
# 1. Initialization
125+
R = ones(typeof(η), d, d)
126+
d > 1 || return R
127+
β = η + 0.5d - 1
128+
u = rand(rng, Beta(β, β))
129+
R[1, 2] = 2u - 1
130+
R[2, 1] = R[1, 2]
131+
# 2.
132+
for k in 2:d - 1
133+
# (a)
134+
β -= 0.5
135+
# (b)
136+
y = rand(rng, Beta(k / 2, β))
137+
# (c)
138+
u = randn(rng, k)
139+
u = u / norm(u)
140+
# (d)
141+
w = sqrt(y) * u
142+
A = cholesky(R[1:k, 1:k]).L
143+
z = A * w
144+
# (e)
145+
R[1:k, k + 1] = z
146+
R[k + 1, 1:k] = z'
147+
end
148+
# 3.
149+
return R
150+
end
151+
152+
# -----------------------------------------------------------------------------
153+
# The free elements of an LKJ matrix each have the same marginal distribution
154+
# -----------------------------------------------------------------------------
155+
156+
function _marginal(lkj::LKJ)
157+
d = lkj.d
158+
η = lkj.η
159+
α = η + 0.5d - 1
160+
LocationScale(-1, 2, Beta(α, α))
161+
end
162+
163+
# -----------------------------------------------------------------------------
164+
# Several redundant implementations of the recipricol integrating constant.
165+
# If f(R; n) = c₀ |R|ⁿ⁻¹, these give log(1 / c₀).
166+
# Every integrating constant formula given in LKJ (2009 JMA) is an expression
167+
# for 1 / c₀, even if they say that it is not.
168+
# -----------------------------------------------------------------------------
169+
170+
function lkj_onion_loginvconst(d::Integer, η::Real)
171+
# Equation (17) in LKJ (2009 JMA)
172+
sumlogs = zero(η)
173+
for k in 2:d - 1
174+
sumlogs += 0.5k*logπ + loggamma+ 0.5(d - 1 - k))
175+
end
176+
α = η + 0.5d - 1
177+
loginvconst = (2η + d - 3)*logtwo + logbeta(α, α) + sumlogs - (d - 2) * loggamma+ 0.5(d - 1))
178+
return loginvconst
179+
end
180+
181+
function lkj_onion_loginvconst_uniform_odd(d::Integer)
182+
# Theorem 5 in LKJ (2009 JMA)
183+
sumlogs = 0.0
184+
for k in 1:div(d - 1, 2)
185+
sumlogs += loggamma(2k)
186+
end
187+
loginvconst = 0.25(d^2 - 1)*logπ + sumlogs - 0.25(d - 1)^2*logtwo - (d - 1)*loggamma(0.5(d + 1))
188+
return loginvconst
189+
end
190+
191+
function lkj_onion_loginvconst_uniform_even(d::Integer)
192+
# Theorem 5 in LKJ (2009 JMA)
193+
sumlogs = 0.0
194+
for k in 1:div(d - 2, 2)
195+
sumlogs += loggamma(2k)
196+
end
197+
loginvconst = 0.25d*(d - 2)*logπ + 0.25(3d^2 - 4d)*logtwo + d*loggamma(0.5d) + sumlogs - (d - 1)*loggamma(d)
198+
end
199+
200+
function lkj_vine_loginvconst(d::Integer, η::Real)
201+
# Equation (16) in LKJ (2009 JMA)
202+
expsum = zero(η)
203+
betasum = zero(η)
204+
for k in 1:d - 1
205+
α = η + 0.5(d - k - 1)
206+
expsum += (2η - 2 + d - k) * (d - k)
207+
betasum += (d - k) * logbeta(α, α)
208+
end
209+
loginvconst = expsum * logtwo + betasum
210+
return loginvconst
211+
end
212+
213+
function lkj_vine_loginvconst_uniform(d::Integer)
214+
# Equation after (16) in LKJ (2009 JMA)
215+
expsum = 0.0
216+
betasum = 0.0
217+
for k in 1:d - 1
218+
α = (k + 1) / 2
219+
expsum += k ^ 2
220+
betasum += k * logbeta(α, α)
221+
end
222+
loginvconst = expsum * logtwo + betasum
223+
return loginvconst
224+
end
225+
226+
function lkj_loginvconst_alt(d::Integer, η::Real)
227+
# Third line in first proof of Section 3.3 in LKJ (2009 JMA)
228+
loginvconst = zero(η)
229+
for k in 1:d - 1
230+
loginvconst += 0.5k*logπ + loggamma+ 0.5(d - 1 - k)) - loggamma+ 0.5(d - 1))
231+
end
232+
return loginvconst
233+
end
234+
235+
function corr_logvolume(n::Integer)
236+
# https://doi.org/10.4169/amer.math.monthly.123.9.909
237+
logvol = 0.0
238+
for k in 1:n - 1
239+
logvol += 0.5k*logπ + k*loggamma((k+1)/2) - k*loggamma((k+2)/2)
240+
end
241+
return logvol
242+
end

src/matrixvariates.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ _logpdf(d::MatrixDistribution, x::AbstractArray)
213213
##### Specific distributions #####
214214

215215
for fname in ["wishart.jl", "inversewishart.jl", "matrixnormal.jl",
216-
"matrixtdist.jl", "matrixbeta.jl", "matrixfdist.jl"]
216+
"matrixtdist.jl", "matrixbeta.jl", "matrixfdist.jl",
217+
"lkj.jl"]
217218
include(joinpath("matrix", fname))
218219
end

0 commit comments

Comments
 (0)