Don't run local uint constant indices in C/Python backends #1335

Merged

1 change: 0 additions & 1 deletion pytensor/compile/mode.py
@@ -489,7 +489,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "BlasOpt",
             "fusion",
             "inplace",
-            "local_uint_constant_indices",
             "scan_save_mem_prealloc",
         ],
     ),
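The line removed above is a per-mode exclusion. With this PR the rewrite is instead registered under the "numba"/"jax" tags in optdb (see the next file), so the NUMBA mode no longer needs to opt out by name. A minimal sketch of the exclusion mechanism, assuming the public get_mode / Mode.excluding API:

from pytensor.compile.mode import get_mode

# A Mode runs only the rewrites its query selects; excluding by name, as the
# deleted line did, filters a rewrite out even when its tags would include it.
numba_mode = get_mode("NUMBA")
mode_without_uint_indices = numba_mode.excluding("local_uint_constant_indices")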
36 changes: 25 additions & 11 deletions pytensor/tensor/rewriting/subtensor.py
@@ -5,7 +5,6 @@
 import numpy as np

 import pytensor
-import pytensor.scalar.basic as ps
 from pytensor import compile
 from pytensor.compile import optdb
 from pytensor.graph.basic import Constant, Variable
@@ -14,8 +13,11 @@
     copy_stack_trace,
     in2out,
     node_rewriter,
+    out2in,
 )
 from pytensor.raise_op import Assert
+from pytensor.scalar import Add, ScalarConstant, ScalarType
+from pytensor.scalar import constant as scalar_constant
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -31,6 +33,7 @@
     register_infer_shape,
     switch,
 )
+from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     remove_dim = []
     node_inputs_idx = 1
     for dim, elem in enumerate(idx):
-        if isinstance(elem, (ps.ScalarType)):
+        if isinstance(elem, ScalarType):
             # The idx is a ScalarType, ie a Type. This means the actual index
             # is contained in node.inputs[1]
             dim_index = node.inputs[node_inputs_idx]
-            if isinstance(dim_index, ps.ScalarConstant):
+            if isinstance(dim_index, ScalarConstant):
                 dim_index = dim_index.value
             if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
                 remove_dim.append(dim)
@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):

     (idx,) = idxs

-    if isinstance(idx, ps.ScalarType | TensorType):
+    if isinstance(idx, ScalarType | TensorType):
         old_idx, idx = idx, node.inputs[1]
         assert idx.type.is_super(old_idx)
     elif isinstance(node.op, AdvancedSubtensor1):
@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
         and node.op.set_instead_of_inc
         and node.inputs[1].owner
         and isinstance(node.inputs[1].owner.op, Elemwise)
-        and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add)
+        and isinstance(node.inputs[1].owner.op.scalar_op, Add)
     ):
         addn = node.inputs[1].owner
         subn = None
@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
     return [merged_subtensors]


-@register_specialize
 @node_rewriter(
     [
         Subtensor,
@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
         if dtype == index_val.dtype:
             continue

-        if index_val.ndim > 0:
-            new_index = pytensor.tensor.as_tensor_variable(
-                index_val.astype(dtype), dtype=dtype
-            )
+        if isinstance(index.type, TensorType):
+            new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
         else:
-            new_index = ps.constant(index_val.astype(dtype), dtype=dtype)
+            new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)

         new_indices[i] = new_index
         has_new_index = True
@@ -1877,6 +1877,20 @@
     return [new_out]


+compile.optdb.register(
+    local_uint_constant_indices.__name__,
+    out2in(local_uint_constant_indices),
+    # We don't include it in the Python / C backends, because those always cast indices to int64 internally.
+    "numba",
+    "jax",
+    # Register after specialization and uncanonicalization:
+    # other rewrites don't worry about the dtype of the indices
+    # and can cause unnecessary extra passes of this optimization,
+    # such as x.shape[np.int(0)] -> x.shape[np.uint(0)].
+    position=4,
+)
+
+
 @register_canonicalize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
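A minimal sketch of what the relocated rewrite does, assuming a backend where it is registered ("NUMBA" here); the uint8 outcome is the expected behaviour, not output copied from the PR:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = x[1]  # the constant index 1 is built as int64 by default
fn = pytensor.function([x], out, mode="NUMBA")
# After local_uint_constant_indices runs, the index constant should be stored
# as uint8, the smallest unsigned dtype that holds the value 1.
pytensor.dprint(fn)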
89 changes: 7 additions & 82 deletions pytensor/tensor/subtensor.py
@@ -3,7 +3,6 @@
 import warnings
 from collections.abc import Callable, Iterable, Sequence
 from itertools import chain, groupby
-from textwrap import dedent
 from typing import cast, overload

 import numpy as np
@@ -19,7 +18,7 @@
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
-from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2
+from pytensor.npy_2_compat import numpy_version, using_numpy_2
 from pytensor.printing import Printer, pprint, set_precedence
 from pytensor.scalar.basic import ScalarConstant, ScalarVariable
 from pytensor.tensor import (
@@ -2130,24 +2129,6 @@ def perform(self, node, inp, out_):
         else:
             o = None

-        # If i.dtype is more precise than numpy.intp (int32 on 32-bit machines,
-        # int64 on 64-bit machines), numpy may raise the following error:
-        # TypeError: array cannot be safely cast to required type.
-        # We need to check if values in i can fit in numpy.intp, because
-        # if they don't, that should be an error (no array can have that
-        # many elements on a 32-bit arch).
-        if i.dtype != np.intp:
-            i_ = np.asarray(i, dtype=np.intp)
-            if not np.can_cast(i.dtype, np.intp):
-                # Check if there was actually an incorrect conversion
-                if np.any(i != i_):
-                    raise IndexError(
-                        "index contains values that are bigger "
-                        "than the maximum array size on this system.",
-                        i,
-                    )
-            i = i_
-
         out[0] = x.take(i, axis=0, out=o)

     def connection_pattern(self, node):
@@ -2187,16 +2168,6 @@ def infer_shape(self, fgraph, node, ishapes):
         x, ilist = ishapes
         return [ilist + x[1:]]

-    def c_support_code(self, **kwargs):
-        # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG,
-        # which is not defined. It should be NPY_MIN_LONG instead in that case.
-        return npy_2_compat_header() + dedent(
-            """\
-            #ifndef MIN_LONG
-            #define MIN_LONG NPY_MIN_LONG
-            #endif"""
-        )
-
     def c_code(self, node, name, input_names, output_names, sub):
         if self.__class__ is not AdvancedSubtensor1:
             raise MethodNotDefined(
@@ -2207,69 +2178,24 @@ def c_code(self, node, name, input_names, output_names, sub):
         output_name = output_names[0]
         fail = sub["fail"]
         return f"""
-        PyArrayObject *indices;
-        int i_type = PyArray_TYPE({i_name});
-        if (i_type != NPY_INTP) {{
-            // Cast {i_name} to NPY_INTP (expected by PyArray_TakeFrom),
-            // if all values fit.
-            if (!PyArray_CanCastSafely(i_type, NPY_INTP) &&
-                PyArray_SIZE({i_name}) > 0) {{
-                npy_int64 min_val, max_val;
-                PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS,
-                                                   NULL);
-                if (py_min_val == NULL) {{
-                    {fail};
-                }}
-                min_val = PyLong_AsLongLong(py_min_val);
-                Py_DECREF(py_min_val);
-                if (min_val == -1 && PyErr_Occurred()) {{
-                    {fail};
-                }}
-                PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS,
-                                                   NULL);
-                if (py_max_val == NULL) {{
-                    {fail};
-                }}
-                max_val = PyLong_AsLongLong(py_max_val);
-                Py_DECREF(py_max_val);
-                if (max_val == -1 && PyErr_Occurred()) {{
-                    {fail};
-                }}
-                if (min_val < NPY_MIN_INTP || max_val > NPY_MAX_INTP) {{
-                    PyErr_SetString(PyExc_IndexError,
-                                    "Index contains values "
-                                    "that are bigger than the maximum array "
-                                    "size on this system.");
-                    {fail};
-                }}
-            }}
-            indices = (PyArrayObject*) PyArray_Cast({i_name}, NPY_INTP);
-            if (indices == NULL) {{
-                {fail};
-            }}
-        }}
-        else {{
-            indices = {i_name};
-            Py_INCREF(indices);
-        }}
         if ({output_name} != NULL) {{
             npy_intp nd, i, *shape;
-            nd = PyArray_NDIM({a_name}) + PyArray_NDIM(indices) - 1;
+            nd = PyArray_NDIM({a_name}) + PyArray_NDIM({i_name}) - 1;
             if (PyArray_NDIM({output_name}) != nd) {{
                 Py_CLEAR({output_name});
             }}
             else {{
                 shape = PyArray_DIMS({output_name});
-                for (i = 0; i < PyArray_NDIM(indices); i++) {{
-                    if (shape[i] != PyArray_DIMS(indices)[i]) {{
+                for (i = 0; i < PyArray_NDIM({i_name}); i++) {{
+                    if (shape[i] != PyArray_DIMS({i_name})[i]) {{
                         Py_CLEAR({output_name});
                         break;
                     }}
                 }}
                 if ({output_name} != NULL) {{
                     for (; i < nd; i++) {{
                         if (shape[i] != PyArray_DIMS({a_name})[
-                                            i-PyArray_NDIM(indices)+1]) {{
+                                            i-PyArray_NDIM({i_name})+1]) {{
                             Py_CLEAR({output_name});
                             break;
                         }}
@@ -2278,13 +2204,12 @@ def c_code(self, node, name, input_names, output_names, sub):
             }}
         }}
         {output_name} = (PyArrayObject*)PyArray_TakeFrom(
-                        {a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
-        Py_DECREF(indices);
+                        {a_name}, (PyObject*){i_name}, 0, {output_name}, NPY_RAISE);
         if ({output_name} == NULL) {fail};
         """

     def c_code_cache_version(self):
-        return (0, 1, 2, 3)
+        return (4,)


 advanced_subtensor1 = AdvancedSubtensor1()
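The deleted perform/c_code branches guarded take against index dtypes that do not cast safely to the platform's intp. Per the comment added in rewriting/subtensor.py, the Python/C backends always cast indices to int64 internally, so unsigned indices no longer reach these paths and the guard is dead code; the c_code_cache_version bump then forces previously compiled modules to be rebuilt. A minimal sketch of the numpy behaviour the old guard worked around (illustrative, not from the PR):

import numpy as np

x = np.arange(5)
print(x.take(np.array([1, 3])))  # int64 casts safely to intp: fine

try:
    # uint64 -> intp is not a safe cast, so numpy refuses; this is the
    # "TypeError: array cannot be safely cast" cited in the deleted comment.
    x.take(np.array([1, 3], dtype=np.uint64))
except TypeError as e:
    print(e)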