Skip to content

Commit f59a07f

Browse files
ricardoV94twiecki
authored andcommitted
Refactor Flat and HalfFlat distributions (#4723)
* Refactor Flat and HalfFlat distributions * Re-enable Gumbel logp test * Remove redundant test
1 parent c62c3bf commit f59a07f

6 files changed

+100
-76
lines changed

pymc3/distributions/continuous.py

+46-36
Original file line numberDiff line numberDiff line change
@@ -308,31 +308,36 @@ def logcdf(value, lower, upper):
308308
)
309309

310310

311+
class FlatRV(RandomVariable):
312+
name = "flat"
313+
ndim_supp = 0
314+
ndims_params = []
315+
dtype = "floatX"
316+
_print_name = ("Flat", "\\operatorname{Flat}")
317+
318+
@classmethod
319+
def rng_fn(cls, rng, size):
320+
raise NotImplementedError("Cannot sample from flat variable")
321+
322+
323+
flat = FlatRV()
324+
325+
311326
class Flat(Continuous):
312327
"""
313328
Uninformative log-likelihood that returns 0 regardless of
314329
the passed value.
315330
"""
316331

317-
def __init__(self, *args, **kwargs):
318-
self._default = 0
319-
super().__init__(defaults=("_default",), *args, **kwargs)
320-
321-
def random(self, point=None, size=None):
322-
"""Raises ValueError as it is not possible to sample from Flat distribution
323-
324-
Parameters
325-
----------
326-
point: dict, optional
327-
size: int, optional
332+
rv_op = flat
328333

329-
Raises
330-
-------
331-
ValueError
332-
"""
333-
raise ValueError("Cannot sample from Flat distribution")
334+
@classmethod
335+
def dist(cls, *, size=None, testval=None, **kwargs):
336+
if testval is None:
337+
testval = np.full(size, floatX(0.0))
338+
return super().dist([], size=size, testval=testval, **kwargs)
334339

335-
def logp(self, value):
340+
def logp(value):
336341
"""
337342
Calculate log-probability of Flat distribution at specified value.
338343
@@ -348,7 +353,7 @@ def logp(self, value):
348353
"""
349354
return at.zeros_like(value)
350355

351-
def logcdf(self, value):
356+
def logcdf(value):
352357
"""
353358
Compute the log of the cumulative distribution function for Flat distribution
354359
at the specified value.
@@ -368,28 +373,33 @@ def logcdf(self, value):
368373
)
369374

370375

371-
class HalfFlat(PositiveContinuous):
372-
"""Improper flat prior over the positive reals."""
376+
class HalfFlatRV(RandomVariable):
377+
name = "half_flat"
378+
ndim_supp = 0
379+
ndims_params = []
380+
dtype = "floatX"
381+
_print_name = ("HalfFlat", "\\operatorname{HalfFlat}")
373382

374-
def __init__(self, *args, **kwargs):
375-
self._default = 1
376-
super().__init__(defaults=("_default",), *args, **kwargs)
383+
@classmethod
384+
def rng_fn(cls, rng, size):
385+
raise NotImplementedError("Cannot sample from half_flat variable")
377386

378-
def random(self, point=None, size=None):
379-
"""Raises ValueError as it is not possible to sample from HalfFlat distribution
380387

381-
Parameters
382-
----------
383-
point: dict, optional
384-
size: int, optional
388+
halfflat = HalfFlatRV()
385389

386-
Raises
387-
-------
388-
ValueError
389-
"""
390-
raise ValueError("Cannot sample from HalfFlat distribution")
391390

392-
def logp(self, value):
391+
class HalfFlat(PositiveContinuous):
392+
"""Improper flat prior over the positive reals."""
393+
394+
rv_op = halfflat
395+
396+
@classmethod
397+
def dist(cls, *, size=None, testval=None, **kwargs):
398+
if testval is None:
399+
testval = np.full(size, floatX(1.0))
400+
return super().dist([], size=size, testval=testval, **kwargs)
401+
402+
def logp(value):
393403
"""
394404
Calculate log-probability of HalfFlat distribution at specified value.
395405
@@ -405,7 +415,7 @@ def logp(self, value):
405415
"""
406416
return bound(at.zeros_like(value), value > 0)
407417

408-
def logcdf(self, value):
418+
def logcdf(value):
409419
"""
410420
Compute the log of the cumulative distribution function for HalfFlat distribution
411421
at the specified value.

pymc3/tests/test_distributions.py

