|
6 | 6 | import inspect
|
7 | 7 | import math
|
8 | 8 | import os
|
| 9 | +from typing import cast |
9 | 10 |
|
| 11 | +from hypothesis import given, note, settings |
| 12 | +import hypothesis.extra.numpy as hnp |
| 13 | +import hypothesis.strategies as st |
| 14 | +from hypothesis.strategies import DrawFn, SearchStrategy |
10 | 15 | import numpy as np
|
11 | 16 | from numpy.testing import assert_allclose, assert_array_equal
|
| 17 | +from numpy.typing import NDArray |
12 | 18 | import pytest
|
13 | 19 | import scipy
|
14 | 20 | import scipy.stats as osp
|
@@ -534,6 +540,12 @@ def get_sp_dist(jax_dist):
|
534 | 540 | T(dist.Normal, 0.0, 1.0),
|
535 | 541 | T(dist.Normal, 1.0, np.array([1.0, 2.0])),
|
536 | 542 | T(dist.Normal, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])),
|
| 543 | + T( |
| 544 | + dist.SkewMultivariateNormal, |
| 545 | + np.array([2.0, 0.0]), |
| 546 | + np.array([[1.0, 0.0], [0.5, 1.0]]), |
| 547 | + np.array([0.0, 0.0]), |
| 548 | + ), |
537 | 549 | T(dist.Pareto, 1.0, 2.0),
|
538 | 550 | T(dist.Pareto, np.array([1.0, 0.5]), np.array([0.3, 2.0])),
|
539 | 551 | T(dist.Pareto, np.array([[1.0], [3.0]]), np.array([1.0, 0.5])),
|
@@ -1502,6 +1514,10 @@ def test_mean_var(jax_dist, sp_dist, params):
|
1502 | 1514 | dist.TwoSidedTruncatedDistribution,
|
1503 | 1515 | ):
|
1504 | 1516 | pytest.skip("Truncated distributions do not has mean/var implemented")
|
| 1517 | + if jax_dist is dist.SkewMultivariateNormal: |
| 1518 | + pytest.skip( |
| 1519 | + "We check SkewMultivariateNormal against MultivariateNormal elsewhere" |
| 1520 | + ) |
1505 | 1521 | if jax_dist is dist.ProjectedNormal:
|
1506 | 1522 | pytest.skip("Mean is defined in submanifold")
|
1507 | 1523 |
|
@@ -2570,3 +2586,169 @@ def sample_binomial_withp0(key):
|
2570 | 2586 | return dist.Binomial(total_count=n, probs=0).sample(key)
|
2571 | 2587 |
|
2572 | 2588 | jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))
|
| 2589 | + |
| 2590 | + |
def locs(size: int) -> SearchStrategy[NDArray[float]]:
    """Build a Hypothesis strategy for location vectors.

    The generated NumPy float arrays have shape ``size`` and every element
    lies in the closed interval [-1, 1]; NaN and infinities are excluded.
    """
    bounded_floats = st.floats(
        min_value=-1, max_value=1, allow_nan=False, allow_infinity=False
    )
    strategy = hnp.arrays(
        dtype=np.dtype("float"), shape=size, elements=bounded_floats
    )
    # The cast only narrows the static type; it has no runtime effect.
    return cast(SearchStrategy[NDArray[float]], strategy)
| 2602 | + |
| 2603 | + |
def skews(size: int) -> SearchStrategy[NDArray[float]]:
    """Build a Hypothesis strategy for skewness vectors.

    The generated NumPy float arrays have shape ``size`` and every element
    lies in the closed interval [-4, 4]; NaN and infinities are excluded.
    """
    bounded_floats = st.floats(
        min_value=-4, max_value=4, allow_nan=False, allow_infinity=False
    )
    strategy = hnp.arrays(
        dtype=np.dtype("float"), shape=size, elements=bounded_floats
    )
    # The cast only narrows the static type; it has no runtime effect.
    return cast(SearchStrategy[NDArray[float]], strategy)
| 2615 | + |
| 2616 | + |
def variances(size: int) -> SearchStrategy[NDArray[float]]:
    """Build a Hypothesis strategy for variance vectors.

    The generated NumPy float arrays have shape ``size`` and every element
    lies in the half-open interval (0.1, 3]; NaN and infinities are excluded.
    """
    # Variances that are too small make it impossible to test against normal,
    # hence the strictly positive lower bound (0.1 itself is excluded).
    positive_floats = st.floats(
        min_value=0.1,
        max_value=3,
        exclude_min=True,
        allow_nan=False,
        allow_infinity=False,
    )
    strategy = hnp.arrays(
        dtype=np.dtype("float"), shape=size, elements=positive_floats
    )
    # The cast only narrows the static type; it has no runtime effect.
    return cast(SearchStrategy[NDArray[float]], strategy)
