Skip to content

Commit 3538491

Browse files
farhanreynaldoricardoV94
authored andcommitted
Refactor Rice and Skew Normal distribution (#4705)
* refactor Skewnormal and Rice distribution * add test for rice b and skewnorm tau params * change default parameter to b * Add float32 xfail to `test_rice` Co-authored-by: Farhan Reynaldo <[email protected]> Co-authored-by: Ricardo <[email protected]>
1 parent ba5e978 commit 3538491

File tree

3 files changed

+104
-122
lines changed

3 files changed

+104
-122
lines changed

pymc3/distributions/continuous.py

+51-109
Original file line numberDiff line numberDiff line change
@@ -3102,6 +3102,21 @@ def logp(value, mu, kappa):
31023102
)
31033103

31043104

3105+
class SkewNormalRV(RandomVariable):
3106+
name = "skewnormal"
3107+
ndim_supp = 0
3108+
ndims_params = [0, 0, 0]
3109+
dtype = "floatX"
3110+
_print_name = ("SkewNormal", "\\operatorname{SkewNormal}")
3111+
3112+
@classmethod
3113+
def rng_fn(cls, rng, mu, sigma, alpha, size=None):
3114+
return stats.skewnorm.rvs(a=alpha, loc=mu, scale=sigma, size=size, random_state=rng)
3115+
3116+
3117+
skewnormal = SkewNormalRV()
3118+
3119+
31053120
class SkewNormal(Continuous):
31063121
r"""
31073122
Univariate skew-normal log-likelihood.
@@ -3160,51 +3175,25 @@ class SkewNormal(Continuous):
31603175
approaching plus/minus infinite we get a half-normal distribution.
31613176
31623177
"""
3178+
rv_op = skewnormal
31633179

3164-
def __init__(self, mu=0.0, sigma=None, tau=None, alpha=1, sd=None, *args, **kwargs):
3165-
super().__init__(*args, **kwargs)
3166-
3180+
@classmethod
3181+
def dist(cls, alpha=1, mu=0.0, sigma=None, tau=None, sd=None, *args, **kwargs):
31673182
if sd is not None:
31683183
sigma = sd
31693184

31703185
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
3171-
self.mu = mu = at.as_tensor_variable(floatX(mu))
3172-
self.tau = at.as_tensor_variable(tau)
3173-
self.sigma = self.sd = at.as_tensor_variable(sigma)
3174-
3175-
self.alpha = alpha = at.as_tensor_variable(floatX(alpha))
3176-
3177-
self.mean = mu + self.sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha ** 2) ** 0.5
3178-
self.variance = self.sigma ** 2 * (1 - (2 * alpha ** 2) / ((1 + alpha ** 2) * np.pi))
3186+
alpha = at.as_tensor_variable(floatX(alpha))
3187+
mu = at.as_tensor_variable(floatX(mu))
3188+
tau = at.as_tensor_variable(tau)
3189+
sigma = at.as_tensor_variable(sigma)
31793190

31803191
assert_negative_support(tau, "tau", "SkewNormal")
31813192
assert_negative_support(sigma, "sigma", "SkewNormal")
31823193

3183-
def random(self, point=None, size=None):
3184-
"""
3185-
Draw random values from SkewNormal distribution.
3186-
3187-
Parameters
3188-
----------
3189-
point: dict, optional
3190-
Dict of variable values on which random values are to be
3191-
conditioned (uses default point if not specified).
3192-
size: int, optional
3193-
Desired size of random sample (returns one sample if not
3194-
specified).
3195-
3196-
Returns
3197-
-------
3198-
array
3199-
"""
3200-
# mu, tau, _, alpha = draw_values(
3201-
# [self.mu, self.tau, self.sigma, self.alpha], point=point, size=size
3202-
# )
3203-
# return generate_samples(
3204-
# stats.skewnorm.rvs, a=alpha, loc=mu, scale=tau ** -0.5, dist_shape=self.shape, size=size
3205-
# )
3194+
return super().dist([mu, sigma, alpha], *args, **kwargs)
32063195