+16-21
Original file line numberDiff line numberDiff line change
@@ -961,18 +961,16 @@ def test_discrete_unif(self):
961961
assert logpt(invalid_dist, 0.5).eval() == -np.inf
962962
assert logcdf(invalid_dist, 2).eval() == -np.inf
963963

964-
@pytest.mark.xfail(reason="Distribution not refactored yet")
965964
def test_flat(self):
966965
self.check_logp(Flat, Runif, {}, lambda value: 0)
967966
with Model():
968967
x = Flat("a")
969968
assert_allclose(x.tag.test_value, 0)
970969
self.check_logcdf(Flat, R, {}, lambda value: np.log(0.5))
971970
# Check infinite cases individually.
972-
assert 0.0 == logcdf(Flat.dist(), np.inf).tag.test_value
973-
assert -np.inf == logcdf(Flat.dist(), -np.inf).tag.test_value
971+
assert 0.0 == logcdf(Flat.dist(), np.inf).eval()
972+
assert -np.inf == logcdf(Flat.dist(), -np.inf).eval()
974973

975-
@pytest.mark.xfail(reason="Distribution not refactored yet")
976974
def test_half_flat(self):
977975
self.check_logp(HalfFlat, Rplus, {}, lambda value: 0)
978976
with Model():
@@ -981,8 +979,8 @@ def test_half_flat(self):
981979
assert x.tag.test_value.shape == (2,)
982980
self.check_logcdf(HalfFlat, Rplus, {}, lambda value: -np.inf)
983981
# Check infinite cases individually.
984-
assert 0.0 == logcdf(HalfFlat.dist(), np.inf).tag.test_value
985-
assert -np.inf == logcdf(HalfFlat.dist(), -np.inf).tag.test_value
982+
assert 0.0 == logcdf(HalfFlat.dist(), np.inf).eval()
983+
assert -np.inf == logcdf(HalfFlat.dist(), -np.inf).eval()
986984

987985
def test_normal(self):
988986
self.check_logp(
@@ -2499,17 +2497,19 @@ def test_vonmises(self):
24992497
lambda value, mu, kappa: floatX(sp.vonmises.logpdf(value, kappa, loc=mu)),
25002498
)
25012499

2502-
@pytest.mark.xfail(reason="Distribution not refactored yet")
25032500
def test_gumbel(self):
2504-
def gumbel(value, mu, beta):
2505-
return floatX(sp.gumbel_r.logpdf(value, loc=mu, scale=beta))
2506-
2507-
self.check_logp(Gumbel, R, {"mu": R, "beta": Rplusbig}, gumbel)
2508-
2509-
def gumbellcdf(value, mu, beta):
2510-
return floatX(sp.gumbel_r.logcdf(value, loc=mu, scale=beta))
2511-
2512-
self.check_logcdf(Gumbel, R, {"mu": R, "beta": Rplusbig}, gumbellcdf)
2501+
self.check_logp(
2502+
Gumbel,
2503+
R,
2504+
{"mu": R, "beta": Rplusbig},
2505+
lambda value, mu, beta: sp.gumbel_r.logpdf(value, loc=mu, scale=beta),
2506+
)
2507+
self.check_logcdf(
2508+
Gumbel,
2509+
R,
2510+
{"mu": R, "beta": Rplusbig},
2511+
lambda value, mu, beta: sp.gumbel_r.logcdf(value, loc=mu, scale=beta),
2512+
)
25132513

25142514
def test_logistic(self):
25152515
self.check_logp(
@@ -2538,11 +2538,6 @@ def test_logitnormal(self):
25382538
decimal=select_by_precision(float64=6, float32=1),
25392539
)
25402540

2541-
@pytest.mark.xfail(reason="Distribution not refactored yet")
2542-
def test_multidimensional_beta_construction(self):
2543-
with Model():
2544-
Beta("beta", alpha=1.0, beta=1.0, size=(10, 20))
2545-
25462541
@pytest.mark.xfail(
25472542
condition=(aesara.config.floatX == "float32"),
25482543
reason="Some combinations underflow to -inf in float32 in pymc version",

pymc3/tests/test_distributions_random.py

+35-15
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,11 @@ def check_rv_size(self):
360360
assert actual == expected, f"size={size}, expected={expected}, actual={actual}"
361361

362362
# test multi-parameters sampling for univariate distributions (with univariate inputs)
363-
if self.pymc_dist.rv_op.ndim_supp == 0 and sum(self.pymc_dist.rv_op.ndims_params) == 0:
363+
if (
364+
self.pymc_dist.rv_op.ndim_supp == 0
365+
and self.pymc_dist.rv_op.ndims_params
366+
and sum(self.pymc_dist.rv_op.ndims_params) == 0
367+
):
364368
params = {
365369
k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items()
366370
}
@@ -394,6 +398,36 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
394398
)
395399

