Skip to content

Commit 4da5edf

Browse files
committed
Implement check_icdf helper to test icdf implementations
Note that adding a nan switch to the icdf expression of discrete variables, prevents the returned dtype to be the same as the original distribution. There is no integer nan!
1 parent f043ad9 commit 4da5edf

File tree

7 files changed

+139
-47
lines changed

7 files changed

+139
-47
lines changed

Diff for: pymc/distributions/continuous.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def polyagamma_cdf(*args, **kwargs):
7979
from pymc.distributions import transforms
8080
from pymc.distributions.dist_math import (
8181
SplineWrapper,
82+
check_icdf_parameters,
83+
check_icdf_value,
8284
check_parameters,
8385
clipped_beta_rvs,
8486
i0e,
@@ -532,7 +534,13 @@ def logcdf(value, mu, sigma):
532534
)
533535

534536
def icdf(value, mu, sigma):
535-
return mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
537+
res = mu + sigma * -np.sqrt(2.0) * at.erfcinv(2 * value)
538+
res = check_icdf_value(res, value)
539+
return check_icdf_parameters(
540+
res,
541+
sigma > 0,
542+
msg="sigma > 0",
543+
)
536544

537545

538546
class TruncatedNormalRV(RandomVariable):

Diff for: pymc/distributions/discrete.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from pymc.distributions.dist_math import (
3737
betaln,
3838
binomln,
39+
check_icdf_parameters,
40+
check_icdf_value,
3941
check_parameters,
4042
factln,
4143
log_diff_normal_cdf,
@@ -820,7 +822,14 @@ def logcdf(value, p):
820822
)
821823

822824
def icdf(value, p):
823-
return at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
825+
res = at.ceil(at.log1p(-value) / at.log1p(-p)).astype("int64")
826+
res = check_icdf_value(res, value)
827+
return check_icdf_parameters(
828+
res,
829+
0 <= p,
830+
p <= 1,
831+
msg="0 <= p <= 1",
832+
)
824833

825834

826835
class HyperGeometric(Discrete):

Diff for: pymc/distributions/dist_math.py

+16
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""
2020
import warnings
2121

22+
from functools import partial
2223
from typing import Iterable
2324

2425
import numpy as np
@@ -77,6 +78,21 @@ def check_parameters(
7778
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
7879

7980

81+
check_icdf_parameters = partial(check_parameters, can_be_replaced_by_ninf=False)
82+
83+
84+
def check_icdf_value(expr: Variable, value: Variable) -> Variable:
85+
"""Wrap icdf expression in nan switch for value."""
86+
value = at.as_tensor_variable(value)
87+
expr = at.switch(
88+
at.and_(value >= 0, value <= 1),
89+
expr,
90+
np.nan,
91+
)
92+
expr.name = "0 <= value <= 1"
93+
return expr
94+
95+
8096
def logpow(x, m):
8197
"""
8298
Calculates log(x**m) since m*log(x) will fail when m, x = 0.

Diff for: pymc/testing.py

+92
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pymc.distributions.shape_utils import change_dist_size
3838
from pymc.initial_point import make_initial_point_fn
3939
from pymc.logprob import joint_logp
40+
from pymc.logprob.abstract import icdf
4041
from pymc.logprob.utils import ParameterValueError
4142
from pymc.pytensorf import (
4243
compile_pymc,
@@ -520,6 +521,97 @@ def check_logcdf(
520521
)
521522

522523

524+
def check_icdf(
525+
pymc_dist: Distribution,
526+
paramdomains: Dict[str, Domain],
527+
scipy_icdf: Callable,
528+
decimal: Optional[int] = None,
529+
n_samples: int = 100,
530+
) -> None:
531+
"""
532+
Generic test for PyMC icdf methods
533+
534+
The following tests are performed by default:
535+
1. Test PyMC icdf and equivalent scipy icdf (ppf) methods give similar
536+
results for parameters inside the supported edges.
537+
Edges are excluded by default, but can be artificially included by
538+
creating a domain with repeated values (e.g., `Domain([0, 0, .5, 1, 1]`)
539+
2. Test PyMC icdf method raises for invalid parameter values
540+
outside the supported edges.
541+
3. Test PyMC icdf method returns np.nan for values below 0 or above 1,
542+
when using valid parameters.
543+
544+
Parameters
545+
----------
546+
pymc_dist: PyMC distribution
547+
paramdomains : Dictionary of Parameter : Domain pairs
548+
Supported domains of distribution parameters
549+
scipy_icdf : Scipy icdf method
550+
Scipy icdf (ppp) method of equivalent pymc_dist distribution
551+
decimal : int, optional
552+
Level of precision with which pymc_dist and scipy_icdf are compared.
553+
Defaults to 6 for float64 and 3 for float32
554+
n_samples : int
555+
Upper limit on the number of valid domain and value combinations that
556+
are compared between pymc and scipy methods. If n_samples is below the
557+
total number of combinations, a random subset is evaluated. Setting
558+
n_samples = -1, will return all possible combinations. Defaults to 100
559+
560+
"""
561+
if decimal is None:
562+
decimal = select_by_precision(float64=6, float32=3)
563+
564+
dist = create_dist_from_paramdomains(pymc_dist, paramdomains)
565+
q = pt.scalar(dtype="float64", name="q")
566+
dist_icdf = icdf(dist, q)
567+
pymc_icdf = pytensor.function(list(inputvars(dist_icdf)), dist_icdf)
568+
569+
# Test pymc and scipy distributions match for values and parameters
570+
# within the supported domain edges (excluding edges)
571+
domains = paramdomains.copy()
572+
domain = Domain([0, 0.1, 0.5, 0.75, 0.95, 0.99, 1]) # Values we test the icdf at
573+
domains["q"] = domain
574+
575+
for point in product(domains, n_samples=n_samples):
576+
point = dict(point)
577+
npt.assert_almost_equal(
578+
pymc_icdf(**point),
579+
scipy_icdf(**point),
580+
decimal=decimal,
581+
err_msg=str(point),
582+
)
583+
584+
valid_value = domain.vals[0]
585+
valid_params = {param: paramdomain.vals[0] for param, paramdomain in paramdomains.items()}
586+
valid_params["q"] = valid_value
587+
588+
# Test pymc distribution raises ParameterValueError for parameters outside the
589+
# supported domain edges (excluding edges)
590+
invalid_params = find_invalid_scalar_params(paramdomains)
591+
for invalid_param, invalid_edges in invalid_params.items():
592+
for invalid_edge in invalid_edges:
593+
if invalid_edge is None:
594+
continue
595+
596+
point = valid_params.copy()
597+
point[invalid_param] = invalid_edge
598+
with pytest.raises(ParameterValueError):
599+
pymc_icdf(**point)
600+
pytest.fail(f"test_params={point}")
601+
602+
# Test that values below 0 or above 1 evaluate to nan
603+
invalid_values = find_invalid_scalar_params({"q": domain})["q"]
604+
for invalid_value in invalid_values:
605+
if invalid_value is not None:
606+
point = valid_params.copy()
607+
point["q"] = invalid_value
608+
npt.assert_equal(
609+
pymc_icdf(**point),
610+
np.nan,
611+
err_msg=str(point),
612+
)
613+
614+
523615
def check_selfconsistency_discrete_logcdf(
524616
distribution: Distribution,
525617
domain: Domain,

Diff for: tests/distributions/test_continuous.py

+6-18
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
Runif,
4444
Unit,
4545
assert_moment_is_expected,
46+
check_icdf,
4647
check_logcdf,
4748
check_logp,
4849
continuous_random_tester,
@@ -270,6 +271,11 @@ def test_normal(self):
270271
lambda value, mu, sigma: st.norm.logcdf(value, mu, sigma),
271272
decimal=select_by_precision(float64=6, float32=1),
272273
)
274+
check_icdf(
275+
pm.Normal,
276+
{"mu": R, "sigma": Rplus},
277+
lambda q, mu, sigma: st.norm.ppf(q, mu, sigma),
278+
)
273279

274280
def test_half_normal(self):
275281
check_logp(
@@ -2269,21 +2275,3 @@ def dist(cls, **kwargs):
22692275
extra_args={"rng": pytensor.shared(rng)},
22702276
ref_rand=ref_rand,
22712277
)
2272-
2273-
2274-
class TestICDF:
2275-
@pytest.mark.parametrize(
2276-
"dist_params, obs, size",
2277-
[
2278-
((0, 1), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
2279-
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), ()),
2280-
((-1, 20), np.array([-0.5, 0, 0.3, 0.5, 1, 1.5], dtype=np.float64), (2, 3)),
2281-
],
2282-
)
2283-
def test_normal_icdf(self, dist_params, obs, size):
2284-
dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size)
2285-
dist_params = dict(zip(dist_params_at, dist_params))
2286-
2287-
x = Normal.dist(*dist_params_at, size=size_at)
2288-
2289-
scipy_logprob_tester(x, obs, dist_params, test_fn=st.norm.ppf, test="icdf")

Diff for: tests/distributions/test_discrete.py

+6-26
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
UnitSortedVector,
5252
Vector,
5353
assert_moment_is_expected,
54+
check_icdf,
5455
check_logcdf,
5556
check_logp,
5657
check_selfconsistency_discrete_logcdf,
@@ -143,6 +144,11 @@ def test_geometric(self):
143144
Nat,
144145
{"p": Unit},
145146
)
147+
check_icdf(
148+
pm.Geometric,
149+
{"p": Unit},
150+
st.geom.ppf,
151+
)
146152

147153
def test_hypergeometric(self):
148154
def modified_scipy_hypergeom_logcdf(value, N, k, n):
@@ -1148,29 +1154,3 @@ def test_shape_inputs(self, eta, cutpoints, sigma, expected):
11481154
)
11491155
p = categorical.owner.inputs[3].eval()
11501156
assert p.shape == expected
1151-
1152-
1153-
class TestICDF:
1154-
@pytest.mark.parametrize(
1155-
"dist_params, obs, size",
1156-
[
1157-
((0.1,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), ()),
1158-
((0.5,), np.array([-0.5, 0, 0.1, 0.5, 0.9, 1.0, 1.5], dtype=np.int64), (3, 2)),
1159-
(
1160-
(np.array([0.0, 0.2, 0.5, 1.0]),),
1161-
np.array([0.7, 0.7, 0.7, 0.7], dtype=np.int64),
1162-
(),
1163-
),
1164-
],
1165-
)
1166-
def test_geometric_icdf(self, dist_params, obs, size):
1167-
dist_params_at, obs_at, size_at = create_pytensor_params(dist_params, obs, size)
1168-
dist_params = dict(zip(dist_params_at, dist_params))
1169-
1170-
x = Geometric.dist(*dist_params_at, size=size_at)
1171-
1172-
def scipy_geom_icdf(value, p):
1173-
# Scipy ppf returns floats
1174-
return st.geom.ppf(value, p).astype(value.dtype)
1175-
1176-
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_geom_icdf, test="icdf")

Diff for: tests/distributions/test_truncated.py

-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def test_truncation_discrete_random(op_type, lower, upper):
174174
x = geometric_op(p, name="x", size=500)
175175
xt = Truncated.dist(x, lower=lower, upper=upper)
176176
assert isinstance(xt.owner.op, TruncatedRV)
177-
assert xt.type.dtype == x.type.dtype
178177

179178
xt_draws = draw(xt)
180179
assert np.all(xt_draws >= lower)

0 commit comments

Comments
 (0)