From 6260c84738349c2c56d99bbc54815fecf3343252 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 12 Sep 2022 17:34:19 +0300 Subject: [PATCH 01/50] Use None as default value for zerosum_axes --- pymc/distributions/__init__.py | 4 +- pymc/distributions/continuous.py | 187 ++++++++++++++++++++++++++++++- pymc/distributions/transforms.py | 63 +++++++++++ 3 files changed, 250 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index 3240bde379..9467831c34 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -56,6 +56,7 @@ VonMises, Wald, Weibull, + ZeroSumNormal, ) from pymc.distributions.discrete import ( Bernoulli, @@ -115,8 +116,9 @@ "Uniform", "Flat", "HalfFlat", - "TruncatedNormal", "Normal", + "TruncatedNormal", + "ZeroSumNormal", "Beta", "Kumaraswamy", "Exponential", diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 06d2295450..0b348e46a7 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -69,10 +69,13 @@ def polyagamma_cdf(*args, **kwargs): raise RuntimeError("polyagamma package is not installed!") +from numpy.core.numeric import normalize_axis_tuple from scipy import stats from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import expit +import pymc as pm + from pymc.aesaraf import floatX from pymc.distributions import transforms from pymc.distributions.dist_math import ( @@ -86,9 +89,20 @@ def polyagamma_cdf(*args, **kwargs): normal_lcdf, zvalue, ) -from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous -from pymc.distributions.shape_utils import rv_size_is_none -from pymc.distributions.transforms import _default_transform +from pymc.distributions.distribution import ( + DIST_PARAMETER_TYPES, + Continuous, + Distribution, + SymbolicRandomVariable, + _moment, +) +from pymc.distributions.logprob import ignore_logprob +from pymc.distributions.shape_utils import ( + _change_dist_size, + convert_dims, + rv_size_is_none, +) +from pymc.distributions.transforms import ZeroSumTransform, _default_transform from pymc.math import invlogit, logdiffexp, logit __all__ = [ @@ -96,6 +110,7 @@ def polyagamma_cdf(*args, **kwargs): "Flat", "HalfFlat", "Normal", + "ZeroSumNormal", "TruncatedNormal", "Beta", "Kumaraswamy", @@ -585,6 +600,172 @@ def logcdf(value, mu, sigma): ) +class ZeroSumNormalRV(SymbolicRandomVariable): + """ZeroSumNormal random variable""" + + _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") + zerosum_axes = None + + def __init__(self, *args, zerosum_axes, **kwargs): + self.zerosum_axes = zerosum_axes + super().__init__(*args, **kwargs) + + +class ZeroSumNormal(Distribution): + r""" + ZeroSumNormal distribution, i.e Normal distribution where one or + several axes are constrained to sum to zero. + By default, the last axis is constrained to sum to zero. + See `zerosum_axes` kwarg for more details. + + Parameters + ---------- + sigma : tensor_like of float + Standard deviation (sigma > 0). + Defaults to 1 if not specified. + For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. + zerosum_axes: list or tuple of strings or integers + Axis (or axes) along which the zero-sum constraint is enforced. + Defaults to [-1], i.e the last axis. + If strings are passed, then ``dims`` is needed. + Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions. + dims: list or tuple of strings, optional + The dimension names of the axes. + Necessary when ``zerosum_axes`` is specified with strings. + + Warnings + -------- + ``sigma`` has to be a scalar, to ensure the zero-sum constraint. + The ability to specifiy a vector of ``sigma`` may be added in future versions. + + Examples + -------- + .. code-block:: python + COORDS = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], + } + with pm.Model(coords=COORDS) as m: + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") + + with pm.Model(coords=COORDS) as m: + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) + + with pm.Model(coords=COORDS) as m: + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) + """ + rv_type = ZeroSumNormalRV + + def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): + dims = convert_dims(dims) + if zerosum_axes is None: + zerosum_axes = [-1] + if not isinstance(zerosum_axes, (list, tuple)): + zerosum_axes = [zerosum_axes] + + if isinstance(zerosum_axes[0], str): + if not dims: + raise ValueError("You need to specify dims if zerosum_axes are strings.") + else: + zerosum_axes_ = [] + for axis in zerosum_axes: + zerosum_axes_.append(dims.index(axis)) + zerosum_axes = zerosum_axes_ + + return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) + + @classmethod + def dist(cls, sigma=1, zerosum_axes=None, **kwargs): + if zerosum_axes is None: + zerosum_axes = [-1] + + sigma = at.as_tensor_variable(floatX(sigma)) + if sigma.ndim > 0: + raise ValueError("sigma has to be a scalar") + + return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs) + + # TODO: This is if we want ZeroSum constraint on other dists than Normal + # def dist(cls, dist, lower, upper, **kwargs): + # if not isinstance(dist, TensorVariable) or not isinstance( + # dist.owner.op, (RandomVariable, SymbolicRandomVariable) + # ): + # raise ValueError( + # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" + # ) + # if dist.owner.op.ndim_supp > 0: + # raise NotImplementedError( + # "Censoring of multivariate distributions has not been implemented yet" + # ) + # check_dist_not_registered(dist) + # return super().dist([dist, lower, upper], **kwargs) + + @classmethod + def rv_op(cls, sigma, zerosum_axes, size=None): + if size is None: + zerosum_axes_ = np.asarray(zerosum_axes) + # just a placeholder size to infer minimum shape + size = np.ones( + max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int + ).tolist() + + # check if zerosum_axes is valid + normalize_axis_tuple(zerosum_axes, len(size)) + + normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size)) + normal_dist_, sigma_ = normal_dist.type(), sigma.type() + + # Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes + zerosum_rv_ = normal_dist_ + for axis in zerosum_axes: + zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True) + + return ZeroSumNormalRV( + inputs=[normal_dist_, sigma_], + outputs=[zerosum_rv_], + zerosum_axes=zerosum_axes, + ndim_supp=0, + )(normal_dist, sigma) + + +@_logprob.register(ZeroSumNormalRV) +def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): + (value,) = values + shape = value.shape + _deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1) + _full_size = at.prod(shape) + _degrees_of_freedom = at.prod(_deg_free_shape) + zerosums = [ + at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes + ] + # out = at.sum( + # pm.logp(dist, value) * _degrees_of_freedom / _full_size, + # axis=op.zerosum_axes, + # ) + # figure out how dimensionality should be handled for logp + # for now, we assume ZSN is a scalar distribut, which is not correct + out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size + return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") + + +@_moment.register(ZeroSumNormalRV) +def zerosumnormal_moment(op, rv, *rv_inputs): + return at.zeros_like(rv) + + +@_change_dist_size.register(ZeroSumNormalRV) +def change_zerosum_size(op, normal_dist, new_size, expand=False): + normal_dist, sigma = normal_dist.owner.inputs + if expand: + new_size = tuple(new_size) + tuple(normal_dist.shape) + return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) + + +@_default_transform.register(ZeroSumNormalRV) +def zerosum_default_transform(op, rv): + return ZeroSumTransform(op.zerosum_axes) + + class TruncatedNormalRV(RandomVariable): name = "truncated_normal" ndim_supp = 0 diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 58082c0c4d..a4d71b8e49 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -39,6 +39,7 @@ "circular", "CholeskyCovPacked", "Chain", + "ZeroSumTransform", ] @@ -266,6 +267,68 @@ def bounds_fn(*rv_inputs): super().__init__(args_fn=bounds_fn) +class ZeroSumTransform(RVTransform): + """ + Constrains the samples of a Normal distribution to sum to zero + along the user-provided ``zerosum_axes``. + By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed + on the last axis. + """ + + name = "zerosum" + + __props__ = ("zerosum_axes",) + + def __init__(self, zerosum_axes): + """ + Parameters + ---------- + zerosum_axes : list of ints + Must be a list of integers (positive or negative). + By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed + on the last axis. + """ + self.zerosum_axes = zerosum_axes + + def forward(self, value, *rv_inputs): + for axis in self.zerosum_axes: + value = extend_axis_rev(value, axis=axis) + return value + + def backward(self, value, *rv_inputs): + for axis in self.zerosum_axes: + value = extend_axis(value, axis=axis) + return value + + def log_jac_det(self, value, *rv_inputs): + return at.constant(0.0) + + +def extend_axis(array, axis): + n = array.shape[axis] + 1 + sum_vals = array.sum(axis, keepdims=True) + norm = sum_vals / (np.sqrt(n) + n) + fill_val = norm - sum_vals / np.sqrt(n) + + out = at.concatenate([array, fill_val], axis=axis) + return out - norm + + +def extend_axis_rev(array, axis): + if axis < 0: + axis = axis % array.ndim + assert axis >= 0 and axis < array.ndim + + n = array.shape[axis] + last = at.take(array, [-1], axis=axis) + + sum_vals = -last * np.sqrt(n) + norm = sum_vals / (np.sqrt(n) + n) + slice_before = (slice(None, None),) * axis + + return array[slice_before + (slice(None, -1),)] + norm + + log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.LogExpM1` From af960162cf6dc6b6e89b68944ed2ea42abfb740b Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 12 Sep 2022 18:32:49 +0300 Subject: [PATCH 02/50] Add tests for ZSN --- pymc/tests/distributions/test_continuous.py | 40 +++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index a594804978..14c94dec2d 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -1802,6 +1802,46 @@ class TestTruncatedNormalUpperArray(BaseTestDistributionRandom): ] +class TestZeroSumNormal(BaseTestDistributionRandom): + COORDS = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], + } + with pm.Model(coords=COORDS) as m: + v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") + s = pm.sample(10) + + assert np.isclose( + s.posterior.v.mean(dim="answers"), 0 + ).all(), "A zerosum_axis is not summing to 0 across all axes." + assert not np.isclose( + s.posterior.v.mean(dim="regions"), 0 + ).all(), "A non zerosum_axis is nonetheless summing to 0 across all samples." + assert s.posterior.v.shape == (4, 10, 3, 4) + + with pm.Model(coords=COORDS) as m: + v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) + s = pm.sample(10) + + assert np.isclose( + s.posterior.v.mean(dim="answers"), 0 + ).all(), "A zerosum_axis is not summing to 0 across all axes." + assert np.isclose( + s.posterior.v.mean(dim="regions"), 0 + ).all(), "A zerosum_axis is not summing to 0 across all axes." + + with pm.Model(coords=COORDS) as m: + v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) + s = pm.sample(10) + + assert np.isclose( + s.posterior.v.mean(dim="answers"), 0 + ).all(), "A zerosum_axis is not summing to 0 across all axes." + assert not np.isclose( + s.posterior.v.mean(dim="regions"), 0 + ).all(), "A non zerosum_axis is nonetheless summing to 0 across all samples." + + class TestWald(BaseTestDistributionRandom): pymc_dist = pm.Wald mu, lam, alpha = 1.0, 1.0, 0.0 From 71e5651bc931461977f0813dc44ac0649abde9dc Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 12 Sep 2022 20:51:10 +0300 Subject: [PATCH 03/50] Reorder dispatched functions --- pymc/distributions/continuous.py | 36 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 0b348e46a7..b3479086d6 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -728,6 +728,24 @@ def rv_op(cls, sigma, zerosum_axes, size=None): )(normal_dist, sigma) +@_change_dist_size.register(ZeroSumNormalRV) +def change_zerosum_size(op, normal_dist, new_size, expand=False): + normal_dist, sigma = normal_dist.owner.inputs + if expand: + new_size = tuple(new_size) + tuple(normal_dist.shape) + return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) + + +@_moment.register(ZeroSumNormalRV) +def zerosumnormal_moment(op, rv, *rv_inputs): + return at.zeros_like(rv) + + +@_default_transform.register(ZeroSumNormalRV) +def zerosum_default_transform(op, rv): + return ZeroSumTransform(op.zerosum_axes) + + @_logprob.register(ZeroSumNormalRV) def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): (value,) = values @@ -748,24 +766,6 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") -@_moment.register(ZeroSumNormalRV) -def zerosumnormal_moment(op, rv, *rv_inputs): - return at.zeros_like(rv) - - -@_change_dist_size.register(ZeroSumNormalRV) -def change_zerosum_size(op, normal_dist, new_size, expand=False): - normal_dist, sigma = normal_dist.owner.inputs - if expand: - new_size = tuple(new_size) + tuple(normal_dist.shape) - return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) - - -@_default_transform.register(ZeroSumNormalRV) -def zerosum_default_transform(op, rv): - return ZeroSumTransform(op.zerosum_axes) - - class TruncatedNormalRV(RandomVariable): name = "truncated_normal" ndim_supp = 0 From 3cadb268a3dec73d0734c79bb9776d6521b0bebb Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 12 Sep 2022 21:13:04 +0300 Subject: [PATCH 04/50] Test pylint --- pymc/distributions/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index a4d71b8e49..15f1307008 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -319,6 +319,8 @@ def extend_axis_rev(array, axis): axis = axis % array.ndim assert axis >= 0 and axis < array.ndim + # normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] + n = array.shape[axis] last = at.take(array, [-1], axis=axis) From a66c58603d39304de5cb99f0651cf8348afa1ae7 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 12 Sep 2022 21:24:11 +0300 Subject: [PATCH 05/50] Ignore type check on normalize_axis_tuple --- pymc/distributions/transforms.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 15f1307008..d03845ce54 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -314,19 +314,22 @@ def extend_axis(array, axis): return out - norm +from numpy.core.numeric import normalize_axis_tuple # type: ignore + + def extend_axis_rev(array, axis): - if axis < 0: - axis = axis % array.ndim - assert axis >= 0 and axis < array.ndim + # if axis < 0: + # axis = axis % array.ndim + # assert axis >= 0 and axis < array.ndim - # normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] + normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] - n = array.shape[axis] - last = at.take(array, [-1], axis=axis) + n = array.shape[normalized_axis] + last = at.take(array, [-1], axis=normalized_axis) sum_vals = -last * np.sqrt(n) norm = sum_vals / (np.sqrt(n) + n) - slice_before = (slice(None, None),) * axis + slice_before = (slice(None, None),) * normalized_axis return array[slice_before + (slice(None, -1),)] + norm From e3be4956b5271713faf03c59269c92a5a317dce0 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Mon, 12 Sep 2022 21:28:52 +0300 Subject: [PATCH 06/50] Disable mypy on import of normalize_axis_tuple --- pymc/distributions/transforms.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index d03845ce54..72b0300c1e 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -27,6 +27,10 @@ from aesara.graph import Op from aesara.tensor import TensorVariable +# ignore mypy error because it somehow considers that +# "numpy.core.numeric has no attribute normalize_axis_tuple" +from numpy.core.numeric import normalize_axis_tuple # type: ignore + __all__ = [ "RVTransform", "simplex", @@ -314,14 +318,7 @@ def extend_axis(array, axis): return out - norm -from numpy.core.numeric import normalize_axis_tuple # type: ignore - - def extend_axis_rev(array, axis): - # if axis < 0: - # axis = axis % array.ndim - # assert axis >= 0 and axis < array.ndim - normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] n = array.shape[normalized_axis] From 759de36288da8d68ac75134144f2af2bdf22097d Mon Sep 17 00:00:00 2001 From: Alexandre Andorra Date: Mon, 12 Sep 2022 21:36:01 +0300 Subject: [PATCH 07/50] Remove base class in tests Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/tests/distributions/test_continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index 14c94dec2d..60051b3cd8 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -1802,7 +1802,7 @@ class TestTruncatedNormalUpperArray(BaseTestDistributionRandom): ] -class TestZeroSumNormal(BaseTestDistributionRandom): +class TestZeroSumNormal: COORDS = { "regions": ["a", "b", "c"], "answers": ["yes", "no", "whatever", "don't understand question"], From a5a1e4595562e06bc7db9dcf5c982b4f6dd9922e Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 15 Sep 2022 15:28:18 +0300 Subject: [PATCH 08/50] Use pytest parametrize --- pymc/tests/distributions/test_continuous.py | 107 +++++++++++++------- 1 file changed, 70 insertions(+), 37 deletions(-) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index 60051b3cd8..515f018c90 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -24,6 +24,7 @@ from aeppl.logprob import ParameterValueError from aesara.compile.mode import Mode +from numpy import AxisError import pymc as pm @@ -1802,44 +1803,76 @@ class TestTruncatedNormalUpperArray(BaseTestDistributionRandom): ] +COORDS = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], +} + + class TestZeroSumNormal: - COORDS = { - "regions": ["a", "b", "c"], - "answers": ["yes", "no", "whatever", "don't understand question"], - } - with pm.Model(coords=COORDS) as m: - v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") - s = pm.sample(10) - - assert np.isclose( - s.posterior.v.mean(dim="answers"), 0 - ).all(), "A zerosum_axis is not summing to 0 across all axes." - assert not np.isclose( - s.posterior.v.mean(dim="regions"), 0 - ).all(), "A non zerosum_axis is nonetheless summing to 0 across all samples." - assert s.posterior.v.shape == (4, 10, 3, 4) - - with pm.Model(coords=COORDS) as m: - v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) - s = pm.sample(10) - - assert np.isclose( - s.posterior.v.mean(dim="answers"), 0 - ).all(), "A zerosum_axis is not summing to 0 across all axes." - assert np.isclose( - s.posterior.v.mean(dim="regions"), 0 - ).all(), "A zerosum_axis is not summing to 0 across all axes." - - with pm.Model(coords=COORDS) as m: - v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) - s = pm.sample(10) - - assert np.isclose( - s.posterior.v.mean(dim="answers"), 0 - ).all(), "A zerosum_axis is not summing to 0 across all axes." - assert not np.isclose( - s.posterior.v.mean(dim="regions"), 0 - ).all(), "A non zerosum_axis is nonetheless summing to 0 across all samples." + @pytest.mark.parametrize( + "dims,zerosum_axes,shape", + [ + (("regions", "answers"), "answers", None), + (("regions", "answers"), ("regions", "answers"), None), + (("regions", "answers"), 0, None), + (("regions", "answers"), -1, None), + (("regions", "answers"), (0, 1), None), + (None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))), + ], + ) + def test_zsn_dims_shape(self, dims, zerosum_axes, shape): + with pm.Model(coords=COORDS) as m: + _ = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) + + assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) + + if not isinstance(zerosum_axes, (list, tuple)): + zerosum_axes = [zerosum_axes] + + if isinstance(zerosum_axes[0], str): + for ax in zerosum_axes: + assert np.isclose( + s.posterior.v.mean(dim=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + + nonzero_ax = list(set(dims).difference(zerosum_axes)) + if nonzero_ax: + assert not np.isclose( + s.posterior.v.mean(dim=nonzero_ax), 0 + ).all(), f"{nonzero_ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + + else: + for ax in zerosum_axes: + if ax < 0: + assert np.isclose( + s.posterior.v.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + else: + ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling + assert np.isclose( + s.posterior.v.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + + @pytest.mark.parametrize( + "dims,zerosum_axes", + [ + (("regions", "answers"), 2), + (("regions", "answers"), (0, -2)), + ], + ) + def test_zsn_fail_axis(self, dims, zerosum_axes): + if isinstance(zerosum_axes, (list, tuple)): + with pytest.raises(ValueError, match="repeated axis"): + with pm.Model(coords=COORDS) as m: + _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) + else: + with pytest.raises(AxisError, match="out of bounds"): + with pm.Model(coords=COORDS) as m: + _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) class TestWald(BaseTestDistributionRandom): From c9eea6e4acc9c0535f82a4f7738ea60739f1f782 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 15 Sep 2022 15:52:21 +0300 Subject: [PATCH 09/50] Add pm.draw to tests --- pymc/tests/distributions/test_continuous.py | 35 +++++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index 515f018c90..7b90988da2 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -1823,9 +1823,15 @@ class TestZeroSumNormal: ) def test_zsn_dims_shape(self, dims, zerosum_axes, shape): with pm.Model(coords=COORDS) as m: - _ = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes) + v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes) s = pm.sample(10, chains=1, tune=100) + # to test forward graph + random_samples = pm.draw( + v, + draws=10, + ) + assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) if not isinstance(zerosum_axes, (list, tuple)): @@ -1833,15 +1839,24 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): if isinstance(zerosum_axes[0], str): for ax in zerosum_axes: - assert np.isclose( - s.posterior.v.mean(dim=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - - nonzero_ax = list(set(dims).difference(zerosum_axes)) - if nonzero_ax: - assert not np.isclose( - s.posterior.v.mean(dim=nonzero_ax), 0 - ).all(), f"{nonzero_ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + for samples in [ + s.posterior.v.mean(dim=ax), + random_samples.mean(axis=dims.index(ax) + 1), + ]: + assert np.isclose( + samples, 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + + nonzero_axes = list(set(dims).difference(zerosum_axes)) + if nonzero_axes: + for ax in nonzero_axes: + for samples in [ + s.posterior.v.mean(dim=ax), + random_samples.mean(axis=dims.index(ax) + 1), + ]: + assert not np.isclose( + samples, 0 + ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." else: for ax in zerosum_axes: From 0582d7c8c890b99a441b5f8ee5b2e761db7f1130 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 15 Sep 2022 16:07:05 +0300 Subject: [PATCH 10/50] Test moment --- pymc/tests/distributions/test_continuous.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index 7b90988da2..f1a4f71cc2 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -964,6 +964,19 @@ def test_normal_moment(self, mu, sigma, size, expected): pm.Normal("x", mu=mu, sigma=sigma, size=size) assert_moment_is_expected(model, expected) + @pytest.mark.parametrize( + "shape, zerosum_axes, expected", + [ + ((2, 5), None, np.zeros((2, 5))), + ((2, 5, 6), None, np.zeros((2, 5, 6))), + ((2, 5, 6), (0, 1), np.zeros((2, 5, 6))), + ], + ) + def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): + with pm.Model() as model: + pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes) + assert_moment_is_expected(model, expected) + @pytest.mark.parametrize( "sigma, size, expected", [ @@ -1811,7 +1824,7 @@ class TestTruncatedNormalUpperArray(BaseTestDistributionRandom): class TestZeroSumNormal: @pytest.mark.parametrize( - "dims,zerosum_axes,shape", + "dims, zerosum_axes, shape", [ (("regions", "answers"), "answers", None), (("regions", "answers"), ("regions", "answers"), None), @@ -1871,7 +1884,7 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." @pytest.mark.parametrize( - "dims,zerosum_axes", + "dims, zerosum_axes", [ (("regions", "answers"), 2), (("regions", "answers"), (0, -2)), From 0bdcdd7d43f5a31525cd624cd7dae6c6f8892816 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 15 Sep 2022 16:20:57 +0300 Subject: [PATCH 11/50] Add change size test --- pymc/tests/distributions/test_continuous.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index f1a4f71cc2..028ac07771 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -32,6 +32,7 @@ from pymc.distributions import logcdf, logp from pymc.distributions.continuous import get_tau_sigma, interpolated from pymc.distributions.dist_math import clipped_beta_rvs +from pymc.distributions.shape_utils import change_dist_size from pymc.tests.distributions.util import ( BaseTestDistributionRandom, Circ, @@ -1895,12 +1896,19 @@ def test_zsn_fail_axis(self, dims, zerosum_axes): with pytest.raises(ValueError, match="repeated axis"): with pm.Model(coords=COORDS) as m: _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) - s = pm.sample(10, chains=1, tune=100) else: with pytest.raises(AxisError, match="out of bounds"): with pm.Model(coords=COORDS) as m: _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) - s = pm.sample(10, chains=1, tune=100) + + def test_zsn_change_dist_size(self): + base_dist = pm.ZeroSumNormal.dist(shape=(4, 9)) + + new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) + assert new_dist.eval().shape == (5, 3) + + new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True) + assert new_dist.eval().shape == (5, 3, 4, 9) class TestWald(BaseTestDistributionRandom): From 854ef4cd4e6fd092554d594c4b597b5d18b9b02f Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 15 Sep 2022 16:36:14 +0300 Subject: [PATCH 12/50] Move ZSN to multivariate.py --- pymc/distributions/__init__.py | 4 +- pymc/distributions/continuous.py | 187 +---------------------------- pymc/distributions/multivariate.py | 171 +++++++++++++++++++++++++- 3 files changed, 175 insertions(+), 187 deletions(-) diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index 9467831c34..753d33a651 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -56,7 +56,6 @@ VonMises, Wald, Weibull, - ZeroSumNormal, ) from pymc.distributions.discrete import ( Bernoulli, @@ -100,6 +99,7 @@ StickBreakingWeights, Wishart, WishartBartlett, + ZeroSumNormal, ) from pymc.distributions.simulator import Simulator from pymc.distributions.timeseries import ( @@ -118,7 +118,6 @@ "HalfFlat", "Normal", "TruncatedNormal", - "ZeroSumNormal", "Beta", "Kumaraswamy", "Exponential", @@ -161,6 +160,7 @@ "Continuous", "Discrete", "MvNormal", + "ZeroSumNormal", "MatrixNormal", "KroneckerNormal", "MvStudentT", diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index b3479086d6..06d2295450 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -69,13 +69,10 @@ def polyagamma_cdf(*args, **kwargs): raise RuntimeError("polyagamma package is not installed!") -from numpy.core.numeric import normalize_axis_tuple from scipy import stats from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import expit -import pymc as pm - from pymc.aesaraf import floatX from pymc.distributions import transforms from pymc.distributions.dist_math import ( @@ -89,20 +86,9 @@ def polyagamma_cdf(*args, **kwargs): normal_lcdf, zvalue, ) -from pymc.distributions.distribution import ( - DIST_PARAMETER_TYPES, - Continuous, - Distribution, - SymbolicRandomVariable, - _moment, -) -from pymc.distributions.logprob import ignore_logprob -from pymc.distributions.shape_utils import ( - _change_dist_size, - convert_dims, - rv_size_is_none, -) -from pymc.distributions.transforms import ZeroSumTransform, _default_transform +from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous +from pymc.distributions.shape_utils import rv_size_is_none +from pymc.distributions.transforms import _default_transform from pymc.math import invlogit, logdiffexp, logit __all__ = [ @@ -110,7 +96,6 @@ def polyagamma_cdf(*args, **kwargs): "Flat", "HalfFlat", "Normal", - "ZeroSumNormal", "TruncatedNormal", "Beta", "Kumaraswamy", @@ -600,172 +585,6 @@ def logcdf(value, mu, sigma): ) -class ZeroSumNormalRV(SymbolicRandomVariable): - """ZeroSumNormal random variable""" - - _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") - zerosum_axes = None - - def __init__(self, *args, zerosum_axes, **kwargs): - self.zerosum_axes = zerosum_axes - super().__init__(*args, **kwargs) - - -class ZeroSumNormal(Distribution): - r""" - ZeroSumNormal distribution, i.e Normal distribution where one or - several axes are constrained to sum to zero. - By default, the last axis is constrained to sum to zero. - See `zerosum_axes` kwarg for more details. - - Parameters - ---------- - sigma : tensor_like of float - Standard deviation (sigma > 0). - Defaults to 1 if not specified. - For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. - zerosum_axes: list or tuple of strings or integers - Axis (or axes) along which the zero-sum constraint is enforced. - Defaults to [-1], i.e the last axis. - If strings are passed, then ``dims`` is needed. - Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions. - dims: list or tuple of strings, optional - The dimension names of the axes. - Necessary when ``zerosum_axes`` is specified with strings. - - Warnings - -------- - ``sigma`` has to be a scalar, to ensure the zero-sum constraint. - The ability to specifiy a vector of ``sigma`` may be added in future versions. - - Examples - -------- - .. code-block:: python - COORDS = { - "regions": ["a", "b", "c"], - "answers": ["yes", "no", "whatever", "don't understand question"], - } - with pm.Model(coords=COORDS) as m: - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") - - with pm.Model(coords=COORDS) as m: - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) - - with pm.Model(coords=COORDS) as m: - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) - """ - rv_type = ZeroSumNormalRV - - def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): - dims = convert_dims(dims) - if zerosum_axes is None: - zerosum_axes = [-1] - if not isinstance(zerosum_axes, (list, tuple)): - zerosum_axes = [zerosum_axes] - - if isinstance(zerosum_axes[0], str): - if not dims: - raise ValueError("You need to specify dims if zerosum_axes are strings.") - else: - zerosum_axes_ = [] - for axis in zerosum_axes: - zerosum_axes_.append(dims.index(axis)) - zerosum_axes = zerosum_axes_ - - return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) - - @classmethod - def dist(cls, sigma=1, zerosum_axes=None, **kwargs): - if zerosum_axes is None: - zerosum_axes = [-1] - - sigma = at.as_tensor_variable(floatX(sigma)) - if sigma.ndim > 0: - raise ValueError("sigma has to be a scalar") - - return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs) - - # TODO: This is if we want ZeroSum constraint on other dists than Normal - # def dist(cls, dist, lower, upper, **kwargs): - # if not isinstance(dist, TensorVariable) or not isinstance( - # dist.owner.op, (RandomVariable, SymbolicRandomVariable) - # ): - # raise ValueError( - # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" - # ) - # if dist.owner.op.ndim_supp > 0: - # raise NotImplementedError( - # "Censoring of multivariate distributions has not been implemented yet" - # ) - # check_dist_not_registered(dist) - # return super().dist([dist, lower, upper], **kwargs) - - @classmethod - def rv_op(cls, sigma, zerosum_axes, size=None): - if size is None: - zerosum_axes_ = np.asarray(zerosum_axes) - # just a placeholder size to infer minimum shape - size = np.ones( - max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int - ).tolist() - - # check if zerosum_axes is valid - normalize_axis_tuple(zerosum_axes, len(size)) - - normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size)) - normal_dist_, sigma_ = normal_dist.type(), sigma.type() - - # Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes - zerosum_rv_ = normal_dist_ - for axis in zerosum_axes: - zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True) - - return ZeroSumNormalRV( - inputs=[normal_dist_, sigma_], - outputs=[zerosum_rv_], - zerosum_axes=zerosum_axes, - ndim_supp=0, - )(normal_dist, sigma) - - -@_change_dist_size.register(ZeroSumNormalRV) -def change_zerosum_size(op, normal_dist, new_size, expand=False): - normal_dist, sigma = normal_dist.owner.inputs - if expand: - new_size = tuple(new_size) + tuple(normal_dist.shape) - return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) - - -@_moment.register(ZeroSumNormalRV) -def zerosumnormal_moment(op, rv, *rv_inputs): - return at.zeros_like(rv) - - -@_default_transform.register(ZeroSumNormalRV) -def zerosum_default_transform(op, rv): - return ZeroSumTransform(op.zerosum_axes) - - -@_logprob.register(ZeroSumNormalRV) -def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): - (value,) = values - shape = value.shape - _deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1) - _full_size = at.prod(shape) - _degrees_of_freedom = at.prod(_deg_free_shape) - zerosums = [ - at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes - ] - # out = at.sum( - # pm.logp(dist, value) * _degrees_of_freedom / _full_size, - # axis=op.zerosum_axes, - # ) - # figure out how dimensionality should be handled for logp - # for now, we assume ZSN is a scalar distribut, which is not correct - out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size - return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") - - class TruncatedNormalRV(RandomVariable): name = "truncated_normal" ndim_supp = 0 diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index a2dfb9500a..446b1ea9a7 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -36,6 +36,7 @@ from aesara.tensor.random.utils import broadcast_params from aesara.tensor.slinalg import Cholesky, SolveTriangular from aesara.tensor.type import TensorType +from numpy.core.numeric import normalize_axis_tuple from scipy import linalg, stats import pymc as pm @@ -63,15 +64,17 @@ _change_dist_size, broadcast_dist_samples_to, change_dist_size, + convert_dims, rv_size_is_none, to_tuple, ) -from pymc.distributions.transforms import Interval, _default_transform +from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform from pymc.math import kron_diag, kron_dot from pymc.util import check_dist_not_registered __all__ = [ "MvNormal", + "ZeroSumNormal", "MvStudentT", "Dirichlet", "Multinomial", @@ -2380,3 +2383,169 @@ def logp(value, alpha, K): K > 0, msg="alpha > 0, K > 0", ) + + +class ZeroSumNormalRV(SymbolicRandomVariable): + """ZeroSumNormal random variable""" + + _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") + zerosum_axes = None + + def __init__(self, *args, zerosum_axes, **kwargs): + self.zerosum_axes = zerosum_axes + super().__init__(*args, **kwargs) + + +class ZeroSumNormal(Distribution): + r""" + ZeroSumNormal distribution, i.e Normal distribution where one or + several axes are constrained to sum to zero. + By default, the last axis is constrained to sum to zero. + See `zerosum_axes` kwarg for more details. + + Parameters + ---------- + sigma : tensor_like of float + Standard deviation (sigma > 0). + Defaults to 1 if not specified. + For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. + zerosum_axes: list or tuple of strings or integers + Axis (or axes) along which the zero-sum constraint is enforced. + Defaults to [-1], i.e the last axis. + If strings are passed, then ``dims`` is needed. + Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions. + dims: list or tuple of strings, optional + The dimension names of the axes. + Necessary when ``zerosum_axes`` is specified with strings. + + Warnings + -------- + ``sigma`` has to be a scalar, to ensure the zero-sum constraint. + The ability to specifiy a vector of ``sigma`` may be added in future versions. + + Examples + -------- + .. code-block:: python + COORDS = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], + } + with pm.Model(coords=COORDS) as m: + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") + + with pm.Model(coords=COORDS) as m: + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) + + with pm.Model(coords=COORDS) as m: + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) + """ + rv_type = ZeroSumNormalRV + + def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): + dims = convert_dims(dims) + if zerosum_axes is None: + zerosum_axes = [-1] + if not isinstance(zerosum_axes, (list, tuple)): + zerosum_axes = [zerosum_axes] + + if isinstance(zerosum_axes[0], str): + if not dims: + raise ValueError("You need to specify dims if zerosum_axes are strings.") + else: + zerosum_axes_ = [] + for axis in zerosum_axes: + zerosum_axes_.append(dims.index(axis)) + zerosum_axes = zerosum_axes_ + + return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) + + @classmethod + def dist(cls, sigma=1, zerosum_axes=None, **kwargs): + if zerosum_axes is None: + zerosum_axes = [-1] + + sigma = at.as_tensor_variable(floatX(sigma)) + if sigma.ndim > 0: + raise ValueError("sigma has to be a scalar") + + return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs) + + # TODO: This is if we want ZeroSum constraint on other dists than Normal + # def dist(cls, dist, lower, upper, **kwargs): + # if not isinstance(dist, TensorVariable) or not isinstance( + # dist.owner.op, (RandomVariable, SymbolicRandomVariable) + # ): + # raise ValueError( + # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" + # ) + # if dist.owner.op.ndim_supp > 0: + # raise NotImplementedError( + # "Censoring of multivariate distributions has not been implemented yet" + # ) + # check_dist_not_registered(dist) + # return super().dist([dist, lower, upper], **kwargs) + + @classmethod + def rv_op(cls, sigma, zerosum_axes, size=None): + if size is None: + zerosum_axes_ = np.asarray(zerosum_axes) + # just a placeholder size to infer minimum shape + size = np.ones( + max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int + ).tolist() + + # check if zerosum_axes is valid + normalize_axis_tuple(zerosum_axes, len(size)) + + normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size)) + normal_dist_, sigma_ = normal_dist.type(), sigma.type() + + # Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes + zerosum_rv_ = normal_dist_ + for axis in zerosum_axes: + zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True) + + return ZeroSumNormalRV( + inputs=[normal_dist_, sigma_], + outputs=[zerosum_rv_], + zerosum_axes=zerosum_axes, + ndim_supp=0, + )(normal_dist, sigma) + + +@_change_dist_size.register(ZeroSumNormalRV) +def change_zerosum_size(op, normal_dist, new_size, expand=False): + normal_dist, sigma = normal_dist.owner.inputs + if expand: + new_size = tuple(new_size) + tuple(normal_dist.shape) + return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) + + +@_moment.register(ZeroSumNormalRV) +def zerosumnormal_moment(op, rv, *rv_inputs): + return at.zeros_like(rv) + + +@_default_transform.register(ZeroSumNormalRV) +def zerosum_default_transform(op, rv): + return ZeroSumTransform(op.zerosum_axes) + + +@_logprob.register(ZeroSumNormalRV) +def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): + (value,) = values + shape = value.shape + _deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1) + _full_size = at.prod(shape) + _degrees_of_freedom = at.prod(_deg_free_shape) + zerosums = [ + at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes + ] + # out = at.sum( + # pm.logp(dist, value) * _degrees_of_freedom / _full_size, + # axis=op.zerosum_axes, + # ) + # figure out how dimensionality should be handled for logp + # for now, we assume ZSN is a scalar distribut, which is not correct + out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size + return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") From fd3aefa8a61da3574b3f85c576dc72b6cd018c49 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 15 Sep 2022 16:44:31 +0300 Subject: [PATCH 13/50] Move ZSN tests to test_multivariate.py --- pymc/tests/distributions/test_continuous.py | 109 ---------------- pymc/tests/distributions/test_multivariate.py | 122 +++++++++++++++++- 2 files changed, 120 insertions(+), 111 deletions(-) diff --git a/pymc/tests/distributions/test_continuous.py b/pymc/tests/distributions/test_continuous.py index 028ac07771..a594804978 100644 --- a/pymc/tests/distributions/test_continuous.py +++ b/pymc/tests/distributions/test_continuous.py @@ -24,7 +24,6 @@ from aeppl.logprob import ParameterValueError from aesara.compile.mode import Mode -from numpy import AxisError import pymc as pm @@ -32,7 +31,6 @@ from pymc.distributions import logcdf, logp from pymc.distributions.continuous import get_tau_sigma, interpolated from pymc.distributions.dist_math import clipped_beta_rvs -from pymc.distributions.shape_utils import change_dist_size from pymc.tests.distributions.util import ( BaseTestDistributionRandom, Circ, @@ -965,19 +963,6 @@ def test_normal_moment(self, mu, sigma, size, expected): pm.Normal("x", mu=mu, sigma=sigma, size=size) assert_moment_is_expected(model, expected) - @pytest.mark.parametrize( - "shape, zerosum_axes, expected", - [ - ((2, 5), None, np.zeros((2, 5))), - ((2, 5, 6), None, np.zeros((2, 5, 6))), - ((2, 5, 6), (0, 1), np.zeros((2, 5, 6))), - ], - ) - def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): - with pm.Model() as model: - pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes) - assert_moment_is_expected(model, expected) - @pytest.mark.parametrize( "sigma, size, expected", [ @@ -1817,100 +1802,6 @@ class TestTruncatedNormalUpperArray(BaseTestDistributionRandom): ] -COORDS = { - "regions": ["a", "b", "c"], - "answers": ["yes", "no", "whatever", "don't understand question"], -} - - -class TestZeroSumNormal: - @pytest.mark.parametrize( - "dims, zerosum_axes, shape", - [ - (("regions", "answers"), "answers", None), - (("regions", "answers"), ("regions", "answers"), None), - (("regions", "answers"), 0, None), - (("regions", "answers"), -1, None), - (("regions", "answers"), (0, 1), None), - (None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))), - ], - ) - def test_zsn_dims_shape(self, dims, zerosum_axes, shape): - with pm.Model(coords=COORDS) as m: - v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes) - s = pm.sample(10, chains=1, tune=100) - - # to test forward graph - random_samples = pm.draw( - v, - draws=10, - ) - - assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) - - if not isinstance(zerosum_axes, (list, tuple)): - zerosum_axes = [zerosum_axes] - - if isinstance(zerosum_axes[0], str): - for ax in zerosum_axes: - for samples in [ - s.posterior.v.mean(dim=ax), - random_samples.mean(axis=dims.index(ax) + 1), - ]: - assert np.isclose( - samples, 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - - nonzero_axes = list(set(dims).difference(zerosum_axes)) - if nonzero_axes: - for ax in nonzero_axes: - for samples in [ - s.posterior.v.mean(dim=ax), - random_samples.mean(axis=dims.index(ax) + 1), - ]: - assert not np.isclose( - samples, 0 - ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." - - else: - for ax in zerosum_axes: - if ax < 0: - assert np.isclose( - s.posterior.v.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - else: - ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling - assert np.isclose( - s.posterior.v.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - - @pytest.mark.parametrize( - "dims, zerosum_axes", - [ - (("regions", "answers"), 2), - (("regions", "answers"), (0, -2)), - ], - ) - def test_zsn_fail_axis(self, dims, zerosum_axes): - if isinstance(zerosum_axes, (list, tuple)): - with pytest.raises(ValueError, match="repeated axis"): - with pm.Model(coords=COORDS) as m: - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) - else: - with pytest.raises(AxisError, match="out of bounds"): - with pm.Model(coords=COORDS) as m: - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) - - def test_zsn_change_dist_size(self): - base_dist = pm.ZeroSumNormal.dist(shape=(4, 9)) - - new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) - assert new_dist.eval().shape == (5, 3) - - new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True) - assert new_dist.eval().shape == (5, 3, 4, 9) - - class TestWald(BaseTestDistributionRandom): pymc_dist = pm.Wald mu, lam, alpha = 1.0, 1.0, 0.0 diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 76052513df..92b14aded7 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -28,6 +28,7 @@ from aeppl.logprob import ParameterValueError from aesara.tensor import TensorVariable from aesara.tensor.random.utils import broadcast_params +from numpy import AxisError import pymc as pm @@ -754,7 +755,12 @@ def test_car_logp(self, sparse, size): # d x d adjacency matrix for a square (d=4) of rook-adjacent sites W = np.array( - [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]] + [ + [0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 1.0, 0.0], + ] ) tau = 2 @@ -1007,6 +1013,19 @@ def test_mv_normal_moment(self, mu, cov, size, expected): # MvNormal logp is only implemented for up to 2D variables assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) + @pytest.mark.parametrize( + "shape, zerosum_axes, expected", + [ + ((2, 5), None, np.zeros((2, 5))), + ((2, 5, 6), None, np.zeros((2, 5, 6))), + ((2, 5, 6), (0, 1), np.zeros((2, 5, 6))), + ], + ) + def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): + with pm.Model() as model: + pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes) + assert_moment_is_expected(model, expected) + @pytest.mark.parametrize( "mu, size, expected", [ @@ -1026,7 +1045,12 @@ def test_mv_normal_moment(self, mu, cov, size, expected): ) def test_car_moment(self, mu, size, expected): W = np.array( - [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]] + [ + [0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 1.0, 0.0], + ] ) tau = 2 alpha = 0.5 @@ -1367,6 +1391,100 @@ def test_issue_3706(self): assert prior_pred["X"].shape == (1, N, 2) +COORDS = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], +} + + +class TestZeroSumNormal: + @pytest.mark.parametrize( + "dims, zerosum_axes, shape", + [ + (("regions", "answers"), "answers", None), + (("regions", "answers"), ("regions", "answers"), None), + (("regions", "answers"), 0, None), + (("regions", "answers"), -1, None), + (("regions", "answers"), (0, 1), None), + (None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))), + ], + ) + def test_zsn_dims_shape(self, dims, zerosum_axes, shape): + with pm.Model(coords=COORDS) as m: + v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) + + # to test forward graph + random_samples = pm.draw( + v, + draws=10, + ) + + assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) + + if not isinstance(zerosum_axes, (list, tuple)): + zerosum_axes = [zerosum_axes] + + if isinstance(zerosum_axes[0], str): + for ax in zerosum_axes: + for samples in [ + s.posterior.v.mean(dim=ax), + random_samples.mean(axis=dims.index(ax) + 1), + ]: + assert np.isclose( + samples, 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + + nonzero_axes = list(set(dims).difference(zerosum_axes)) + if nonzero_axes: + for ax in nonzero_axes: + for samples in [ + s.posterior.v.mean(dim=ax), + random_samples.mean(axis=dims.index(ax) + 1), + ]: + assert not np.isclose( + samples, 0 + ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + + else: + for ax in zerosum_axes: + if ax < 0: + assert np.isclose( + s.posterior.v.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + else: + ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling + assert np.isclose( + s.posterior.v.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + + @pytest.mark.parametrize( + "dims, zerosum_axes", + [ + (("regions", "answers"), 2), + (("regions", "answers"), (0, -2)), + ], + ) + def test_zsn_fail_axis(self, dims, zerosum_axes): + if isinstance(zerosum_axes, (list, tuple)): + with pytest.raises(ValueError, match="repeated axis"): + with pm.Model(coords=COORDS) as m: + _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + else: + with pytest.raises(AxisError, match="out of bounds"): + with pm.Model(coords=COORDS) as m: + _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + + def test_zsn_change_dist_size(self): + base_dist = pm.ZeroSumNormal.dist(shape=(4, 9)) + + new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) + assert new_dist.eval().shape == (5, 3) + + new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True) + assert new_dist.eval().shape == (5, 3, 4, 9) + + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size) From e94e4f1a1fb35811f29aa74193f75e63d03c6543 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sat, 17 Sep 2022 13:36:18 +0300 Subject: [PATCH 14/50] Add check if zerosum_axes is iterable in dist method --- pymc/distributions/multivariate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 446b1ea9a7..ea7c881f96 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2463,6 +2463,8 @@ def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): def dist(cls, sigma=1, zerosum_axes=None, **kwargs): if zerosum_axes is None: zerosum_axes = [-1] + if not isinstance(zerosum_axes, (list, tuple)): + zerosum_axes = [zerosum_axes] sigma = at.as_tensor_variable(floatX(sigma)) if sigma.ndim > 0: From dec4a9fa8dd8a9e2fff0dcc44bf21d70589bf598 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sat, 17 Sep 2022 13:38:16 +0300 Subject: [PATCH 15/50] Improve test_zsn_change_dist_size --- pymc/tests/distributions/test_multivariate.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 92b14aded7..6e2bec2e08 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1415,10 +1415,7 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): s = pm.sample(10, chains=1, tune=100) # to test forward graph - random_samples = pm.draw( - v, - draws=10, - ) + random_samples = pm.draw(v, draws=10) assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) @@ -1475,14 +1472,39 @@ def test_zsn_fail_axis(self, dims, zerosum_axes): with pm.Model(coords=COORDS) as m: _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) - def test_zsn_change_dist_size(self): - base_dist = pm.ZeroSumNormal.dist(shape=(4, 9)) + @pytest.mark.parametrize( + "zerosum_axes", + [(-1), (-2), (1), ((0, 1)), ((-2, -1))], + ) + def test_zsn_change_dist_size(self, zerosum_axes): + base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes) + random_samples = pm.draw(base_dist, draws=100) + + if not isinstance(zerosum_axes, (list, tuple)): + zerosum_axes = [zerosum_axes] + self.assert_zerosum_axes(random_samples, zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) assert new_dist.eval().shape == (5, 3) + random_samples = pm.draw(new_dist, draws=100) + self.assert_zerosum_axes(random_samples, zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True) assert new_dist.eval().shape == (5, 3, 4, 9) + random_samples = pm.draw(new_dist, draws=100) + self.assert_zerosum_axes(random_samples, zerosum_axes) + + def assert_zerosum_axes(self, random_samples, zerosum_axes): + for ax in zerosum_axes: + if ax < 0: + assert np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + else: + ax = ax + 1 + assert np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." class TestMvStudentTCov(BaseTestDistributionRandom): From f7a55c5e4cd4a13fb7912560449affe8295f8d62 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sun, 18 Sep 2022 14:19:30 +0300 Subject: [PATCH 16/50] Improve docstrings --- pymc/distributions/multivariate.py | 8 +++++++- pymc/distributions/transforms.py | 3 +-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index ea7c881f96..1ed6264511 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2403,10 +2403,16 @@ class ZeroSumNormal(Distribution): By default, the last axis is constrained to sum to zero. See `zerosum_axes` kwarg for more details. + .. math: + + ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J) + where $J_{ij} = 1$ + Parameters ---------- sigma : tensor_like of float - Standard deviation (sigma > 0). + Scale parameter (sigma > 0). + It's actually the standard deviation of the underlying, unconstrained Normal distribution. Defaults to 1 if not specified. For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. zerosum_axes: list or tuple of strings or integers diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 72b0300c1e..05cf48090d 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -273,8 +273,7 @@ def bounds_fn(*rv_inputs): class ZeroSumTransform(RVTransform): """ - Constrains the samples of a Normal distribution to sum to zero - along the user-provided ``zerosum_axes``. + Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``. By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed on the last axis. """ From da6eaab1a22efe71f16b3ffa423204f5ca53b1c9 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Tue, 27 Sep 2022 16:25:05 +0200 Subject: [PATCH 17/50] Refactor get_steps to work with multivariate support shapes --- pymc/distributions/shape_utils.py | 103 ++++++++++++++++++- pymc/distributions/timeseries.py | 87 +++------------- pymc/tests/distributions/test_shape_utils.py | 65 +++++++++++- pymc/tests/distributions/test_timeseries.py | 62 +---------- 4 files changed, 183 insertions(+), 134 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 74a96ba4b1..89824bf582 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -20,7 +20,7 @@ import warnings from functools import singledispatch -from typing import Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import numpy as np @@ -28,11 +28,15 @@ from aesara import tensor as at from aesara.graph.basic import Variable from aesara.graph.op import Op, compute_test_value +from aesara.raise_op import Assert from aesara.tensor.random.op import RandomVariable from aesara.tensor.shape import SpecifyShape from aesara.tensor.var import TensorVariable from typing_extensions import TypeAlias +from pymc.aesaraf import convert_observed_data +from pymc.model import modelcontext + __all__ = [ "to_tuple", "shapes_broadcasting", @@ -666,3 +670,100 @@ def change_specify_shape_size(op, ss, new_size, expand) -> TensorVariable: # specify_shape has a wrong signature https://github.com/aesara-devs/aesara/issues/1164 return at.specify_shape(new_var, new_shapes) # type: ignore + + +def get_support_shape( + support_shape: Optional[Sequence[Union[int, np.ndarray, TensorVariable]]], + *, + shape: Optional[Shape] = None, + dims: Optional[Dims] = None, + observed: Optional[Any] = None, + support_shape_offset: Sequence[int] = None, + ndim_supp: int = 1, +): + """Extract length of support shapes from shape / dims / observed information + + Parameters + ---------- + support_shape: + User-specified support shape for multivariate distribution + shape: + User-specified shape for multivariate distribution + dims: + User-specified dims for multivariate distribution + observed: + User-specified observed data from multivariate distribution + support_shape_offset: + Difference between last shape dimensions and the length of explicit support shapes in multivariate distribution, defaults to 0. + For timeseries, this is shape[-1] = support_shape[-1] + 1 + ndim_supp: + Number of support dimensions of the given multivariate distribution, defaults to 1 + + Returns + ------- + support_shape + Support shape, if specified directly by user, or inferred from the last dimensions of + shape / dims / observed. When two sources of support shape information are provided, + a symbolic Assert is added to ensure they are consistent. + """ + if support_shape_offset is None: + support_shape_offset = [0] * ndim_supp + inferred_support_shape = None + + if shape is not None: + shape = to_tuple(shape) + assert isinstance(shape, tuple) + inferred_support_shape = at.stack( + [shape[-i - 1] - support_shape_offset[-i - 1] for i in range(ndim_supp)] + ) + + if inferred_support_shape is None and dims is not None: + dims = convert_dims(dims) + assert isinstance(dims, tuple) + model = modelcontext(None) + inferred_support_shape = at.stack( + [ + model.dim_lengths[dims[-i - 1]] - support_shape_offset[-i - 1] # type: ignore + for i in range(ndim_supp) + ] + ) + + if inferred_support_shape is None and observed is not None: + observed = convert_observed_data(observed) + inferred_support_shape = at.stack( + [observed.shape[-i - 1] - support_shape_offset[-i - 1] for i in range(ndim_supp)] + ) + + if inferred_support_shape is None: + inferred_support_shape = support_shape + # If there are two sources of information for the support shapes, assert they are consistent: + elif support_shape is not None: + inferred_support_shape = Assert(msg="Steps do not match last shape dimension")( + inferred_support_shape, at.all(at.eq(inferred_support_shape, support_shape)) + ) + return inferred_support_shape + + +def get_support_shape_1d( + support_shape: Optional[Union[int, np.ndarray, TensorVariable]], + *, + shape: Optional[Shape] = None, + dims: Optional[Dims] = None, + observed: Optional[Any] = None, + support_shape_offset: int = 0, +): + """Helper function for cases when you just care about one dimension.""" + if support_shape is not None: + support_shape_tuple = (support_shape,) + + support_shape_tuple = get_support_shape( + support_shape_tuple, + shape=shape, + dims=dims, + observed=observed, + support_shape_offset=(support_shape_offset,), + ) + if support_shape_tuple is not None: + (support_shape,) = support_shape_tuple + + return support_shape diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index bb23cce574..942a2ed49c 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import Any, Optional, Union +from typing import Optional import aesara import aesara.tensor as at @@ -24,12 +24,11 @@ from aesara import scan from aesara.graph import FunctionGraph, rewrite_graph from aesara.graph.basic import Node, clone_replace -from aesara.raise_op import Assert from aesara.tensor import TensorVariable from aesara.tensor.random.op import RandomVariable from aesara.tensor.rewriting.basic import ShapeFeature, topo_constant_folding -from pymc.aesaraf import convert_observed_data, floatX, intX +from pymc.aesaraf import floatX, intX from pymc.distributions import distribution, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.distribution import ( @@ -40,14 +39,11 @@ ) from pymc.distributions.logprob import ignore_logprob, logp from pymc.distributions.shape_utils import ( - Dims, - Shape, _change_dist_size, change_dist_size, - convert_dims, + get_support_shape_1d, to_tuple, ) -from pymc.model import modelcontext from pymc.util import check_dist_not_registered __all__ = [ @@ -61,61 +57,6 @@ ] -def get_steps( - steps: Optional[Union[int, np.ndarray, TensorVariable]], - *, - shape: Optional[Shape] = None, - dims: Optional[Dims] = None, - observed: Optional[Any] = None, - step_shape_offset: int = 0, -): - """Extract number of steps from shape / dims / observed information - - Parameters - ---------- - steps: - User specified steps for timeseries distribution - shape: - User specified shape for timeseries distribution - dims: - User specified dims for timeseries distribution - observed: - User specified observed data from timeseries distribution - step_shape_offset: - Difference between last shape dimension and number of steps in timeseries - distribution, defaults to 0 - - Returns - ------- - steps - Steps, if specified directly by user, or inferred from the last dimension of - shape / dims / observed. When two sources of step information are provided, - a symbolic Assert is added to ensure they are consistent. - """ - inferred_steps = None - if shape is not None: - shape = to_tuple(shape) - inferred_steps = shape[-1] - step_shape_offset - - if inferred_steps is None and dims is not None: - dims = convert_dims(dims) - model = modelcontext(None) - inferred_steps = model.dim_lengths[dims[-1]] - step_shape_offset - - if inferred_steps is None and observed is not None: - observed = convert_observed_data(observed) - inferred_steps = observed.shape[-1] - step_shape_offset - - if inferred_steps is None: - inferred_steps = steps - # If there are two sources of information for the steps, assert they are consistent - elif steps is not None: - inferred_steps = Assert(msg="Steps do not match last shape dimension")( - inferred_steps, at.eq(inferred_steps, steps) - ) - return inferred_steps - - class RandomWalkRV(SymbolicRandomVariable): """RandomWalk Variable""" @@ -132,21 +73,21 @@ class RandomWalk(Distribution): rv_type = RandomWalkRV def __new__(cls, *args, steps=None, **kwargs): - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), - step_shape_offset=1, + support_shape_offset=1, ) return super().__new__(cls, *args, steps=steps, **kwargs) @classmethod def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVariable: - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape"), - step_shape_offset=1, + support_shape_offset=1, ) if steps is None: raise ValueError("Must specify steps or shape parameter") @@ -391,12 +332,12 @@ class AR(Distribution): def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs): rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho))) ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), - step_shape_offset=ar_order, + support_shape_offset=ar_order, ) return super().__new__( cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs @@ -427,7 +368,9 @@ def dist( init_dist = kwargs.pop("init") ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) - steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order) + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=ar_order + ) if steps is None: raise ValueError("Must specify steps or shape parameter") steps = at.as_tensor_variable(intX(steps), ndim=0) diff --git a/pymc/tests/distributions/test_shape_utils.py b/pymc/tests/distributions/test_shape_utils.py index 8dad510a85..b1d2dc9568 100644 --- a/pymc/tests/distributions/test_shape_utils.py +++ b/pymc/tests/distributions/test_shape_utils.py @@ -18,9 +18,10 @@ import numpy as np import pytest -from aesara import Mode from aesara import tensor as at +from aesara.compile.mode import Mode from aesara.graph import Constant, ancestors +from aesara.tensor import TensorVariable from aesara.tensor.random import normal from aesara.tensor.shape import SpecifyShape @@ -36,10 +37,12 @@ convert_shape, convert_size, get_broadcastable_dist_samples, + get_support_shape_1d, rv_size_is_none, shapes_broadcasting, to_tuple, ) +from pymc.model import Model test_shapes = [ (tuple(), (1,), (4,), (5, 4)), @@ -622,3 +625,63 @@ def test_change_specify_shape_size_multivariate(): new_x.eval({batch: 5, supp: 3}).shape == (10, 5, 5, 3) with pytest.raises(AssertionError, match=re.escape("expected (None, None, 5, 3)")): new_x.eval({batch: 6, supp: 3}).shape == (10, 5, 5, 3) + + +@pytest.mark.parametrize( + "steps, shape, step_shape_offset, expected_steps, consistent", + [ + (10, None, 0, 10, True), + (10, None, 1, 10, True), + (None, (10,), 0, 10, True), + (None, (10,), 1, 9, True), + (None, (10, 5), 0, 5, True), + (None, None, 0, None, True), + (10, (10,), 0, 10, True), + (10, (11,), 1, 10, True), + (10, (5, 5), 0, 5, False), + (10, (5, 10), 1, 9, False), + ], +) +@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) +def test_get_support_shape_1d( + info_source, steps, shape, step_shape_offset, expected_steps, consistent +): + if info_source == "shape": + inferred_steps = get_support_shape_1d( + support_shape=steps, shape=shape, support_shape_offset=step_shape_offset + ) + + elif info_source == "dims": + if shape is None: + dims = None + coords = {} + else: + dims = tuple(str(i) for i, shape in enumerate(shape)) + coords = {str(i): range(shape) for i, shape in enumerate(shape)} + with Model(coords=coords): + inferred_steps = get_support_shape_1d( + support_shape=steps, dims=dims, support_shape_offset=step_shape_offset + ) + + elif info_source == "observed": + if shape is None: + observed = None + else: + observed = np.zeros(shape) + inferred_steps = get_support_shape_1d( + support_shape=steps, observed=observed, support_shape_offset=step_shape_offset + ) + + if not isinstance(inferred_steps, TensorVariable): + assert inferred_steps == expected_steps + else: + if consistent: + assert inferred_steps.eval() == expected_steps + else: + # check that inferred steps is still correct by ignoring the assert + f = aesara.function( + [], inferred_steps, mode=Mode().including("local_remove_all_assert") + ) + assert f() == expected_steps + with pytest.raises(AssertionError, match="Steps do not match"): + inferred_steps.eval() diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 545a45a20f..77fb9c2646 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -19,8 +19,6 @@ import pytest import scipy.stats as st -from aesara.tensor import TensorVariable - import pymc as pm from pymc.aesaraf import floatX @@ -28,14 +26,8 @@ from pymc.distributions.discrete import DiracDelta from pymc.distributions.logprob import logp from pymc.distributions.multivariate import Dirichlet -from pymc.distributions.shape_utils import change_dist_size -from pymc.distributions.timeseries import ( - AR, - GARCH11, - EulerMaruyama, - GaussianRandomWalk, - get_steps, -) +from pymc.distributions.shape_utils import change_dist_size, to_tuple +from pymc.distributions.timeseries import AR, GARCH11, EulerMaruyama, GaussianRandomWalk from pymc.model import Model from pymc.sampling import draw, sample, sample_posterior_predictive from pymc.tests.distributions.util import ( @@ -48,56 +40,6 @@ from pymc.tests.helpers import SeededTest, select_by_precision -@pytest.mark.parametrize( - "steps, shape, step_shape_offset, expected_steps, consistent", - [ - (10, None, 0, 10, True), - (10, None, 1, 10, True), - (None, (10,), 0, 10, True), - (None, (10,), 1, 9, True), - (None, (10, 5), 0, 5, True), - (None, None, 0, None, True), - (10, (10,), 0, 10, True), - (10, (11,), 1, 10, True), - (10, (5, 5), 0, 5, False), - (10, (5, 10), 1, 9, False), - ], -) -@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) -def test_get_steps(info_source, steps, shape, step_shape_offset, expected_steps, consistent): - if info_source == "shape": - inferred_steps = get_steps(steps=steps, shape=shape, step_shape_offset=step_shape_offset) - - elif info_source == "dims": - if shape is None: - dims = None - coords = {} - else: - dims = tuple(str(i) for i, shape in enumerate(shape)) - coords = {str(i): range(shape) for i, shape in enumerate(shape)} - with Model(coords=coords): - inferred_steps = get_steps(steps=steps, dims=dims, step_shape_offset=step_shape_offset) - - elif info_source == "observed": - if shape is None: - observed = None - else: - observed = np.zeros(shape) - inferred_steps = get_steps( - steps=steps, observed=observed, step_shape_offset=step_shape_offset - ) - - if not isinstance(inferred_steps, TensorVariable): - assert inferred_steps == expected_steps - else: - if consistent: - assert inferred_steps.eval() == expected_steps - else: - assert inferred_steps.owner.inputs[0].eval() == expected_steps - with pytest.raises(AssertionError, match="Steps do not match"): - inferred_steps.eval() - - class TestGaussianRandomWalk: def test_logp(self): def ref_logp(value, mu, sigma): From a5ed1f0f3c5548d5f61a63bc4307a9f40182c9c7 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Tue, 27 Sep 2022 18:52:05 +0200 Subject: [PATCH 18/50] Refactor ZSN dist and logp for rightmost zerosum_axes --- pymc/distributions/multivariate.py | 150 ++++++++++++++++++----------- pymc/distributions/shape_utils.py | 2 + 2 files changed, 95 insertions(+), 57 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 1ed6264511..74986ec4af 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -36,7 +36,8 @@ from aesara.tensor.random.utils import broadcast_params from aesara.tensor.slinalg import Cholesky, SolveTriangular from aesara.tensor.type import TensorType -from numpy.core.numeric import normalize_axis_tuple + +# from numpy.core.numeric import normalize_axis_tuple from scipy import linalg, stats import pymc as pm @@ -64,7 +65,7 @@ _change_dist_size, broadcast_dist_samples_to, change_dist_size, - convert_dims, + get_support_shape, rv_size_is_none, to_tuple, ) @@ -2389,11 +2390,7 @@ class ZeroSumNormalRV(SymbolicRandomVariable): """ZeroSumNormal random variable""" _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}") - zerosum_axes = None - - def __init__(self, *args, zerosum_axes, **kwargs): - self.zerosum_axes = zerosum_axes - super().__init__(*args, **kwargs) + default_output = 0 class ZeroSumNormal(Distribution): @@ -2447,36 +2444,57 @@ class ZeroSumNormal(Distribution): """ rv_type = ZeroSumNormalRV - def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): - dims = convert_dims(dims) - if zerosum_axes is None: - zerosum_axes = [-1] - if not isinstance(zerosum_axes, (list, tuple)): - zerosum_axes = [zerosum_axes] + # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): + # dims = convert_dims(dims) + # if zerosum_axes is None: + # zerosum_axes = [-1] + # if not isinstance(zerosum_axes, (list, tuple)): + # zerosum_axes = [zerosum_axes] - if isinstance(zerosum_axes[0], str): - if not dims: - raise ValueError("You need to specify dims if zerosum_axes are strings.") - else: - zerosum_axes_ = [] - for axis in zerosum_axes: - zerosum_axes_.append(dims.index(axis)) - zerosum_axes = zerosum_axes_ + # if isinstance(zerosum_axes[0], str): + # if not dims: + # raise ValueError("You need to specify dims if zerosum_axes are strings.") + # else: + # zerosum_axes_ = [] + # for axis in zerosum_axes: + # zerosum_axes_.append(dims.index(axis)) + # zerosum_axes = zerosum_axes_ - return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) + # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) @classmethod - def dist(cls, sigma=1, zerosum_axes=None, **kwargs): + def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): if zerosum_axes is None: - zerosum_axes = [-1] - if not isinstance(zerosum_axes, (list, tuple)): - zerosum_axes = [zerosum_axes] + zerosum_axes = 1 + if not isinstance(zerosum_axes, int): + raise TypeError("zerosum_axes has to be an integer") + if not zerosum_axes > 0: + raise ValueError("zerosum_axes has to be > 0") sigma = at.as_tensor_variable(floatX(sigma)) if sigma.ndim > 0: raise ValueError("sigma has to be a scalar") - return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs) + support_shape = get_support_shape( + support_shape=support_shape, + shape=kwargs.get("shape"), + ndim_supp=zerosum_axes, + ) + if support_shape is None: + if zerosum_axes > 0: + raise ValueError("You must specify shape or support_shape parameter") + # edge case doesn't work for now, because at.stack in get_support_shape fails + # else: + # support_shape = () # because it's just a Normal in that case + support_shape = at.as_tensor_variable(intX(support_shape)) + + assert zerosum_axes == at.get_vector_length( + support_shape + ), "support_shape has to be as long as zerosum_axes" + + return super().dist( + [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs + ) # TODO: This is if we want ZeroSum constraint on other dists than Normal # def dist(cls, dist, lower, upper, **kwargs): @@ -2494,39 +2512,55 @@ def dist(cls, sigma=1, zerosum_axes=None, **kwargs): # return super().dist([dist, lower, upper], **kwargs) @classmethod - def rv_op(cls, sigma, zerosum_axes, size=None): - if size is None: - zerosum_axes_ = np.asarray(zerosum_axes) - # just a placeholder size to infer minimum shape - size = np.ones( - max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int - ).tolist() + def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): + # if size is None: + # zerosum_axes_ = np.asarray(zerosum_axes) + # # just a placeholder size to infer minimum shape + # size = np.ones( + # max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int + # ).tolist() # check if zerosum_axes is valid - normalize_axis_tuple(zerosum_axes, len(size)) - - normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size)) - normal_dist_, sigma_ = normal_dist.type(), sigma.type() + # normalize_axis_tuple(zerosum_axes, len(size)) + + shape = to_tuple(size) + tuple(support_shape) + normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape)) + normal_dist_, sigma_, support_shape_ = ( + normal_dist.type(), + sigma.type(), + support_shape.type(), + ) # Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes zerosum_rv_ = normal_dist_ - for axis in zerosum_axes: - zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True) + for axis in range(zerosum_axes): + zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) return ZeroSumNormalRV( - inputs=[normal_dist_, sigma_], - outputs=[zerosum_rv_], - zerosum_axes=zerosum_axes, - ndim_supp=0, - )(normal_dist, sigma) + inputs=[normal_dist_, sigma_, support_shape_], + outputs=[zerosum_rv_, support_shape_], + ndim_supp=zerosum_axes, + )(normal_dist, sigma, support_shape) + + # TODO: + # write __new__ + # refactor ZSN tests + # test get_support_shape with 2D + # test ZSN logp + # test ZSN variance + # fix failing Ubuntu test @_change_dist_size.register(ZeroSumNormalRV) def change_zerosum_size(op, normal_dist, new_size, expand=False): - normal_dist, sigma = normal_dist.owner.inputs + normal_dist, sigma, support_shape = normal_dist.owner.inputs if expand: - new_size = tuple(new_size) + tuple(normal_dist.shape) - return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size) + original_shape = tuple(normal_dist.shape) + old_size = original_shape[len(original_shape) - op.ndim_supp :] + new_size = tuple(new_size) + old_size + return ZeroSumNormal.rv_op( + sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size + ) @_moment.register(ZeroSumNormalRV) @@ -2540,20 +2574,22 @@ def zerosum_default_transform(op, rv): @_logprob.register(ZeroSumNormalRV) -def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs): +def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): (value,) = values shape = value.shape - _deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1) + zerosum_axes = op.ndim_supp + _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1) _full_size = at.prod(shape) - _degrees_of_freedom = at.prod(_deg_free_shape) + _degrees_of_freedom = at.prod(_deg_free_support_shape) zerosums = [ - at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes + at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9)) + for axis in range(zerosum_axes) ] - # out = at.sum( - # pm.logp(dist, value) * _degrees_of_freedom / _full_size, - # axis=op.zerosum_axes, - # ) + out = at.sum( + pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size, + axis=tuple(np.arange(-zerosum_axes, 0)), + ) # figure out how dimensionality should be handled for logp # for now, we assume ZSN is a scalar distribut, which is not correct - out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size + # out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 89824bf582..118b818df6 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -706,6 +706,8 @@ def get_support_shape( shape / dims / observed. When two sources of support shape information are provided, a symbolic Assert is added to ensure they are consistent. """ + if ndim_supp < 1: + raise NotImplementedError("ndim_supp must be bigger than 0") if support_shape_offset is None: support_shape_offset = [0] * ndim_supp inferred_support_shape = None From 126e76bfbb548c7464263184a8aa51bbdf2e76b3 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 28 Sep 2022 16:14:34 +0200 Subject: [PATCH 19/50] Start writing __new__ method --- pymc/distributions/multivariate.py | 79 ++++++++++++++++++------------ pymc/distributions/shape_utils.py | 8 +-- 2 files changed, 54 insertions(+), 33 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 74986ec4af..539044a62d 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -18,6 +18,7 @@ import warnings from functools import reduce +from typing import Optional import aesara import aesara.tensor as at @@ -36,8 +37,6 @@ from aesara.tensor.random.utils import broadcast_params from aesara.tensor.slinalg import Cholesky, SolveTriangular from aesara.tensor.type import TensorType - -# from numpy.core.numeric import normalize_axis_tuple from scipy import linalg, stats import pymc as pm @@ -2412,20 +2411,24 @@ class ZeroSumNormal(Distribution): It's actually the standard deviation of the underlying, unconstrained Normal distribution. Defaults to 1 if not specified. For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint. - zerosum_axes: list or tuple of strings or integers - Axis (or axes) along which the zero-sum constraint is enforced. - Defaults to [-1], i.e the last axis. - If strings are passed, then ``dims`` is needed. - Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions. - dims: list or tuple of strings, optional - The dimension names of the axes. - Necessary when ``zerosum_axes`` is specified with strings. + zerosum_axes: int, defaults to 1 + Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position. + Defaults to 1, i.e the rightmost axis. + dims: sequence of strings, optional + Dimension names of the distribution. Works the same as for other PyMC distributions. + Necessary if ``shape`` is not passed. + shape: tuple of integers, optional + Shape of the distribution. Works the same as for other PyMC distributions. + Necessary if ``dims`` is not passed. Warnings -------- ``sigma`` has to be a scalar, to ensure the zero-sum constraint. The ability to specifiy a vector of ``sigma`` may be added in future versions. + ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``, + just use ``pm.Normal``. + Examples -------- .. code-block:: python @@ -2444,23 +2447,21 @@ class ZeroSumNormal(Distribution): """ rv_type = ZeroSumNormalRV - # def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs): - # dims = convert_dims(dims) - # if zerosum_axes is None: - # zerosum_axes = [-1] - # if not isinstance(zerosum_axes, (list, tuple)): - # zerosum_axes = [zerosum_axes] + def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs): + if dims is not None or kwargs.get("observed") is not None: + zerosum_axes = cls.check_zerosum_axes(zerosum_axes) - # if isinstance(zerosum_axes[0], str): - # if not dims: - # raise ValueError("You need to specify dims if zerosum_axes are strings.") - # else: - # zerosum_axes_ = [] - # for axis in zerosum_axes: - # zerosum_axes_.append(dims.index(axis)) - # zerosum_axes = zerosum_axes_ + support_shape = get_support_shape( + support_shape=support_shape, + shape=None, # Shape will be checked in `cls.dist` + dims=dims, + observed=kwargs.get("observed", None), + ndim_supp=zerosum_axes, + ) - # return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs) + return super().__new__( + cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs + ) @classmethod def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): @@ -2480,10 +2481,13 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): shape=kwargs.get("shape"), ndim_supp=zerosum_axes, ) + + # print(f"{support_shape.eval() = }") + if support_shape is None: if zerosum_axes > 0: raise ValueError("You must specify shape or support_shape parameter") - # edge case doesn't work for now, because at.stack in get_support_shape fails + # edge-case doesn't work for now, because at.stack in get_support_shape fails # else: # support_shape = () # because it's just a Normal in that case support_shape = at.as_tensor_variable(intX(support_shape)) @@ -2511,6 +2515,16 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): # check_dist_not_registered(dist) # return super().dist([dist, lower, upper], **kwargs) + @classmethod + def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int: + if zerosum_axes is None: + zerosum_axes = 1 + if not isinstance(zerosum_axes, int): + raise TypeError("zerosum_axes has to be an integer") + if not zerosum_axes > 0: + raise ValueError("zerosum_axes has to be > 0") + return zerosum_axes + @classmethod def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): # if size is None: @@ -2553,11 +2567,14 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): @_change_dist_size.register(ZeroSumNormalRV) def change_zerosum_size(op, normal_dist, new_size, expand=False): + normal_dist, sigma, support_shape = normal_dist.owner.inputs + if expand: original_shape = tuple(normal_dist.shape) old_size = original_shape[len(original_shape) - op.ndim_supp :] new_size = tuple(new_size) + old_size + return ZeroSumNormal.rv_op( sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size ) @@ -2570,7 +2587,8 @@ def zerosumnormal_moment(op, rv, *rv_inputs): @_default_transform.register(ZeroSumNormalRV) def zerosum_default_transform(op, rv): - return ZeroSumTransform(op.zerosum_axes) + zerosum_axes = tuple(np.arange(-op.ndim_supp, 0)) + return ZeroSumTransform(zerosum_axes) @_logprob.register(ZeroSumNormalRV) @@ -2578,18 +2596,19 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): (value,) = values shape = value.shape zerosum_axes = op.ndim_supp + _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1) _full_size = at.prod(shape) _degrees_of_freedom = at.prod(_deg_free_support_shape) + zerosums = [ at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9)) for axis in range(zerosum_axes) ] + out = at.sum( pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size, axis=tuple(np.arange(-zerosum_axes, 0)), ) - # figure out how dimensionality should be handled for logp - # for now, we assume ZSN is a scalar distribut, which is not correct - # out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size + return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 118b818df6..5ad08d7098 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -681,7 +681,7 @@ def get_support_shape( support_shape_offset: Sequence[int] = None, ndim_supp: int = 1, ): - """Extract length of support shapes from shape / dims / observed information + """Extract the support shapes from shape / dims / observed information Parameters ---------- @@ -694,7 +694,8 @@ def get_support_shape( observed: User-specified observed data from multivariate distribution support_shape_offset: - Difference between last shape dimensions and the length of explicit support shapes in multivariate distribution, defaults to 0. + Difference between last shape dimensions and the length of + explicit support shapes in multivariate distribution, defaults to 0. For timeseries, this is shape[-1] = support_shape[-1] + 1 ndim_supp: Number of support dimensions of the given multivariate distribution, defaults to 1 @@ -740,9 +741,10 @@ def get_support_shape( inferred_support_shape = support_shape # If there are two sources of information for the support shapes, assert they are consistent: elif support_shape is not None: - inferred_support_shape = Assert(msg="Steps do not match last shape dimension")( + inferred_support_shape = Assert(msg="support_shape does not match last shape dimension")( inferred_support_shape, at.all(at.eq(inferred_support_shape, support_shape)) ) + return inferred_support_shape From 3a8d8982b3833820e6d4c16a04d7835b8f05c186 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 28 Sep 2022 16:41:06 +0200 Subject: [PATCH 20/50] Handle single output and fix transform --- pymc/distributions/distribution.py | 8 ++++++++ pymc/distributions/multivariate.py | 2 -- pymc/distributions/transforms.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 772e1a6ce6..1515899be0 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -415,6 +415,14 @@ def dist( @_get_measurable_outputs.register(SymbolicRandomVariable) def _get_measurable_outputs_symbolic_random_variable(op, node): # This tells Aeppl that any non RandomType outputs are measurable + + # Assume that if there is one default_output, that's the only one that is measurable + # In the rare case this is not what one wants, a specialized _get_measuarable_outputs + # can dispatch for a subclassed Op + if op.default_output is not None: + return [node.default_output()] + + # Otherwise assume that any outputs that are not of RandomType are measurable return [out for out in node.outputs if not isinstance(out.type, RandomType)] diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 539044a62d..004dceb0d8 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2482,8 +2482,6 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): ndim_supp=zerosum_axes, ) - # print(f"{support_shape.eval() = }") - if support_shape is None: if zerosum_axes > 0: raise ValueError("You must specify shape or support_shape parameter") diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 05cf48090d..58046dd34a 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -291,7 +291,7 @@ def __init__(self, zerosum_axes): By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed on the last axis. """ - self.zerosum_axes = zerosum_axes + self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes) def forward(self, value, *rv_inputs): for axis in self.zerosum_axes: From 4c5273757e2e18bf453fe3c032fbe1c7f4320ed4 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 28 Sep 2022 18:33:26 +0200 Subject: [PATCH 21/50] Fix indexing of at.stack in get_support_shape --- pymc/distributions/shape_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 5ad08d7098..8ef5acc112 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -717,7 +717,7 @@ def get_support_shape( shape = to_tuple(shape) assert isinstance(shape, tuple) inferred_support_shape = at.stack( - [shape[-i - 1] - support_shape_offset[-i - 1] for i in range(ndim_supp)] + [shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)] ) if inferred_support_shape is None and dims is not None: @@ -726,15 +726,15 @@ def get_support_shape( model = modelcontext(None) inferred_support_shape = at.stack( [ - model.dim_lengths[dims[-i - 1]] - support_shape_offset[-i - 1] # type: ignore - for i in range(ndim_supp) + model.dim_lengths[dims[i]] - support_shape_offset[i] # type: ignore + for i in np.arange(-ndim_supp, 0) ] ) if inferred_support_shape is None and observed is not None: observed = convert_observed_data(observed) inferred_support_shape = at.stack( - [observed.shape[-i - 1] - support_shape_offset[-i - 1] for i in range(ndim_supp)] + [observed.shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)] ) if inferred_support_shape is None: From 7e4ed0a3a4d94c556ea0d01c0fa94ce133f8450c Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 28 Sep 2022 18:37:26 +0200 Subject: [PATCH 22/50] Fix examples in ZSN docstrings --- pymc/distributions/multivariate.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 004dceb0d8..832c4ebef3 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2437,13 +2437,16 @@ class ZeroSumNormal(Distribution): "answers": ["yes", "no", "whatever", "don't understand question"], } with pm.Model(coords=COORDS) as m: - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers") + # the zero sum axis will be 'answers' + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers")) with pm.Model(coords=COORDS) as m: - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers")) + # the zero sum axes will be 'answers' and 'regions' + ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2) with pm.Model(coords=COORDS) as m: - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1) + # the zero sum axes will be the last two + ...: v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2) """ rv_type = ZeroSumNormalRV @@ -2525,18 +2528,13 @@ def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int: @classmethod def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): - # if size is None: - # zerosum_axes_ = np.asarray(zerosum_axes) - # # just a placeholder size to infer minimum shape - # size = np.ones( - # max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int - # ).tolist() - - # check if zerosum_axes is valid - # normalize_axis_tuple(zerosum_axes, len(size)) shape = to_tuple(size) + tuple(support_shape) normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape)) + + if zerosum_axes > normal_dist.ndim: + raise ValueError("Shape of distribution is too small for the number of zerosum axes") + normal_dist_, sigma_, support_shape_ = ( normal_dist.type(), sigma.type(), @@ -2555,7 +2553,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): )(normal_dist, sigma, support_shape) # TODO: - # write __new__ # refactor ZSN tests # test get_support_shape with 2D # test ZSN logp From 44b5b91a1749bc31b5b3fb23acabbfd024f23355 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 28 Sep 2022 18:38:10 +0200 Subject: [PATCH 23/50] Refactor test_zsn_dims_shape --- pymc/tests/distributions/test_multivariate.py | 60 +++++++------------ 1 file changed, 23 insertions(+), 37 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 6e2bec2e08..bd26b2eab5 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1401,12 +1401,12 @@ class TestZeroSumNormal: @pytest.mark.parametrize( "dims, zerosum_axes, shape", [ - (("regions", "answers"), "answers", None), - (("regions", "answers"), ("regions", "answers"), None), - (("regions", "answers"), 0, None), - (("regions", "answers"), -1, None), - (("regions", "answers"), (0, 1), None), - (None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))), + (("regions", "answers"), None, None), + (("regions", "answers"), 1, None), + (("regions", "answers"), 2, None), + (None, None, (len(COORDS["regions"]), len(COORDS["answers"]))), + (None, 1, (len(COORDS["regions"]), len(COORDS["answers"]))), + (None, 2, (len(COORDS["regions"]), len(COORDS["answers"]))), ], ) def test_zsn_dims_shape(self, dims, zerosum_axes, shape): @@ -1419,41 +1419,27 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) - if not isinstance(zerosum_axes, (list, tuple)): - zerosum_axes = [zerosum_axes] + zerosum_axes = np.arange(-v.owner.op.ndim_supp, 0) + nonzero_axes = np.arange(v.ndim - v.owner.op.ndim_supp) + + for ax in zerosum_axes: + for samples in [ + s.posterior.v.mean(axis=ax), + random_samples.mean(axis=ax), + ]: + assert np.isclose( + samples, 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - if isinstance(zerosum_axes[0], str): - for ax in zerosum_axes: + if nonzero_axes: + for ax in nonzero_axes: for samples in [ - s.posterior.v.mean(dim=ax), - random_samples.mean(axis=dims.index(ax) + 1), + s.posterior.v.mean(axis=ax), + random_samples.mean(axis=ax), ]: - assert np.isclose( + assert not np.isclose( samples, 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - - nonzero_axes = list(set(dims).difference(zerosum_axes)) - if nonzero_axes: - for ax in nonzero_axes: - for samples in [ - s.posterior.v.mean(dim=ax), - random_samples.mean(axis=dims.index(ax) + 1), - ]: - assert not np.isclose( - samples, 0 - ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." - - else: - for ax in zerosum_axes: - if ax < 0: - assert np.isclose( - s.posterior.v.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - else: - ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling - assert np.isclose( - s.posterior.v.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." @pytest.mark.parametrize( "dims, zerosum_axes", From 99dbb38e5792b4480e644ac6bf073f3de3601be2 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 28 Sep 2022 19:17:51 +0200 Subject: [PATCH 24/50] Refactor test_zsn_fail_axis --- pymc/tests/distributions/test_multivariate.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index bd26b2eab5..2d7b53d0da 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -28,7 +28,6 @@ from aeppl.logprob import ParameterValueError from aesara.tensor import TensorVariable from aesara.tensor.random.utils import broadcast_params -from numpy import AxisError import pymc as pm @@ -1442,21 +1441,29 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." @pytest.mark.parametrize( - "dims, zerosum_axes", + "error, match, shape, support_shape, zerosum_axes", [ - (("regions", "answers"), 2), - (("regions", "answers"), (0, -2)), + (IndexError, "index out of range", (3, 4, 5), None, 4), + (AssertionError, "does not match", (3, 4), 3, None), # support_shape should be 4 + ( + AssertionError, + "does not match", + (3, 4), + (3, 4), + None, + ), # doesn't work because zerosum_axes = 1 ], ) - def test_zsn_fail_axis(self, dims, zerosum_axes): - if isinstance(zerosum_axes, (list, tuple)): - with pytest.raises(ValueError, match="repeated axis"): - with pm.Model(coords=COORDS) as m: - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) - else: - with pytest.raises(AxisError, match="out of bounds"): - with pm.Model(coords=COORDS) as m: - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): + with pytest.raises(error, match=match): + with pm.Model() as m: + _ = pm.ZeroSumNormal( + "v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes + ) + + # v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work + + # v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't @pytest.mark.parametrize( "zerosum_axes", From e3dc1d46846d7915336ebe29b0ff54bf2a96ebd1 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 15:19:01 +0200 Subject: [PATCH 25/50] Refactor test_zsn_change_dist_size --- pymc/distributions/multivariate.py | 2 +- pymc/tests/distributions/test_multivariate.py | 22 ++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 832c4ebef3..2770a26783 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2567,7 +2567,7 @@ def change_zerosum_size(op, normal_dist, new_size, expand=False): if expand: original_shape = tuple(normal_dist.shape) - old_size = original_shape[len(original_shape) - op.ndim_supp :] + old_size = original_shape[: len(original_shape) - op.ndim_supp] new_size = tuple(new_size) + old_size return ZeroSumNormal.rv_op( diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 2d7b53d0da..3f0cf3c2db 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1467,18 +1467,19 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): @pytest.mark.parametrize( "zerosum_axes", - [(-1), (-2), (1), ((0, 1)), ((-2, -1))], + [1, 2], ) def test_zsn_change_dist_size(self, zerosum_axes): base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes) random_samples = pm.draw(base_dist, draws=100) - if not isinstance(zerosum_axes, (list, tuple)): - zerosum_axes = [zerosum_axes] self.assert_zerosum_axes(random_samples, zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) - assert new_dist.eval().shape == (5, 3) + if zerosum_axes == 1: + assert new_dist.eval().shape == (5, 3, 9) + elif zerosum_axes == 2: + assert new_dist.eval().shape == (5, 3, 4, 9) random_samples = pm.draw(new_dist, draws=100) self.assert_zerosum_axes(random_samples, zerosum_axes) @@ -1488,16 +1489,11 @@ def test_zsn_change_dist_size(self, zerosum_axes): self.assert_zerosum_axes(random_samples, zerosum_axes) def assert_zerosum_axes(self, random_samples, zerosum_axes): + zerosum_axes = np.arange(-zerosum_axes, 0) for ax in zerosum_axes: - if ax < 0: - assert np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - else: - ax = ax + 1 - assert np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + assert np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." class TestMvStudentTCov(BaseTestDistributionRandom): From 09f0d91b11564a5acedb6773cf4a74cfe5f7cc8e Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 15:40:37 +0200 Subject: [PATCH 26/50] Simplify test_zsn_dims_shape --- pymc/tests/distributions/test_multivariate.py | 48 ++++++++----------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 3f0cf3c2db..e351e8d424 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1418,27 +1418,15 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) - zerosum_axes = np.arange(-v.owner.op.ndim_supp, 0) - nonzero_axes = np.arange(v.ndim - v.owner.op.ndim_supp) - - for ax in zerosum_axes: - for samples in [ - s.posterior.v.mean(axis=ax), - random_samples.mean(axis=ax), - ]: - assert np.isclose( - samples, 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - - if nonzero_axes: - for ax in nonzero_axes: - for samples in [ - s.posterior.v.mean(axis=ax), - random_samples.mean(axis=ax), - ]: - assert not np.isclose( - samples, 0 - ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + ndim_supp = v.owner.op.ndim_supp + zerosum_axes = np.arange(-ndim_supp, 0) + nonzero_axes = np.arange(v.ndim - ndim_supp) + for samples in [ + s.posterior.v, + random_samples, + ]: + self.assert_zerosum_axes(samples, zerosum_axes) + self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) @pytest.mark.parametrize( "error, match, shape, support_shape, zerosum_axes", @@ -1473,6 +1461,7 @@ def test_zsn_change_dist_size(self, zerosum_axes): base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes) random_samples = pm.draw(base_dist, draws=100) + zerosum_axes = np.arange(-zerosum_axes, 0) self.assert_zerosum_axes(random_samples, zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) @@ -1488,12 +1477,17 @@ def test_zsn_change_dist_size(self, zerosum_axes): random_samples = pm.draw(new_dist, draws=100) self.assert_zerosum_axes(random_samples, zerosum_axes) - def assert_zerosum_axes(self, random_samples, zerosum_axes): - zerosum_axes = np.arange(-zerosum_axes, 0) - for ax in zerosum_axes: - assert np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): + if check_zerosum_axes: + for ax in axes_to_check: + assert np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + else: + for ax in axes_to_check: + assert not np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." class TestMvStudentTCov(BaseTestDistributionRandom): From cf5b384271c60a747d29d01e79cbee64bac54c49 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 15:42:39 +0200 Subject: [PATCH 27/50] Refactor test_zsn_dims_shape --- pymc/tests/distributions/test_multivariate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index e351e8d424..f87ff8bfb0 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1016,8 +1016,8 @@ def test_mv_normal_moment(self, mu, cov, size, expected): "shape, zerosum_axes, expected", [ ((2, 5), None, np.zeros((2, 5))), - ((2, 5, 6), None, np.zeros((2, 5, 6))), - ((2, 5, 6), (0, 1), np.zeros((2, 5, 6))), + ((2, 5, 6), 2, np.zeros((2, 5, 6))), + ((2, 5, 6), 3, np.zeros((2, 5, 6))), ], ) def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): From 3e86a3ea405c7ea6e03e6a58c4c1f81e03383855 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 16:35:36 +0200 Subject: [PATCH 28/50] Fix get_support_shape --- pymc/distributions/multivariate.py | 3 +-- pymc/distributions/shape_utils.py | 9 +++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 2770a26783..49a80f19c4 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2487,7 +2487,7 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): if support_shape is None: if zerosum_axes > 0: - raise ValueError("You must specify shape or support_shape parameter") + raise ValueError("You must specify dims, shape or support_shape parameter") # edge-case doesn't work for now, because at.stack in get_support_shape fails # else: # support_shape = () # because it's just a Normal in that case @@ -2553,7 +2553,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): )(normal_dist, sigma, support_shape) # TODO: - # refactor ZSN tests # test get_support_shape with 2D # test ZSN logp # test ZSN variance diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 8ef5acc112..4abf0b4eb9 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -741,8 +741,13 @@ def get_support_shape( inferred_support_shape = support_shape # If there are two sources of information for the support shapes, assert they are consistent: elif support_shape is not None: - inferred_support_shape = Assert(msg="support_shape does not match last shape dimension")( - inferred_support_shape, at.all(at.eq(inferred_support_shape, support_shape)) + inferred_support_shape = at.stack( + [ + Assert(msg="support_shape does not match last shape dimension")( + inferred, at.eq(inferred, explicit) + ) + for inferred, explicit in zip(inferred_support_shape, support_shape) + ] ) return inferred_support_shape From ce68f020ed30299f185bb9898c61007f5909cb25 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 16:36:30 +0200 Subject: [PATCH 29/50] Test support_shape handling --- pymc/tests/distributions/test_multivariate.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index f87ff8bfb0..6e0de958a1 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1432,14 +1432,14 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape): "error, match, shape, support_shape, zerosum_axes", [ (IndexError, "index out of range", (3, 4, 5), None, 4), - (AssertionError, "does not match", (3, 4), 3, None), # support_shape should be 4 + (AssertionError, "does not match", (3, 4), (3,), None), # support_shape should be 4 ( AssertionError, "does not match", (3, 4), (3, 4), None, - ), # doesn't work because zerosum_axes = 1 + ), # doesn't work because zerosum_axes = 1 by default ], ) def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): @@ -1449,9 +1449,20 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): "v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes ) - # v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work + @pytest.mark.parametrize( + "shape, support_shape", + [ + (None, (3, 4)), + ((3, 4), (3, 4)), + ], + ) + def test_zsn_support_shape(self, shape, support_shape): + with pm.Model() as m: + v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, zerosum_axes=2) - # v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't + random_samples = pm.draw(v, draws=10) + zerosum_axes = np.arange(-2, 0) + self.assert_zerosum_axes(random_samples, zerosum_axes) @pytest.mark.parametrize( "zerosum_axes", @@ -1465,9 +1476,9 @@ def test_zsn_change_dist_size(self, zerosum_axes): self.assert_zerosum_axes(random_samples, zerosum_axes) new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) - if zerosum_axes == 1: + try: assert new_dist.eval().shape == (5, 3, 9) - elif zerosum_axes == 2: + except AssertionError: assert new_dist.eval().shape == (5, 3, 4, 9) random_samples = pm.draw(new_dist, draws=100) self.assert_zerosum_axes(random_samples, zerosum_axes) From b50909e99e5af32ab1b5bfa95725653d1e777e90 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 17:07:19 +0200 Subject: [PATCH 30/50] Remove TODO list comment --- pymc/distributions/multivariate.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 49a80f19c4..2e63109476 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2488,7 +2488,7 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): if support_shape is None: if zerosum_axes > 0: raise ValueError("You must specify dims, shape or support_shape parameter") - # edge-case doesn't work for now, because at.stack in get_support_shape fails + # TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails # else: # support_shape = () # because it's just a Normal in that case support_shape = at.as_tensor_variable(intX(support_shape)) @@ -2552,12 +2552,6 @@ def rv_op(cls, sigma, zerosum_axes, support_shape, size=None): ndim_supp=zerosum_axes, )(normal_dist, sigma, support_shape) - # TODO: - # test get_support_shape with 2D - # test ZSN logp - # test ZSN variance - # fix failing Ubuntu test - @_change_dist_size.register(ZeroSumNormalRV) def change_zerosum_size(op, normal_dist, new_size, expand=False): From 7ba1d0ff3a4912f7857119a3ff48d1bb4db18b13 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 29 Sep 2022 19:17:05 +0200 Subject: [PATCH 31/50] Add test of ZSN variance --- pymc/tests/distributions/test_multivariate.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 6e8bd6f89a..fc7de70fd0 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1501,6 +1501,23 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes= random_samples.mean(axis=ax), 0 ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + @pytest.mark.parametrize( + "sigma, n", + [ + (5, 3), + (2, 6), + ], + ) + def test_zsn_variance(self, sigma, n): + + dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=n) + random_samples = pm.draw(dist, draws=100_000) + + empirical_var = random_samples.var(axis=0) + theoretical_var = sigma**2 * (n - 1) / n + + np.testing.assert_allclose(empirical_var, theoretical_var, rtol=1e-02) + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): From 5ee950a51f611c659bad015c40f3bbc252af4942 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Fri, 30 Sep 2022 14:04:13 +0200 Subject: [PATCH 32/50] Remove unused imports --- pymc/distributions/timeseries.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 0a790e7ee4..ab6f82c932 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -25,7 +25,7 @@ from aesara.tensor import TensorVariable from aesara.tensor.random.op import RandomVariable -from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX +from pymc.aesaraf import constant_fold, floatX, intX from pymc.distributions import distribution, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.distribution import ( @@ -42,7 +42,6 @@ to_tuple, ) from pymc.exceptions import NotConstantValueError -from pymc.model import modelcontext from pymc.util import check_dist_not_registered __all__ = [ From 13a54e6d817fb712b9770b469a3f8f10b036c3a4 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Fri, 30 Sep 2022 14:08:59 +0200 Subject: [PATCH 33/50] Replace get_steps by get_support_shape_1d in timeseries.py --- pymc/distributions/timeseries.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 41b21f5f57..376479d5f0 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -583,18 +583,20 @@ class GARCH11(Distribution): rv_type = GARCH11RV def __new__(cls, *args, steps=None, **kwargs): - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), - step_shape_offset=1, + support_shape_offset=1, ) return super().__new__(cls, *args, steps=steps, **kwargs) @classmethod def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs): - steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=1) + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1 + ) if steps is None: raise ValueError("Must specify steps or shape parameter") steps = at.as_tensor_variable(intX(steps), ndim=0) From ca655bca7110014105b666f49373cfe8ac4c2fc1 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Fri, 30 Sep 2022 14:40:41 +0200 Subject: [PATCH 34/50] Split dims and shape test --- pymc/tests/distributions/test_multivariate.py | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index fc7de70fd0..aef01ec26c 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1399,19 +1399,44 @@ def test_issue_3706(self): class TestZeroSumNormal: @pytest.mark.parametrize( - "dims, zerosum_axes, shape", + "dims, zerosum_axes", [ - (("regions", "answers"), None, None), - (("regions", "answers"), 1, None), - (("regions", "answers"), 2, None), - (None, None, (len(COORDS["regions"]), len(COORDS["answers"]))), - (None, 1, (len(COORDS["regions"]), len(COORDS["answers"]))), - (None, 2, (len(COORDS["regions"]), len(COORDS["answers"]))), + (("regions", "answers"), None), + (("regions", "answers"), 1), + (("regions", "answers"), 2), ], ) - def test_zsn_dims_shape(self, dims, zerosum_axes, shape): + def test_zsn_dims(self, dims, zerosum_axes): with pm.Model(coords=COORDS) as m: - v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes) + v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) + + # to test forward graph + random_samples = pm.draw(v, draws=10) + + assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) + + ndim_supp = v.owner.op.ndim_supp + zerosum_axes = np.arange(-ndim_supp, 0) + nonzero_axes = np.arange(v.ndim - ndim_supp) + for samples in [ + s.posterior.v, + random_samples, + ]: + self.assert_zerosum_axes(samples, zerosum_axes) + self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) + + @pytest.mark.parametrize( + "zerosum_axes, shape", + [ + (None, (len(COORDS["regions"]), len(COORDS["answers"]))), + (1, (len(COORDS["regions"]), len(COORDS["answers"]))), + (2, (len(COORDS["regions"]), len(COORDS["answers"]))), + ], + ) + def test_zsn_shape(self, shape, zerosum_axes): + with pm.Model(coords=COORDS) as m: + v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes) s = pm.sample(10, chains=1, tune=100) # to test forward graph From 9d419eff39b34a09acdffdcd4b7a165ff2db1243 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Fri, 30 Sep 2022 15:42:34 +0200 Subject: [PATCH 35/50] Fix test_get_support_shape_1d --- pymc/distributions/shape_utils.py | 2 ++ pymc/tests/distributions/test_shape_utils.py | 34 +++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 9d4ed23cde..b9b45c1238 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -755,6 +755,8 @@ def get_support_shape_1d( """Helper function for cases when you just care about one dimension.""" if support_shape is not None: support_shape_tuple = (support_shape,) + else: + support_shape_tuple = None support_shape_tuple = get_support_shape( support_shape_tuple, diff --git a/pymc/tests/distributions/test_shape_utils.py b/pymc/tests/distributions/test_shape_utils.py index a52232413c..3c60ec1b28 100644 --- a/pymc/tests/distributions/test_shape_utils.py +++ b/pymc/tests/distributions/test_shape_utils.py @@ -605,7 +605,7 @@ def test_change_specify_shape_size_multivariate(): @pytest.mark.parametrize( - "steps, shape, step_shape_offset, expected_steps, consistent", + "support_shape, shape, support_shape_offset, expected_support_shape, consistent", [ (10, None, 0, 10, True), (10, None, 1, 10, True), @@ -621,11 +621,11 @@ def test_change_specify_shape_size_multivariate(): ) @pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) def test_get_support_shape_1d( - info_source, steps, shape, step_shape_offset, expected_steps, consistent + info_source, support_shape, shape, support_shape_offset, expected_support_shape, consistent ): if info_source == "shape": - inferred_steps = get_support_shape_1d( - support_shape=steps, shape=shape, support_shape_offset=step_shape_offset + inferred_support_shape = get_support_shape_1d( + support_shape=support_shape, shape=shape, support_shape_offset=support_shape_offset ) elif info_source == "dims": @@ -633,11 +633,11 @@ def test_get_support_shape_1d( dims = None coords = {} else: - dims = tuple(str(i) for i, shape in enumerate(shape)) + dims = tuple(str(i) for i, _ in enumerate(shape)) coords = {str(i): range(shape) for i, shape in enumerate(shape)} with Model(coords=coords): - inferred_steps = get_support_shape_1d( - support_shape=steps, dims=dims, support_shape_offset=step_shape_offset + inferred_support_shape = get_support_shape_1d( + support_shape=support_shape, dims=dims, support_shape_offset=support_shape_offset ) elif info_source == "observed": @@ -645,20 +645,22 @@ def test_get_support_shape_1d( observed = None else: observed = np.zeros(shape) - inferred_steps = get_support_shape_1d( - support_shape=steps, observed=observed, support_shape_offset=step_shape_offset + inferred_support_shape = get_support_shape_1d( + support_shape=support_shape, + observed=observed, + support_shape_offset=support_shape_offset, ) - if not isinstance(inferred_steps, TensorVariable): - assert inferred_steps == expected_steps + if not isinstance(inferred_support_shape, TensorVariable): + assert inferred_support_shape == expected_support_shape else: if consistent: - assert inferred_steps.eval() == expected_steps + assert inferred_support_shape.eval() == expected_support_shape else: # check that inferred steps is still correct by ignoring the assert f = aesara.function( - [], inferred_steps, mode=Mode().including("local_remove_all_assert") + [], inferred_support_shape, mode=Mode().including("local_remove_all_assert") ) - assert f() == expected_steps - with pytest.raises(AssertionError, match="Steps do not match"): - inferred_steps.eval() + assert f() == expected_support_shape + with pytest.raises(AssertionError, match="support_shape does not match"): + inferred_support_shape.eval() From 85da56cf021c06826820481cf8c8f7ee0968cf12 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Fri, 30 Sep 2022 16:29:36 +0200 Subject: [PATCH 36/50] Add test_get_support_shape --- pymc/tests/distributions/test_shape_utils.py | 77 ++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/pymc/tests/distributions/test_shape_utils.py b/pymc/tests/distributions/test_shape_utils.py index 3c60ec1b28..3a9f4bb8de 100644 --- a/pymc/tests/distributions/test_shape_utils.py +++ b/pymc/tests/distributions/test_shape_utils.py @@ -37,6 +37,7 @@ convert_shape, convert_size, get_broadcastable_dist_samples, + get_support_shape, get_support_shape_1d, rv_size_is_none, shapes_broadcasting, @@ -664,3 +665,79 @@ def test_get_support_shape_1d( assert f() == expected_support_shape with pytest.raises(AssertionError, match="support_shape does not match"): inferred_support_shape.eval() + + +@pytest.mark.parametrize( + "support_shape, shape, support_shape_offset, expected_support_shape, ndim_supp, consistent", + [ + ((10, 5), None, (0,), (10, 5), 1, True), + ((10, 5), None, (1, 1), (10, 5), 1, True), + (None, (10, 5), (0,), 5, 1, True), + (None, (10, 5), (1,), 4, 1, True), + (None, (10, 5, 2), (0,), 2, 1, True), + (None, None, None, None, 1, True), + ((10, 5), (10, 5), None, (10, 5), 2, True), + ((10, 5), (11, 10, 5), None, (10, 5), 2, True), + (None, (11, 10, 5), (0, 1, 0), (11, 9, 5), 3, True), + ((10, 5), (10, 5, 5), (0,), (5,), 1, False), + ((10, 5), (10, 5), (1, 1), (9, 4), 2, False), + ], +) +@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) +def test_get_support_shape( + info_source, + support_shape, + shape, + support_shape_offset, + expected_support_shape, + ndim_supp, + consistent, +): + if info_source == "shape": + inferred_support_shape = get_support_shape( + support_shape=support_shape, + shape=shape, + support_shape_offset=support_shape_offset, + ndim_supp=ndim_supp, + ) + + elif info_source == "dims": + if shape is None: + dims = None + coords = {} + else: + dims = tuple(str(i) for i, _ in enumerate(shape)) + coords = {str(i): range(shape) for i, shape in enumerate(shape)} + with Model(coords=coords): + inferred_support_shape = get_support_shape( + support_shape=support_shape, + dims=dims, + support_shape_offset=support_shape_offset, + ndim_supp=ndim_supp, + ) + + elif info_source == "observed": + if shape is None: + observed = None + else: + observed = np.zeros(shape) + inferred_support_shape = get_support_shape( + support_shape=support_shape, + observed=observed, + support_shape_offset=support_shape_offset, + ndim_supp=ndim_supp, + ) + + if not isinstance(inferred_support_shape, TensorVariable): + assert inferred_support_shape == expected_support_shape + else: + if consistent: + assert (inferred_support_shape.eval() == expected_support_shape).all() + else: + # check that inferred support shape is still correct by ignoring the assert + f = aesara.function( + [], inferred_support_shape, mode=Mode().including("local_remove_all_assert") + ) + assert (f() == expected_support_shape).all() + with pytest.raises(AssertionError, match="support_shape does not match"): + inferred_support_shape.eval() From f363118078d908ffdd4c526053dff0587a2714f7 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 18:25:39 +0200 Subject: [PATCH 37/50] Add ZSN logp test --- pymc/tests/distributions/test_multivariate.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index aef01ec26c..dda35d8427 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1543,6 +1543,52 @@ def test_zsn_variance(self, sigma, n): np.testing.assert_allclose(empirical_var, theoretical_var, rtol=1e-02) + @pytest.mark.parametrize( + "sigma, shape, zerosum_axes, mvn_axes", + [ + (5, 3, None, [-1]), + (2, 6, None, [-1]), + (5, (7, 3), None, [-1]), + (5, (2, 7, 3), 2, [1, 2]), + ], + ) + def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes): + + zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes) + zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval() + mvn_logp = self.logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes) + + np.testing.assert_allclose(zsn_logp, mvn_logp) + + def logp_norm(self, value, sigma, axes): + """ + Special case of the MvNormal, that's equivalent to the ZSN. + Only to test the ZSN logp + """ + axes = [ax if ax >= 0 else value.ndim + ax for ax in axes] + if len(set(axes)) < len(axes): + raise ValueError("Must specify unique zero sum axes") + other_axes = [ax for ax in range(value.ndim) if ax not in axes] + new_order = other_axes + axes + reshaped_value = np.reshape( + np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1] + ) + + degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes]) + full_size = np.prod([value.shape[ax] for ax in axes]) + + ns = value.shape[-1] + psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size + exp = 0.5 * (reshaped_value / sigma) ** 2 + inds = np.ones_like(value, dtype="bool") + for ax in axes: + inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9) + inds = np.reshape( + np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1] + )[..., 0] + + return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf) + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): From 64eca5cd52cf824646a258307b1f451ed4e4e2f9 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 18:33:18 +0200 Subject: [PATCH 38/50] Fix test_inconsistent_steps_and_shape --- pymc/tests/distributions/test_timeseries.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 947720ed4a..3af42bcb57 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -122,7 +122,9 @@ def test_missing_steps(self): GaussianRandomWalk.dist(shape=None, init_dist=Normal.dist(0, 100)) def test_inconsistent_steps_and_shape(self): - with pytest.raises(AssertionError, match="Steps do not match last shape dimension"): + with pytest.raises( + AssertionError, match="support_shape does not match last shape dimension" + ): x = GaussianRandomWalk.dist(steps=12, shape=45, init_dist=Normal.dist(0, 100)) def test_inferred_steps_from_dims(self): From c5e76c94f5bbe81f8370d69b5d5654acd7b2b124 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 18:44:37 +0200 Subject: [PATCH 39/50] Integrate review comments --- pymc/distributions/multivariate.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 6405b557ad..16bd5cc413 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2468,12 +2468,7 @@ def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwar @classmethod def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): - if zerosum_axes is None: - zerosum_axes = 1 - if not isinstance(zerosum_axes, int): - raise TypeError("zerosum_axes has to be an integer") - if not zerosum_axes > 0: - raise ValueError("zerosum_axes has to be > 0") + zerosum_axes = cls.check_zerosum_axes(zerosum_axes) sigma = at.as_tensor_variable(floatX(sigma)) if sigma.ndim > 0: @@ -2501,21 +2496,6 @@ def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs): [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs ) - # TODO: This is if we want ZeroSum constraint on other dists than Normal - # def dist(cls, dist, lower, upper, **kwargs): - # if not isinstance(dist, TensorVariable) or not isinstance( - # dist.owner.op, (RandomVariable, SymbolicRandomVariable) - # ): - # raise ValueError( - # f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}" - # ) - # if dist.owner.op.ndim_supp > 0: - # raise NotImplementedError( - # "Censoring of multivariate distributions has not been implemented yet" - # ) - # check_dist_not_registered(dist) - # return super().dist([dist, lower, upper], **kwargs) - @classmethod def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int: if zerosum_axes is None: From 08c9df0f0e6511f194faa006f7f59a83a3641688 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 18:51:42 +0200 Subject: [PATCH 40/50] Solve freaking pre-commit issues --- pymc/tests/distributions/test_multivariate.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index dda35d8427..3ef3537f42 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -755,12 +755,7 @@ def test_car_logp(self, sparse, size): # d x d adjacency matrix for a square (d=4) of rook-adjacent sites W = np.array( - [ - [0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 0.0, 1.0], - [1.0, 0.0, 0.0, 1.0], - [0.0, 1.0, 1.0, 0.0], - ] + [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]] ) tau = 2 @@ -1045,12 +1040,7 @@ def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): ) def test_car_moment(self, mu, size, expected): W = np.array( - [ - [0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 0.0, 1.0], - [1.0, 0.0, 0.0, 1.0], - [0.0, 1.0, 1.0, 0.0], - ] + [[0.0, 1.0, 1.0, 0.0], [1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]] ) tau = 2 alpha = 0.5 From c120f7e1b26310cf3eecfe581e59dc18676f9a9e Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 18:55:35 +0200 Subject: [PATCH 41/50] Put assert_zerosum_axes at top of test class --- pymc/tests/distributions/test_multivariate.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 3ef3537f42..b6410af855 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1388,6 +1388,18 @@ def test_issue_3706(self): class TestZeroSumNormal: + def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): + if check_zerosum_axes: + for ax in axes_to_check: + assert np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + else: + for ax in axes_to_check: + assert not np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + @pytest.mark.parametrize( "dims, zerosum_axes", [ @@ -1504,18 +1516,6 @@ def test_zsn_change_dist_size(self, zerosum_axes): random_samples = pm.draw(new_dist, draws=100) self.assert_zerosum_axes(random_samples, zerosum_axes) - def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): - if check_zerosum_axes: - for ax in axes_to_check: - assert np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." - else: - for ax in axes_to_check: - assert not np.isclose( - random_samples.mean(axis=ax), 0 - ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." - @pytest.mark.parametrize( "sigma, n", [ From ba5f3a141f1e1818122ebd8d9f182bc3dd0a3302 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 19:35:13 +0200 Subject: [PATCH 42/50] Improve error message of get_support_shape --- pymc/distributions/shape_utils.py | 2 +- pymc/tests/distributions/test_timeseries.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index b9b45c1238..2bcc85a89a 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -734,7 +734,7 @@ def get_support_shape( elif support_shape is not None: inferred_support_shape = at.stack( [ - Assert(msg="support_shape does not match last shape dimension")( + Assert(msg="support_shape does not match respective shape dimension")( inferred, at.eq(inferred, explicit) ) for inferred, explicit in zip(inferred_support_shape, support_shape) diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 3af42bcb57..f7f6a7227f 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -123,7 +123,7 @@ def test_missing_steps(self): def test_inconsistent_steps_and_shape(self): with pytest.raises( - AssertionError, match="support_shape does not match last shape dimension" + AssertionError, match="support_shape does not match respective shape dimension" ): x = GaussianRandomWalk.dist(steps=12, shape=45, init_dist=Normal.dist(0, 100)) From 48dafe99c3484fa5d7a304c75099515568d16cab Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Wed, 5 Oct 2022 19:39:25 +0200 Subject: [PATCH 43/50] Nicer format for ZSN logp test --- pymc/tests/distributions/test_multivariate.py | 58 +++++++++---------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index b6410af855..64bcfe3edd 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1543,41 +1543,39 @@ def test_zsn_variance(self, sigma, n): ], ) def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes): + def logp_norm(value, sigma, axes): + """ + Special case of the MvNormal, that's equivalent to the ZSN. + Only to test the ZSN logp + """ + axes = [ax if ax >= 0 else value.ndim + ax for ax in axes] + if len(set(axes)) < len(axes): + raise ValueError("Must specify unique zero sum axes") + other_axes = [ax for ax in range(value.ndim) if ax not in axes] + new_order = other_axes + axes + reshaped_value = np.reshape( + np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1] + ) - zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes) - zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval() - mvn_logp = self.logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes) + degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes]) + full_size = np.prod([value.shape[ax] for ax in axes]) - np.testing.assert_allclose(zsn_logp, mvn_logp) + psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size + exp = 0.5 * (reshaped_value / sigma) ** 2 + inds = np.ones_like(value, dtype="bool") + for ax in axes: + inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9) + inds = np.reshape( + np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1] + )[..., 0] - def logp_norm(self, value, sigma, axes): - """ - Special case of the MvNormal, that's equivalent to the ZSN. - Only to test the ZSN logp - """ - axes = [ax if ax >= 0 else value.ndim + ax for ax in axes] - if len(set(axes)) < len(axes): - raise ValueError("Must specify unique zero sum axes") - other_axes = [ax for ax in range(value.ndim) if ax not in axes] - new_order = other_axes + axes - reshaped_value = np.reshape( - np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1] - ) - - degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes]) - full_size = np.prod([value.shape[ax] for ax in axes]) + return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf) - ns = value.shape[-1] - psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size - exp = 0.5 * (reshaped_value / sigma) ** 2 - inds = np.ones_like(value, dtype="bool") - for ax in axes: - inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9) - inds = np.reshape( - np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1] - )[..., 0] + zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes) + zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval() + mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes) - return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf) + np.testing.assert_allclose(zsn_logp, mvn_logp) class TestMvStudentTCov(BaseTestDistributionRandom): From 6612a24a2391f1a52fc07d96b2dc16bdcd965c0b Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 6 Oct 2022 10:13:33 +0200 Subject: [PATCH 44/50] Increase tolerance for test_zsn_variance --- pymc/distributions/multivariate.py | 6 +-- pymc/tests/distributions/test_multivariate.py | 47 +++++++++++-------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 16bd5cc413..5bf640590a 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2438,15 +2438,15 @@ class ZeroSumNormal(Distribution): } with pm.Model(coords=COORDS) as m: # the zero sum axis will be 'answers' - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers")) + v = pm.ZeroSumNormal("v", dims=("regions", "answers")) with pm.Model(coords=COORDS) as m: # the zero sum axes will be 'answers' and 'regions' - ...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2) + v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2) with pm.Model(coords=COORDS) as m: # the zero sum axes will be the last two - ...: v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2) + v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2) """ rv_type = ZeroSumNormalRV diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 64bcfe3edd..263b720595 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1381,13 +1381,12 @@ def test_issue_3706(self): assert prior_pred["X"].shape == (1, N, 2) -COORDS = { - "regions": ["a", "b", "c"], - "answers": ["yes", "no", "whatever", "don't understand question"], -} - - class TestZeroSumNormal: + coords = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], + } + def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): if check_zerosum_axes: for ax in axes_to_check: @@ -1409,14 +1408,19 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes= ], ) def test_zsn_dims(self, dims, zerosum_axes): - with pm.Model(coords=COORDS) as m: + with pm.Model(coords=self.coords) as m: v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) s = pm.sample(10, chains=1, tune=100) # to test forward graph random_samples = pm.draw(v, draws=10) - assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) + assert s.posterior.v.shape == ( + 1, + 10, + len(self.coords["regions"]), + len(self.coords["answers"]), + ) ndim_supp = v.owner.op.ndim_supp zerosum_axes = np.arange(-ndim_supp, 0) @@ -1429,22 +1433,25 @@ def test_zsn_dims(self, dims, zerosum_axes): self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) @pytest.mark.parametrize( - "zerosum_axes, shape", - [ - (None, (len(COORDS["regions"]), len(COORDS["answers"]))), - (1, (len(COORDS["regions"]), len(COORDS["answers"]))), - (2, (len(COORDS["regions"]), len(COORDS["answers"]))), - ], + "zerosum_axes", + (None, 1, 2), ) - def test_zsn_shape(self, shape, zerosum_axes): - with pm.Model(coords=COORDS) as m: + def test_zsn_shape(self, zerosum_axes): + shape = (len(self.coords["regions"]), len(self.coords["answers"])) + + with pm.Model(coords=self.coords) as m: v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes) s = pm.sample(10, chains=1, tune=100) # to test forward graph random_samples = pm.draw(v, draws=10) - assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"])) + assert s.posterior.v.shape == ( + 1, + 10, + len(self.coords["regions"]), + len(self.coords["answers"]), + ) ndim_supp = v.owner.op.ndim_supp zerosum_axes = np.arange(-ndim_supp, 0) @@ -1525,13 +1532,13 @@ def test_zsn_change_dist_size(self, zerosum_axes): ) def test_zsn_variance(self, sigma, n): - dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=n) - random_samples = pm.draw(dist, draws=100_000) + dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=(100_000, n)) + random_samples = pm.draw(dist) empirical_var = random_samples.var(axis=0) theoretical_var = sigma**2 * (n - 1) / n - np.testing.assert_allclose(empirical_var, theoretical_var, rtol=1e-02) + np.testing.assert_allclose(empirical_var, theoretical_var, atol=0.4) @pytest.mark.parametrize( "sigma, shape, zerosum_axes, mvn_axes", From 6b07a2a79b7def081628744e22dddb5515f62675 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 6 Oct 2022 18:42:30 +0200 Subject: [PATCH 45/50] Add ZSN to docs --- docs/source/api/distributions/multivariate.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api/distributions/multivariate.rst b/docs/source/api/distributions/multivariate.rst index 1156f04a25..ac401b9944 100644 --- a/docs/source/api/distributions/multivariate.rst +++ b/docs/source/api/distributions/multivariate.rst @@ -8,6 +8,7 @@ Multivariate MvNormal MvStudentT + ZeroSumNormal Dirichlet Multinomial DirichletMultinomial From 135ed473351b8b1837791c9c20e4097af8a9ba78 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 6 Oct 2022 21:22:41 +0200 Subject: [PATCH 46/50] Refactor ZSN docs --- pymc/distributions/multivariate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 5bf640590a..81b4c40916 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2399,7 +2399,7 @@ class ZeroSumNormal(Distribution): By default, the last axis is constrained to sum to zero. See `zerosum_axes` kwarg for more details. - .. math: + .. math:: ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J) where $J_{ij} = 1$ @@ -2431,7 +2431,9 @@ class ZeroSumNormal(Distribution): Examples -------- - .. code-block:: python + Define a `ZeroSumNormal` variable, with `sigma=1` and + `zerosum_axes=1` by default:: + COORDS = { "regions": ["a", "b", "c"], "answers": ["yes", "no", "whatever", "don't understand question"], From cba01877eb919af0660b3facdd38270790d7c5d7 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Thu, 6 Oct 2022 22:24:17 +0200 Subject: [PATCH 47/50] Better latex in ZSN docs --- pymc/distributions/multivariate.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 81b4c40916..d758adb3b0 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2401,8 +2401,11 @@ class ZeroSumNormal(Distribution): .. math:: - ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J) - where $J_{ij} = 1$ + \begin{align*} + ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\ + \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ + n = \text{nbr of zero-sum axes} + \end{align*} Parameters ---------- From 566f3086cfbe00fd41d1a69f4cf5b30742b20c90 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Fri, 7 Oct 2022 12:58:50 +0200 Subject: [PATCH 48/50] Add ZeroSumTransform to docs --- docs/source/api/distributions/transforms.rst | 1 + pymc/distributions/multivariate.py | 2 +- pymc/distributions/transforms.py | 15 +++++++-------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/api/distributions/transforms.rst b/docs/source/api/distributions/transforms.rst index 904ee19ea5..434e2065c5 100644 --- a/docs/source/api/distributions/transforms.rst +++ b/docs/source/api/distributions/transforms.rst @@ -33,6 +33,7 @@ Specific Transform Classes LogExpM1 Ordered SumTo1 + ZeroSumTransform Transform Composition Classes diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index d758adb3b0..e92a479d15 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2422,7 +2422,7 @@ class ZeroSumNormal(Distribution): Necessary if ``shape`` is not passed. shape: tuple of integers, optional Shape of the distribution. Works the same as for other PyMC distributions. - Necessary if ``dims`` is not passed. + Necessary if ``dims`` or ``observed`` is not passed. Warnings -------- diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 0b10460ee0..d895dd7a7b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -276,6 +276,13 @@ class ZeroSumTransform(RVTransform): Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``. By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed on the last axis. + + Parameters + ---------- + zerosum_axes : list of ints + Must be a list of integers (positive or negative). + By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed + on the last axis. """ name = "zerosum" @@ -283,14 +290,6 @@ class ZeroSumTransform(RVTransform): __props__ = ("zerosum_axes",) def __init__(self, zerosum_axes): - """ - Parameters - ---------- - zerosum_axes : list of ints - Must be a list of integers (positive or negative). - By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed - on the last axis. - """ self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes) def forward(self, value, *rv_inputs): From 5954e65b07cb079d6b10845dbbe0c40d4eab824e Mon Sep 17 00:00:00 2001 From: Alexandre Andorra Date: Fri, 7 Oct 2022 15:04:39 +0200 Subject: [PATCH 49/50] Remove mention of default value in ZS transform docs Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/distributions/transforms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index d895dd7a7b..408080d58c 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -281,8 +281,6 @@ class ZeroSumTransform(RVTransform): ---------- zerosum_axes : list of ints Must be a list of integers (positive or negative). - By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed - on the last axis. """ name = "zerosum" From 3e729221b444f6b68e0ba450342a133dca9d350d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Date: Fri, 7 Oct 2022 16:21:55 +0200 Subject: [PATCH 50/50] Update pymc/distributions/transforms.py --- pymc/distributions/transforms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 408080d58c..ee142b46fe 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -274,8 +274,6 @@ def bounds_fn(*rv_inputs): class ZeroSumTransform(RVTransform): """ Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``. - By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed - on the last axis. Parameters ----------