Support more cases of numba advanced indexing #1254

Merged: 4 commits, Mar 3, 2025
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/subtensor.py
@@ -150,7 +150,7 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
for adv_idx in adv_idxs
)
# Must be consecutive
and not op.non_contiguous_adv_indexing(node)
and not op.non_consecutive_adv_indexing(node)
# y in set/inc_subtensor cannot be broadcasted
and (
y is None
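For context (not part of this diff): the rename from "contiguous" to "consecutive" refers to whether the advanced indices are separated by basic indices. A minimal NumPy sketch of the distinction the dispatcher has to make; the array and index names are illustrative only:

```python
import numpy as np

x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx = np.array([0, 1])

# Consecutive advanced indices: the advanced dimensions stay in place.
print(x[:, idx, idx, :].shape)  # (2, 2, 5)

# Non-consecutive advanced indices (split by a basic slice): NumPy moves
# the advanced dimensions to the front of the result.
print(x[:, idx, :, idx].shape)  # (2, 2, 4)
```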
132 changes: 104 additions & 28 deletions pytensor/tensor/rewriting/subtensor.py
@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
return [copy_stack_trace(node.outputs[0], new_out)]


@node_rewriter(tracks=[AdvancedSubtensor])
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
def ravel_multidimensional_int_idx(fgraph, node):
"""Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba

x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers

x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))

NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes

It also handles multiple integer indices, but only if they don't broadcast

x[eye(3), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes

Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither the indices nor y broadcast

x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape((-1, *y.shape[1:])))

"""
x, *idxs = node.inputs
op = node.op
non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)

if is_inc_subtensor:
x, y, *idxs = node.inputs
# Inc/SetSubtensor is harder to reason about due to y
# Bail out if y broadcasts or if the advanced indices are non-consecutive
if non_consecutive_adv_indexing or (
y.type.broadcastable != x[tuple(idxs)].type.broadcastable
):
return None

else:
x, *idxs = node.inputs

if any(
(
@@ -2049,50 +2072,103 @@ def ravel_multidimensional_int_idx(fgraph, node):
)
for idx in idxs
):
# Get out if there are any other advanced indexes or np.newaxis
# Get out if there are any other advanced indices or np.newaxis
return None

int_idxs = [
int_idxs_and_pos = [
(i, idx)
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
]

if len(int_idxs) != 1:
# Get out if there are no or multiple integer idxs
if not int_idxs_and_pos:
return None

[(int_idx_pos, int_idx)] = int_idxs
if int_idx.type.ndim < 2:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
int_idxs_pos, int_idxs = zip(
*int_idxs_and_pos, strict=False
) # strict=False is safe here: zip(*pairs) always yields equal-length tuples

first_int_idx_pos = int_idxs_pos[0]
first_int_idx = int_idxs[0]
first_int_idx_bcast = first_int_idx.type.broadcastable

if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return None

raveled_int_idx = int_idx.ravel()
new_idxs = list(idxs)
new_idxs[int_idx_pos] = raveled_int_idx
raveled_subtensor = x[tuple(new_idxs)]

# Reshape into correct shape
# Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
# must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
raveled_shape = raveled_subtensor.shape
unraveled_shape = (
*raveled_shape[:int_idx_pos],
*int_idx.shape,
*raveled_shape[int_idx_pos + 1 :],
)
new_out = raveled_subtensor.reshape(unraveled_shape)
int_idxs_ndim = len(first_int_idx_bcast)
if (
int_idxs_ndim == 0
): # This should be a basic indexing operation, rewrite elsewhere
return None

int_idxs_need_raveling = int_idxs_ndim > 1
if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return None

# Reorder non-consecutive indices
if non_consecutive_adv_indexing:
assert not is_inc_subtensor # Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition = list(int_idxs_pos) + [
i for i in range(len(idxs)) if i not in int_idxs_pos
]
idxs = tuple(idxs[a] for a in transposition)
x = x.transpose(transposition)
first_int_idx_pos = 0
del int_idxs_pos # Make sure they are not wrongly used

# Ravel multidimensional indices
if int_idxs_need_raveling:
idxs = list(idxs)
for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
idxs[idx_pos] = int_idx.ravel()

# Index with reordered and/or raveled indices
new_subtensor = x[tuple(idxs)]

if is_inc_subtensor:
y_shape = tuple(y.shape)
y_raveled_shape = (
*y_shape[:first_int_idx_pos],
-1,
*y_shape[first_int_idx_pos + int_idxs_ndim :],
)
y_raveled = y.reshape(y_raveled_shape)

new_out = inc_subtensor(
new_subtensor,
y_raveled,
set_instead_of_inc=op.set_instead_of_inc,
ignore_duplicates=op.ignore_duplicates,
inplace=op.inplace,
)

else:
# Unravel advanced indexing dimensions
raveled_shape = tuple(new_subtensor.shape)
unraveled_shape = (
*raveled_shape[:first_int_idx_pos],
*first_int_idx.shape,
*raveled_shape[first_int_idx_pos + 1 :],
)
new_out = new_subtensor.reshape(unraveled_shape)

return [copy_stack_trace(node.outputs[0], new_out)]


optdb["specialize"].register(
ravel_multidimensional_bool_idx.__name__,
ravel_multidimensional_bool_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)

optdb["specialize"].register(
ravel_multidimensional_int_idx.__name__,
ravel_multidimensional_int_idx,
"numba",
use_db_name_as_tag=False, # Not included if only "specialize" is requested
)
64 changes: 47 additions & 17 deletions pytensor/tensor/subtensor.py
@@ -1,5 +1,6 @@
import logging
import sys
import warnings
from collections.abc import Callable, Iterable
from itertools import chain, groupby
from textwrap import dedent
@@ -59,6 +60,7 @@
zscalar,
)
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
@@ -527,11 +529,20 @@ def basic_shape(shape, indices):
if isinstance(idx, slice):
res_shape += (slice_len(idx, n),)
elif isinstance(getattr(idx, "type", None), SliceType):
if idx.owner:
idx_inputs = idx.owner.inputs
if idx.owner is None:
if not isinstance(idx, Constant):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape += (None,)
continue
else:
sl: slice = idx.data
slice_inputs = (sl.start, sl.stop, sl.step)
elif isinstance(idx.owner.op, MakeSlice):
slice_inputs = idx.owner.inputs
else:
idx_inputs = (None,)
res_shape += (slice_len(slice(*idx_inputs), n),)
raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
res_shape += (slice_len(slice(*slice_inputs), n),)
elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),)
elif isinstance(getattr(idx, "type", None), NoneTypeT):
@@ -570,8 +581,8 @@ def group_indices(indices):
return idx_groups


def _non_contiguous_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
def _non_consecutive_adv_indexing(indices) -> bool:
"""Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
idx_groups = group_indices(indices)
# This means that there are at least two groups of advanced indexing separated by basic indexing
return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
@@ -601,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
idx_groups = group_indices(indices)

if _non_contiguous_adv_indexing(indices):
if _non_consecutive_adv_indexing(indices):
# In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups = sorted(idx_groups, key=lambda x: x[0])
@@ -2728,6 +2739,11 @@ def is_bool_index(idx):
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)
for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out)

adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
@@ -2781,10 +2797,17 @@ def grad(self, inputs, grads):

@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedSubtensor.non_consecutive_adv_indexing(node)

@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).

This function checks if the advanced indexing is non-contiguous,
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.

@@ -2799,10 +2822,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)
return _non_consecutive_adv_indexing(idxs)


advanced_subtensor = AdvancedSubtensor()
@@ -2820,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
if isinstance(batch_idx, TensorVariable)
)

if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
# Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
@@ -2829,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# TODO: Implement these internally, so Blockwise is always a safe fallback
if any(not isinstance(idx, TensorVariable) for idx in idxs):
raise NotImplementedError(
"Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
"Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported."
)
else:
@@ -2939,10 +2962,17 @@ def grad(self, inpt, output_gradients):

@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)

@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).

This function checks if the advanced indexing is non-contiguous,
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position.

@@ -2957,10 +2987,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, _, *idxs = node.inputs
return _non_contiguous_adv_indexing(idxs)
return _non_consecutive_adv_indexing(idxs)


advanced_inc_subtensor = AdvancedIncSubtensor()
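For context (not part of this diff): `_non_consecutive_adv_indexing` asks whether the advanced indices form more than one run once they are grouped against the basic indices. A simplified, standalone sketch of that check (not the library's implementation; the classifier below only treats NumPy arrays as advanced indices):

```python
from itertools import groupby

import numpy as np


def is_advanced(idx):
    # Simplified classifier for this sketch: integer/boolean arrays are
    # advanced indices; slices and None are basic.
    return isinstance(idx, np.ndarray)


def non_consecutive_adv_indexing(indices):
    # Advanced indexing is non-consecutive when basic indices split the
    # advanced indices into more than one run.
    runs = [adv for adv, _ in groupby(indices, key=is_advanced)]
    return runs.count(True) > 1


idx = np.array([0, 1])
print(non_consecutive_adv_indexing((slice(None), idx, idx, slice(None))))  # False
print(non_consecutive_adv_indexing((slice(None), idx, slice(None), idx)))  # True
```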
1 change: 0 additions & 1 deletion scripts/mypy-failing.txt
@@ -11,7 +11,6 @@ pytensor/link/numba/dispatch/scan.py
pytensor/printing.py
pytensor/raise_op.py
pytensor/sparse/basic.py
pytensor/sparse/type.py
pytensor/tensor/basic.py
pytensor/tensor/blas_c.py
pytensor/tensor/blas_headers.py
8 changes: 7 additions & 1 deletion scripts/run_mypy.py
@@ -142,7 +142,13 @@ def check_no_unexpected_results(mypy_lines: Iterable[str]):
print(*missing, sep="\n")
sys.exit(1)
cp = subprocess.run(
["mypy", "--show-error-codes", "pytensor"],
[
"mypy",
"--show-error-codes",
"--disable-error-code",
"annotation-unchecked",
"pytensor",
],
capture_output=True,
)
output = cp.stdout.decode()
10 changes: 9 additions & 1 deletion tests/link/jax/test_pad.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from packaging import version

import pytensor.tensor as pt
from pytensor import config
@@ -16,7 +17,14 @@
"mode, kwargs",
[
("constant", {"constant_values": 0}),
("constant", {"constant_values": (1, 2)}),
pytest.param(
"constant",
{"constant_values": (1, 2)},
marks=pytest.mark.skipif(
version.parse(jax.__version__) > version.parse("0.4.35"),
reason="Bug in JAX: https://github.com/jax-ml/jax/issues/26888",
),
),
("edge", {}),
("linear_ramp", {"end_values": 0}),
("linear_ramp", {"end_values": (1, 2)}),