Skip to content

Commit af8ccb1

Browse files
committed
Cleanup Scan symbolic buffer size graph
Graph was being broken by Scalar/Tensor conversions that prevented fusion
1 parent cddefef commit af8ccb1

File tree

2 files changed

+68
-28
lines changed

2 files changed

+68
-28
lines changed

Diff for: pytensor/tensor/subtensor.py

+43-15
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
alloc,
3434
get_scalar_constant_value,
3535
nonzero,
36-
scalar_from_tensor,
36+
)
37+
from pytensor.tensor.basic import (
38+
constant as tensor_constant,
3739
)
3840
from pytensor.tensor.blockwise import vectorize_node_fallback
3941
from pytensor.tensor.elemwise import DimShuffle
@@ -256,20 +258,20 @@ def get_idx_list(inputs, idx_list):
256258
def get_canonical_form_slice(
257259
theslice: slice,
258260
length: int | np.integer | ScalarVariable | TensorVariable,
259-
) -> tuple[slice, int | ScalarConstant]: ...
261+
) -> tuple[slice, int | TensorVariable]: ...
260262

261263

262264
@overload
263265
def get_canonical_form_slice(
264266
theslice: int | np.integer | ScalarVariable | TensorVariable,
265267
length: int | np.integer | ScalarVariable | TensorVariable,
266-
) -> tuple[ScalarVariable, int]: ...
268+
) -> tuple[TensorVariable, int]: ...
267269

268270

269271
def get_canonical_form_slice(
270272
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
271273
length: int | np.integer | ScalarVariable | TensorVariable,
272-
) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
274+
) -> tuple[slice | TensorVariable, int | TensorVariable]:
273275
"""Convert indices or slices to canonical form.
274276
275277
Scalar integer indices or python Slices with Scalar/None attributes
@@ -296,30 +298,56 @@ def get_canonical_form_slice(
296298
"""
297299
from pytensor.tensor import ge, lt, sign, switch
298300

299-
# Other non-slice types are the scalar indexing case
300-
if not isinstance(theslice, slice):
301-
if isinstance(theslice, int | np.integer | ScalarVariable) or (
302-
isinstance(theslice, TensorVariable) and theslice.ndim == 0
303-
):
304-
cano = switch(lt(theslice, 0), (theslice + length), theslice)
305-
return scalar_from_tensor(cano), 1
306-
raise ValueError(f"Slice {theslice} is not a supported slice type.")
301+
def undo_scalarization(x):
302+
"""Undo scalarization of a variable.
307303
308-
# At this point we have a slice object. Possibly with symbolic inputs.
304+
PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305+
But reasoning symbolically about the result of multiple indexing operations, we usually
306+
want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307+
308+
This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309+
"""
310+
if isinstance(x, ScalarVariable):
311+
if isinstance(x, ScalarConstant):
312+
return tensor_constant(x.data, dtype=x.dtype)
313+
elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
314+
return x.owner.inputs[0]
315+
else:
316+
return as_tensor_variable(x)
317+
return x
309318

310319
def analyze(x):
311320
try:
312321
x_constant = as_index_literal(x)
313322
is_constant = True
314323
except NotScalarConstantError:
315-
x_constant = x
324+
x_constant = undo_scalarization(x)
316325
is_constant = False
317326
return x_constant, is_constant
318327

328+
length, is_length_constant = analyze(length)
329+
330+
# Other non-slice types are the scalar indexing case
331+
if not isinstance(theslice, slice):
332+
if not (
333+
isinstance(theslice, int | np.integer | ScalarVariable)
334+
or (isinstance(theslice, TensorVariable) and theslice.ndim == 0)
335+
):
336+
raise ValueError(f"Slice {theslice} is not a supported slice type.")
337+
338+
idx, is_index_constant = analyze(theslice)
339+
if is_index_constant:
340+
if idx >= 0:
341+
return idx, 1
342+
else:
343+
return idx + length, 1
344+
else:
345+
return switch(lt(idx, 0), idx + length, idx), 1
346+
347+
# At this point we have a slice object. Possibly with symbolic inputs.
319348
start, is_start_constant = analyze(theslice.start)
320349
stop, is_stop_constant = analyze(theslice.stop)
321350
step, is_step_constant = analyze(theslice.step)
322-
length, is_length_constant = analyze(length)
323351

324352
if (
325353
is_start_constant

Diff for: tests/tensor/test_subtensor.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
from pytensor.configdefaults import config
1717
from pytensor.gradient import grad
1818
from pytensor.graph import Constant
19+
from pytensor.graph.basic import equal_computations
1920
from pytensor.graph.op import get_test_value
2021
from pytensor.graph.rewriting.utils import is_same_graph
2122
from pytensor.printing import pprint
2223
from pytensor.scalar.basic import as_scalar, int16
2324
from pytensor.tensor import as_tensor, get_vector_length, vectorize
2425
from pytensor.tensor.blockwise import Blockwise
2526
from pytensor.tensor.elemwise import DimShuffle
26-
from pytensor.tensor.math import exp, isinf
27+
from pytensor.tensor.math import exp, isinf, lt, switch
2728
from pytensor.tensor.math import sum as pt_sum
2829
from pytensor.tensor.shape import specify_shape
2930
from pytensor.tensor.subtensor import (
@@ -136,30 +137,41 @@ def test_unsupported_inputs(self, idx):
136137
def test_scalar_constant(self):
137138
a = as_scalar(0)
138139
length = lscalar()
139-
res = get_canonical_form_slice(a, length)
140-
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
141-
assert res[1] == 1
140+
res, direction = get_canonical_form_slice(a, length)
141+
assert res == 0
142+
assert direction == 1
143+
144+
b = as_scalar(-1)
145+
res, direction = get_canonical_form_slice(b, length)
146+
assert equal_computations([res], [as_tensor(-1) + length])
147+
assert direction == 1
142148

143149
def test_tensor_constant(self):
144150
a = as_tensor(0)
145151
length = lscalar()
146-
res = get_canonical_form_slice(a, length)
147-
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
148-
assert res[1] == 1
152+
res, direction = get_canonical_form_slice(a, length)
153+
assert equal_computations([res], [a])
154+
assert direction == 1
155+
156+
b = as_tensor(-1)
157+
res, direction = get_canonical_form_slice(b, length)
158+
assert equal_computations([res], [b + length])
159+
assert direction == 1
149160

150161
def test_symbolic_scalar(self):
151162
a = int16()
152163
length = lscalar()
153-
res = get_canonical_form_slice(a, length)
154-
assert res[0].owner.op, ptb.switch
155-
assert res[1] == 1
164+
res, direction = get_canonical_form_slice(a, length)
165+
a_t = as_tensor(a)
166+
assert equal_computations([res], [switch(lt(a_t, 0), a_t + length, a_t)])
167+
assert direction == 1
156168

157169
def test_symbolic_tensor(self):
158170
a = lscalar()
159171
length = lscalar()
160-
res = get_canonical_form_slice(a, length)
161-
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
162-
assert res[1] == 1
172+
res, direction = get_canonical_form_slice(a, length)
173+
assert equal_computations([res], [switch(lt(a, 0), a + length, a)])
174+
assert direction == 1
163175

164176
@pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
165177
def test_all_integer(self, int_fn):

0 commit comments

Comments
 (0)