3207-
def logp(self, value):
3196+
def logp(value, mu, sigma, alpha):
32083197
"""
32093198
Calculate log-probability of SkewNormal distribution at specified value.
32103199
@@ -3218,20 +3207,14 @@ def logp(self, value):
32183207
-------
32193208
TensorVariable
32203209
"""
3221-
tau = self.tau
3222-
sigma = self.sigma
3223-
mu = self.mu
3224-
alpha = self.alpha
3210+
tau, sigma = get_tau_sigma(sigma=sigma)
32253211
return bound(
32263212
at.log(1 + at.erf(((value - mu) * at.sqrt(tau) * alpha) / at.sqrt(2)))
32273213
+ (-tau * (value - mu) ** 2 + at.log(tau / np.pi / 2.0)) / 2.0,
32283214
tau > 0,
32293215
sigma > 0,
32303216
)
32313217

3232-
def _distr_parameters_for_repr(self):
3233-
return ["mu", "sigma", "alpha"]
3234-
32353218

32363219
class Triangular(BoundedContinuous):
32373220
r"""
@@ -3474,6 +3457,21 @@ def logcdf(
34743457
)
34753458

34763459

3460+
class RiceRV(RandomVariable):
3461+
name = "rice"
3462+
ndim_supp = 0
3463+
ndims_params = [0, 0]
3464+
dtype = "floatX"
3465+
_print_name = ("Rice", "\\operatorname{Rice}")
3466+
3467+
@classmethod
3468+
def rng_fn(cls, rng, b, sigma, size=None):
3469+
return stats.rice.rvs(b=b, scale=sigma, size=size, random_state=rng)
3470+
3471+
3472+
rice = RiceRV()
3473+
3474+
34773475
class Rice(PositiveContinuous):
34783476
r"""
34793477
Rice distribution.
@@ -3533,42 +3531,21 @@ class Rice(PositiveContinuous):
35333531
b = \dfrac{\nu}{\sigma}
35343532
35353533
"""
3534+
rv_op = rice
35363535

3537-
def __init__(self, nu=None, sigma=None, b=None, sd=None, *args, **kwargs):
3538-
super().__init__(*args, **kwargs)
3536+
@classmethod
3537+
def dist(cls, nu=None, sigma=None, b=None, sd=None, *args, **kwargs):
35393538
if sd is not None:
35403539
sigma = sd
35413540

3542-
nu, b, sigma = self.get_nu_b(nu, b, sigma)
3543-
self.nu = nu = at.as_tensor_variable(floatX(nu))
3544-
self.sigma = self.sd = sigma = at.as_tensor_variable(floatX(sigma))
3545-
self.b = b = at.as_tensor_variable(floatX(b))
3546-
3547-
nu_sigma_ratio = -(nu ** 2) / (2 * sigma ** 2)
3548-
self.mean = (
3549-
sigma
3550-
* np.sqrt(np.pi / 2)
3551-
* at.exp(nu_sigma_ratio / 2)
3552-
* (
3553-
(1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2)
3554-
- nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2)
3555-
)
3556-
)
3557-
self.variance = (
3558-
2 * sigma ** 2
3559-
+ nu ** 2
3560-
- (np.pi * sigma ** 2 / 2)
3561-
* (
3562-
at.exp(nu_sigma_ratio / 2)
3563-
* (
3564-
(1 - nu_sigma_ratio) * at.i0(-nu_sigma_ratio / 2)
3565-
- nu_sigma_ratio * at.i1(-nu_sigma_ratio / 2)
3566-
)
3567-
)
3568-
** 2
3569-
)
3541+
nu, b, sigma = cls.get_nu_b(nu, b, sigma)
3542+
b = at.as_tensor_variable(floatX(b))
3543+
sigma = at.as_tensor_variable(floatX(sigma))
35703544

3571-
def get_nu_b(self, nu, b, sigma):
3545+
return super().dist([b, sigma], *args, **kwargs)
3546+
3547+
@classmethod
3548+
def get_nu_b(cls, nu, b, sigma):
35723549
if sigma is None:
35733550
sigma = 1.0
35743551
if nu is None and b is not None:
@@ -3579,35 +3556,7 @@ def get_nu_b(self, nu, b, sigma):
35793556
return nu, b, sigma
35803557
raise ValueError("Rice distribution must specify either nu" " or b.")
35813558

3582-
def random(self, point=None, size=None):
3583-
"""
3584-
Draw random values from Rice distribution.
3585-
3586-
Parameters
3587-
----------
3588-
point: dict, optional
3589-
Dict of variable values on which random values are to be
3590-
conditioned (uses default point if not specified).
3591-
size: int, optional
3592-
Desired size of random sample (returns one sample if not
3593-
specified).
3594-
3595-
Returns
3596-
-------
3597-
array
3598-
"""
3599-
# nu, sigma = draw_values([self.nu, self.sigma], point=point, size=size)
3600-
# return generate_samples(self._random, nu=nu, sigma=sigma, dist_shape=self.shape, size=size)
3601-
3602-
def _random(self, nu, sigma, size):
3603-
"""Wrapper around stats.rice.rvs that converts Rice's
3604-
parametrization to scipy.rice. All parameter arrays should have
3605-
been broadcasted properly by generate_samples at this point and size is
3606-
the scipy.rvs representation.
3607-
"""
3608-
return stats.rice.rvs(b=nu / sigma, scale=sigma, size=size)
3609-
3610-
def logp(self, value):
3559+
def logp(value, b, sigma):
36113560
"""
36123561
Calculate log-probability of Rice distribution at specified value.
36133562
@@ -3621,20 +3570,13 @@ def logp(self, value):
36213570
-------
36223571
TensorVariable
36233572
"""
3624-
nu = self.nu
3625-
sigma = self.sigma
3626-
b = self.b
36273573
x = value / sigma
36283574
return bound(
36293575
at.log(x * at.exp((-(x - b) * (x - b)) / 2) * i0e(x * b) / sigma),
36303576
sigma >= 0,
3631-
nu >= 0,
36323577
value > 0,
36333578
)
36343579

3635-
def _distr_parameters_for_repr(self):
3636-
return ["nu", "sigma"]
3637-
36383580

36393581
class Logistic(Continuous):
36403582
r"""

