
Commit 1809d35

Add draft of SkewMultivariateNormal
1 parent 00aa9ef commit 1809d35

File tree

5 files changed: +322 -0 lines changed


.gitignore

+1
@@ -9,6 +9,7 @@ build
 *.pyo
 /build
 /dist
+/.hypothesis

 # IDE
 .idea

numpyro/distributions/__init__.py

+2
@@ -39,6 +39,7 @@
     Pareto,
     RelaxedBernoulli,
     RelaxedBernoulliLogits,
+    SkewMultivariateNormal,
     SoftLaplace,
     StudentT,
     Uniform,
@@ -158,6 +159,7 @@
     "MultivariateStudentT",
     "LowRankMultivariateNormal",
     "Normal",
+    "SkewMultivariateNormal",
     "NegativeBinomialProbs",
     "NegativeBinomialLogits",
     "NegativeBinomial2",

numpyro/distributions/continuous.py

+136
@@ -25,7 +25,10 @@
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.

+from typing import Union, cast
+
 import numpy as np
+from numpy.typing import NDArray

 from jax import lax
 from jax.experimental.sparse import BCOO
@@ -1731,6 +1734,139 @@ def variance(self):
         return jnp.broadcast_to(self.scale**2, self.batch_shape)


+def skew_delta(skewers_: NDArray[float], cov_: NDArray[float]):
+    return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt(
+        1
+        + jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis]
+    )
+
+
+# Regularized Multivariate Regression Models with Skew-t Error Distributions
+# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac
+class SkewMultivariateNormal(Distribution):
+    arg_constraints = {
+        "loc": constraints.real_vector,
+        "scale_tril": constraints.lower_cholesky,
+        "skewers": constraints.real_vector,
+    }
+    support = constraints.real_vector
+    reparametrized_params = ["loc", "scale_tril", "skewers"]
+    uv_norm = Normal(0.0, 1.0)
+
+    @staticmethod
+    def mk_big_mv_norm(
+        loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float]
+    ):
+        cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril)
+        delta_ = skew_delta(skewers, cov)
+        cov_star = jnp.block(
+            [
+                [
+                    jnp.ones(skewers.shape[:-1] + (1, 1)),
+                    jnp.expand_dims(delta_, axis=-2),
+                ],
+                [jnp.expand_dims(delta_, axis=-1), cov],
+            ]
+        )
+
+        return MultivariateNormal(
+            loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star)
+        )
+
+    def __init__(
+        self,
+        loc: Union[NDArray[float], float],
+        scale_tril: NDArray[float],
+        skewers: NDArray[float],
+        validate_args: None = None,
+    ):
+        if jnp.ndim(loc) == 0:
+            (loc_,) = promote_shapes(loc, shape=(1,))
+        else:
+            loc_ = cast(NDArray[float], loc)
+        batch_shape = lax.broadcast_shapes(
+            jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1]
+        )
+        (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:])
+        (self.skewers,) = promote_shapes(
+            skewers, shape=batch_shape + skewers.shape[-1:]
+        )
+        (self.scale_tril,) = promote_shapes(
+            scale_tril, shape=batch_shape + scale_tril.shape[-2:]
+        )
+        cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril)
+        self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1))
+
+        # Used for sampling
+        self._big_mv_norm = self.mk_big_mv_norm(
+            # The blog post just uses unstandardized skewers here but that leads to
+            # a discrepancy between sampling and log_prob
+            loc=self.loc,
+            skewers=skewers / self._std_devs,
+            scale_tril=scale_tril,
+        )
+        # Used for log_prob
+        self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril)
+
+        skew_mean = jnp.sqrt(2 / jnp.pi) * skew_delta(
+            self.skewers / self._std_devs, cov_batch
+        )
+        self._mean = self.loc + skew_mean
+        # The paper just uses `mean` here but that's definitely not right because
+        # it potentially leads to covariance matrices which are not positive semi definite
+        self._covariance = cov_batch - jnp.einsum(
+            "...i,...j->...ij", skew_mean, skew_mean
+        )
+
+        event_shape = jnp.shape(self.scale_tril)[-1:]
+        super().__init__(
+            batch_shape=batch_shape,
+            event_shape=event_shape,
+            validate_args=validate_args,
+        )
+
+    @validate_sample
+    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
+        return (
+            jnp.log(2)
+            + self._mv_norm.log_prob(value)
+            + jnp.log(
+                self.uv_norm.cdf(
+                    jnp.einsum(
+                        "...k,...k->...",
+                        (value - self.loc) / self._std_devs,
+                        self.skewers,
+                    )
+                )
+            )
+        )
+
+    @staticmethod
+    def infer_shapes(
+        loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float]
+    ):
+        event_shape = (scale_tril[-1],)
+        batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1])
+        return batch_shape, event_shape

+    # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/
+    def sample(
+        self, key: random.PRNGKey, sample_shape: tuple[int, ...] = ()
+    ) -> NDArray[float]:
+        assert is_prng_key(key)
+        x = self._big_mv_norm.sample(key, sample_shape=sample_shape)
+        sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:]
+        return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc
+
+    @property
+    def mean(self):
+        return jnp.broadcast_to(self._mean, self.shape())
+
+    @property
+    def covariance_matrix(self):
+        return self._covariance
+
+
 class Pareto(TransformedDistribution):
     arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
     reparametrized_params = ["scale", "alpha"]
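
For orientation, the density implemented by log_prob above is the multivariate skew-normal form 2 * N(x; loc, Sigma) * Phi(skewers . (x - loc) / std_devs), where Sigma = scale_tril @ scale_tril.T, Phi is the standard normal CDF, and std_devs are the square roots of the diagonal of Sigma. A minimal sketch of a cross-check against scipy (the concrete numbers are made up for illustration, and it assumes a numpyro build that includes this commit):

import numpy as np
import scipy.stats as osp
import numpyro.distributions as dist

loc = np.array([2.0, 0.0])
scale_tril = np.array([[1.0, 0.0], [0.5, 1.0]])
skewers = np.array([1.5, -0.5])
cov = scale_tril @ scale_tril.T  # full covariance from its Cholesky factor
std_devs = np.sqrt(np.diag(cov))

x = np.array([1.2, 0.3])
# 2 * multivariate normal pdf * standard normal cdf of the standardized skew projection
pdf = (
    2.0
    * osp.multivariate_normal(mean=loc, cov=cov).pdf(x)
    * osp.norm.cdf(skewers @ ((x - loc) / std_devs))
)

smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
print(np.log(pdf), smvn.log_prob(x))  # should agree up to floating point error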

setup.py

+1
@@ -51,6 +51,7 @@
    "test": [
        "black[jupyter]>=21.8b0",
        "flake8",
+       "hypothesis[numpy]",
        "isort>=5.0",
        "pytest>=4.1",
        "pyro-api>=0.1.1",

test/test_distributions.py

+182
@@ -6,9 +6,15 @@
 import inspect
 import math
 import os
+from typing import cast

+from hypothesis import given, note, settings
+import hypothesis.extra.numpy as hnp
+import hypothesis.strategies as st
+from hypothesis.strategies import DrawFn, SearchStrategy
 import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
+from numpy.typing import NDArray
 import pytest
 import scipy
 import scipy.stats as osp
@@ -534,6 +540,12 @@ def get_sp_dist(jax_dist):
     T(dist.Normal, 0.0, 1.0),
     T(dist.Normal, 1.0, np.array([1.0, 2.0])),
     T(dist.Normal, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])),
+    T(
+        dist.SkewMultivariateNormal,
+        np.array([2.0, 0.0]),
+        np.array([[1.0, 0.0], [0.5, 1.0]]),
+        np.array([0.0, 0.0]),
+    ),
     T(dist.Pareto, 1.0, 2.0),
     T(dist.Pareto, np.array([1.0, 0.5]), np.array([0.3, 2.0])),
     T(dist.Pareto, np.array([[1.0], [3.0]]), np.array([1.0, 0.5])),
@@ -1502,6 +1514,10 @@ def test_mean_var(jax_dist, sp_dist, params):
         dist.TwoSidedTruncatedDistribution,
     ):
         pytest.skip("Truncated distributions do not has mean/var implemented")
+    if jax_dist is dist.SkewMultivariateNormal:
+        pytest.skip(
+            "We check SkewMultivariateNormal against MultivariateNormal elsewhere"
+        )
     if jax_dist is dist.ProjectedNormal:
         pytest.skip("Mean is defined in submanifold")
@@ -2570,3 +2586,169 @@ def sample_binomial_withp0(key):
         return dist.Binomial(total_count=n, probs=0).sample(key)

     jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))
+
+
+def locs(size: int) -> SearchStrategy[NDArray[float]]:
+    return cast(
+        SearchStrategy[NDArray[float]],
+        hnp.arrays(
+            elements=st.floats(
+                min_value=-1, max_value=1, allow_nan=False, allow_infinity=False
+            ),
+            dtype=np.dtype("float"),
+            shape=size,
+        ),
+    )
+
+
+def skews(size: int) -> SearchStrategy[NDArray[float]]:
+    return cast(
+        SearchStrategy[NDArray[float]],
+        hnp.arrays(
+            elements=st.floats(
+                min_value=-4, max_value=4, allow_nan=False, allow_infinity=False
+            ),
+            dtype=np.dtype("float"),
+            shape=size,
+        ),
+    )
+
+
+def variances(size: int) -> SearchStrategy[NDArray[float]]:
+    return cast(
+        SearchStrategy[NDArray[float]],
+        hnp.arrays(
+            # Variances that are too small make it impossible to test t against normal
+            elements=st.floats(
+                min_value=0.1,
+                max_value=3,
+                allow_nan=False,
+                allow_infinity=False,
+                exclude_min=True,
+            ),
+            dtype=np.dtype("float"),
+            shape=size,
+        ),
+    )
+
+
+def corr_vech_to_matrix(vech: NDArray[float]):
+    width = (math.isqrt(8 * vech.size + 1) + 1) // 2
+    zeros = np.zeros((width, width))
+    zeros[np.tril_indices(width, k=-1)] = vech
+    np.fill_diagonal(zeros, 1)
+    return zeros
+
+
+def correlation_chols(size: int) -> SearchStrategy[NDArray[float]]:
+    return hnp.arrays(
+        # Floating point issues mean we sometimes get arrays which aren't positive semi-definite
+        # if we allow correlations of exactly 1 and -1
+        elements=st.floats(
+            min_value=-0.99, max_value=0.99, allow_nan=False, allow_infinity=False
+        ),
+        dtype=np.dtype("float"),
+        shape=size * (size - 1) // 2,
+    ).map(
+        corr_vech_to_matrix  # type: ignore
+    )
+
+
+@st.composite
+def loc_and_scale(draw: DrawFn):
+    # Would need to generalize meshgrid to relax this restriction
+    size = 2
+    corr = draw(correlation_chols(size))
+    var = draw(variances(size))
+    return (draw(locs(size)), jnp.sqrt(var)[..., None] * corr)
+
+
+@st.composite
+def loc_and_scale_and_skewers(draw: DrawFn):
+    # Would need to generalize meshgrid to relax this restriction
+    size = 2
+    corr = draw(correlation_chols(size))
+    var = draw(variances(size))
+    return (
+        draw(locs(size)),
+        jnp.sqrt(var)[..., None] * corr,
+        draw(skews(size)),
+    )
+
+
+X, Y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100))
+grid = np.dstack((X, Y))
+X_wide, Y_wide = np.meshgrid(np.linspace(-6, 6, 50), np.linspace(-6, 6, 50))
+grid_wide = np.dstack((X_wide, Y_wide))
+
+
+@settings(deadline=None)
+@given(loc_and_scale())
+def test_skew_normal_log_prob_generalizes_normal(
+    loc_scale_tril: tuple[NDArray[float], NDArray[float]]
+):
+    loc, scale_tril = loc_scale_tril
+    mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
+    smvn = dist.SkewMultivariateNormal(
+        loc=loc, scale_tril=scale_tril, skewers=np.zeros(scale_tril.shape[-1])
+    )
+    assert_allclose(mvn.log_prob(grid), smvn.log_prob(grid), atol=1e-6)
+
+
+@settings(deadline=None)
+@given(loc_and_scale())
+def test_skew_normal_moments_generalize_normal(
+    loc_scale_tril: tuple[NDArray[float], NDArray[float]]
+):
+    loc, scale_tril = loc_scale_tril
+    mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
+    smvn = dist.SkewMultivariateNormal(
+        loc=loc, scale_tril=scale_tril, skewers=np.zeros(scale_tril.shape[-1])
+    )
+    assert_allclose(mvn.mean, smvn.mean, atol=1e-30)
+    assert_allclose(mvn.covariance_matrix, smvn.covariance_matrix, atol=1e-30)
+
+
+@settings(deadline=None, max_examples=10)
+@given(loc_and_scale_and_skewers())
+def test_skew_normal_log_prob_vs_samples(
+    loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]]
+):
+    loc, scale_tril, skewers = loc_scale_tril_skewers
+    note(f"Covariance: {scale_tril @ scale_tril.T}")
+    smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
+    samples = smvn.sample(random.PRNGKey(0), sample_shape=(50_000,))
+    # gaussian_kde needs a different format
+    grid_ = np.vstack([X_wide.ravel(), Y_wide.ravel()])
+    lp = jnp.exp(smvn.log_prob(grid_.T))
+    k = osp.gaussian_kde(samples.T, bw_method="scott")(grid_)
+
+    lp_normed = (lp - lp.min()) / (lp.max() - lp.min())
+    k_normed = (k - k.min()) / (k.max() - k.min())
+    assert_allclose(lp_normed, k_normed, atol=0.07)
+
+
+def split_cov(cov: NDArray[float]) -> tuple[NDArray[float], NDArray[float]]:
+    std_devs = np.sqrt(np.diag(cov))
+    dinv = np.diag(1 / std_devs)
+    corr = dinv @ cov @ dinv
+    tril_i = np.tril_indices(len(std_devs), k=-1)
+    return (std_devs, corr[tril_i])
+
+
+@settings(deadline=None)
+@given(loc_and_scale_and_skewers())
+def test_skew_normal_moments_vs_samples(
+    loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]]
+):
+    loc, scale_tril, skewers = loc_scale_tril_skewers
+    note(f"Covariance: {scale_tril @ scale_tril.T}")
+    smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
+    samples = smvn.sample(random.PRNGKey(0), sample_shape=(500_000,))
+    assert_allclose(np.mean(samples, axis=0), smvn.mean, rtol=0.005, atol=0.001)
+
+    std_devs_sample, corr_sample = split_cov(np.cov(samples.T))
+    std_devs_dist, corr_dist = split_cov(smvn.covariance_matrix)
+    assert_allclose(std_devs_sample, std_devs_dist, rtol=0.003)
+    note(f"Sample corr: {corr_sample}, Distribution corr: {corr_dist}")
+    assert_allclose(corr_sample, corr_dist, atol=0.006)
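
As a quick usage illustration mirroring test_skew_normal_moments_vs_samples above (the parameter values are made up for the example, and it assumes a numpyro build that includes this commit):

import numpy as np
from jax import random
import numpyro.distributions as dist

loc = np.array([0.0, 1.0])
scale_tril = np.array([[1.0, 0.0], [0.3, 0.8]])
skewers = np.array([3.0, -1.0])

smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
samples = smvn.sample(random.PRNGKey(0), sample_shape=(200_000,))

# Empirical moments should be close to the closed-form mean / covariance_matrix
print(np.mean(samples, axis=0), smvn.mean)
print(np.cov(samples.T), smvn.covariance_matrix)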
