From 7988d0bd023c8ba05a2d97bcbb563a67ed9ed82a Mon Sep 17 00:00:00 2001 From: Max Horn Date: Sat, 27 Feb 2021 15:09:30 +0100 Subject: [PATCH] Keep broadcasting information in make_shared_replacements It seems like broadcasting information gets lost when applying `pm.make_shared_replacements`, leading to problems with the metropolis sampler. Potentially related issues below: - https://github.com/pymc-devs/pymc3/issues/1083 - https://github.com/pymc-devs/pymc3/issues/1304 - https://github.com/pymc-devs/pymc3/issues/1983 This fix was previously suggested in the following issue: - https://github.com/pymc-devs/pymc3/issues/3337 It could be that further adaptations are necessary as indicated in the issue. Strangely, this does not seem to lead to problems when using NUTS. --- RELEASE-NOTES.md | 2 ++ pymc3/aesaraf.py | 7 ++++++- pymc3/tests/test_aesaraf.py | 25 +++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index ce4a7fc2c4..b8f7537b89 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -12,6 +12,8 @@ ### Maintenance - The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see [#4509](https://github.com/pymc-devs/pymc3/pull/4509)). - Remove float128 dtype support (see [#4514](https://github.com/pymc-devs/pymc3/pull/4514)). +- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com/pymc-devs/pymc3/pull/4492)). ++ ... ## PyMC3 3.11.1 (12 February 2021) diff --git a/pymc3/aesaraf.py b/pymc3/aesaraf.py index 87b370e55f..ce9bc93cd5 100644 --- a/pymc3/aesaraf.py +++ b/pymc3/aesaraf.py @@ -238,7 +238,12 @@ def make_shared_replacements(vars, model): Dict of variable -> new shared variable """ othervars = set(model.vars) - set(vars) - return {var: aesara.shared(var.tag.test_value, var.name + "_shared") for var in othervars} + return { + var: aesara.shared( + var.tag.test_value, var.name + "_shared", broadcastable=var.broadcastable + ) + for var in othervars + } def join_nonshared_inputs(xs, vars, shared, make_shared=False): diff --git a/pymc3/tests/test_aesaraf.py b/pymc3/tests/test_aesaraf.py index 1b591e0a85..597d6f3183 100644 --- a/pymc3/tests/test_aesaraf.py +++ b/pymc3/tests/test_aesaraf.py @@ -21,6 +21,8 @@ from aesara.tensor.type import TensorType +import pymc3 as pm + from pymc3.aesaraf import _conversion_map, take_along_axis from pymc3.vartypes import int_types @@ -28,6 +30,29 @@ INTX = str(_conversion_map[FLOATX]) +class TestBroadcasting: + def test_make_shared_replacements(self): + """Check if pm.make_shared_replacements preserves broadcasting.""" + + with pm.Model() as test_model: + test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10)) + test2 = pm.Normal("test2", mu=0.0, sigma=1.0, shape=(10, 1)) + + # Replace test1 with a shared variable, keep test 2 the same + replacement = pm.make_shared_replacements([test_model.test2], test_model) + assert test_model.test1.broadcastable == replacement[test_model.test1].broadcastable + + def test_metropolis_sampling(self): + """Check if the Metropolis sampler can handle broadcasting.""" + with pm.Model() as test_model: + test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10)) + test2 = pm.Normal("test2", mu=test1, sigma=1.0, shape=(10, 10)) + + step = pm.Metropolis() + # This should fail immediately if broadcasting does not work. + pm.sample(tune=5, draws=7, cores=1, step=step, compute_convergence_checks=False) + + def _make_along_axis_idx(arr_shape, indices, axis): # compute dimensions to iterate over if str(indices.dtype) not in int_types: