Skip to content

Commit d50742d

Browse files
authored
Add Exponential distribution to model/transforms/autoreparam.py (#365)
1 parent 99170df commit d50742d

File tree

2 files changed

+61
-17
lines changed

2 files changed

+61
-17
lines changed

Diff for: pymc_experimental/model/transforms/autoreparam.py

+38
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,44 @@ def _(
246246
return vip_rep
247247

248248

249+
@_vip_reparam_node.register
250+
def _(
251+
op: pm.Exponential,
252+
node: Apply,
253+
name: str,
254+
dims: List[Variable],
255+
transform: Optional[Transform],
256+
lam: pt.TensorVariable,
257+
) -> ModelDeterministic:
258+
rng, size, scale = node.inputs
259+
scale_centered = scale**lam
260+
scale_noncentered = scale ** (1 - lam)
261+
vip_rv_ = pm.Exponential.dist(
262+
scale=scale_centered,
263+
size=size,
264+
rng=rng,
265+
)
266+
vip_rv_value_ = vip_rv_.clone()
267+
vip_rv_.name = f"{name}::tau_"
268+
if transform is not None:
269+
vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__"
270+
else:
271+
vip_rv_value_.name = vip_rv_.name
272+
vip_rv = model_free_rv(
273+
vip_rv_,
274+
vip_rv_value_,
275+
transform,
276+
*dims,
277+
)
278+
279+
vip_rep_ = scale_noncentered * vip_rv
280+
281+
vip_rep_.name = name
282+
283+
vip_rep = model_deterministic(vip_rep_, *dims)
284+
return vip_rep
285+
286+
249287
def vip_reparametrize(
250288
model: pm.Model,
251289
var_names: Sequence[str],

Diff for: tests/model/transforms/test_autoreparam.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def model_c():
1111
m = pm.Normal("m")
1212
s = pm.LogNormal("s")
1313
pm.Normal("g", m, s, shape=5)
14+
pm.Exponential("e", scale=s, shape=7)
1415
return mod
1516

1617

@@ -20,31 +21,34 @@ def model_nc():
2021
m = pm.Normal("m")
2122
s = pm.LogNormal("s")
2223
pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
24+
pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
2325
return mod
2426

2527

26-
def test_reparametrize_created(model_c: pm.Model):
27-
model_reparam, vip = vip_reparametrize(model_c, ["g"])
28-
assert "g" in vip.get_lambda()
29-
assert "g::lam_logit__" in model_reparam.named_vars
30-
assert "g::tau_" in model_reparam.named_vars
28+
@pytest.mark.parametrize("var", ["g", "e"])
29+
def test_reparametrize_created(model_c: pm.Model, var):
30+
model_reparam, vip = vip_reparametrize(model_c, [var])
31+
assert f"{var}" in vip.get_lambda()
32+
assert f"{var}::lam_logit__" in model_reparam.named_vars
33+
assert f"{var}::tau_" in model_reparam.named_vars
3134
vip.set_all_lambda(1)
32-
assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any()
35+
assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()
3336

3437

35-
def test_random_draw(model_c: pm.Model, model_nc):
38+
@pytest.mark.parametrize("var", ["g", "e"])
39+
def test_random_draw(model_c: pm.Model, model_nc, var):
3640
model_c = pm.do(model_c, {"m": 3, "s": 2})
3741
model_nc = pm.do(model_nc, {"m": 3, "s": 2})
38-
model_v, vip = vip_reparametrize(model_c, ["g"])
39-
assert "g" in [v.name for v in model_v.deterministics]
40-
c = pm.draw(model_c["g"], random_seed=42, draws=1000)
41-
nc = pm.draw(model_nc["g"], random_seed=42, draws=1000)
42+
model_v, vip = vip_reparametrize(model_c, [var])
43+
assert var in [v.name for v in model_v.deterministics]
44+
c = pm.draw(model_c[var], random_seed=42, draws=1000)
45+
nc = pm.draw(model_nc[var], random_seed=42, draws=1000)
4246
vip.set_all_lambda(1)
43-
v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000)
47+
v_1 = pm.draw(model_v[var], random_seed=42, draws=1000)
4448
vip.set_all_lambda(0)
45-
v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000)
49+
v_0 = pm.draw(model_v[var], random_seed=42, draws=1000)
4650
vip.set_all_lambda(0.5)
47-
v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000)
51+
v_05 = pm.draw(model_v[var], random_seed=42, draws=1000)
4852
np.testing.assert_allclose(c.mean(), nc.mean())
4953
np.testing.assert_allclose(c.mean(), v_0.mean())
5054
np.testing.assert_allclose(v_05.mean(), v_1.mean())
@@ -57,10 +61,12 @@ def test_random_draw(model_c: pm.Model, model_nc):
5761

5862

5963
def test_reparam_fit(model_c):
60-
model_v, vip = vip_reparametrize(model_c, ["g"])
64+
vars = ["g", "e"]
65+
model_v, vip = vip_reparametrize(model_c, ["g", "e"])
6166
with model_v:
62-
vip.fit(random_seed=42)
63-
np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01)
67+
vip.fit(50000, random_seed=42)
68+
for var in vars:
69+
np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
6470

6571

6672
def test_multilevel():

0 commit comments

Comments
 (0)