12
12
from pytensor .compile import get_mode
13
13
from pytensor .compile .ops import deep_copy_op
14
14
from pytensor .gradient import grad
15
- from pytensor .scalar import float64
15
+ from pytensor .scalar import Composite , float64
16
16
from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
17
17
from pytensor .tensor .math import All , Any , Max , Min , Prod , ProdWithoutZeros , Sum
18
18
from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
@@ -548,7 +548,7 @@ def test_Argmax(x, axes, exc):
548
548
)
549
549
550
550
551
- def test_elemwise_out_type ():
551
+ def test_elemwise_inplace_out_type ():
552
552
# Create a graph with an elemwise
553
553
# Ravel failes if the elemwise output type is reported incorrectly
554
554
x = pt .matrix ()
@@ -563,6 +563,28 @@ def test_elemwise_out_type():
563
563
assert func (x_val ).shape == (18 ,)
564
564
565
565
566
def test_elemwise_multiple_inplace_outs():
    """An Elemwise whose inplace_pattern maps several outputs must reuse the
    matching input buffers under the NUMBA backend (outputs alias inputs)."""
    a = pt.vector()
    b = pt.vector()

    # Build a two-output scalar Composite from the first entries of each vector.
    a_scalar = pt.scalar_from_tensor(a[0])
    b_scalar = pt.scalar_from_tensor(b[0])
    scalar_outs = (a_scalar + 1, b_scalar + 1)

    composite = Composite([a_scalar, b_scalar], scalar_outs)
    # Output 0 overwrites input 0, output 1 overwrites input 1.
    elemwise = Elemwise(composite, inplace_pattern={0: 0, 1: 1})
    graph_outs = elemwise(a, b)

    fn = function([a, b], graph_outs, mode="NUMBA", accept_inplace=True)

    a_val = np.array([1, 2, 3], dtype=config.floatX)
    b_val = np.array([4, 5, 6], dtype=config.floatX)
    res_a, res_b = fn(a_val, b_val)

    # Identity checks: true in-place destruction, not copies of the inputs.
    assert res_a is a_val
    assert res_b is b_val
    np.testing.assert_allclose(res_a, [2, 3, 4])
    np.testing.assert_allclose(res_b, [5, 6, 7])
566
588
def test_scalar_loop ():
567
589
a = float64 ("a" )
568
590
scalar_loop = pytensor .scalar .ScalarLoop ([a ], [a + a ])
0 commit comments