|
19 | 19 | import theano
|
20 | 20 | import theano.tensor as tt
|
21 | 21 |
|
| 22 | +import pymc3 as pm |
| 23 | + |
22 | 24 | from pymc3.theanof import _conversion_map, take_along_axis
|
23 | 25 | from pymc3.vartypes import int_types
|
24 | 26 |
|
25 | 27 | FLOATX = str(theano.config.floatX)
|
26 | 28 | INTX = str(_conversion_map[FLOATX])
|
27 | 29 |
|
28 | 30 |
|
| 31 | +class TestBroadcasting: |
| 32 | + def test_make_shared_replacements(self): |
| 33 | + """Check if pm.make_shared_replacements preserves broadcasting.""" |
| 34 | + |
| 35 | + with pm.Model() as test_model: |
| 36 | + test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10)) |
| 37 | + test2 = pm.Normal("test2", mu=0.0, sigma=1.0, shape=(10, 1)) |
| 38 | + |
| 39 | + # Replace test1 with a shared variable, keep test 2 the same |
| 40 | + replacement = pm.make_shared_replacements([test_model.test2], test_model) |
| 41 | + assert test_model.test1.broadcastable == replacement[test_model.test1].broadcastable |
| 42 | + |
| 43 | + def test_metropolis_sampling(self): |
| 44 | + """Check if the Metropolis sampler can handle broadcasting.""" |
| 45 | + with pm.Model() as test_model: |
| 46 | + test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10)) |
| 47 | + test2 = pm.Normal("test2", mu=test1, sigma=1.0, shape=(10, 10)) |
| 48 | + |
| 49 | + step = pm.Metropolis() |
| 50 | + # This should fail immediately if broadcasting does not work. |
| 51 | + pm.sample(tune=5, draws=7, cores=1, step=step, compute_convergence_checks=False) |
| 52 | + |
| 53 | + |
29 | 54 | def _make_along_axis_idx(arr_shape, indices, axis):
|
30 | 55 | # compute dimensions to iterate over
|
31 | 56 | if str(indices.dtype) not in int_types:
|
|
0 commit comments