| 2633 | + |
| 2634 | + |
def corr_vech_to_matrix(vech: NDArray[float]):
    """Expand strictly-lower-triangular entries into a unit-diagonal matrix.

    ``vech`` holds the entries below the diagonal in ``np.tril_indices``
    order (row by row).  The result is a square float matrix with ones on
    the diagonal, ``vech`` below it, and zeros above it.
    """
    # A width-n matrix has n * (n - 1) / 2 entries below the diagonal;
    # solve that relation for n given the number of entries supplied.
    n = (1 + math.isqrt(1 + 8 * vech.size)) // 2
    matrix = np.eye(n)
    matrix[np.tril_indices(n, k=-1)] = vech
    return matrix
| 2641 | + |
| 2642 | + |
def correlation_chols(size: int) -> SearchStrategy[NDArray[float]]:
    """Build a Hypothesis strategy for ``size`` x ``size`` triangular factors.

    Each example is produced by ``corr_vech_to_matrix``: a lower-triangular
    matrix with a unit diagonal whose below-diagonal entries are drawn from
    [-0.99, 0.99].
    """
    below_diagonal_count = size * (size - 1) // 2
    # Floating point issues mean we sometimes get arrays which aren't positive
    # semi-definite if we allow correlations of exactly 1 and -1
    entries = st.floats(
        min_value=-0.99, max_value=0.99, allow_nan=False, allow_infinity=False
    )
    vechs = hnp.arrays(
        dtype=np.dtype("float"), shape=below_diagonal_count, elements=entries
    )
    return vechs.map(corr_vech_to_matrix)  # type: ignore
| 2655 | + |
| 2656 | + |
@st.composite
def loc_and_scale(draw: DrawFn):
    """Draw a ``(loc, scale_tril)`` pair for a two-dimensional distribution.

    ``scale_tril`` is the drawn unit-diagonal lower-triangular matrix with
    each row multiplied by the square root of the corresponding variance.
    """
    # Would need to generalize meshgrid to relax this restriction
    dim = 2
    # Keep this draw order (factor, variances, location) stable.
    unit_tril = draw(correlation_chols(dim))
    variance = draw(variances(dim))
    loc = draw(locs(dim))
    scale_tril = jnp.sqrt(variance)[..., None] * unit_tril
    return (loc, scale_tril)
| 2664 | + |
| 2665 | + |
@st.composite
def loc_and_scale_and_skewers(draw: DrawFn):
    """Draw a ``(loc, scale_tril, skewers)`` triple in two dimensions.

    ``scale_tril`` is the drawn unit-diagonal lower-triangular matrix with
    each row multiplied by the square root of the corresponding variance.
    """
    # Would need to generalize meshgrid to relax this restriction
    dim = 2
    # Keep this draw order (factor, variances, location, skewers) stable.
    unit_tril = draw(correlation_chols(dim))
    variance = draw(variances(dim))
    loc = draw(locs(dim))
    scale_tril = jnp.sqrt(variance)[..., None] * unit_tril
    skewers = draw(skews(dim))
    return (loc, scale_tril, skewers)
| 2677 | + |
| 2678 | + |
# Fine 100 x 100 evaluation grid covering [-3, 3] on both axes; ``grid`` has a
# trailing axis of length 2 holding the (x, y) coordinates of each point.
X, Y = np.meshgrid(
    np.linspace(-3, 3, num=100), np.linspace(-3, 3, num=100), indexing="xy"
)
grid = np.dstack([X, Y])
# Coarser 50 x 50 grid covering the wider range [-6, 6] on both axes.
X_wide, Y_wide = np.meshgrid(
    np.linspace(-6, 6, num=50), np.linspace(-6, 6, num=50), indexing="xy"
)
grid_wide = np.dstack([X_wide, Y_wide])
| 2683 | + |
| 2684 | + |
@settings(deadline=None)
@given(loc_and_scale())
def test_skew_normal_log_prob_generalizes_normal(
    loc_scale_tril: tuple[NDArray[float], NDArray[float]]
):
    """With all-zero skewers the log density must match MultivariateNormal."""
    loc, scale_tril = loc_scale_tril
    zero_skewers = np.zeros(scale_tril.shape[-1])
    reference = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    unskewed = dist.SkewMultivariateNormal(
        loc=loc, scale_tril=scale_tril, skewers=zero_skewers
    )
    # Compare the two log densities at every point of the module-level grid.
    expected = reference.log_prob(grid)
    actual = unskewed.log_prob(grid)
    assert_allclose(expected, actual, atol=1e-6)
