Skip to content

Refactoring the ChiSquared distribution #4695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 7, 2021
2 changes: 1 addition & 1 deletion docs/source/api/distributions/continuous.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Continuous
InverseGamma
Weibull
Lognormal
ChiSquared
ChiSquare
Wald
Pareto
ExGaussian
Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
AsymmetricLaplace,
Beta,
Cauchy,
ChiSquared,
ChiSquare,
ExGaussian,
Exponential,
Flat,
Expand Down Expand Up @@ -125,7 +125,7 @@
"Bound",
"Lognormal",
"HalfStudentT",
"ChiSquared",
"ChiSquare",
"HalfNormal",
"Wald",
"Pareto",
Expand Down
44 changes: 40 additions & 4 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BetaRV,
WeibullRV,
cauchy,
chisquare,
exponential,
gamma,
gumbel,
Expand Down Expand Up @@ -87,7 +88,7 @@
"Weibull",
"HalfStudentT",
"Lognormal",
"ChiSquared",
"ChiSquare",
"HalfNormal",
"Wald",
"Pareto",
Expand Down Expand Up @@ -2548,7 +2549,7 @@ def logcdf(value, alpha, beta):
)


class ChiSquared(Gamma):
class ChiSquare(Gamma):
r"""
:math:`\chi^2` log-likelihood.

Expand Down Expand Up @@ -2586,11 +2587,46 @@ class ChiSquared(Gamma):
nu: int
Degrees of freedom (nu > 0).
"""
rv_op = chisquare

def __init__(self, nu, *args, **kwargs):
self.nu = nu = at.as_tensor_variable(floatX(nu))
@classmethod
def dist(cls, nu, *args, **kwargs):
nu = at.as_tensor_variable(floatX(nu))
super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)

def logp(value, nu):
"""
Calculate log-probability of ChiSquare distribution at specified value.

Parameters
----------
value: numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or Aesara tensor

Returns
-------
TensorVariable
"""
return Gamma.logpdf(value, nu/2, 2)

def logcdf(value, nu):
"""
Compute the log of the cumulative distribution function for ChiSquare distribution
at the specified value.

Parameters
----------
value: numeric or np.ndarray or `TensorVariable`
Value(s) for which log CDF is calculated. If the log CDF for
multiple values are desired the values must be provided in a numpy
array or `TensorVariable`.
Returns
-------
TensorVariable
"""
return Gamma.logcdf(value, nu/2, 2)


# TODO: Remove this once logpt for multiplication is working!
class WeibullBetaRV(WeibullRV):
Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

from pymc3.aesaraf import floatX, intX
from pymc3.distributions import transforms
from pymc3.distributions.continuous import ChiSquared, Normal
from pymc3.distributions.continuous import ChiSquare, Normal
from pymc3.distributions.dist_math import bound, factln, logpow
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.distributions.special import gammaln, multigammaln
Expand Down Expand Up @@ -905,7 +905,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testv
tril_testval = None

c = at.sqrt(
ChiSquared("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, testval=diag_testval)
ChiSquare("%s_c" % name, nu - np.arange(2, 2 + n_diag), shape=n_diag, testval=diag_testval)
)
pm._log.info("Added new variable %s_c to model diagonal of Wishart." % name)
z = Normal("%s_z" % name, 0.0, 1.0, shape=n_tril, testval=tril_testval)
Expand Down
2 changes: 1 addition & 1 deletion pymc3/gp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def _build_prior(self, name, X, reparameterize=True, **kwargs):
cov = stabilize(self.cov_func(X))
shape = infer_shape(X, kwargs.pop("shape", None))
if reparameterize:
chi2 = pm.ChiSquared(name + "_chi2_", self.nu)
chi2 = pm.ChiSquare(name + "_chi2_", self.nu)
v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=shape, **kwargs)
f = pm.Deterministic(name, (at.sqrt(self.nu) / chi2) * (mu + cholesky(cov).dot(v)))
else:
Expand Down
5 changes: 2 additions & 3 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Bound,
Categorical,
Cauchy,
ChiSquared,
ChiSquare,
Constant,
DensityDist,
Dirichlet,
Expand Down Expand Up @@ -1030,10 +1030,9 @@ def test_half_normal(self):
lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_chi_squared(self):
self.check_logp(
ChiSquared,
ChiSquare,
Rplus,
{"nu": Rplusdunif},
lambda value, nu: sp.chi2.logpdf(value, df=nu),
Expand Down
13 changes: 10 additions & 3 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,17 @@ class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
params = {"kappa": 1.0, "b": 1.0, "mu": 0.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestChiSquared(BaseTestCases.BaseTestCase):
distribution = pm.ChiSquared
class TestChiSquare(BaseTestCases.BaseTestCase):
distribution = pm.ChiSquare
params = {"nu": 2.0}
expected_rv_op_params = {"nu": 2.0}
reference_dist_params = {"df": 2.0}
reference_dist = seeded_numpy_distribution_builder("chisquare")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
Expand Down