Skip to content

Commit b06c928

Browse files
Use aesara.tensor.atleast_1d in pymc3.aesaraf.change_rv_size
1 parent 0b8bed3 commit b06c928

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pymc3/aesaraf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def change_rv_size(
160160
if expand:
161161
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
162162
size = rv_node.op._infer_shape(size, dist_params)
163-
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)
163+
new_size = tuple(at.atleast_1d(new_size)) + tuple(size)
164164

165165
# Make sure the new size is a tensor. This helps to not unnecessarily pick
166166
# up a `Cast` in some cases

pymc3/tests/test_aesaraf.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ def test_change_rv_size():
5151
loc = at.as_tensor_variable([1, 2])
5252
rv = normal(loc=loc)
5353
assert rv.ndim == 1
54-
assert rv.eval().shape == (2,)
54+
assert tuple(rv.shape.eval()) == (2,)
5555

5656
rv_new = change_rv_size(rv, new_size=(3,), expand=True)
5757
assert rv_new.ndim == 2
58-
assert rv_new.eval().shape == (3, 2)
58+
assert tuple(rv_new.shape.eval()) == (3, 2)
5959

6060
# Make sure that the shape used to determine the expanded size doesn't
6161
# depend on the old `RandomVariable`.
@@ -65,7 +65,7 @@ def test_change_rv_size():
6565

6666
rv_newer = change_rv_size(rv_new, new_size=(4,), expand=True)
6767
assert rv_newer.ndim == 3
68-
assert rv_newer.eval().shape == (4, 3, 2)
68+
assert tuple(rv_newer.shape.eval()) == (4, 3, 2)
6969

7070
# Make sure we avoid introducing a `Cast` by converting the new size before
7171
# constructing the new `RandomVariable`
@@ -74,7 +74,19 @@ def test_change_rv_size():
7474
rv_newer = change_rv_size(rv, new_size=new_size, expand=False)
7575
assert rv_newer.ndim == 2
7676
assert isinstance(rv_newer.owner.inputs[1], Constant)
77-
assert rv_newer.eval().shape == (4, 3)
77+
assert tuple(rv_newer.shape.eval()) == (4, 3)
78+
79+
rv = normal(0, 1)
80+
new_size = at.as_tensor(np.array([4, 3], dtype="int32"))
81+
rv_newer = change_rv_size(rv, new_size=new_size, expand=True)
82+
assert rv_newer.ndim == 2
83+
assert tuple(rv_newer.shape.eval()) == (4, 3)
84+
85+
rv = normal(0, 1)
86+
new_size = at.as_tensor(2, dtype="int32")
87+
rv_newer = change_rv_size(rv, new_size=new_size, expand=True)
88+
assert rv_newer.ndim == 1
89+
assert tuple(rv_newer.shape.eval()) == (2,)
7890

7991

8092
class TestBroadcasting:

0 commit comments

Comments
 (0)