396400

401+
class TestFlat(BaseTestDistribution):
402+
pymc_dist = pm.Flat
403+
pymc_dist_params = {}
404+
expected_rv_op_params = {}
405+
tests_to_run = [
406+
"check_pymc_params_match_rv_op",
407+
"check_rv_size",
408+
"check_not_implemented",
409+
]
410+
411+
def check_not_implemented(self):
412+
with pytest.raises(NotImplementedError):
413+
self.pymc_rv.eval()
414+
415+
416+
class TestHalfFlat(BaseTestDistribution):
417+
pymc_dist = pm.HalfFlat
418+
pymc_dist_params = {}
419+
expected_rv_op_params = {}
420+
tests_to_run = [
421+
"check_pymc_params_match_rv_op",
422+
"check_rv_size",
423+
"check_not_implemented",
424+
]
425+
426+
def check_not_implemented(self):
427+
with pytest.raises(NotImplementedError):
428+
self.pymc_rv.eval()
429+
430+
397431
class TestDiscreteWeibull(BaseTestDistribution):
398432
def discrete_weibul_rng_fn(self, size, q, beta, uniform_rng_fct):
399433
return np.ceil(np.power(np.log(1 - uniform_rng_fct(size=size)) / np.log(q), 1.0 / beta)) - 1
@@ -1240,20 +1274,6 @@ def ref_rand(size, mu, sigma, nu):
12401274

12411275
pymc3_random(pm.ExGaussian, {"mu": R, "sigma": Rplus, "nu": Rplus}, ref_rand=ref_rand)
12421276

1243-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1244-
def test_flat(self):
1245-
with pm.Model():
1246-
f = pm.Flat("f")
1247-
with pytest.raises(ValueError):
1248-
f.random(1)
1249-
1250-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1251-
def test_half_flat(self):
1252-
with pm.Model():
1253-
f = pm.HalfFlat("f")
1254-
with pytest.raises(ValueError):
1255-
f.random(1)
1256-
12571277
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
12581278
def test_matrix_normal(self):
12591279
def ref_rand(size, mu, rowcov, colcov):

pymc3/tests/test_distributions_timeseries.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pymc3.tests.helpers import select_by_precision
2424

2525
# pytestmark = pytest.mark.usefixtures("seeded_test")
26-
pytestmark = pytest.mark.xfail(reason="This test relies on the deprecated Distribution interface")
26+
pytestmark = pytest.mark.xfail(reason="Timeseries not refactored")
2727

2828

2929
def test_AR():

pymc3/tests/test_sampling.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,6 @@ def test_sum_normal(self):
603603
_, pval = stats.kstest(ppc["b"], stats.norm(scale=scale).cdf)
604604
assert pval > 0.001
605605

606-
@pytest.mark.xfail(reason="HalfFlat not refactored for v4")
607606
def test_model_not_drawable_prior(self):
608607
data = np.random.poisson(lam=10, size=200)
609608
model = pm.Model()
@@ -613,7 +612,7 @@ def test_model_not_drawable_prior(self):
613612
trace = pm.sample(tune=1000)
614613

615614
with model:
616-
with pytest.raises(ValueError) as excinfo:
615+
with pytest.raises(NotImplementedError) as excinfo:
617616
pm.sample_prior_predictive(50)
618617
assert "Cannot sample" in str(excinfo.value)
619618
samples = pm.sample_posterior_predictive(trace, 40)

pymc3/tests/test_step.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ def test_step_categorical(self):
621621
trace = sample(8000, tune=0, step=step, start=start, model=model, random_seed=1)
622622
self.check_stat(check, trace, step.__class__.__name__)
623623

624-
@pytest.mark.xfail(reason="Flat not refactored for v4")
624+
@pytest.mark.xfail(reason="EllipticalSlice not refactored for v4")
625625
def test_step_elliptical_slice(self):
626626
start, model, (K, L, mu, std, noise) = mv_prior_simple()
627627
unc = noise ** 0.5

0 commit comments

Comments
 (0)