|
| 1 | +""" |
| 2 | + CorrBijector <: Bijector{2} |
| 3 | +
|
| 4 | +A bijector implementation of Stan's parametrization method for Correlation matrix: |
| 5 | +https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html |
| 6 | +
|
| 7 | +Basically, a unconstrained strictly upper triangular matrix `y` is transformed to |
| 8 | +a correlation matrix by following readable but not that efficient form: |
| 9 | +
|
| 10 | +``` |
| 11 | +K = size(y, 1) |
| 12 | +z = tanh.(y) |
| 13 | +
|
| 14 | +for j=1:K, i=1:K |
| 15 | + if i>j |
| 16 | + w[i,j] = 0 |
| 17 | + elseif 1==i==j |
| 18 | + w[i,j] = 1 |
| 19 | + elseif 1<i==j |
| 20 | + w[i,j] = prod(sqrt(1 .- z[1:i-1, j].^2)) |
| 21 | + elseif 1==i<j |
| 22 | + w[i,j] = z[i,j] |
| 23 | + elseif 1<i<j |
| 24 | + w[i,j] = z[i,j] * prod(sqrt(1 .- z[1:i-1, j].^2)) |
| 25 | + end |
| 26 | +end |
| 27 | +``` |
| 28 | +
|
| 29 | +It is easy to see that every column is a unit vector, for example: |
| 30 | +
|
| 31 | +``` |
| 32 | +w3' w3 == |
| 33 | +w[1,3]^2 + w[2,3]^2 + w[3,3]^2 == |
| 34 | +z[1,3]^2 + (z[2,3] * sqrt(1 - z[1,3]^2))^2 + (sqrt(1-z[1,3]^2) * sqrt(1-z[2,3]^2))^2 == |
| 35 | +z[1,3]^2 + z[2,3]^2 * (1-z[1,3]^2) + (1-z[1,3]^2) * (1-z[2,3]^2) == |
| 36 | +z[1,3]^2 + z[2,3]^2 - z[2,3]^2 * z[1,3]^2 + 1 -z[1,3]^2 - z[2,3]^2 + z[1,3]^2 * z[2,3]^2 == |
| 37 | +1 |
| 38 | +``` |
| 39 | +
|
| 40 | +And diagonal elements are positive, so `w` is a cholesky factor for a positive matrix. |
| 41 | +
|
| 42 | +``` |
| 43 | +x = w' * w |
| 44 | +``` |
| 45 | +
|
| 46 | +Consider block matrix representation for `x` |
| 47 | +
|
| 48 | +``` |
| 49 | +x = [w1'; w2'; ... wn'] * [w1 w2 ... wn] == |
| 50 | +[w1'w1 w1'w2 ... w1'wn; |
| 51 | + w2'w1 w2'w2 ... w2'wn; |
| 52 | + ... |
| 53 | +] |
| 54 | +``` |
| 55 | +
|
| 56 | +The diagonal elements are given by `wk'wk = 1`, thus `x` is a correlation matrix. |
| 57 | +
|
| 58 | +Every step is invertible, so this is a bijection(bijector). |
| 59 | +
|
| 60 | +Note: The implementation doesn't follow their "manageable expression" directly, |
| 61 | +because their equation seems wrong (7/30/2020). Insteadly it follows definition |
| 62 | +above the "manageable expression" directly, which is also described in above doc. |
| 63 | +""" |
| 64 | +struct CorrBijector <: Bijector{2} end |
| 65 | + |
| 66 | +function (b::CorrBijector)(x::AbstractMatrix{<:Real}) |
| 67 | + w = cholesky(x).U # keep LowerTriangular until here can avoid some computation |
| 68 | + r = _link_chol_lkj(w) |
| 69 | + return r + zero(x) |
| 70 | + # This dense format itself is required by a test, though I can't get the point. |
| 71 | + # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 |
| 72 | +end |
| 73 | + |
| 74 | +(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X) |
| 75 | + |
| 76 | +function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real}) |
| 77 | + w = _inv_link_chol_lkj(y) |
| 78 | + return w' * w |
| 79 | +end |
| 80 | +(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y) |
| 81 | + |
| 82 | + |
| 83 | +function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) |
| 84 | + K = LinearAlgebra.checksquare(y) |
| 85 | + |
| 86 | + result = float(zero(eltype(y))) |
| 87 | + for j in 2:K, i in 1:(j - 1) |
| 88 | + @inbounds abs_y_i_j = abs(y[i, j]) |
| 89 | + result += (K - i + 1) * (logtwo - (abs_y_i_j + log1pexp(-2 * abs_y_i_j))) |
| 90 | + end |
| 91 | + |
| 92 | + return result |
| 93 | +end |
| 94 | +function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) |
| 95 | + #= |
| 96 | + It may be more efficient if we can use un-contraint value to prevent call of b |
| 97 | + It's recommended to directly call |
| 98 | + `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})` |
| 99 | + if possible. |
| 100 | + =# |
| 101 | + return -logabsdetjac(inv(b), (b(X))) |
| 102 | +end |
| 103 | +function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) |
| 104 | + return mapvcat(X) do x |
| 105 | + logabsdetjac(b, x) |
| 106 | + end |
| 107 | +end |
| 108 | + |
| 109 | + |
| 110 | +function _inv_link_chol_lkj(y) |
| 111 | + K = LinearAlgebra.checksquare(y) |
| 112 | + |
| 113 | + w = similar(y) |
| 114 | + |
| 115 | + @inbounds for j in 1:K |
| 116 | + w[1, j] = 1 |
| 117 | + for i in 2:j |
| 118 | + z = tanh(y[i-1, j]) |
| 119 | + tmp = w[i-1, j] |
| 120 | + w[i-1, j] = z * tmp |
| 121 | + w[i, j] = tmp * sqrt(1 - z^2) |
| 122 | + end |
| 123 | + for i in (j+1):K |
| 124 | + w[i, j] = 0 |
| 125 | + end |
| 126 | + end |
| 127 | + |
| 128 | + return w |
| 129 | +end |
| 130 | + |
| 131 | +""" |
| 132 | + function _link_chol_lkj(w) |
| 133 | +
|
| 134 | +Link function for cholesky factor. |
| 135 | +
|
| 136 | +An alternative and maybe more efficient implementation was considered: |
| 137 | +
|
| 138 | +``` |
| 139 | +for i=2:K, j=(i+1):K |
| 140 | + z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2)) |
| 141 | +end |
| 142 | +``` |
| 143 | +
|
| 144 | +But this implementation will not work when w[i-1, j] = 0. |
| 145 | +Though it is a zero measure set, unit matrix initialization will not work. |
| 146 | +
|
| 147 | +For equivelence, following explanations is given by @torfjelde: |
| 148 | +
|
| 149 | +For `(i, j)` in the loop below, we define |
| 150 | +
|
| 151 | + z₍ᵢ₋₁, ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²)) |
| 152 | +
|
| 153 | +and so |
| 154 | +
|
| 155 | + z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²)) |
| 156 | + = (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) |
| 157 | + = (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎ |
| 158 | + = (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) |
| 159 | + = (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) |
| 160 | +
|
| 161 | +which is the above implementation. |
| 162 | +""" |
| 163 | +function _link_chol_lkj(w) |
| 164 | + K = LinearAlgebra.checksquare(w) |
| 165 | + |
| 166 | + z = similar(w) # z is also UpperTriangular. |
| 167 | + # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. |
| 168 | + |
| 169 | + # This block can't be integrated with loop below, because w[1,1] != 0. |
| 170 | + @inbounds z[1, 1] = 0 |
| 171 | + |
| 172 | + @inbounds for j=2:K |
| 173 | + z[1, j] = atanh(w[1, j]) |
| 174 | + tmp = sqrt(1 - w[1, j]^2) |
| 175 | + for i in 2:(j - 1) |
| 176 | + p = w[i, j] / tmp |
| 177 | + tmp *= sqrt(1 - p^2) |
| 178 | + z[i, j] = atanh(p) |
| 179 | + end |
| 180 | + z[j, j] = 0 |
| 181 | + end |
| 182 | + |
| 183 | + return z |
| 184 | +end |
0 commit comments