Skip to content

Commit 03b62a3

Browse files
committed
Allow inplacing of SITSOT and last MITSOT in numba Scan, when they are discarded immediately
1 parent 1fa22d8 commit 03b62a3

File tree

3 files changed

+109
-2
lines changed

3 files changed

+109
-2
lines changed

pytensor/link/numba/dispatch/scan.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def range_arr(x):
5555

5656

5757
@numba_funcify.register(Scan)
58-
def numba_funcify_Scan(op, node, **kwargs):
58+
def numba_funcify_Scan(op: Scan, node, **kwargs):
5959
# Apply inner rewrites
6060
# TODO: Not sure this is the right place to do this, should we have a rewrite that
6161
# explicitly triggers the optimization of the inner graphs of Scan?
@@ -67,9 +67,32 @@ def numba_funcify_Scan(op, node, **kwargs):
6767
.optimizer
6868
)
6969
fgraph = op.fgraph
70+
# When the buffer can only hold one SITSOT or as as many MITSOT as there are taps,
71+
# We must always discard the oldest tap, so it's safe to destroy it in the inner function.
72+
# TODO: Allow inplace for MITMOT
73+
destroyable_sitsot = [
74+
inner_sitsot
75+
for outer_sitsot, inner_sitsot in zip(
76+
op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True
77+
)
78+
if outer_sitsot.type.shape[0] == 1
79+
]
80+
destroyable_mitsot = [
81+
oldest_inner_mitmot
82+
for outer_mitsot, oldest_inner_mitmot, taps in zip(
83+
op.outer_mitsot(node.inputs),
84+
op.oldest_inner_mitsot(fgraph.inputs),
85+
op.info.mit_sot_in_slices,
86+
strict=True,
87+
)
88+
if outer_mitsot.type.shape[0] == abs(min(taps))
89+
]
90+
destroyable = {*destroyable_sitsot, *destroyable_mitsot}
7091
add_supervisor_to_fgraph(
7192
fgraph=fgraph,
72-
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
93+
input_specs=[
94+
In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs
95+
],
7396
accept_inplace=True,
7497
)
7598
rewriter(fgraph)

pytensor/scan/op.py

+10
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ def inner_mitsot(self, list_inputs):
321321
self.info.n_seqs + n_mitmot_taps : self.info.n_seqs + ntaps_upto_sit_sot
322322
]
323323

324+
def oldest_inner_mitsot(self, list_inputs):
325+
inner_mitsot_inputs = self.inner_mitsot(list_inputs)
326+
oldest_inner_mitsot_inputs = []
327+
offset = 0
328+
for taps in self.info.mit_sot_in_slices:
329+
oldest_tap = np.argmin(taps)
330+
oldest_inner_mitsot_inputs += [inner_mitsot_inputs[offset + oldest_tap]]
331+
offset += len(taps)
332+
return oldest_inner_mitsot_inputs
333+
324334
def outer_mitsot(self, list_inputs):
325335
offset = 1 + self.info.n_seqs + self.info.n_mit_mot
326336
return list_inputs[offset : offset + self.info.n_mit_sot]

tests/link/numba/test_scan.py

+74
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,80 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
451451
benchmark(numba_fn, *test.values())
452452

453453

454+
@pytest.mark.parametrize("n_steps_constant", (True, False))
455+
def test_inplace_taps(n_steps_constant):
456+
"""Test that numba will inplace in the inner_function of the oldest sit-sot, mit-sot taps."""
457+
n_steps = 10 if n_steps_constant else scalar("n_steps", dtype=int)
458+
a = scalar("a")
459+
x0 = scalar("x0")
460+
y0 = vector("y0", shape=(2,))
461+
z0 = vector("z0", shape=(3,))
462+
463+
def step(ztm3, ztm1, xtm1, ytm1, ytm2, a):
464+
z = ztm1 + 1 + ztm3 + a
465+
x = xtm1 + 1
466+
y = ytm1 + 1 + ytm2 + a
467+
return z, x, z + x + y, y
468+
469+
[zs, xs, ws, ys], _ = scan(
470+
fn=step,
471+
outputs_info=[
472+
dict(initial=z0, taps=[-3, -1]),
473+
dict(initial=x0, taps=[-1]),
474+
None,
475+
dict(initial=y0, taps=[-1, -2]),
476+
],
477+
non_sequences=[a],
478+
n_steps=n_steps,
479+
)
480+
numba_fn, _ = compare_numba_and_py(
481+
[n_steps] * (not n_steps_constant) + [a, x0, y0, z0],
482+
[zs[-1], xs[-1], ws[-1], ys[-1]],
483+
[10] * (not n_steps_constant) + [np.pi, np.e, [1, np.euler_gamma], [0, 1, 2]],
484+
numba_mode="NUMBA",
485+
eval_obj_mode=False,
486+
)
487+
[scan_op] = [
488+
node.op
489+
for node in numba_fn.maker.fgraph.toposort()
490+
if isinstance(node.op, Scan)
491+
]
492+
493+
# Scan reorders inputs internally, so we need to check its ordering
494+
inner_inps = scan_op.fgraph.inputs
495+
mit_sot_inps = scan_op.inner_mitsot(inner_inps)
496+
oldest_mit_sot_inps = [
497+
# Implicitly assume that the first mit-sot input is the one with 3 taps
498+
# This is not a required behavior and the test can change if we need to change Scan.
499+
mit_sot_inps[:2][scan_op.info.mit_sot_in_slices[0].index(-3)],
500+
mit_sot_inps[2:][scan_op.info.mit_sot_in_slices[1].index(-2)],
501+
]
502+
[sit_sot_inp] = scan_op.inner_sitsot(inner_inps)
503+
504+
inner_outs = scan_op.fgraph.outputs
505+
mit_sot_outs = scan_op.inner_mitsot_outs(inner_outs)
506+
[sit_sot_out] = scan_op.inner_sitsot_outs(inner_outs)
507+
[nit_sot_out] = scan_op.inner_nitsot_outs(inner_outs)
508+
509+
if n_steps_constant:
510+
assert mit_sot_outs[0].owner.op.destroy_map == {
511+
0: [mit_sot_outs[0].owner.inputs.index(oldest_mit_sot_inps[0])]
512+
}
513+
assert mit_sot_outs[1].owner.op.destroy_map == {
514+
0: [mit_sot_outs[1].owner.inputs.index(oldest_mit_sot_inps[1])]
515+
}
516+
assert sit_sot_out.owner.op.destroy_map == {
517+
0: [sit_sot_out.owner.inputs.index(sit_sot_inp)]
518+
}
519+
else:
520+
# This is not a feature, but a current limitation
521+
# https://github.com/pymc-devs/pytensor/issues/1283
522+
assert mit_sot_outs[0].owner.op.destroy_map == {}
523+
assert mit_sot_outs[1].owner.op.destroy_map == {}
524+
assert sit_sot_out.owner.op.destroy_map == {}
525+
assert nit_sot_out.owner.op.destroy_map == {}
526+
527+
454528
@pytest.mark.parametrize(
455529
"buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
456530
)

0 commit comments

Comments
 (0)