Skip to content

Commit d67c9a3

Browse files
committed
add missing data test
1 parent 64496cd commit d67c9a3

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

Diff for: pymc3/tests/test_bart.py

+12
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,15 @@ def test_bart_random():
6565
assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
6666
assert pred_all.shape == (2, 50)
6767
assert pred_first.shape == (10,)
68+
69+
70+
def test_missing_data():
71+
X = np.random.normal(0, 1, size=(2, 50)).T
72+
Y = np.random.normal(0, 1, size=50)
73+
X[10:20, 0] = np.nan
74+
75+
with pm.Model() as model:
76+
mu = pm.BART("mu", X, Y, m=10)
77+
sigma = pm.HalfNormal("sigma", 1)
78+
y = pm.Normal("y", mu, sigma, observed=Y)
79+
idata = pm.sample(random_seed=3415)

0 commit comments

Comments
 (0)