@@ -545,6 +545,27 @@ class LKJ(TransformedDistribution):
545
545
When ``concentration < 1``, the distribution favors samples with small determinant. This is
546
546
useful when we know a priori that some underlying variables are correlated.
547
547
548
+ Sample code for using LKJ in the context of a multivariate normal sample::
549
+
550
+ def model(y): # y has dimension N x d
551
+ d = y.shape[1]
552
+ N = y.shape[0]
553
+ # Vector of variances for each of the d variables
554
+ theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
555
+
556
+ concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
557
+ corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
558
+ sigma = jnp.sqrt(theta)
559
+ # we can also use a faster formula `cov_mat = jnp.outer(theta, theta) * corr_mat`
560
+ cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))
561
+
562
+ # Vector of expectations
563
+ mu = jnp.zeros(d)
564
+
565
+ with numpyro.plate("observations", N):
566
+ obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
567
+ return obs
568
+
548
569
:param int dimension: dimension of the matrices
549
570
:param ndarray concentration: concentration/shape parameter of the
550
571
distribution (often referred to as eta)
@@ -606,6 +627,28 @@ class LKJCholesky(Distribution):
606
627
(hence small determinant). This is useful when we know a priori that some underlying
607
628
variables are correlated.
608
629
630
+ Sample code for using LKJCholesky in the context of a multivariate normal sample::
631
+
632
+ def model(y): # y has dimension N x d
633
+ d = y.shape[1]
634
+ N = y.shape[0]
635
+ # Vector of variances for each of the d variables
636
+ theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
637
+ # Lower cholesky factor of a correlation matrix
638
+ concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
639
+ L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
640
+ # Lower cholesky factor of the covariance matrix
641
+ sigma = jnp.sqrt(theta)
642
+ # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
643
+ L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)
644
+
645
+ # Vector of expectations
646
+ mu = jnp.zeros(d)
647
+
648
+ with numpyro.plate("observations", N):
649
+ obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
650
+ return obs
651
+
609
652
:param int dimension: dimension of the matrices
610
653
:param ndarray concentration: concentration/shape parameter of the
611
654
distribution (often referred to as eta)
0 commit comments