33
33
alloc ,
34
34
get_scalar_constant_value ,
35
35
nonzero ,
36
- scalar_from_tensor ,
36
+ )
37
+ from pytensor .tensor .basic import (
38
+ constant as tensor_constant ,
37
39
)
38
40
from pytensor .tensor .blockwise import vectorize_node_fallback
39
41
from pytensor .tensor .elemwise import DimShuffle
@@ -256,20 +258,20 @@ def get_idx_list(inputs, idx_list):
256
258
def get_canonical_form_slice (
257
259
theslice : slice ,
258
260
length : int | np .integer | ScalarVariable | TensorVariable ,
259
- ) -> tuple [slice , int | ScalarConstant ]: ...
261
+ ) -> tuple [slice , int | TensorVariable ]: ...
260
262
261
263
262
264
@overload
263
265
def get_canonical_form_slice (
264
266
theslice : int | np .integer | ScalarVariable | TensorVariable ,
265
267
length : int | np .integer | ScalarVariable | TensorVariable ,
266
- ) -> tuple [ScalarVariable , int ]: ...
268
+ ) -> tuple [TensorVariable , int ]: ...
267
269
268
270
269
271
def get_canonical_form_slice (
270
272
theslice : slice | int | np .integer | ScalarVariable | TensorVariable ,
271
273
length : int | np .integer | ScalarVariable | TensorVariable ,
272
- ) -> tuple [slice | ScalarVariable , int | ScalarConstant ]:
274
+ ) -> tuple [slice | TensorVariable , int | TensorVariable ]:
273
275
"""Convert indices or slices to canonical form.
274
276
275
277
Scalar integer indices or python Slices with Scalar/None attributes
@@ -296,30 +298,56 @@ def get_canonical_form_slice(
296
298
"""
297
299
from pytensor .tensor import ge , lt , sign , switch
298
300
299
- # Other non-slice types are the scalar indexing case
300
- if not isinstance (theslice , slice ):
301
- if isinstance (theslice , int | np .integer | ScalarVariable ) or (
302
- isinstance (theslice , TensorVariable ) and theslice .ndim == 0
303
- ):
304
- cano = switch (lt (theslice , 0 ), (theslice + length ), theslice )
305
- return scalar_from_tensor (cano ), 1
306
- raise ValueError (f"Slice { theslice } is not a supported slice type." )
301
+ def undo_scalarization (x ):
302
+ """Undo scalarization of a variable.
307
303
308
- # At this point we have a slice object. Possibly with symbolic inputs.
304
+ PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305
+ But reasoning symbolically about the result of multiple indexing operations, we usually
306
+ want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307
+
308
+ This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309
+ """
310
+ if isinstance (x , ScalarVariable ):
311
+ if isinstance (x , ScalarConstant ):
312
+ return tensor_constant (x .data , dtype = x .dtype )
313
+ elif x .owner is not None and isinstance (x .owner .op , ScalarFromTensor ):
314
+ return x .owner .inputs [0 ]
315
+ else :
316
+ return as_tensor_variable (x )
317
+ return x
309
318
310
319
def analyze (x ):
311
320
try :
312
321
x_constant = as_index_literal (x )
313
322
is_constant = True
314
323
except NotScalarConstantError :
315
- x_constant = x
324
+ x_constant = undo_scalarization ( x )
316
325
is_constant = False
317
326
return x_constant , is_constant
318
327
328
+ length , is_length_constant = analyze (length )
329
+
330
+ # Other non-slice types are the scalar indexing case
331
+ if not isinstance (theslice , slice ):
332
+ if not (
333
+ isinstance (theslice , int | np .integer | ScalarVariable )
334
+ or (isinstance (theslice , TensorVariable ) and theslice .ndim == 0 )
335
+ ):
336
+ raise ValueError (f"Slice { theslice } is not a supported slice type." )
337
+
338
+ idx , is_index_constant = analyze (theslice )
339
+ if is_index_constant :
340
+ if idx >= 0 :
341
+ return idx , 1
342
+ else :
343
+ return idx + length , 1
344
+ else :
345
+ return switch (lt (idx , 0 ), idx + length , idx ), 1
346
+
347
+ # At this point we have a slice object. Possibly with symbolic inputs.
319
348
start , is_start_constant = analyze (theslice .start )
320
349
stop , is_stop_constant = analyze (theslice .stop )
321
350
step , is_step_constant = analyze (theslice .step )
322
- length , is_length_constant = analyze (length )
323
351
324
352
if (
325
353
is_start_constant
0 commit comments