Skip to content

Commit ab57a79

Browse files
committed
Cleanup Scan symbolic buffer size graph
Graph was being broken by Scalar/Tensor conversions that prevented fusion
1 parent e1cbe76 commit ab57a79

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

Diff for: pytensor/scan/rewriting.py

+3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
get_idx_list,
7070
get_slice_elements,
7171
set_subtensor,
72+
undo_scalarization,
7273
)
7374
from pytensor.tensor.variable import TensorConstant, TensorVariable
7475

@@ -1343,6 +1344,8 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
13431344
except KeyError:
13441345
length = out.shape[0]
13451346
cf_slice = get_canonical_form_slice(this_slice[0], length)
1347+
cf_slice = (undo_scalarization(cf_slice[0]), cf_slice[1])
1348+
13461349
slices[i] += [(cf_slice, this_slice)] # type: ignore
13471350

13481351
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:

Diff for: pytensor/tensor/rewriting/subtensor.py

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
get_slice_elements,
8585
inc_subtensor,
8686
indices_from_subtensor,
87+
undo_scalarization,
8788
)
8889
from pytensor.tensor.type import TensorType, integer_dtypes
8990
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
@@ -1136,6 +1137,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
11361137
# We are in the more complex case when we do not actually know
11371138
# if the first slice was in reverse or not.
11381139
# in case it was not in reverse:
1140+
sl2 = undo_scalarization(sl2)
11391141
p_val = sl1.start + sl2 * sl1.step
11401142
# case it was in reverse we need to realize that we do not want
11411143
# the k-th element from sl.start but the k-th element from

Diff for: pytensor/tensor/subtensor.py

+22
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
nonzero,
3636
scalar_from_tensor,
3737
)
38+
from pytensor.tensor.basic import (
39+
constant as tensor_constant,
40+
)
3841
from pytensor.tensor.blockwise import vectorize_node_fallback
3942
from pytensor.tensor.elemwise import DimShuffle
4043
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
@@ -252,6 +255,23 @@ def get_idx_list(inputs, idx_list):
252255
return indices_from_subtensor(inputs[1:], idx_list)
253256

254257

258+
def undo_scalarization(x):
259+
"""Undo scalarization of a variable.
260+
261+
PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
262+
When reason symbolically about the result of multiple indexing operations, we usually
263+
want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
264+
265+
This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
266+
"""
267+
if isinstance(x, ScalarVariable):
268+
if isinstance(x, ScalarConstant):
269+
return tensor_constant(x.data, dtype=x.dtype)
270+
elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
271+
return x.owner.inputs[0]
272+
return x
273+
274+
255275
@overload
256276
def get_canonical_form_slice(
257277
theslice: slice,
@@ -298,6 +318,7 @@ def get_canonical_form_slice(
298318

299319
# Other non-slice types are the scalar indexing case
300320
if not isinstance(theslice, slice):
321+
theslice = undo_scalarization(theslice)
301322
if isinstance(theslice, int | np.integer | ScalarVariable) or (
302323
isinstance(theslice, TensorVariable) and theslice.ndim == 0
303324
):
@@ -381,6 +402,7 @@ def analyze(x):
381402
elif is_stop_length:
382403
# start:length:1
383404
if is_start_constant and start >= 0:
405+
length = undo_scalarization(length)
384406
return slice(switch(lt(start, length), start, length), length, 1), 1
385407
start_plus_len = start + length
386408
start = switch(

0 commit comments

Comments
 (0)