pymc3/tests/test_distributions.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,6 @@ def test_half_studentt(self):
14961496
lambda value, sigma: sp.halfcauchy.logpdf(value, 0, sigma),
14971497
)
14981498

1499-
@pytest.mark.xfail(reason="Distribution not refactored yet")
15001499
def test_skew_normal(self):
15011500
self.check_logp(
15021501
SkewNormal,
@@ -2545,19 +2544,24 @@ def test_multidimensional_beta_construction(self):
25452544
with Model():
25462545
Beta("beta", alpha=1.0, beta=1.0, size=(10, 20))
25472546

2548-
@pytest.mark.xfail(reason="Distribution not refactored yet")
2547+
@pytest.mark.xfail(
2548+
condition=(aesara.config.floatX == "float32"),
2549+
reason="Some combinations underflow to -inf in float32 in pymc version",
2550+
)
25492551
def test_rice(self):
25502552
self.check_logp(
25512553
Rice,
25522554
Rplus,
2553-
{"nu": Rplus, "sigma": Rplusbig},
2554-
lambda value, nu, sigma: sp.rice.logpdf(value, b=nu / sigma, loc=0, scale=sigma),
2555+
{"b": Rplus, "sigma": Rplusbig},
2556+
lambda value, b, sigma: sp.rice.logpdf(value, b=b, loc=0, scale=sigma),
25552557
)
2558+
2559+
def test_rice_nu(self):
25562560
self.check_logp(
25572561
Rice,
25582562
Rplus,
2559-
{"b": Rplus, "sigma": Rplusbig},
2560-
lambda value, b, sigma: sp.rice.logpdf(value, b=b, loc=0, scale=sigma),
2563+
{"nu": Rplus, "sigma": Rplusbig},
2564+
lambda value, nu, sigma: sp.rice.logpdf(value, b=nu / sigma, loc=0, scale=sigma),
25612565
)
25622566

25632567
def test_moyal_logp(self):

pymc3/tests/test_distributions_random.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,6 @@ class TestTruncatedNormalUpper(BaseTestCases.BaseTestCase):
265265
params = {"mu": 0.0, "tau": 1.0, "upper": 0.5}
266266

267267

268-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
269-
class TestSkewNormal(BaseTestCases.BaseTestCase):
270-
distribution = pm.SkewNormal
271-
params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
272-
273-
274268
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
275269
class TestWald(BaseTestCases.BaseTestCase):
276270
distribution = pm.Wald
@@ -514,6 +508,49 @@ def seeded_kumaraswamy_rng_fn(self):
514508
]
515509