| 2696 | + |
| 2697 | + |
@settings(deadline=None)
@given(loc_and_scale())
def test_skew_normal_moments_generalize_normal(
    loc_scale_tril: tuple[NDArray[float], NDArray[float]]
):
    """With all-zero skewers the mean and covariance must match MultivariateNormal."""
    loc, scale_tril = loc_scale_tril
    zero_skewers = np.zeros(scale_tril.shape[-1])
    reference = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril)
    unskewed = dist.SkewMultivariateNormal(
        loc=loc, scale_tril=scale_tril, skewers=zero_skewers
    )
    # The absolute tolerance is negligible, so assert_allclose's default
    # relative tolerance effectively governs both comparisons.
    assert_allclose(reference.mean, unskewed.mean, atol=1e-30)
    assert_allclose(
        reference.covariance_matrix, unskewed.covariance_matrix, atol=1e-30
    )
| 2710 | + |
| 2711 | + |
@settings(deadline=None, max_examples=10)
@given(loc_and_scale_and_skewers())
def test_skew_normal_log_prob_vs_samples(
    loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]]
):
    """Compare the analytic density with a kernel density estimate of samples.

    Both surfaces are min-max rescaled to [0, 1] over the wide grid before
    being compared with an absolute tolerance.
    """

    def rescale(values):
        # Min-max normalisation so only the shapes of the surfaces are compared.
        lowest, highest = values.min(), values.max()
        return (values - lowest) / (highest - lowest)

    loc, scale_tril, skewers = loc_scale_tril_skewers
    note(f"Covariance: {scale_tril @ scale_tril.T}")
    smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
    draws = smvn.sample(random.PRNGKey(0), sample_shape=(50_000,))
    # gaussian_kde needs a different format: one row per dimension.
    points = np.vstack([X_wide.ravel(), Y_wide.ravel()])
    density = jnp.exp(smvn.log_prob(points.T))
    kde = osp.gaussian_kde(draws.T, bw_method="scott")
    estimate = kde(points)

    assert_allclose(rescale(density), rescale(estimate), atol=0.07)
| 2729 | + |
| 2730 | + |
def split_cov(cov: NDArray[float]) -> tuple[NDArray[float], NDArray[float]]:
    """Split a covariance matrix into standard deviations and correlations.

    Returns a ``(std_devs, correlations)`` tuple: ``std_devs`` are the square
    roots of the diagonal of ``cov`` and ``correlations`` are the entries of
    the correlation matrix strictly below the diagonal, in
    ``np.tril_indices`` order.
    """
    sigma = np.sqrt(np.diag(cov))
    inverse_scale = np.diag(1 / sigma)
    # Scaling on both sides by diag(1 / sigma) turns covariances into correlations.
    correlation = inverse_scale @ cov @ inverse_scale
    rows, cols = np.tril_indices(len(sigma), k=-1)
    return (sigma, correlation[rows, cols])
| 2737 | + |
| 2738 | + |
@settings(deadline=None)
@given(loc_and_scale_and_skewers())
def test_skew_normal_moments_vs_samples(
    loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]]
):
    """Check the analytic mean and covariance against empirical sample moments.

    The covariance is compared in two parts, standard deviations and
    correlations, each with its own tolerance.
    """
    loc, scale_tril, skewers = loc_scale_tril_skewers
    note(f"Covariance: {scale_tril @ scale_tril.T}")
    smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers)
    draws = smvn.sample(random.PRNGKey(0), sample_shape=(500_000,))

    empirical_mean = np.mean(draws, axis=0)
    assert_allclose(empirical_mean, smvn.mean, rtol=0.005, atol=0.001)

    # np.cov expects one row per variable, hence the transpose.
    empirical_cov = np.cov(draws.T)
    std_devs_sample, corr_sample = split_cov(empirical_cov)
    std_devs_dist, corr_dist = split_cov(smvn.covariance_matrix)
    assert_allclose(std_devs_sample, std_devs_dist, rtol=0.003)
    note(f"Sample corr: {corr_sample}, Distribution corr: {corr_dist}")
    assert_allclose(corr_sample, corr_dist, atol=0.006)
0 commit comments