Refactoring the ChiSquared distribution #4695

Merged
13 commits merged on Jun 7, 2021
44 changes: 40 additions & 4 deletions pymc3/distributions/continuous.py
@@ -28,6 +28,7 @@
    BetaRV,
    WeibullRV,
    cauchy,
    chisquare,
    exponential,
    gamma,
    gumbel,
@@ -2548,7 +2549,7 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(Gamma):
class ChiSquared(PositiveContinuous):
    r"""
    :math:`\chi^2` log-likelihood.

@@ -2586,10 +2587,45 @@ class ChiSquared(Gamma):
    nu: int
        Degrees of freedom (nu > 0).
    """
    rv_op = chisquare

    def __init__(self, nu, *args, **kwargs):
        self.nu = nu = at.as_tensor_variable(floatX(nu))
        super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)
    @classmethod
    def dist(cls, nu, *args, **kwargs):
        nu = at.as_tensor_variable(floatX(nu))
        return super().dist([nu], *args, **kwargs)

    def logp(value, nu):
        """
        Calculate log-probability of ChiSquared distribution at specified value.

        Parameters
        ----------
        value: numeric
            Value(s) for which log-probability is calculated. If the log probabilities for multiple
            values are desired the values must be provided in a numpy array or Aesara tensor

        Returns
        -------
        TensorVariable
        """
        return Gamma.logp(value, nu / 2, 2)

    def logcdf(value, nu):
        """
        Compute the log of the cumulative distribution function for ChiSquared distribution
        at the specified value.

        Parameters
        ----------
        value: numeric or np.ndarray or `TensorVariable`
            Value(s) for which log CDF is calculated. If the log CDF for
            multiple values are desired the values must be provided in a numpy
            array or `TensorVariable`.

        Returns
        -------
        TensorVariable
        """
        return Gamma.logcdf(value, nu / 2, 2)


# TODO: Remove this once logpt for multiplication is working!
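
For context on the delegation above: a chi-squared distribution with nu degrees of freedom is a Gamma distribution with shape nu / 2 and scale 2 (rate 1/2), which is why logp and logcdf can forward straight to Gamma with the arguments (nu / 2, 2). A quick scipy-only check of that equivalence (illustrative sketch added here, not part of the PR):

import numpy as np
import scipy.stats as sp

nu = 3.0
value = np.linspace(0.1, 10.0, 25)

# ChiSquared(nu) is Gamma with shape nu / 2 and scale 2, so the log-density
# and log-CDF agree with the gamma reference at every point.
np.testing.assert_allclose(
    sp.chi2.logpdf(value, df=nu),
    sp.gamma.logpdf(value, a=nu / 2, scale=2),
)
np.testing.assert_allclose(
    sp.chi2.logcdf(value, df=nu),
    sp.gamma.logcdf(value, a=nu / 2, scale=2),
)
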
11 changes: 8 additions & 3 deletions pymc3/tests/test_distributions.py
@@ -1030,14 +1030,19 @@ def test_half_normal(self):
            lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
        )

    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_chi_squared(self):
    def test_chisquared(self):
        self.check_logp(
            ChiSquared,
            Rplus,
            {"nu": Rplusdunif},
            {"nu": Rplus},
            lambda value, nu: sp.chi2.logpdf(value, df=nu),
        )
        self.check_logcdf(
            ChiSquared,
            Rplus,
            {"nu": Rplus},
            lambda value, nu: sp.chi2.logcdf(value, df=nu),
        )

    @pytest.mark.xfail(reason="Distribution not refactored yet")
    def test_wald_logp(self):
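
The check_logp / check_logcdf helpers used above evaluate the distribution's log-probability over combinations of the value domain (Rplus) and the parameter domain ({"nu": Rplus}) and compare against the scipy lambda. A simplified, hypothetical stand-in for that loop (check_logp_sketch and the scipy-based stand-in for ChiSquared.logp are illustrative, not the test suite's actual helpers):

import itertools
import numpy as np
import scipy.stats as sp

def check_logp_sketch(logp_fn, reference_fn, values, params):
    # Compare the candidate logp against the reference on every (value, nu) pair.
    for value, nu in itertools.product(values, params):
        np.testing.assert_allclose(logp_fn(value, nu), reference_fn(value, nu), rtol=1e-6)

# Stand-in for ChiSquared.logp, which delegates to Gamma with shape nu / 2 and scale 2.
chisq_logp = lambda value, nu: sp.gamma.logpdf(value, a=nu / 2, scale=2)
reference = lambda value, nu: sp.chi2.logpdf(value, df=nu)

check_logp_sketch(chisq_logp, reference, values=[0.5, 1.0, 3.0], params=[1.0, 2.5, 5.0])
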
19 changes: 13 additions & 6 deletions pymc3/tests/test_distributions_random.py
@@ -277,12 +277,6 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
    params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestChiSquared(BaseTestCases.BaseTestCase):
    distribution = pm.ChiSquared
    params = {"nu": 2.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestExGaussian(BaseTestCases.BaseTestCase):
    distribution = pm.ExGaussian
@@ -753,6 +747,19 @@ class TestInverseGammaMuSigma(BaseTestDistribution):
    tests_to_run = ["check_pymc_params_match_rv_op"]


class TestChiSquared(BaseTestDistribution):
    pymc_dist = pm.ChiSquared
    pymc_dist_params = {"nu": 2.0}
    expected_rv_op_params = {"nu": 2.0}
    reference_dist_params = {"df": 2.0}
    reference_dist = seeded_numpy_distribution_builder("chisquare")
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]


class TestBinomial(BaseTestDistribution):
    pymc_dist = pm.Binomial
    pymc_dist_params = {"n": 100, "p": 0.33}
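
The new TestChiSquared class above spells the same parameter two ways: the pymc3/rv_op side calls it nu, while the numpy reference distribution ("chisquare") calls it df. A small numpy-only sanity check of that reference (seed and sample size chosen arbitrarily for illustration):

import numpy as np

nu = 2.0
rng = np.random.default_rng(2021)
draws = rng.chisquare(df=nu, size=100_000)

# A chi-squared random variable has mean nu and variance 2 * nu, so the
# reference draws should land close to those moments.
print(draws.mean())  # ~2.0
print(draws.var())   # ~4.0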