Skip to content

Commit 79245ce

Browse files
authored
Refactor LogitNormal (#4703)
1 parent a68f571 commit 79245ce

File tree

3 files changed

+56
-48
lines changed

3 files changed

+56
-48
lines changed

pymc3/distributions/continuous.py

+26-36
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from aesara.tensor.var import TensorVariable
4848
from scipy import stats
4949
from scipy.interpolate import InterpolatedUnivariateSpline
50+
from scipy.special import expit
5051

5152
from pymc3.aesaraf import floatX
5253
from pymc3.distributions import logp_transform, transforms
@@ -66,7 +67,7 @@
6667
)
6768
from pymc3.distributions.distribution import Continuous
6869
from pymc3.distributions.special import log_i0
69-
from pymc3.math import invlogit, log1mexp, log1pexp, logdiffexp, logit
70+
from pymc3.math import log1mexp, log1pexp, logdiffexp, logit
7071
from pymc3.util import UNSET
7172

7273
__all__ = [
@@ -3672,6 +3673,21 @@ def logcdf(value, mu, s):
36723673
)
36733674

36743675

3676+
class LogitNormalRV(RandomVariable):
3677+
name = "logit_normal"
3678+
ndim_supp = 0
3679+
ndims_params = [0, 0]
3680+
dtype = "floatX"
3681+
_print_name = ("logitNormal", "\\operatorname{logitNormal}")
3682+
3683+
@classmethod
3684+
def rng_fn(cls, rng, mu, sigma, size=None):
3685+
return expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng))
3686+
3687+
3688+
logit_normal = LogitNormalRV()
3689+
3690+
36753691
class LogitNormal(UnitContinuous):
36763692
r"""
36773693
Logit-Normal log-likelihood.
@@ -3716,44 +3732,22 @@ class LogitNormal(UnitContinuous):
37163732
tau: float
37173733
Scale parameter (tau > 0).
37183734
"""
3735+
rv_op = logit_normal
37193736

3720-
def __init__(self, mu=0, sigma=None, tau=None, sd=None, **kwargs):
3737+
@classmethod
3738+
def dist(cls, mu=0, sigma=None, tau=None, sd=None, **kwargs):
37213739
if sd is not None:
37223740
sigma = sd
3723-
self.mu = mu = at.as_tensor_variable(floatX(mu))
3741+
mu = at.as_tensor_variable(floatX(mu))
37243742
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
3725-
self.sigma = self.sd = at.as_tensor_variable(sigma)
3726-
self.tau = tau = at.as_tensor_variable(tau)
3727-
3728-
self.median = invlogit(mu)
3743+
sigma = sd = at.as_tensor_variable(sigma)
3744+
tau = at.as_tensor_variable(tau)
37293745
assert_negative_support(sigma, "sigma", "LogitNormal")
37303746
assert_negative_support(tau, "tau", "LogitNormal")
37313747

3732-
super().__init__(**kwargs)
3733-
3734-
def random(self, point=None, size=None):
3735-
"""
3736-
Draw random values from LogitNormal distribution.
3737-
3738-
Parameters
3739-
----------
3740-
point: dict, optional
3741-
Dict of variable values on which random values are to be
3742-
conditioned (uses default point if not specified).
3743-
size: int, optional
3744-
Desired size of random sample (returns one sample if not
3745-
specified).
3746-
3747-
Returns
3748-
-------
3749-
array
3750-
"""
3751-
# mu, _, sigma = draw_values([self.mu, self.tau, self.sigma], point=point, size=size)
3752-
# return expit(
3753-
# generate_samples(stats.norm.rvs, loc=mu, scale=sigma, dist_shape=self.shape, size=size)
3754-
# )
3748+
return super().dist([mu, sigma], **kwargs)
37553749

3756-
def logp(self, value):
3750+
def logp(value, mu, sigma):
37573751
"""
37583752
Calculate log-probability of LogitNormal distribution at specified value.
37593753
@@ -3767,8 +3761,7 @@ def logp(self, value):
37673761
-------
37683762
TensorVariable
37693763
"""
3770-
mu = self.mu
3771-
tau = self.tau
3764+
tau, sigma = get_tau_sigma(sigma=sigma)
37723765
return bound(
37733766
-0.5 * tau * (logit(value) - mu) ** 2
37743767
+ 0.5 * at.log(tau / (2.0 * np.pi))
@@ -3778,9 +3771,6 @@ def logp(self, value):
37783771
tau > 0,
37793772
)
37803773

3781-
def _distr_parameters_for_repr(self):
3782-
return ["mu", "sigma"]
3783-
37843774

