Skip to content

Commit c217727

Browse files
committed
Only add dim_length shared var if new coordinate
1 parent 360cb6e commit c217727

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

pymc/model/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,7 @@ def add_coord(
964964
if name in self.coords:
965965
if not np.array_equal(values, self.coords[name]):
966966
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
967+
return
967968
if length is not None and not isinstance(length, int | Variable):
968969
raise ValueError(
969970
f"The `length` passed for the '{name}' coord must be an int, PyTensor Variable or None."

tests/model/test_core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import scipy.stats as st
3434

3535
from pytensor.graph import graph_inputs
36+
from pytensor.graph.basic import get_var_by_name
3637
from pytensor.raise_op import Assert
3738
from pytensor.tensor.random.op import RandomVariable
3839
from pytensor.tensor.variable import TensorConstant
@@ -858,6 +859,19 @@ def test_nested_model_coords():
858859
assert set(m2.named_vars_to_dims) < set(m1.named_vars_to_dims)
859860

860861

862+
def test_multiple_add_coords_with_same_name():
863+
coord = {"dim1": ["a", "b", "c"]}
864+
with pm.Model(coords=coord) as m:
865+
a = pm.Normal("a", dims="dim1")
866+
with pm.Model(coords=coord) as nested_m:
867+
b = pm.Normal("b", dims="dim1")
868+
m.add_coords(coord)
869+
c = pm.Normal("c", dims="dim1")
870+
d = pm.Deterministic("d", a + b + c)
871+
variables = get_var_by_name([d], "dim1")
872+
assert len(variables) == 1 and variables[0] is m.dim_lengths["dim1"]
873+
874+
861875
class TestSetUpdateCoords:
862876
def test_shapeerror_from_set_data_dimensionality(self):
863877
with pm.Model() as pmodel:

0 commit comments

Comments
 (0)