Skip to content

Commit afb7695

Browse files
authored
Fix indexing in convolve1d with mode="same" (#1337)
1 parent 3af923b commit afb7695

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

Diff for: pytensor/tensor/signal/conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ def convolve1d(
119119
if mode == "same":
120120
# We implement "same" as "valid" with padded `in1`.
121121
in1_batch_shape = tuple(in1.shape)[:-1]
122-
zeros_left = in2.shape[0] // 2
123-
zeros_right = (in2.shape[0] - 1) // 2
122+
zeros_left = in2.shape[-1] // 2
123+
zeros_right = (in2.shape[-1] - 1) // 2
124124
in1 = join(
125125
-1,
126126
zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),

Diff for: tests/tensor/signal/test_conv.py

+13
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,16 @@ def test_convolve1d_batch():
4747
res_np = np.convolve(x_test[0], y_test[0])
4848
np.testing.assert_allclose(res[0], res_np, rtol=rtol)
4949
np.testing.assert_allclose(res[1], res_np, rtol=rtol)
50+
51+
52+
def test_convolve1d_batch_same():
53+
x = matrix("data")
54+
y = matrix("kernel")
55+
out = convolve1d(x, y, mode="same")
56+
57+
rng = np.random.default_rng(38)
58+
x_test = rng.normal(size=(2, 8)).astype(x.dtype)
59+
y_test = rng.normal(size=(2, 8)).astype(x.dtype)
60+
61+
res = out.eval({x: x_test, y: y_test})
62+
assert res.shape == (2, 8)

0 commit comments

Comments
 (0)