@@ -24,31 +24,6 @@ def test_leaf_node():
24
24
assert leaf_node .get_idx_right_child () == 12
25
25
26
26
27
- def test_model ():
28
- X = np .linspace (7 , 15 , 100 )
29
- Y = np .sin (np .random .normal (X , 0.2 )) + 3
30
- X = X [:, None ]
31
-
32
- with pm .Model () as model :
33
- sigma = pm .HalfNormal ("sigma" , 1 )
34
- mu = pm .BART ("mu" , X , Y , m = 50 )
35
- y = pm .Normal ("y" , mu , sigma , observed = Y )
36
- idata = pm .sample (chains = 4 )
37
- mean = idata .posterior ["mu" ].stack (samples = ("chain" , "draw" )).mean ("samples" )
38
-
39
- np .testing .assert_allclose (mean , Y , 0.5 )
40
-
41
- Y = np .repeat ([0 , 1 ], 50 )
42
- with pm .Model () as model :
43
- mu_ = pm .BART ("mu_" , X , Y , m = 50 )
44
- mu = pm .Deterministic ("mu" , pm .math .invlogit (mu_ ))
45
- y = pm .Bernoulli ("y" , mu , observed = Y )
46
- idata = pm .sample (chains = 4 )
47
- mean = idata .posterior ["mu" ].stack (samples = ("chain" , "draw" )).mean ("samples" )
48
-
49
- np .testing .assert_allclose (mean , Y , atol = 0.5 )
50
-
51
-
52
27
def test_bart_vi ():
53
28
X = np .random .normal (0 , 1 , size = (3 , 250 )).T
54
29
Y = np .random .normal (0 , 1 , size = 250 )
0 commit comments