516510

511+
class TestSkewNormal(BaseTestDistribution):
512+
pymc_dist = pm.SkewNormal
513+
pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
514+
expected_rv_op_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
515+
reference_dist_params = {"loc": 0.0, "scale": 1.0, "a": 5.0}
516+
reference_dist = seeded_scipy_distribution_builder("skewnorm")
517+
tests_to_run = [
518+
"check_pymc_params_match_rv_op",
519+
"check_pymc_draws_match_reference",
520+
"check_rv_size",
521+
]
522+
523+
524+
class TestSkewNormalTau(BaseTestDistribution):
525+
pymc_dist = pm.SkewNormal
526+
tau, sigma = get_tau_sigma(tau=2.0)
527+
pymc_dist_params = {"mu": 0.0, "tau": tau, "alpha": 5.0}
528+
expected_rv_op_params = {"mu": 0.0, "sigma": sigma, "alpha": 5.0}
529+
tests_to_run = ["check_pymc_params_match_rv_op"]
530+
531+
532+
class TestRice(BaseTestDistribution):
533+
pymc_dist = pm.Rice
534+
b, sigma = 1, 2
535+
pymc_dist_params = {"b": b, "sigma": sigma}
536+
expected_rv_op_params = {"b": b, "sigma": sigma}
537+
reference_dist_params = {"b": b, "scale": sigma}
538+
reference_dist = seeded_scipy_distribution_builder("rice")
539+
tests_to_run = [
540+
"check_pymc_params_match_rv_op",
541+
"check_pymc_draws_match_reference",
542+
"check_rv_size",
543+
]
544+
545+
546+
class TestRiceNu(BaseTestDistribution):
547+
pymc_dist = pm.Rice
548+
nu = sigma = 2
549+
pymc_dist_params = {"nu": nu, "sigma": sigma}
550+
expected_rv_op_params = {"b": nu / sigma, "sigma": sigma}
551+
tests_to_run = ["check_pymc_params_match_rv_op"]
552+
553+
517554
class TestStudentTLam(BaseTestDistribution):
518555
pymc_dist = pm.StudentT
519556
lam, sigma = get_tau_sigma(tau=2.0)
@@ -1145,7 +1182,6 @@ def ref_rand(size, mu, sigma, upper):
11451182
pm.TruncatedNormal, {"mu": R, "sigma": Rplusbig, "upper": Rplusbig}, ref_rand=ref_rand
11461183
)
11471184

1148-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
11491185
def test_skew_normal(self):
11501186
def ref_rand(size, alpha, mu, sigma):
11511187
return st.skewnorm.rvs(size=size, a=alpha, loc=mu, scale=sigma)

0 commit comments

Comments
 (0)