Commit a3ece91

yiyuezhuo, torfjelde, and devmotion authored
Add LKJ bijector (#125)
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: January Desk <[email protected]>
1 parent 3a2b1e4 commit a3ece91

File tree: 11 files changed (+440, -4 lines)


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "Bijectors"
 uuid = "76274a88-744f-5084-9051-94815aaf08c4"
-version = "0.8.2"
+version = "0.8.3"

 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/bijectors/corr.jl

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
+"""
+    CorrBijector <: Bijector{2}
+
+A bijector implementation of Stan's parametrization method for correlation matrices:
+https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
+
+Basically, an unconstrained strictly upper triangular matrix `y` is transformed to
+a correlation matrix via the following readable, but not that efficient, form:
+
+```
+K = size(y, 1)
+z = tanh.(y)
+
+for j=1:K, i=1:K
+    if i>j
+        w[i,j] = 0
+    elseif 1==i==j
+        w[i,j] = 1
+    elseif 1<i==j
+        w[i,j] = prod(sqrt(1 .- z[1:i-1, j].^2))
+    elseif 1==i<j
+        w[i,j] = z[i,j]
+    elseif 1<i<j
+        w[i,j] = z[i,j] * prod(sqrt(1 .- z[1:i-1, j].^2))
+    end
+end
+```
+
+It is easy to see that every column is a unit vector, for example:
+
+```
+w3' w3 ==
+w[1,3]^2 + w[2,3]^2 + w[3,3]^2 ==
+z[1,3]^2 + (z[2,3] * sqrt(1 - z[1,3]^2))^2 + (sqrt(1 - z[1,3]^2) * sqrt(1 - z[2,3]^2))^2 ==
+z[1,3]^2 + z[2,3]^2 * (1 - z[1,3]^2) + (1 - z[1,3]^2) * (1 - z[2,3]^2) ==
+z[1,3]^2 + z[2,3]^2 - z[2,3]^2 * z[1,3]^2 + 1 - z[1,3]^2 - z[2,3]^2 + z[1,3]^2 * z[2,3]^2 ==
+1
+```
+
+The diagonal elements are positive, so `w` is a Cholesky factor of a positive definite matrix:
+
+```
+x = w' * w
+```
+
+Consider the block matrix representation of `x`:
+
+```
+x = [w1'; w2'; ... wn'] * [w1 w2 ... wn] ==
+[w1'w1 w1'w2 ... w1'wn;
+ w2'w1 w2'w2 ... w2'wn;
+ ...
+]
+```
+
+The diagonal elements are given by `wk'wk = 1`, thus `x` is a correlation matrix.
+
+Every step is invertible, so this is a bijection (bijector).
+
+Note: The implementation doesn't follow Stan's "manageable expression" directly,
+because that equation appears to be wrong (as of 7/30/2020). Instead it follows the
+definition given just above the "manageable expression", which is also described in
+the documentation linked above.
+"""
+struct CorrBijector <: Bijector{2} end
+
+function (b::CorrBijector)(x::AbstractMatrix{<:Real})
+    w = cholesky(x).U  # keeping the factor as a triangular type up to this point avoids some computation
+    r = _link_chol_lkj(w)
+    return r + zero(x)
+    # This dense format is required by a test, though the reason for it is unclear:
+    # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
+end
+
+(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)
+
+function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
+    w = _inv_link_chol_lkj(y)
+    return w' * w
+end
+(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y)
+
+
+function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
+    K = LinearAlgebra.checksquare(y)
+
+    result = float(zero(eltype(y)))
+    for j in 2:K, i in 1:(j - 1)
+        @inbounds abs_y_i_j = abs(y[i, j])
+        result += (K - i + 1) * (logtwo - (abs_y_i_j + log1pexp(-2 * abs_y_i_j)))
+    end
+
+    return result
+end
+function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
+    #=
+    It might be more efficient to use the unconstrained value directly and avoid the
+    call to `b`; it is therefore recommended to call
+    `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
+    directly if possible.
+    =#
+    return -logabsdetjac(inv(b), (b(X)))
+end
+function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
+    return mapvcat(X) do x
+        logabsdetjac(b, x)
+    end
+end
+
+
+function _inv_link_chol_lkj(y)
+    K = LinearAlgebra.checksquare(y)
+
+    w = similar(y)
+
+    @inbounds for j in 1:K
+        w[1, j] = 1
+        for i in 2:j
+            z = tanh(y[i-1, j])
+            tmp = w[i-1, j]
+            w[i-1, j] = z * tmp
+            w[i, j] = tmp * sqrt(1 - z^2)
+        end
+        for i in (j+1):K
+            w[i, j] = 0
+        end
+    end
+
+    return w
+end
+
+"""
+    _link_chol_lkj(w)
+
+Link function for the Cholesky factor.
+
+An alternative, possibly more efficient, implementation was considered:
+
+```
+for i=2:K, j=(i+1):K
+    z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2))
+end
+```
+
+But this implementation does not work when `w[i-1, j] == 0`.
+Although that is a measure-zero set, initializing with an identity matrix hits it exactly.
+
+For the equivalence, the following explanation is given by @torfjelde:
+
+For `(i, j)` in the loop below, we define
+
+    z₍ᵢ₋₁,ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²))
+
+and so
+
+    z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²))
+           = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²))
+           = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎
+           = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎)
+           = (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²))
+
+which is the above implementation.
+"""
+function _link_chol_lkj(w)
+    K = LinearAlgebra.checksquare(w)
+
+    z = similar(w)  # z is also upper triangular.
+    # Some of the zero filling could be avoided, but the diagonal still has to be zeroed.
+
+    # This block can't be merged into the loop below, because w[1,1] != 0.
+    @inbounds z[1, 1] = 0
+
+    @inbounds for j=2:K
+        z[1, j] = atanh(w[1, j])
+        tmp = sqrt(1 - w[1, j]^2)
+        for i in 2:(j - 1)
+            p = w[i, j] / tmp
+            tmp *= sqrt(1 - p^2)
+            z[i, j] = atanh(p)
+        end
+        z[j, j] = 0
+    end
+
+    return z
+end
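
For orientation, here is a minimal usage sketch of the new bijector (not part of the commit). It assumes Bijectors with this change and LinearAlgebra are loaded, and it uses only functions defined in the diff above together with the package's `inv` and `logabsdetjac`; the 3×3 matrix is just an illustrative, valid correlation matrix.

```
using Bijectors, LinearAlgebra

b = Bijectors.CorrBijector()

# A small, valid correlation matrix: symmetric, unit diagonal, positive definite.
x = [1.0 0.3 0.1;
     0.3 1.0 0.2;
     0.1 0.2 1.0]

y = b(x)            # unconstrained values in the strictly upper triangle
x_back = inv(b)(y)  # rebuilds w column by column and returns w' * w

# The round trip recovers x, and the result has a unit diagonal,
# matching the unit-column argument in the docstring.
@assert isapprox(x_back, x; atol=1e-8)
@assert all(isapprox.(diag(x_back), 1))

# Log-abs-det Jacobian of the inverse map, as implemented above.
Bijectors.logabsdetjac(inv(b), y)
```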

src/compat/forwarddiff.jl

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@ function jacobian(
     x::AbstractVector{<:Real}
 )
     return ForwardDiff.jacobian(b, x)
-end
+end

(The `end` line is textually unchanged; this pattern typically indicates that a trailing newline was added at the end of the file.)

src/compat/reversediff.jl

Lines changed: 3 additions & 1 deletion
@@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
     simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
     ReverseDiffAD, Inverse
 import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
-    _simplex_inv_bijector, replace_diag, jacobian, getpd, lower
+    _simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
+    _inv_link_chol_lkj, _link_chol_lkj

 using Compat: eachcol
 using Distributions: LocationScale
@@ -180,4 +181,5 @@ lower(A::TrackedMatrix) = track(lower, A)
     return lower(Ad), Δ -> (lower(Δ),)
 end

+
 end

src/compat/tracker.jl

Lines changed: 102 additions & 0 deletions
@@ -440,3 +440,105 @@ lower(A::TrackedMatrix) = track(lower, A)
     Ad = data(A)
     return lower(Ad), Δ -> (lower(Δ),)
 end
+
+_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
+@grad function _inv_link_chol_lkj(y_tracked)
+    y = data(y_tracked)
+
+    K = LinearAlgebra.checksquare(y)
+
+    w = similar(y)
+
+    z_mat = similar(y)    # cache for the adjoint
+    tmp_mat = similar(y)
+
+    @inbounds for j in 1:K
+        w[1, j] = 1
+        for i in 2:j
+            z = tanh(y[i-1, j])
+            tmp = w[i-1, j]
+
+            z_mat[i, j] = z
+            tmp_mat[i, j] = tmp
+
+            w[i-1, j] = z * tmp
+            w[i, j] = tmp * sqrt(1 - z^2)
+        end
+        for i in (j+1):K
+            w[i, j] = 0
+        end
+    end
+
+    function pullback_inv_link_chol_lkj(Δw)
+        LinearAlgebra.checksquare(Δw)
+
+        Δy = zero(y)
+
+        @inbounds for j in 1:K
+            Δtmp = Δw[j,j]
+            for i in j:-1:2
+                Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j]
+                Δy[i-1, j] = Δz / cosh(y[i-1, j])^2
+                Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2)
+            end
+        end
+
+        return (Δy,)
+    end
+
+    return w, pullback_inv_link_chol_lkj
+end
+
+_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
+@grad function _link_chol_lkj(w_tracked)
+    w = data(w_tracked)
+
+    K = LinearAlgebra.checksquare(w)
+
+    z = similar(w)
+
+    @inbounds z[1, 1] = 0
+
+    tmp_mat = similar(w)  # cache for the pullback
+
+    @inbounds for j=2:K
+        z[1, j] = atanh(w[1, j])
+        tmp = sqrt(1 - w[1, j]^2)
+        tmp_mat[1, j] = tmp
+        for i in 2:(j - 1)
+            p = w[i, j] / tmp
+            tmp *= sqrt(1 - p^2)
+            tmp_mat[i, j] = tmp
+            z[i, j] = atanh(p)
+        end
+        z[j, j] = 0
+    end
+
+    function pullback_link_chol_lkj(Δz)
+        LinearAlgebra.checksquare(Δz)
+
+        Δw = similar(w)
+
+        @inbounds Δw[1,1] = zero(eltype(Δz))
+
+        @inbounds for j=2:K
+            Δw[j, j] = 0
+            Δtmp = zero(eltype(Δz))  # Δtmp_mat[j-1,j]
+            for i in (j-1):-1:2
+                p = w[i, j] / tmp_mat[i-1, j]
+                ftmp = sqrt(1 - p^2)
+                d_ftmp_p = -p / ftmp
+                d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2
+
+                Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p
+                Δw[i, j] = Δp / tmp_mat[i-1, j]
+                Δtmp = Δp * d_p_tmp + Δtmp * ftmp  # update to the "previous" Δtmp
+            end
+            Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j]
+        end
+
+        return (Δw,)
+    end
+
+    return z, pullback_link_chol_lkj
+end
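
As a hedged sanity check (not part of the commit), one might compare the custom Tracker pullback for `_inv_link_chol_lkj` against a ForwardDiff gradient of the same scalar-valued function. This sketch assumes that loading Tracker activates the compat code above; the test input is arbitrary, and the final comparison is expected to hold if the hand-written adjoint is correct.

```
using Bijectors, LinearAlgebra
using Tracker, ForwardDiff  # loading Tracker should activate src/compat/tracker.jl

# Unconstrained input; only the strictly upper triangular entries are used.
y0 = [0.0  0.5 -0.3;
      0.0  0.0  0.8;
      0.0  0.0  0.0]

# Scalar test function so both backends return a gradient of the same shape as y0.
f(y) = sum(Bijectors._inv_link_chol_lkj(y))

g_tracker = Tracker.data(Tracker.gradient(f, y0)[1])  # hits the @grad rule above
g_forward = ForwardDiff.gradient(f, y0)               # dual-number reference

isapprox(g_tracker, g_forward; atol=1e-8)
```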
