|
21 | 21 | import arviz as az
|
22 | 22 | import numpy as np
|
23 | 23 | import numpy.testing as npt
|
| 24 | +import pandas as pd |
24 | 25 | import pytest
|
25 | 26 | import theano
|
26 | 27 | import theano.tensor as tt
|
@@ -56,7 +57,7 @@ def test_parallel_sample_does_not_reuse_seed(self):
|
56 | 57 | random_numbers = []
|
57 | 58 | draws = []
|
58 | 59 | for _ in range(2):
|
59 |
| - np.random.seed(1) # seeds in other processes don't effect main process |
| 60 | + np.random.seed(1) # seeds in other processes don't affect main process |
60 | 61 | with self.model:
|
61 | 62 | trace = pm.sample(100, tune=0, cores=cores, return_inferencedata=False)
|
62 | 63 | # numpy thread mentioned race condition. might as well check none are equal
|
@@ -1108,6 +1109,38 @@ def test_potentials_warning(self):
|
1108 | 1109 | pm.sample_prior_predictive(samples=5)
|
1109 | 1110 |
|
1110 | 1111 |
|
| 1112 | +def test_prior_sampling_mixture(): |
| 1113 | + """ |
| 1114 | + Added this test because the NormalMixture distribution did not support |
| 1115 | + component shape identification, causing prior predictive sampling to error out. |
| 1116 | + """ |
| 1117 | + old_faithful_df = pd.read_csv(pm.get_data("old_faithful.csv")) |
| 1118 | + old_faithful_df["std_waiting"] = ( |
| 1119 | + old_faithful_df.waiting - old_faithful_df.waiting.mean() |
| 1120 | + ) / old_faithful_df.waiting.std() |
| 1121 | + N = old_faithful_df.shape[0] |
| 1122 | + K = 30 |
| 1123 | + |
| 1124 | + def stick_breaking(beta): |
| 1125 | + portion_remaining = tt.concatenate([[1], tt.extra_ops.cumprod(1 - beta)[:-1]]) |
| 1126 | + result = beta * portion_remaining |
| 1127 | + return result / tt.sum(result, axis=-1, keepdims=True) |
| 1128 | + |
| 1129 | + with pm.Model() as model: |
| 1130 | + alpha = pm.Gamma("alpha", 1.0, 1.0) |
| 1131 | + beta = pm.Beta("beta", 1.0, alpha, shape=K) |
| 1132 | + w = pm.Deterministic("w", stick_breaking(beta)) |
| 1133 | + |
| 1134 | + tau = pm.Gamma("tau", 1.0, 1.0, shape=K) |
| 1135 | + lambda_ = pm.Gamma("lambda_", 10.0, 1.0, shape=K) |
| 1136 | + mu = pm.Normal("mu", 0, tau=lambda_ * tau, shape=K) |
| 1137 | + obs = pm.NormalMixture( |
| 1138 | + "obs", w, mu, tau=lambda_ * tau, observed=old_faithful_df.std_waiting.values |
| 1139 | + ) |
| 1140 | + |
| 1141 | + pm.sample_prior_predictive() |
| 1142 | + |
| 1143 | + |
1111 | 1144 | class TestSamplePosteriorPredictive:
|
1112 | 1145 | def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture):
|
1113 | 1146 | pmodel, trace = point_list_arg_bug_fixture
|
|
0 commit comments