diff --git a/pymc/model/core.py b/pymc/model/core.py index 5a7b2651cf..87b4df23f0 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -40,7 +40,7 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable from pymc.blocking import DictToArrayBijection, RaveledVars -from pymc.data import is_valid_observed +from pymc.data import MinibatchOp, is_valid_observed from pymc.exceptions import ( BlockModelAccessError, ImputationWarning, @@ -1241,6 +1241,15 @@ def register_rv( self.add_named_variable(rv_var, dims) self.set_initval(rv_var, initval) else: + if ( + isinstance(observed, TensorVariable) + and observed.owner is not None + and isinstance(observed.owner.op, MinibatchOp) + and total_size is None + ): + warnings.warn( + f"total_size not provided for observed variable `{name}` that uses pm.Minibatch" + ) if not is_valid_observed(observed): raise TypeError( "Variables that depend on other nodes cannot be used for observed data." diff --git a/tests/variational/test_minibatch_rv.py b/tests/variational/test_minibatch_rv.py index 33229e0bb7..6ece053852 100644 --- a/tests/variational/test_minibatch_rv.py +++ b/tests/variational/test_minibatch_rv.py @@ -112,6 +112,13 @@ def test_random(self): assert mx is not x np.testing.assert_array_equal(draw(mx, random_seed=1), draw(x, random_seed=1)) + def test_warning_on_missing_total_size(self): + total_size = 1000 + with pytest.warns(match="total_size not provided"): + with pm.Model() as m: + MB = pm.Minibatch(np.arange(total_size, dtype="float64"), batch_size=100) + pm.Normal("n", observed=MB) + @pytest.mark.filterwarnings("error") def test_minibatch_parameter_and_value(self): rng = np.random.default_rng(161)