Skip to content

Commit 2d3ec8f

Browse files
authored
Add tests for tt.switch related bugs (#4448)
* Add tests for edge cases * Add release-note
1 parent 03d7af5 commit 2d3ec8f

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
### Maintenance
1111
- We upgraded to `Theano-PyMC v1.1.2` which [includes bugfixes](https://github.com/pymc-devs/aesara/compare/rel-1.1.0...rel-1.1.2) for warning floods and compiledir locking (see [#4444](https://github.com/pymc-devs/pymc3/pull/4444))
12+
- `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448))
1213
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
1314

1415
## PyMC3 3.11.0 (21 January 2021)

pymc3/tests/test_model.py

+28
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,34 @@ def test_tensor_type_conversion(self):
366366

367367
assert m["x2_missing"].type == gf._extra_vars_shared["x2_missing"].type
368368

369+
def test_theano_switch_broadcast_edge_cases(self):
370+
# Tests against two subtle issues related to a previous bug in Theano where tt.switch would not
371+
# always broadcast tensors with single values https://github.com/pymc-devs/aesara/issues/270
372+
373+
# Known issue 1: https://github.com/pymc-devs/pymc3/issues/4389
374+
data = np.zeros(10)
375+
with pm.Model() as m:
376+
p = pm.Beta("p", 1, 1)
377+
obs = pm.Bernoulli("obs", p=p, observed=data)
378+
# Assert logp is correct
379+
npt.assert_allclose(
380+
obs.logp(m.test_point),
381+
np.log(0.5) * 10,
382+
)
383+
384+
# Known issue 2: https://github.com/pymc-devs/pymc3/issues/4417
385+
# fmt: off
386+
data = np.array([
387+
1.35202174, -0.83690274, 1.11175166, 1.29000367, 0.21282749,
388+
0.84430966, 0.24841369, 0.81803141, 0.20550244, -0.45016253,
389+
])
390+
# fmt: on
391+
with pm.Model() as m:
392+
mu = pm.Normal("mu", 0, 5)
393+
obs = pm.TruncatedNormal("obs", mu=mu, sigma=1, lower=-1, upper=2, observed=data)
394+
# Assert dlogp is correct
395+
npt.assert_allclose(m.dlogp([mu])({"mu": 0}), 2.499424682024436, rtol=1e-5)
396+
369397

370398
def test_multiple_observed_rv():
371399
"Test previously buggy MultiObservedRV comparison code."

0 commit comments

Comments
 (0)