@@ -55,7 +55,7 @@ def range_arr(x):
55
55
56
56
57
57
@numba_funcify .register (Scan )
58
- def numba_funcify_Scan (op , node , ** kwargs ):
58
+ def numba_funcify_Scan (op : Scan , node , ** kwargs ):
59
59
# Apply inner rewrites
60
60
# TODO: Not sure this is the right place to do this, should we have a rewrite that
61
61
# explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
67
67
.optimizer
68
68
)
69
69
fgraph = op .fgraph
70
+ # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
71
+ # We must always discard the oldest tap, so it's safe to destroy it in the inner function.
72
+ # TODO: Allow inplace for MITMOT
73
+ destroyable_sitsot = [
74
+ inner_sitsot
75
+ for outer_sitsot , inner_sitsot in zip (
76
+ op .outer_sitsot (node .inputs ), op .inner_sitsot (fgraph .inputs ), strict = True
77
+ )
78
+ if outer_sitsot .type .shape [0 ] == 1
79
+ ]
80
+ destroyable_mitsot = [
81
+ oldest_inner_mitmot
82
+ for outer_mitsot , oldest_inner_mitmot , taps in zip (
83
+ op .outer_mitsot (node .inputs ),
84
+ op .oldest_inner_mitsot (fgraph .inputs ),
85
+ op .info .mit_sot_in_slices ,
86
+ strict = True ,
87
+ )
88
+ if outer_mitsot .type .shape [0 ] == abs (min (taps ))
89
+ ]
90
+ destroyable = {* destroyable_sitsot , * destroyable_mitsot }
70
91
add_supervisor_to_fgraph (
71
92
fgraph = fgraph ,
72
- input_specs = [In (x , borrow = True , mutable = False ) for x in fgraph .inputs ],
93
+ input_specs = [
94
+ In (x , borrow = True , mutable = x in destroyable ) for x in fgraph .inputs
95
+ ],
73
96
accept_inplace = True ,
74
97
)
75
98
rewriter (fgraph )
0 commit comments