Skip to content

Commit 03448f7

Browse files
Keep broadcasting information in make_shared_replacements (pymc-devs#4526)
1 parent 9218f33 commit 03448f7

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

RELEASE-NOTES.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
+ ...
1010

1111
### Maintenance
12-
- ⚠ Our memoization mechanism wasn't robust against hash collisions (#4506), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util`.
13-
- ...
12+
- ⚠ Our memoization mechanism wasn't robust against hash collisions (#4506), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see #4525).
13+
- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com/pymc-devs/pymc3/pull/4492)).
1414

1515
## PyMC3 3.11.1 (12 February 2021)
1616

pymc3/tests/test_theanof.py

+25
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,38 @@
1919
import theano
2020
import theano.tensor as tt
2121

22+
import pymc3 as pm
23+
2224
from pymc3.theanof import _conversion_map, take_along_axis
2325
from pymc3.vartypes import int_types
2426

2527
FLOATX = str(theano.config.floatX)
2628
INTX = str(_conversion_map[FLOATX])
2729

2830

31+
class TestBroadcasting:
32+
def test_make_shared_replacements(self):
33+
"""Check if pm.make_shared_replacements preserves broadcasting."""
34+
35+
with pm.Model() as test_model:
36+
test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10))
37+
test2 = pm.Normal("test2", mu=0.0, sigma=1.0, shape=(10, 1))
38+
39+
# Replace test1 with a shared variable, keep test 2 the same
40+
replacement = pm.make_shared_replacements([test_model.test2], test_model)
41+
assert test_model.test1.broadcastable == replacement[test_model.test1].broadcastable
42+
43+
def test_metropolis_sampling(self):
44+
"""Check if the Metropolis sampler can handle broadcasting."""
45+
with pm.Model() as test_model:
46+
test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10))
47+
test2 = pm.Normal("test2", mu=test1, sigma=1.0, shape=(10, 10))
48+
49+
step = pm.Metropolis()
50+
# This should fail immediately if broadcasting does not work.
51+
pm.sample(tune=5, draws=7, cores=1, step=step, compute_convergence_checks=False)
52+
53+
2954
def _make_along_axis_idx(arr_shape, indices, axis):
3055
# compute dimensions to iterate over
3156
if str(indices.dtype) not in int_types:

pymc3/theanof.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,12 @@ def make_shared_replacements(vars, model):
235235
Dict of variable -> new shared variable
236236
"""
237237
othervars = set(model.vars) - set(vars)
238-
return {var: theano.shared(var.tag.test_value, var.name + "_shared") for var in othervars}
238+
return {
239+
var: theano.shared(
240+
var.tag.test_value, var.name + "_shared", broadcastable=var.broadcastable
241+
)
242+
for var in othervars
243+
}
239244

240245

241246
def join_nonshared_inputs(xs, vars, shared, make_shared=False):

0 commit comments

Comments
 (0)