37853775
class Interpolated(BoundedContinuous):
37863776
r"""

pymc3/tests/test_distributions.py

-1
Original file line numberDiff line numberDiff line change
@@ -2527,7 +2527,6 @@ def test_logistic(self):
25272527
decimal=select_by_precision(float64=6, float32=1),
25282528
)
25292529

2530-
@pytest.mark.xfail(reason="Distribution not refactored yet")
25312530
def test_logitnormal(self):
25322531
self.check_logp(
25332532
LogitNormal,

pymc3/tests/test_distributions_random.py

+30-11
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def setup_method(self, *args, **kwargs):
158158
self.model = pm.Model()
159159

160160
def get_random_variable(self, shape, with_vector_params=False, name=None):
161-
""" Creates a RandomVariable of the parametrized distribution. """
161+
"""Creates a RandomVariable of the parametrized distribution."""
162162
if with_vector_params:
163163
params = {
164164
key: value * np.ones(self.shape, dtype=np.dtype(type(value)))
@@ -187,7 +187,7 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
187187

188188
@staticmethod
189189
def sample_random_variable(random_variable, size):
190-
""" Draws samples from a RandomVariable using its .random() method. """
190+
"""Draws samples from a RandomVariable using its .random() method."""
191191
if size is None:
192192
return random_variable.eval()
193193
else:
@@ -196,7 +196,7 @@ def sample_random_variable(random_variable, size):
196196
@pytest.mark.parametrize("size", [None, (), 1, (1,), 5, (4, 5)], ids=str)
197197
@pytest.mark.parametrize("shape", [None, ()], ids=str)
198198
def test_scalar_distribution_shape(self, shape, size):
199-
""" Draws samples of different [size] from a scalar [shape] RV. """
199+
"""Draws samples of different [size] from a scalar [shape] RV."""
200200
rv = self.get_random_variable(shape)
201201
exp_shape = self.default_shape if shape is None else tuple(np.atleast_1d(shape))
202202
exp_size = self.default_size if size is None else tuple(np.atleast_1d(size))
@@ -216,7 +216,7 @@ def test_scalar_distribution_shape(self, shape, size):
216216
"shape", [None, (), (1,), (1, 1), (1, 2), (10, 11, 1), (9, 10, 2)], ids=str
217217
)
218218
def test_scalar_sample_shape(self, shape, size):
219-
""" Draws samples of scalar [size] from a [shape] RV. """
219+
"""Draws samples of scalar [size] from a [shape] RV."""
220220
rv = self.get_random_variable(shape)
221221
exp_shape = self.default_shape if shape is None else tuple(np.atleast_1d(shape))
222222
exp_size = self.default_size if size is None else tuple(np.atleast_1d(size))
@@ -289,12 +289,6 @@ class TestExGaussian(BaseTestCases.BaseTestCase):
289289
params = {"mu": 0.0, "sigma": 1.0, "nu": 1.0}
290290

291291

292-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
293-
class TestLogitNormal(BaseTestCases.BaseTestCase):
294-
distribution = pm.LogitNormal
295-
params = {"mu": 0.0, "sigma": 1.0}
296-
297-
298292
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
299293
class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase):
300294
distribution = pm.ZeroInflatedNegativeBinomial
@@ -575,6 +569,32 @@ class TestNormal(BaseTestDistribution):
575569
]
576570

577571

572+
class TestLogitNormal(BaseTestDistribution):
573+
def logit_normal_rng_fn(self, rng, size, loc, scale):
574+
return expit(st.norm.rvs(loc=loc, scale=scale, size=size, random_state=rng))
575+
576+
pymc_dist = pm.LogitNormal
577+
pymc_dist_params = {"mu": 5.0, "sigma": 10.0}
578+
expected_rv_op_params = {"mu": 5.0, "sigma": 10.0}
579+
reference_dist_params = {"loc": 5.0, "scale": 10.0}
580+
reference_dist = lambda self: functools.partial(
581+
self.logit_normal_rng_fn, rng=self.get_random_state()
582+
)
583+
tests_to_run = [
584+
"check_pymc_params_match_rv_op",
585+
"check_pymc_draws_match_reference",
586+
"check_rv_size",
587+
]
588+
589+
590+
class TestLogitNormalTau(BaseTestDistribution):
591+
pymc_dist = pm.LogitNormal
592+
tau, sigma = get_tau_sigma(tau=25.0)
593+
pymc_dist_params = {"mu": 1.0, "tau": tau}
594+
expected_rv_op_params = {"mu": 1.0, "sigma": sigma}
595+
tests_to_run = ["check_pymc_params_match_rv_op"]
596+
597+
578598
class TestNormalTau(BaseTestDistribution):
579599
pymc_dist = pm.Normal
580600
tau, sigma = get_tau_sigma(tau=25.0)
@@ -1443,7 +1463,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
14431463
with expectation:
14441464
m.random()
14451465

1446-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
14471466
def test_logitnormal(self):
14481467
def ref_rand(size, mu, sigma):
14491468
return expit(st.norm.rvs(loc=mu, scale=sigma, size=size))

0 commit comments

Comments
 (0)