Skip to content

Commit 3ae469f

Browse files
committed
Don't run local_uint_constant_indices in C/python backend
Indices are always cast to int64 by the underlying methods. Also don't run in specialize, to reduce number of passes. Other rewrites may introduce temporar indexing operations (such as x.shape[i]) which always default to int64, and it's useless to optimize immediately.
1 parent 6e84408 commit 3ae469f

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

Diff for: pytensor/compile/mode.py

-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
489489
"BlasOpt",
490490
"fusion",
491491
"inplace",
492-
"local_uint_constant_indices",
493492
"scan_save_mem_prealloc",
494493
],
495494
),

Diff for: pytensor/tensor/rewriting/subtensor.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66

77
import pytensor
8-
import pytensor.scalar.basic as ps
98
from pytensor import compile
109
from pytensor.compile import optdb
1110
from pytensor.graph.basic import Constant, Variable
@@ -14,8 +13,11 @@
1413
copy_stack_trace,
1514
in2out,
1615
node_rewriter,
16+
out2in,
1717
)
1818
from pytensor.raise_op import Assert
19+
from pytensor.scalar import Add, ScalarConstant, ScalarType
20+
from pytensor.scalar import constant as scalar_constant
1921
from pytensor.tensor.basic import (
2022
Alloc,
2123
Join,
@@ -31,6 +33,7 @@
3133
register_infer_shape,
3234
switch,
3335
)
36+
from pytensor.tensor.basic import constant as tensor_constant
3437
from pytensor.tensor.blockwise import Blockwise
3538
from pytensor.tensor.elemwise import Elemwise
3639
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
588591
remove_dim = []
589592
node_inputs_idx = 1
590593
for dim, elem in enumerate(idx):
591-
if isinstance(elem, (ps.ScalarType)):
594+
if isinstance(elem, ScalarType):
592595
# The idx is a ScalarType, ie a Type. This means the actual index
593596
# is contained in node.inputs[1]
594597
dim_index = node.inputs[node_inputs_idx]
595-
if isinstance(dim_index, ps.ScalarConstant):
598+
if isinstance(dim_index, ScalarConstant):
596599
dim_index = dim_index.value
597600
if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
598601
remove_dim.append(dim)
@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
770773

771774
(idx,) = idxs
772775

773-
if isinstance(idx, ps.ScalarType | TensorType):
776+
if isinstance(idx, ScalarType | TensorType):
774777
old_idx, idx = idx, node.inputs[1]
775778
assert idx.type.is_super(old_idx)
776779
elif isinstance(node.op, AdvancedSubtensor1):
@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
895898
and node.op.set_instead_of_inc
896899
and node.inputs[1].owner
897900
and isinstance(node.inputs[1].owner.op, Elemwise)
898-
and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add)
901+
and isinstance(node.inputs[1].owner.op.scalar_op, Add)
899902
):
900903
addn = node.inputs[1].owner
901904
subn = None
@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
17891792
return [merged_subtensors]
17901793

17911794

1792-
@register_specialize
17931795
@node_rewriter(
17941796
[
17951797
Subtensor,
@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
18501852
if dtype == index_val.dtype:
18511853
continue
18521854

1853-
if index_val.ndim > 0:
1854-
new_index = pytensor.tensor.as_tensor_variable(
1855-
index_val.astype(dtype), dtype=dtype
1856-
)
1855+
if isinstance(index.type, TensorType):
1856+
new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
18571857
else:
1858-
new_index = ps.constant(index_val.astype(dtype), dtype=dtype)
1858+
new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)
18591859

18601860
new_indices[i] = new_index
18611861
has_new_index = True
@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
18771877
return [new_out]
18781878

18791879

1880+
compile.optdb.register(
1881+
local_uint_constant_indices.__name__,
1882+
out2in(local_uint_constant_indices),
1883+
# Python / C backends always cast indices to int64 internally.
1884+
"numba",
1885+
"jax",
1886+
# After specialization and uncanonicalization
1887+
# Other rewrites don't worry about the dtype of the indices
1888+
# And can cause unnecessary passes of this optimization
1889+
# Such as x.shape[np.int(0)] -> x.shape[np.uint(0)]
1890+
position=4,
1891+
)
1892+
1893+
18801894
@register_canonicalize("shape_unsafe")
18811895
@register_stabilize("shape_unsafe")
18821896
@register_specialize("shape_unsafe")

0 commit comments

Comments
 (0)