@@ -11,6 +11,7 @@ def model_c():
11
11
m = pm .Normal ("m" )
12
12
s = pm .LogNormal ("s" )
13
13
pm .Normal ("g" , m , s , shape = 5 )
14
+ pm .Exponential ("e" , scale = s , shape = 7 )
14
15
return mod
15
16
16
17
@@ -20,31 +21,34 @@ def model_nc():
20
21
m = pm .Normal ("m" )
21
22
s = pm .LogNormal ("s" )
22
23
pm .Deterministic ("g" , pm .Normal ("z" , shape = 5 ) * s + m )
24
+ pm .Deterministic ("e" , pm .Exponential ("z_e" , 1 , shape = 7 ) * s )
23
25
return mod
24
26
25
27
26
@pytest.mark.parametrize("var", ["g", "e"])
def test_reparametrize_created(model_c: pm.Model, var):
    """Reparametrizing ``var`` must register it with VIP and create the
    auxiliary variables in the transformed model.

    Checks that ``var`` appears in the lambda mapping and that the
    ``<var>::lam_logit__`` / ``<var>::tau_`` nodes exist in ``named_vars``.
    """
    model_reparam, vip = vip_reparametrize(model_c, [var])
    # `var` is already a str — no f-string wrapper needed.
    assert var in vip.get_lambda()
    assert f"{var}::lam_logit__" in model_reparam.named_vars
    assert f"{var}::tau_" in model_reparam.named_vars

    # lambda == 1 corresponds to logit(1) == +inf, so after forcing every
    # lambda to 1 the logit parameter must be non-finite.
    # NOTE: the key must be exactly f"{var}::lam_logit__" (no leading space),
    # matching the named_vars assertions above.
    vip.set_all_lambda(1)
    assert not np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()
33
36
34
37
35
- def test_random_draw (model_c : pm .Model , model_nc ):
38
+ @pytest .mark .parametrize ("var" , ["g" , "e" ])
39
+ def test_random_draw (model_c : pm .Model , model_nc , var ):
36
40
model_c = pm .do (model_c , {"m" : 3 , "s" : 2 })
37
41
model_nc = pm .do (model_nc , {"m" : 3 , "s" : 2 })
38
- model_v , vip = vip_reparametrize (model_c , ["g" ])
39
- assert "g" in [v .name for v in model_v .deterministics ]
40
- c = pm .draw (model_c ["g" ], random_seed = 42 , draws = 1000 )
41
- nc = pm .draw (model_nc ["g" ], random_seed = 42 , draws = 1000 )
42
+ model_v , vip = vip_reparametrize (model_c , [var ])
43
+ assert var in [v .name for v in model_v .deterministics ]
44
+ c = pm .draw (model_c [var ], random_seed = 42 , draws = 1000 )
45
+ nc = pm .draw (model_nc [var ], random_seed = 42 , draws = 1000 )
42
46
vip .set_all_lambda (1 )
43
- v_1 = pm .draw (model_v ["g" ], random_seed = 42 , draws = 1000 )
47
+ v_1 = pm .draw (model_v [var ], random_seed = 42 , draws = 1000 )
44
48
vip .set_all_lambda (0 )
45
- v_0 = pm .draw (model_v ["g" ], random_seed = 42 , draws = 1000 )
49
+ v_0 = pm .draw (model_v [var ], random_seed = 42 , draws = 1000 )
46
50
vip .set_all_lambda (0.5 )
47
- v_05 = pm .draw (model_v ["g" ], random_seed = 42 , draws = 1000 )
51
+ v_05 = pm .draw (model_v [var ], random_seed = 42 , draws = 1000 )
48
52
np .testing .assert_allclose (c .mean (), nc .mean ())
49
53
np .testing .assert_allclose (c .mean (), v_0 .mean ())
50
54
np .testing .assert_allclose (v_05 .mean (), v_1 .mean ())
@@ -57,10 +61,12 @@ def test_random_draw(model_c: pm.Model, model_nc):
57
61
58
62
59
63
def test_reparam_fit(model_c):
    """Fitting VIP over both reparametrized variables should drive each
    lambda to ~0, i.e. prefer the centered parametrization for this model.
    """
    # Single source of truth for the variable names: avoid shadowing the
    # `vars` builtin and don't repeat the ["g", "e"] literal.
    var_names = ["g", "e"]
    model_v, vip = vip_reparametrize(model_c, var_names)
    with model_v:
        vip.fit(50000, random_seed=42)
        for var in var_names:
            np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
64
70
65
71
66
72
def test_multilevel ():
0 commit comments