diff --git a/pymc3/aesaraf.py b/pymc3/aesaraf.py index e30248b841..8cd422aedc 100644 --- a/pymc3/aesaraf.py +++ b/pymc3/aesaraf.py @@ -156,6 +156,10 @@ def change_rv_size( size = rv_node.op._infer_shape(size, dist_params) new_size = tuple(np.atleast_1d(new_size)) + tuple(size) + # Make sure the new size is a tensor. This helps to not unnecessarily pick + # up a `Cast` in some cases + new_size = at.as_tensor(new_size, ndim=1, dtype="int64") + new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params) rv_var = new_rv_node.outputs[-1] rv_var.name = name diff --git a/pymc3/model.py b/pymc3/model.py index 17ad367043..7567b9ac23 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1116,8 +1116,6 @@ def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, trans ): raise TypeError("Observed data cannot consist of symbolic variables.") - data = pandas_to_array(data) - # `rv_var` is potentially a new variable (e.g. the original # variable could have its size changed to match the data, or be a # new graph that accounts for missing data) diff --git a/pymc3/tests/test_aesaraf.py b/pymc3/tests/test_aesaraf.py index 4e923212a7..d82583fb42 100644 --- a/pymc3/tests/test_aesaraf.py +++ b/pymc3/tests/test_aesaraf.py @@ -23,7 +23,7 @@ import pytest import scipy.sparse as sps -from aesara.graph.basic import Variable, ancestors +from aesara.graph.basic import Constant, Variable, ancestors from aesara.tensor.random.basic import normal, uniform from aesara.tensor.random.op import RandomVariable from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1 @@ -67,6 +67,15 @@ def test_change_rv_size(): assert rv_newer.ndim == 3 assert rv_newer.eval().shape == (4, 3, 2) + # Make sure we avoid introducing a `Cast` by converting the new size before + # constructing the new `RandomVariable` + rv = normal(0, 1) + new_size = np.array([4, 3], dtype="int32") + rv_newer = change_rv_size(rv, new_size=new_size, expand=False) + assert rv_newer.ndim == 2 + assert isinstance(rv_newer.owner.inputs[1], Constant) + assert rv_newer.eval().shape == (4, 3) + class TestBroadcasting: def test_make_shared_replacements(self): diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 15d7c11237..2759f2d3b5 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -379,14 +379,10 @@ def check_rv_size(self): sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)] for size, expected in zip(sizes_to_check, sizes_expected): - actual = change_rv_size(self.pymc_rv, size).eval().shape + pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size) + actual = tuple(pymc_rv.shape.eval()) assert actual == expected, f"size={size}, expected={expected}, actual={actual}" - # test negative sizes raise - for size in [-2, (3, -2)]: - with pytest.raises(ValueError): - change_rv_size(self.pymc_rv, size).eval() - # test multi-parameters sampling for univariate distributions (with univariate inputs) if self.pymc_dist.rv_op.ndim_supp == 0 and sum(self.pymc_dist.rv_op.ndims_params) == 0: params = { @@ -400,7 +396,8 @@ def check_rv_size(self): (5, self.repeated_params_shape), ] for size, expected in zip(sizes_to_check, sizes_expected): - actual = change_rv_size(self.pymc_rv, size).eval().shape + pymc_rv = self.pymc_dist.dist(**params, size=size) + actual = tuple(pymc_rv.shape.eval()) assert actual == expected def validate_tests_list(self):