Skip to content

Commit d9b1085

Browse files
authored
Don't run local uint constant indices in C/Python backends (#1335)
* Let numpy methods handle integer size problems in AdvancedSubtensor1

* Don't run `local_uint_constant_indices` in C/Python backend

Indices are always cast to int64 by the underlying methods. Also don't run in specialize, to reduce the number of passes. Other rewrites may introduce temporary indexing operations (such as x.shape[i]) which always default to int64, and it's useless to optimize immediately.
1 parent afb7695 commit d9b1085

File tree

3 files changed

+32
-94
lines changed

3 files changed

+32
-94
lines changed

Diff for: pytensor/compile/mode.py

-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
489489
"BlasOpt",
490490
"fusion",
491491
"inplace",
492-
"local_uint_constant_indices",
493492
"scan_save_mem_prealloc",
494493
],
495494
),

Diff for: pytensor/tensor/rewriting/subtensor.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66

77
import pytensor
8-
import pytensor.scalar.basic as ps
98
from pytensor import compile
109
from pytensor.compile import optdb
1110
from pytensor.graph.basic import Constant, Variable
@@ -14,8 +13,11 @@
1413
copy_stack_trace,
1514
in2out,
1615
node_rewriter,
16+
out2in,
1717
)
1818
from pytensor.raise_op import Assert
19+
from pytensor.scalar import Add, ScalarConstant, ScalarType
20+
from pytensor.scalar import constant as scalar_constant
1921
from pytensor.tensor.basic import (
2022
Alloc,
2123
Join,
@@ -31,6 +33,7 @@
3133
register_infer_shape,
3234
switch,
3335
)
36+
from pytensor.tensor.basic import constant as tensor_constant
3437
from pytensor.tensor.blockwise import Blockwise
3538
from pytensor.tensor.elemwise import Elemwise
3639
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
588591
remove_dim = []
589592
node_inputs_idx = 1
590593
for dim, elem in enumerate(idx):
591-
if isinstance(elem, (ps.ScalarType)):
594+
if isinstance(elem, ScalarType):
592595
# The idx is a ScalarType, ie a Type. This means the actual index
593596
# is contained in node.inputs[1]
594597
dim_index = node.inputs[node_inputs_idx]
595-
if isinstance(dim_index, ps.ScalarConstant):
598+
if isinstance(dim_index, ScalarConstant):
596599
dim_index = dim_index.value
597600
if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
598601
remove_dim.append(dim)
@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
770773

771774
(idx,) = idxs
772775

773-
if isinstance(idx, ps.ScalarType | TensorType):
776+
if isinstance(idx, ScalarType | TensorType):
774777
old_idx, idx = idx, node.inputs[1]
775778
assert idx.type.is_super(old_idx)
776779
elif isinstance(node.op, AdvancedSubtensor1):
@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
895898
and node.op.set_instead_of_inc
896899
and node.inputs[1].owner
897900
and isinstance(node.inputs[1].owner.op, Elemwise)
898-
and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add)
901+
and isinstance(node.inputs[1].owner.op.scalar_op, Add)
899902
):
900903
addn = node.inputs[1].owner
901904
subn = None
@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
17891792
return [merged_subtensors]
17901793

17911794

