
Commit 5fc2cb8

Fix bug in Numba inplace vectorize code with multiple outputs
1 parent 789b509 commit 5fc2cb8

File tree: 2 files changed (+25 -3 lines)
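For context, here is a minimal user-level graph that hits the code path being fixed: an Elemwise over a two-output scalar Composite where both outputs are computed in place on their inputs, compiled with the Numba backend. This is only a sketch mirroring the regression test added below (variable names are illustrative; it assumes PyTensor with Numba installed):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.scalar import Composite
from pytensor.tensor.elemwise import Elemwise

x = pt.vector("x")
y = pt.vector("y")

# A scalar Composite with two outputs, each mapped back onto its own input
x_ = pt.scalar_from_tensor(x[0])
y_ = pt.scalar_from_tensor(y[0])
composite_op = Composite([x_, y_], [x_ + 1, y_ + 1])
elemwise_op = Elemwise(composite_op, inplace_pattern={0: 0, 1: 1})

fn = pytensor.function([x, y], elemwise_op(x, y), mode="NUMBA", accept_inplace=True)

x_val = np.array([1, 2, 3], dtype=pytensor.config.floatX)
y_val = np.array([4, 5, 6], dtype=pytensor.config.floatX)
out1, out2 = fn(x_val, y_val)  # outputs reuse the input buffers: out1 is x_val, out2 is y_val

With inplace_pattern={0: 0, 1: 1}, output 0 overwrites input 0 and output 1 overwrites input 1, which is the multiple-in-place-outputs case the commit title refers to.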

Diff for: pytensor/link/numba/dispatch/vectorize_codegen.py (+1 -1)

@@ -265,7 +265,7 @@ def codegen(
             ctx.nrt.incref(
                 builder,
                 sig.return_type.types[inplace_idx],
-                outputs[inplace_idx]._get_value(),
+                outputs[inplace_idx]._getvalue(),
             )
         return ctx.make_tuple(
             builder, sig.return_type, [out._getvalue() for out in outputs]
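The fix is in the method name: the lowered outputs expose `_getvalue()` (the same accessor already used in the `ctx.make_tuple` call just below), while `_get_value()` does not exist on these objects, so before this change the NRT incref of an in-place output's buffer failed with an attribute error when such a function was compiled.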

Diff for: tests/link/numba/test_elemwise.py (+24 -2)

@@ -12,7 +12,7 @@
 from pytensor.compile import get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.gradient import grad
-from pytensor.scalar import float64
+from pytensor.scalar import Composite, float64
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad

@@ -548,7 +548,7 @@ def test_Argmax(x, axes, exc):
     )


-def test_elemwise_out_type():
+def test_elemwise_inplace_out_type():
     # Create a graph with an elemwise
     # Ravel failes if the elemwise output type is reported incorrectly
     x = pt.matrix()

@@ -563,6 +563,28 @@ def test_elemwise_out_type():
     assert func(x_val).shape == (18,)


+def test_elemwise_multiple_inplace_outs():
+    x = pt.vector()
+    y = pt.vector()
+
+    x_ = pt.scalar_from_tensor(x[0])
+    y_ = pt.scalar_from_tensor(y[0])
+    out_ = x_ + 1, y_ + 1
+
+    composite_op = Composite([x_, y_], out_)
+    elemwise_op = Elemwise(composite_op, inplace_pattern={0: 0, 1: 1})
+    out = elemwise_op(x, y)
+
+    fn = function([x, y], out, mode="NUMBA", accept_inplace=True)
+    x_test = np.array([1, 2, 3], dtype=config.floatX)
+    y_test = np.array([4, 5, 6], dtype=config.floatX)
+    out1, out2 = fn(x_test, y_test)
+    assert out1 is x_test
+    assert out2 is y_test
+    np.testing.assert_allclose(out1, [2, 3, 4])
+    np.testing.assert_allclose(out2, [5, 6, 7])
+
+
 def test_scalar_loop():
     a = float64("a")
     scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
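With PyTensor and Numba installed, the new regression test can be run on its own with `pytest tests/link/numba/test_elemwise.py::test_elemwise_multiple_inplace_outs`.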
