@@ -451,6 +451,80 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
451
451
benchmark (numba_fn , * test .values ())
452
452
453
453
454
@pytest.mark.parametrize("n_steps_constant", (True, False))
def test_inplace_taps(n_steps_constant):
    """Check which inner outputs of a numba-compiled Scan are computed inplace.

    With a constant ``n_steps`` the ops producing the mit-sot and sit-sot
    inner outputs are expected to destroy the oldest tap they read.
    With a symbolic ``n_steps`` no inner output is expected to be inplace.
    """
    if n_steps_constant:
        n_steps = 10
    else:
        n_steps = scalar("n_steps", dtype=int)
    a = scalar("a")
    x0 = scalar("x0")
    y0 = vector("y0", shape=(2,))
    z0 = vector("z0", shape=(3,))

    def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
        z = ztm1 + 1 + ztm3 + a
        x = xtm1 + 1
        y = ytm1 + 1 + ytm2 + a
        return z, x, z + x + y, y

    # Recurrent states, in order: z (mit-sot), x (sit-sot), w (nit-sot), y (mit-sot)
    outputs_info = [
        dict(initial=z0, taps=[-3, -1]),
        dict(initial=x0, taps=[-1]),
        None,
        dict(initial=y0, taps=[-1, -2]),
    ]
    scan_outputs, _ = scan(
        fn=step,
        outputs_info=outputs_info,
        non_sequences=[a],
        n_steps=n_steps,
    )
    zs, xs, ws, ys = scan_outputs

    # The symbolic n_steps (and its test value) is only an input when it is not constant
    graph_inputs = [a, x0, y0, z0]
    test_values = [np.pi, np.e, [1, np.euler_gamma], [0, 1, 2]]
    if not n_steps_constant:
        graph_inputs = [n_steps, *graph_inputs]
        test_values = [10, *test_values]

    numba_fn, _ = compare_numba_and_py(
        graph_inputs,
        [zs[-1], xs[-1], ws[-1], ys[-1]],
        test_values,
        numba_mode="NUMBA",
        eval_obj_mode=False,
    )

    # Exactly one Scan node is expected in the compiled graph
    scan_ops = [
        node.op
        for node in numba_fn.maker.fgraph.toposort()
        if isinstance(node.op, Scan)
    ]
    (scan_op,) = scan_ops

    # Scan reorders inputs internally, so we need to check its ordering
    inner_inputs = scan_op.fgraph.inputs
    mit_sot_inputs = scan_op.inner_mitsot(inner_inputs)
    tap_slices = scan_op.info.mit_sot_in_slices
    # Implicitly assume that the first mit-sot is the one with taps [-3, -1],
    # i.e. that it owns the first two inner mit-sot inputs.
    # This is not a required behavior and the test can change if we need to change Scan.
    first_mit_sot_inputs = mit_sot_inputs[:2]
    second_mit_sot_inputs = mit_sot_inputs[2:]
    oldest_first_inp = first_mit_sot_inputs[tap_slices[0].index(-3)]
    oldest_second_inp = second_mit_sot_inputs[tap_slices[1].index(-2)]
    (sit_sot_inp,) = scan_op.inner_sitsot(inner_inputs)

    inner_outputs = scan_op.fgraph.outputs
    mit_sot_outs = scan_op.inner_mitsot_outs(inner_outputs)
    first_mit_sot_out = mit_sot_outs[0]
    second_mit_sot_out = mit_sot_outs[1]
    (sit_sot_out,) = scan_op.inner_sitsot_outs(inner_outputs)
    (nit_sot_out,) = scan_op.inner_nitsot_outs(inner_outputs)

    def destroy_map_of(out):
        # destroy_map of the op that computes `out`
        return out.owner.op.destroy_map

    def inplace_on(out, inp):
        # destroy_map of an op whose only output overwrites `inp`
        return {0: [out.owner.inputs.index(inp)]}

    if not n_steps_constant:
        # This is not a feature, but a current limitation
        # https://github.com/pymc-devs/pytensor/issues/1283
        for out in (first_mit_sot_out, second_mit_sot_out, sit_sot_out, nit_sot_out):
            assert destroy_map_of(out) == {}
        return

    assert destroy_map_of(first_mit_sot_out) == inplace_on(
        first_mit_sot_out, oldest_first_inp
    )
    assert destroy_map_of(second_mit_sot_out) == inplace_on(
        second_mit_sot_out, oldest_second_inp
    )
    assert destroy_map_of(sit_sot_out) == inplace_on(sit_sot_out, sit_sot_inp)
526
+
527
+
454
528
@pytest .mark .parametrize (
455
529
"buffer_size" , ("unit" , "aligned" , "misaligned" , "whole" , "whole+init" )
456
530
)
0 commit comments