1792-
@register_specialize
17931795
@node_rewriter(
17941796
[
17951797
Subtensor,
@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
18501852
if dtype == index_val.dtype:
18511853
continue
18521854

1853-
if index_val.ndim > 0:
1854-
new_index = pytensor.tensor.as_tensor_variable(
1855-
index_val.astype(dtype), dtype=dtype
1856-
)
1855+
if isinstance(index.type, TensorType):
1856+
new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
18571857
else:
1858-
new_index = ps.constant(index_val.astype(dtype), dtype=dtype)
1858+
new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)
18591859

18601860
new_indices[i] = new_index
18611861
has_new_index = True
@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
18771877
return [new_out]
18781878

18791879

1880+
compile.optdb.register(
1881+
local_uint_constant_indices.__name__,
1882+
out2in(local_uint_constant_indices),
1883+
# We don't include in the Python / C because those always cast indices to int64 internally.
1884+
"numba",
1885+
"jax",
1886+
# After specialization and uncanonicalization
1887+
# Other rewrites don't worry about the dtype of the indices
1888+
# And can cause unnecessary passes of this optimization
1889+
# Such as x.shape[np.int(0)] -> x.shape[np.uint(0)]
1890+
position=4,
1891+
)
1892+
1893+
18801894
@register_canonicalize("shape_unsafe")
18811895
@register_stabilize("shape_unsafe")
18821896
@register_specialize("shape_unsafe")

Diff for: pytensor/tensor/subtensor.py

+7-82
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import warnings
44
from collections.abc import Callable, Iterable, Sequence
55
from itertools import chain, groupby
6-
from textwrap import dedent
76
from typing import cast, overload
87

98
import numpy as np
@@ -19,7 +18,7 @@
1918
from pytensor.graph.utils import MethodNotDefined
2019
from pytensor.link.c.op import COp
2120
from pytensor.link.c.params_type import ParamsType
22-
from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2
21+
from pytensor.npy_2_compat import numpy_version, using_numpy_2
2322
from pytensor.printing import Printer, pprint, set_precedence
2423
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
2524
from pytensor.tensor import (
@@ -2130,24 +2129,6 @@ def perform(self, node, inp, out_):
21302129
else:
21312130
o = None
21322131

2133-
# If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
2134-
# int64 on 64-bit machines), numpy may raise the following error:
2135-
# TypeError: array cannot be safely cast to required type.
2136-
# We need to check if values in i can fit in numpy.intp, because
2137-
# if they don't, that should be an error (no array can have that
2138-
# many elements on a 32-bit arch).
2139-
if i.dtype != np.intp:
2140-
i_ = np.asarray(i, dtype=np.intp)
2141-
if not np.can_cast(i.dtype, np.intp):
2142-
# Check if there was actually an incorrect conversion
2143-
if np.any(i != i_):
2144-
raise IndexError(
2145-
"index contains values that are bigger "
2146-
"than the maximum array size on this system.",
2147-
i,
2148-
)
2149-
i = i_
2150-
21512132
out[0] = x.take(i, axis=0, out=o)
21522133

21532134
def connection_pattern(self, node):
@@ -2187,16 +2168,6 @@ def infer_shape(self, fgraph, node, ishapes):
21872168
x, ilist = ishapes
21882169
return [ilist + x[1:]]
21892170

2190-
def c_support_code(self, **kwargs):
2191-
# In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
2192-
# which is not defined. It should be NPY_MIN_LONG instead in that case.
2193-
return npy_2_compat_header() + dedent(
2194-
"""\
2195-
#ifndef MIN_LONG
2196-
#define MIN_LONG NPY_MIN_LONG
2197-
#endif"""
2198-
)
2199-
22002171
def c_code(self, node, name, input_names, output_names, sub):
22012172
if self.__class__ is not AdvancedSubtensor1:
22022173
raise MethodNotDefined(
@@ -2207,69 +2178,24 @@ def c_code(self, node, name, input_names, output_names, sub):
22072178
output_name = output_names[0]
22082179
fail = sub["fail"]
22092180
return f"""
2210-
PyArrayObject *indices;
2211-
int i_type = PyArray_TYPE({i_name});
2212-
if (i_type != NPY_INTP) {{
2213-
// Cast {i_name} to NPY_INTP (expected by PyArray_TakeFrom),
2214-
// if all values fit.
2215-
if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
2216-
PyArray_SIZE({i_name}) > 0) {{
2217-
npy_int64 min_val, max_val;
2218-
PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
2219-
NULL);
2220-
if (py_min_val == NULL) {{
2221-
{fail};
2222-
}}
2223-
min_val = PyLong_AsLongLong(py_min_val);
2224-
Py_DECREF(py_min_val);
2225-
if (min_val == -1 && PyErr_Occurred()) {{
2226-
{fail};
2227-
}}
2228-
PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
2229-
NULL);
2230-
if (py_max_val == NULL) {{
2231-
{fail};
2232-
}}
2233-
max_val = PyLong_AsLongLong(py_max_val);
2234-
Py_DECREF(py_max_val);
2235-
if (max_val == -1 && PyErr_Occurred()) {{
2236-
{fail};
2237-
}}
2238-
if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
2239-
PyErr_SetString(PyExc_IndexError,
2240-
"Index contains values "
2241-
"that are bigger than the maximum array "
2242-
"size on this system.");
2243-
{fail};
2244-
}}
2245-
}}
2246-
indices = (PyArrayObject*) PyArray_Cast({i_name}, NPY_INTP);
2247-
if (indices == NULL) {{
2248-
{fail};
2249-
}}
2250-
}}
2251-
else {{
2252-
indices = {i_name};
2253-
Py_INCREF(indices);
2254-
}}
22552181
if ({output_name} != NULL) {{
22562182
npy_intp nd, i, *shape;
2257-
nd = PyArray_NDIM({a_name}) + PyArray_NDIM(indices) - 1;
2183+
nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
22582184
if (PyArray_NDIM({output_name}) != nd) {{
22592185
Py_CLEAR({output_name});
22602186
}}
22612187
else {{
22622188
shape = PyArray_DIMS({output_name});
2263-
for (i = 0; i < PyArray_NDIM(indices); i++) {{
2264-
if (shape[i] != PyArray_DIMS(indices)[i]) {{
2189+
for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
2190+
if (shape[i] != PyArray_DIMS({i_name})[i]) {{
22652191
Py_CLEAR({output_name});
22662192
break;
22672193
}}
22682194
}}
22692195
if ({output_name} != NULL) {{
22702196
for (; i < nd; i++) {{
22712197
if (shape[i] != PyArray_DIMS({a_name})[
2272-
i-PyArray_NDIM(indices)+1]) {{
2198+
i-PyArray_NDIM({i_name})+1]) {{
22732199
Py_CLEAR({output_name});
22742200
break;
22752201
}}
@@ -2278,13 +2204,12 @@ def c_code(self, node, name, input_names, output_names, sub):
22782204
}}
22792205
}}
22802206
{output_name} = (PyArrayObject*)PyArray_TakeFrom(
2281-
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
2282-
Py_DECREF(indices);
2207+
{a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
22832208
if ({output_name} == NULL) {fail};
22842209
"""
22852210

22862211
def c_code_cache_version(self):
2287-
return (0, 1, 2, 3)
2212+
return (4,)
22882213

22892214

22902215
advanced_subtensor1 = AdvancedSubtensor1()

0 commit comments

Comments
 (0)