Skip to content

Commit f8f482a

Browse files
authored
Port LKJ example from Pyro (#1065)
* port LKJ example from Pyro * add code samples in docstring for LKJ and LKJCholesky * delete LKJ example * change sample name for correlation matrix from L_omega to corr_mat * incorporate review comments
1 parent ab18282 commit f8f482a

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

numpyro/distributions/continuous.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,27 @@ class LKJ(TransformedDistribution):
545545
When ``concentration < 1``, the distribution favors samples with small determinent. This is
546546
useful when we know a priori that some underlying variables are correlated.
547547
548+
Sample code for using LKJ in the context of multivariate normal sample::
549+
550+
def model(y): # y has dimension N x d
551+
d = y.shape[1]
552+
N = y.shape[0]
553+
# Vector of variances for each of the d variables
554+
theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
555+
556+
concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
557+
corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
558+
sigma = jnp.sqrt(theta)
559+
# we can also use a faster formula `cov_mat = jnp.outer(theta, theta) * corr_mat`
560+
cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))
561+
562+
# Vector of expectations
563+
mu = jnp.zeros(d)
564+
565+
with numpyro.plate("observations", N):
566+
obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
567+
return obs
568+
548569
:param int dimension: dimension of the matrices
549570
:param ndarray concentration: concentration/shape parameter of the
550571
distribution (often referred to as eta)
@@ -606,6 +627,28 @@ class LKJCholesky(Distribution):
606627
(hence small determinent). This is useful when we know a priori that some underlying
607628
variables are correlated.
608629
630+
Sample code for using LKJCholesky in the context of multivariate normal sample::
631+
632+
def model(y): # y has dimension N x d
633+
d = y.shape[1]
634+
N = y.shape[0]
635+
# Vector of variances for each of the d variables
636+
theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
637+
# Lower cholesky factor of a correlation matrix
638+
concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
639+
L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
640+
# Lower cholesky factor of the covariance matrix
641+
sigma = jnp.sqrt(theta)
642+
# we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
643+
L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)
644+
645+
# Vector of expectations
646+
mu = jnp.zeros(d)
647+
648+
with numpyro.plate("observations", N):
649+
obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
650+
return obs
651+
609652
:param int dimension: dimension of the matrices
610653
:param ndarray concentration: concentration/shape parameter of the
611654
distribution (often referred to as eta)

0 commit comments

Comments
 (0)