Skip to content

Commit ec18b5d

Browse files
committed
Make nested models share coords with parents
1 parent b29124b commit ec18b5d

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

Diff for: RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
120120
- New named dimensions can be introduced to the model via `pm.Data(..., dims=...)`. For mutable data variables (see above) the lengths of these dimensions are symbolic, so they can be re-sized via `pm.set_data()`.
121121
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
122122
- Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169)
123+
- Nested models now inherit the parent model's coordinates. [#5344](https://github.com/pymc-devs/pymc/pull/5344)
123124
- ...
124125

125126

Diff for: pymc/model.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -628,10 +628,6 @@ def __init__(
628628
rng_seeder: Optional[Union[int, np.random.RandomState]] = None,
629629
):
630630
self.name = name
631-
self._coords = {}
632-
self._RV_dims = {}
633-
self._dim_lengths = {}
634-
self.add_coords(coords)
635631
self.check_bounds = check_bounds
636632

637633
if rng_seeder is None:
@@ -654,6 +650,9 @@ def __init__(
654650
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
655651
self.deterministics = treelist(parent=self.parent.deterministics)
656652
self.potentials = treelist(parent=self.parent.potentials)
653+
self._coords = self.parent._coords
654+
self._RV_dims = treedict(parent=self.parent._RV_dims)
655+
self._dim_lengths = self.parent._dim_lengths
657656
else:
658657
self.named_vars = treedict()
659658
self.values_to_rvs = treedict()
@@ -663,6 +662,10 @@ def __init__(
663662
self.auto_deterministics = treelist()
664663
self.deterministics = treelist()
665664
self.potentials = treelist()
665+
self._coords = {}
666+
self._RV_dims = treedict()
667+
self._dim_lengths = {}
668+
self.add_coords(coords)
666669

667670
from pymc.printing import str_for_model
668671

Diff for: pymc/tests/test_model.py

+14
Original file line numberDiff line numberDiff line change
@@ -651,3 +651,17 @@ def test_datalogpt_multiple_shapes():
651651
# This would raise a TypeError, see #4803 and #4804
652652
x_val = m.rvs_to_values[x]
653653
m.datalogpt.eval({x_val: 0})
654+
655+
656+
def test_nested_model_coords():
657+
COORDS = {"dim": range(10)}
658+
with pm.Model(name="m1", coords=COORDS) as m1:
659+
a = pm.Normal("a")
660+
with pm.Model(name="m2") as m2:
661+
b = pm.Normal("b")
662+
c = pm.HalfNormal("c")
663+
d = pm.Normal("d", b, c, dims="dim")
664+
e = pm.Normal("e", a + d, dims="dim")
665+
assert m1.coords is m2.coords
666+
assert m1.dim_lengths is m2.dim_lengths
667+
assert set(m2.RV_dims) < set(m1.RV_dims)

0 commit comments

Comments
 (0)