Skip to content

Commit 3d58304

Browse files
committed
Numerical improvements to LKJCholesky forward transform
1 parent 8a525f1 commit 3d58304

File tree

3 files changed

+50
-20
lines changed

3 files changed

+50
-20
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.15.6"
3+
version = "0.15.7"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/bijectors/corr.jl

+14-13
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,15 @@ which is the above implementation.
293293
function _link_chol_lkj(W::AbstractMatrix)
294294
K = LinearAlgebra.checksquare(W)
295295

296-
y = similar(W) # z is also UpperTriangular.
296+
y = similar(W) # W is upper triangular.
297297
# Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero.
298298

299299
@inbounds for j in 1:K
300-
remainder_sq = one(eltype(W))
301-
for i in 1:(j - 1)
300+
remainder_sq = W[j, j]^2
301+
for i in (j - 1):-1:1
302302
z = W[i, j] / sqrt(remainder_sq)
303-
y[i, j] = atanh(z)
304-
remainder_sq -= W[i, j]^2
303+
y[i, j] = asinh(z)
304+
remainder_sq += W[i, j]^2
305305
end
306306
for i in j:K
307307
y[i, j] = 0
@@ -317,17 +317,18 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix)
317317

318318
y = similar(W, N)
319319

320-
idx = 1
320+
starting_idx = 1
321321
@inbounds for j in 2:K
322-
y[idx] = atanh(W[1, j])
323-
idx += 1
324-
remainder_sq = 1 - W[1, j]^2
325-
for i in 2:(j - 1)
322+
y[starting_idx] = atanh(W[1, j])
323+
starting_idx += 1
324+
remainder_sq = W[j, j]^2
325+
for i in (j - 1):-1:2
326+
idx = starting_idx + i - 2
326327
z = W[i, j] / sqrt(remainder_sq)
327-
y[idx] = atanh(z)
328-
remainder_sq -= W[i, j]^2
329-
idx += 1
328+
y[idx] = asinh(z)
329+
remainder_sq += W[i, j]^2
330330
end
331+
starting_idx += length((j - 1):-1:2)
331332
end
332333

333334
return y

test/transform.jl

+35-6
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,47 @@ end
237237
end
238238

239239
@testset "LKJCholesky" begin
240+
# Convert Cholesky factor to its free parameters, i.e. its off-diagonal elements
241+
function chol_3by3_to_free_params(x::Cholesky)
242+
if x.uplo == :U
243+
return [x.U[1, 2], x.U[1, 3], x.U[2, 3]]
244+
else
245+
return [x.L[2, 1], x.L[3, 1], x.L[3, 2]]
246+
end
247+
# TODO: Generalise to arbitrary dimension using this code:
248+
# inds = [
249+
# LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
250+
# (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
251+
# ]
252+
end
253+
254+
# Reconstruct Cholesky factor from its free parameters
255+
# Note that x[i, i] is always positive so we don't need to worry about the sign
256+
function free_params_to_chol_3by3(free_params::AbstractVector, uplo::Symbol)
257+
x = UpperTriangular(zeros(eltype(free_params), 3, 3))
258+
x[1, 1] = 1
259+
x[1, 2] = free_params[1]
260+
x[1, 3] = free_params[2]
261+
x[2, 2] = sqrt(1 - free_params[1]^2)
262+
x[2, 3] = free_params[3]
263+
x[3, 3] = sqrt(1 - free_params[2]^2 - free_params[3]^2)
264+
if uplo == :U
265+
return Cholesky(x)
266+
else
267+
return Cholesky(transpose(x))
268+
end
269+
end
270+
240271
@testset "uplo: $uplo" for uplo in [:L, :U]
241272
dist = LKJCholesky(3, 1, uplo)
242273
single_sample_tests(dist)
243274

244275
x = rand(dist)
245276

246-
inds = [
247-
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
248-
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
249-
]
250-
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
251-
J = J[:, inds]
277+
# Here, we need to pass ForwardDiff only the free parameters of the
278+
# Cholesky factor so that we get a square Jacobian matrix
279+
free_params = chol_3by3_to_free_params(x)
280+
J = ForwardDiff.jacobian(z -> link(dist, free_params_to_chol_3by3(z, uplo)), free_params)
252281
logpdf_turing = logpdf_with_trans(dist, x, true)
253282
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
254283
end

0 commit comments

Comments
 (0)