Skip to content

Commit e1cbe76

Browse files
committed
Allow inplacing of SITSOT and last MITSOT in numba Scan, when they are discarded immediately
1 parent d2f0948 commit e1cbe76

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

Diff for: pytensor/link/numba/dispatch/scan.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def range_arr(x):
5555

5656

5757
@numba_funcify.register(Scan)
58-
def numba_funcify_Scan(op, node, **kwargs):
58+
def numba_funcify_Scan(op: Scan, node, **kwargs):
5959
# Apply inner rewrites
6060
# TODO: Not sure this is the right place to do this, should we have a rewrite that
6161
# explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
6767
.optimizer
6868
)
6969
fgraph = op.fgraph
70+
# When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
71+
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
72+
# TODO: Allow inplace for MITMOT
73+
destroyable_sitsot = [
74+
inner_sitsot
75+
for outer_sitsot, inner_sitsot in zip(
76+
op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True
77+
)
78+
if outer_sitsot.type.shape[0] == 1
79+
]
80+
destroyable_mitsot = [
81+
oldest_inner_mitmot
82+
for outer_mitsot, oldest_inner_mitmot, taps in zip(
83+
op.outer_mitsot(node.inputs),
84+
op.oldest_inner_mitsot(fgraph.inputs),
85+
op.info.mit_sot_in_slices,
86+
strict=True,
87+
)
88+
if outer_mitsot.type.shape[0] == abs(min(taps))
89+
]
90+
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
7091
add_supervisor_to_fgraph(
7192
fgraph=fgraph,
72-
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
93+
input_specs=[
94+
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
95+
],
7396
accept_inplace=True,
7497
)
7598
rewriter(fgraph)

Diff for: pytensor/scan/op.py

+10
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ def inner_mitsot(self, list_inputs):
321321
self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot
322322
]
323323

324+
def oldest_inner_mitsot(self, list_inputs):
325+
inner_mitsot_inputs = self.inner_mitsot(list_inputs)
326+
oldest_inner_mitsot_inputs = []
327+
offset = 0
328+
for taps in self.info.mit_sot_in_slices:
329+
oldest_tap = np.argmin(taps)
330+
oldest_inner_mitsot_inputs += [inner_mitsot_inputs[offset + oldest_tap]]
331+
offset += len(taps)
332+
return oldest_inner_mitsot_inputs
333+
324334
def outer_mitsot(self, list_inputs):
325335
offset = 1 + self.info.n_seqs + self.info.n_mit_mot
326336
return list_inputs[offset : offset + self.info.n_mit_sot]

0 commit comments

Comments
 (0)