Skip to content

Commit 3469d23

Browse files
Refactor Student T Distribution (#4694)
* feat: adapt Student's t distribution
* Change the default parameterization to sigma and tweak tests

Co-authored-by: Ricardo <[email protected]>
1 parent 791a1c4 commit 3469d23

File tree

4 files changed

+66
-61
lines changed

4 files changed

+66
-61
lines changed

pymc3/distributions/continuous.py

+27-44
Original file line numberDiff line numberDiff line change
@@ -1793,6 +1793,21 @@ def logcdf(value, mu, sigma):
17931793
)
17941794

17951795

1796+
class StudentTRV(RandomVariable):
    """Random-variable op whose draws come from ``scipy.stats.t``."""

    name = "studentt"
    # Zero support dimensions; one entry per ``rng_fn`` parameter
    # (nu, mu, sigma), each declared as 0-dimensional.
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    dtype = "floatX"
    # Plain-text and LaTeX names used when printing the variable.
    _print_name = ("StudentT", "\\operatorname{StudentT}")

    @classmethod
    def rng_fn(cls, rng, nu, mu, sigma, size=None):
        """Draw Student's t samples via ``stats.t.rvs``.

        ``nu`` is passed as the degrees of freedom, ``mu`` as ``loc`` and
        ``sigma`` as ``scale``; ``rng`` is forwarded as ``random_state`` and
        ``size`` is passed through unchanged.
        """
        draws = stats.t.rvs(nu, loc=mu, scale=sigma, size=size, random_state=rng)
        return draws
1806+
1807+
1808+
# Module-level op instance; the StudentT distribution class uses it as its ``rv_op``.
studentt = StudentTRV()
1809+
1810+
17961811
class StudentT(Continuous):
17971812
r"""
17981813
Student's T log-likelihood.
@@ -1856,45 +1871,22 @@ class StudentT(Continuous):
18561871
with pm.Model():
18571872
x = pm.StudentT('x', nu=15, mu=0, lam=1/23)
18581873
"""
1874+
rv_op = studentt
18591875

1860-
def __init__(self, nu, mu=0, lam=None, sigma=None, sd=None, *args, **kwargs):
1861-
super().__init__(*args, **kwargs)
1876+
@classmethod
1877+
def dist(cls, nu, mu=0, lam=None, sigma=None, sd=None, *args, **kwargs):
18621878
if sd is not None:
18631879
sigma = sd
1864-
self.nu = nu = at.as_tensor_variable(floatX(nu))
1880+
nu = at.as_tensor_variable(floatX(nu))
18651881
lam, sigma = get_tau_sigma(tau=lam, sigma=sigma)
1866-
self.lam = lam = at.as_tensor_variable(lam)
1867-
self.sigma = self.sd = sigma = at.as_tensor_variable(sigma)
1868-
self.mean = self.median = self.mode = self.mu = mu = at.as_tensor_variable(mu)
1869-
1870-
self.variance = at.switch((nu > 2) * 1, (1 / self.lam) * (nu / (nu - 2)), np.inf)
1882+
sigma = at.as_tensor_variable(sigma)
18711883

1872-
assert_negative_support(lam, "lam (sigma)", "StudentT")
1884+
assert_negative_support(sigma, "sigma (lam)", "StudentT")
18731885
assert_negative_support(nu, "nu", "StudentT")
18741886

1875-
def random(self, point=None, size=None):
1876-
"""
1877-
Draw random values from StudentT distribution.
1878-
1879-
Parameters
1880-
----------
1881-
point: dict, optional
1882-
Dict of variable values on which random values are to be
1883-
conditioned (uses default point if not specified).
1884-
size: int, optional
1885-
Desired size of random sample (returns one sample if not
1886-
specified).
1887+
return super().dist([nu, mu, sigma], **kwargs)
18871888

1888-
Returns
1889-
-------
1890-
array
1891-
"""
1892-
# nu, mu, lam = draw_values([self.nu, self.mu, self.lam], point=point, size=size)
1893-
# return generate_samples(
1894-
# stats.t.rvs, nu, loc=mu, scale=lam ** -0.5, dist_shape=self.shape, size=size
1895-
# )
1896-
1897-
def logp(self, value):
1889+
def logp(value, nu, mu, sigma):
18981890
"""
18991891
Calculate log-probability of StudentT distribution at specified value.
19001892
@@ -1908,11 +1900,7 @@ def logp(self, value):
19081900
-------
19091901
TensorVariable
19101902
"""
1911-
nu = self.nu
1912-
mu = self.mu
1913-
lam = self.lam
1914-
sigma = self.sigma
1915-
1903+
lam, sigma = get_tau_sigma(sigma=sigma)
19161904
return bound(
19171905
gammaln((nu + 1.0) / 2.0)
19181906
+ 0.5 * at.log(lam / (nu * np.pi))
@@ -1923,10 +1911,7 @@ def logp(self, value):
19231911
sigma > 0,
19241912
)
19251913

1926-
def _distr_parameters_for_repr(self):
1927-
return ["nu", "mu", "lam"]
1928-
1929-
def logcdf(self, value):
1914+
def logcdf(value, nu, mu, sigma):
19301915
"""
19311916
Compute the log of the cumulative distribution function for Student's T distribution
19321917
at the specified value.
@@ -1946,10 +1931,8 @@ def logcdf(self, value):
19461931
f"StudentT.logcdf expects a scalar value but received a {np.ndim(value)}-dimensional object."
19471932
)
19481933

1949-
nu = self.nu
1950-
mu = self.mu
1951-
sigma = self.sigma
1952-
lam = self.lam
1934+
lam, sigma = get_tau_sigma(sigma=sigma)
1935+
19531936
t = (value - mu) / sigma
19541937
sqrt_t2_nu = at.sqrt(t ** 2 + nu)
19551938
x = (t + sqrt_t2_nu) / (2.0 * sqrt_t2_nu)

pymc3/tests/test_distributions.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -1301,20 +1301,33 @@ def test_lognormal(self):
13011301
n_samples=5, # Just testing alternative parametrization
13021302
)
13031303

1304-
@pytest.mark.xfail(reason="Distribution not refactored yet")
13051304
def test_t(self):
13061305
self.check_logp(
13071306
StudentT,
13081307
R,
13091308
{"nu": Rplus, "mu": R, "lam": Rplus},
13101309
lambda value, nu, mu, lam: sp.t.logpdf(value, nu, mu, lam ** -0.5),
13111310
)
1311+
self.check_logp(
1312+
StudentT,
1313+
R,
1314+
{"nu": Rplus, "mu": R, "sigma": Rplus},
1315+
lambda value, nu, mu, sigma: sp.t.logpdf(value, nu, mu, sigma),
1316+
n_samples=5, # Just testing alternative parametrization
1317+
)
13121318
self.check_logcdf(
13131319
StudentT,
13141320
R,
13151321
{"nu": Rplus, "mu": R, "lam": Rplus},
13161322
lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam ** -0.5),
1317-
n_samples=10,
1323+
n_samples=10, # relies on slow incomplete beta
1324+
)
1325+
self.check_logcdf(
1326+
StudentT,
1327+
R,
1328+
{"nu": Rplus, "mu": R, "sigma": Rplus},
1329+
lambda value, nu, mu, sigma: sp.t.logcdf(value, nu, mu, sigma),
1330+
n_samples=5, # Just testing alternative parametrization
13181331
)
13191332

13201333
def test_cauchy(self):

pymc3/tests/test_distributions_random.py

+23-13
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,6 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
283283
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}
284284

285285

286-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
287-
class TestStudentT(BaseTestCases.BaseTestCase):
288-
distribution = pm.StudentT
289-
params = {"nu": 5.0, "mu": 0.0, "lam": 1.0}
290-
291-
292286
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
293287
class TestChiSquared(BaseTestCases.BaseTestCase):
294288
distribution = pm.ChiSquared
@@ -481,6 +475,29 @@ class TestGumbel(BaseTestDistribution):
481475
]
482476

483477

478+
class TestStudentT(BaseTestDistribution):
    """Random-draw checks for ``pm.StudentT`` in its ``sigma`` parametrization."""

    pymc_dist = pm.StudentT
    # Parameters handed to the PyMC distribution and the values expected on the RV op.
    pymc_dist_params = dict(nu=5.0, mu=-1.0, sigma=2.0)
    expected_rv_op_params = dict(nu=5.0, mu=-1.0, sigma=2.0)
    # The same distribution expressed with the reference's df/loc/scale names.
    reference_dist_params = dict(df=5.0, loc=-1.0, scale=2.0)
    reference_dist = seeded_scipy_distribution_builder("t")
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]
489+
490+
491+
class TestStudentTLam(BaseTestDistribution):
    """Check that ``pm.StudentT`` given ``lam`` passes the matching ``sigma`` to its op."""

    pymc_dist = pm.StudentT
    # Derive the (lam, sigma) pair from tau=2.0 with the library's own helper so
    # the expected value below follows the same conversion the distribution uses.
    lam, sigma = get_tau_sigma(tau=2.0)
    pymc_dist_params = {"nu": 5.0, "mu": -1.0, "lam": lam}
    # The op's third parameter is sigma, so the expected value is labelled
    # "sigma" (it was previously stored under the misleading key "lam").
    # NOTE(review): assumes the base check compares values in order and does
    # not look entries up by key -- confirm against BaseTestDistribution.
    expected_rv_op_params = {"nu": 5.0, "mu": -1.0, "sigma": sigma}
    reference_dist_params = {"df": 5.0, "loc": -1.0, "scale": sigma}
    reference_dist = seeded_scipy_distribution_builder("t")
    tests_to_run = ["check_pymc_params_match_rv_op"]
499+
500+
484501
class TestNormal(BaseTestDistribution):
485502
pymc_dist = pm.Normal
486503
pymc_dist_params = {"mu": 5.0, "sigma": 10.0}
@@ -1134,13 +1151,6 @@ def ref_rand(size, kappa, b, mu):
11341151

11351152
pymc3_random(pm.AsymmetricLaplace, {"b": Rplus, "kappa": Rplus, "mu": R}, ref_rand=ref_rand)
11361153

1137-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1138-
def test_student_t(self):
1139-
def ref_rand(size, nu, mu, lam):
1140-
return st.t.rvs(nu, mu, lam ** -0.5, size=size)
1141-
1142-
pymc3_random(pm.StudentT, {"nu": Rplus, "mu": R, "lam": Rplus}, ref_rand=ref_rand)
1143-
11441154
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
11451155
def test_ex_gaussian(self):
11461156
def ref_rand(size, mu, sigma, nu):

pymc3/tests/test_posteriors.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class TestNUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture):
7676
min_n_eff = 400
7777

7878

79-
@pytest.mark.xfail(reason="StudentT not refactored for v4")
8079
class TestNUTSStudentT(sf.NutsFixture, sf.StudentTFixture):
8180
n_samples = 10000
8281
tune = 1000
@@ -98,7 +97,7 @@ class TestNUTSNormalLong(sf.NutsFixture, sf.NormalFixture):
9897
atol = 0.001
9998

10099

101-
@pytest.mark.xfail(reason="StudentT not refactored for v4")
100+
@pytest.mark.xfail(reason="LKJCholeskyCov not refactored for v4")
102101
class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture):
103102
n_samples = 2000
104103
tune = 1000

0 commit comments

Comments
 (0)