Commit cd649ab
Support more cases of multi-dimensional advanced indexing and updating in Numba

Extends the pre-existing rewrite to ravel multiple integer indices and to place them consecutively. The following cases should now be supported without object mode:

* Advanced integer indexing (not mixed with basic or boolean indexing) that does not require broadcasting of indices
* Consecutive advanced integer index updating (set/inc, not mixed with basic or boolean indexing) that does not require broadcasting of indices or of y
1 parent 7cd054e commit cd649ab
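For orientation, a minimal sketch of the first newly supported case, written against the public PyTensor API. This is illustrative only and not taken from the commit: the variable names, test values, and the use of `mode="NUMBA"` are assumptions.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
idx = pt.lmatrix("idx")  # 2D integer index; previously this fell back to object mode

out = x[idx]  # multi-dimensional advanced integer indexing, no basic/boolean mixing
fn = pytensor.function([x, idx], out, mode="NUMBA")

x_val = np.arange(12.0).reshape(3, 4)
idx_val = np.eye(3, dtype=np.int64)
np.testing.assert_allclose(fn(x_val, idx_val), x_val[idx_val])  # matches NumPy
```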

File tree

4 files changed (+183 −58 lines)

Diff for: pytensor/link/numba/dispatch/subtensor.py (+1 −1)

```diff
@@ -150,7 +150,7 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
             for adv_idx in adv_idxs
         )
         # Must be consecutive
-        and not op.non_contiguous_adv_indexing(node)
+        and not op.non_consecutive_adv_indexing(node)
         # y in set/inc_subtensor cannot be broadcasted
         and (
             y is None
```
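The second guard ("y in set/inc_subtensor cannot be broadcasted") separates updates whose value already matches the indexed shape from those that would first need broadcasting. A rough NumPy analogy of the distinction (my example, not part of the diff; the comments describe what the dispatch guard would decide):

```python
import numpy as np

x = np.zeros((3, 4))
idx = np.array([0, 2])

# Handled by this dispatch path: y already has the indexed shape x[idx], i.e. (2, 4)
x[idx] = np.ones((2, 4))

# Not handled by this path: y would first have to be broadcast from (4,) to (2, 4)
x[idx] = np.ones(4)
```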

Diff for: pytensor/tensor/rewriting/subtensor.py (+104 −28)

```diff
@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     return [copy_stack_trace(node.outputs[0], new_out)]


-@node_rewriter(tracks=[AdvancedSubtensor])
+@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
 def ravel_multidimensional_int_idx(fgraph, node):
-    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
-
-    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+    """Convert multidimensional integer indexing into equivalent consecutive vector integer index,
+    supported by Numba or by our specialized dispatchers

+        x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))

     NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

-        x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+        x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    It also handles multiple integer indices, but only if they don't broadcast
+
+        x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
+
+        x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
+
     """
-    x, *idxs = node.inputs
+    op = node.op
+    non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
+    is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
+
+    if is_inc_subtensor:
+        x, y, *idxs = node.inputs
+        # Inc/SetSubtensor is harder to reason about due to y
+        # We get out if it's broadcasting or if the advanced indices are non-consecutive
+        if non_consecutive_adv_indexing or (
+            y.type.broadcastable != x[tuple(idxs)].type.broadcastable
+        ):
+            return None
+
+    else:
+        x, *idxs = node.inputs

     if any(
         (
@@ -2049,50 +2072,103 @@ def ravel_multidimensional_int_idx(fgraph, node):
         )
         for idx in idxs
     ):
-        # Get out if there are any other advanced indexes or np.newaxis
+        # Get out if there are any other advanced indices or np.newaxis
         return None

-    int_idxs = [
+    int_idxs_and_pos = [
         (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
     ]

-    if len(int_idxs) != 1:
-        # Get out if there are no or multiple integer idxs
+    if not int_idxs_and_pos:
         return None

-    [(int_idx_pos, int_idx)] = int_idxs
-    if int_idx.type.ndim < 2:
-        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+    int_idxs_pos, int_idxs = zip(
+        *int_idxs_and_pos, strict=False
+    )  # strict=False because by definition it's true
+
+    first_int_idx_pos = int_idxs_pos[0]
+    first_int_idx = int_idxs[0]
+    first_int_idx_bcast = first_int_idx.type.broadcastable
+
+    if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
+        # We don't have a view-only broadcasting operation
+        # Explicitly broadcasting the indices can incur a memory / copy overhead
         return None

-    raveled_int_idx = int_idx.ravel()
-    new_idxs = list(idxs)
-    new_idxs[int_idx_pos] = raveled_int_idx
-    raveled_subtensor = x[tuple(new_idxs)]
-
-    # Reshape into correct shape
-    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
-    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
-    raveled_shape = raveled_subtensor.shape
-    unraveled_shape = (
-        *raveled_shape[:int_idx_pos],
-        *int_idx.shape,
-        *raveled_shape[int_idx_pos + 1 :],
-    )
-    new_out = raveled_subtensor.reshape(unraveled_shape)
+    int_idxs_ndim = len(first_int_idx_bcast)
+    if (
+        int_idxs_ndim == 0
+    ):  # This should be a basic indexing operation, rewrite elsewhere
+        return None
+
+    int_idxs_need_raveling = int_idxs_ndim > 1
+    if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
+        # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
+        return None
+
+    # Reorder non-consecutive indices
+    if non_consecutive_adv_indexing:
+        assert not is_inc_subtensor  # Sanity check that we got out if this was the case
+        # This case works as if all the advanced indices were on the front
+        transposition = list(int_idxs_pos) + [
+            i for i in range(len(idxs)) if i not in int_idxs_pos
+        ]
+        idxs = tuple(idxs[a] for a in transposition)
+        x = x.transpose(transposition)
+        first_int_idx_pos = 0
+        del int_idxs_pos  # Make sure they are not wrongly used
+
+    # Ravel multidimensional indices
+    if int_idxs_need_raveling:
+        idxs = list(idxs)
+        for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
+            idxs[idx_pos] = int_idx.ravel()
+
+    # Index with reordered and/or raveled indices
+    new_subtensor = x[tuple(idxs)]
+
+    if is_inc_subtensor:
+        y_shape = tuple(y.shape)
+        y_raveled_shape = (
+            *y_shape[:first_int_idx_pos],
+            -1,
+            *y_shape[first_int_idx_pos + int_idxs_ndim :],
+        )
+        y_raveled = y.reshape(y_raveled_shape)
+
+        new_out = inc_subtensor(
+            new_subtensor,
+            y_raveled,
+            set_instead_of_inc=op.set_instead_of_inc,
+            ignore_duplicates=op.ignore_duplicates,
+            inplace=op.inplace,
+        )
+
+    else:
+        # Unravel advanced indexing dimensions
+        raveled_shape = tuple(new_subtensor.shape)
+        unraveled_shape = (
+            *raveled_shape[:first_int_idx_pos],
+            *first_int_idx.shape,
+            *raveled_shape[first_int_idx_pos + 1 :],
+        )
+        new_out = new_subtensor.reshape(unraveled_shape)
+
     return [copy_stack_trace(node.outputs[0], new_out)]


 optdb["specialize"].register(
     ravel_multidimensional_bool_idx.__name__,
     ravel_multidimensional_bool_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )

 optdb["specialize"].register(
     ravel_multidimensional_int_idx.__name__,
     ravel_multidimensional_int_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
```
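The graph rewrite above encodes a NumPy identity: ravel the integer index and unravel the advanced output dimensions afterwards, or, for set/inc updates, flatten the corresponding dimensions of y instead. A small NumPy check of both cases for intuition (illustrative only; the rewrite itself manipulates symbolic PyTensor variables, and the names here are mine):

```python
import numpy as np

x = np.arange(4 * 5, dtype=float).reshape(4, 5)
idx = np.eye(3, dtype=int)  # multidimensional integer index, shape (3, 3)

# Subtensor case: ravel the index, then reshape the advanced output dims back
direct = x[idx]
rewritten = x[idx.ravel()].reshape((*idx.shape, *x.shape[1:]))
np.testing.assert_array_equal(direct, rewritten)

# Set case: instead of reshaping the output, flatten the advanced dims of y
y = np.random.default_rng(0).random((3, 3, 5))
a, b = x.copy(), x.copy()
a[idx] = y
b[idx.ravel()] = y.reshape(-1, *y.shape[idx.ndim:])
np.testing.assert_array_equal(a, b)
```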

Diff for: pytensor/tensor/subtensor.py (+28 −13)

```diff
@@ -1,5 +1,6 @@
 import logging
 import sys
+import warnings
 from collections.abc import Callable, Iterable
 from itertools import chain, groupby
 from textwrap import dedent
@@ -580,8 +581,8 @@ def group_indices(indices):
     return idx_groups


-def _non_contiguous_adv_indexing(indices) -> bool:
-    """Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
+def _non_consecutive_adv_indexing(indices) -> bool:
+    """Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
     idx_groups = group_indices(indices)
     # This means that there are at least two groups of advanced indexing separated by basic indexing
     return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
@@ -611,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
     remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
     idx_groups = group_indices(indices)

-    if _non_contiguous_adv_indexing(indices):
+    if _non_consecutive_adv_indexing(indices):
         # In this case NumPy places the advanced index groups in the front of the array
         # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
         idx_groups = sorted(idx_groups, key=lambda x: x[0])
@@ -2796,10 +2797,17 @@ def grad(self, inputs, grads):

     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
-        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+        Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).

-        This function checks if the advanced indexing is non-contiguous,
+        This function checks if the advanced indexing is non-consecutive,
         in which case the advanced index dimensions are placed on the left of the
         output array, regardless of their original position.
@@ -2814,10 +2822,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         Returns
         -------
         bool
-            True if the advanced indexing is non-contiguous, False otherwise.
+            True if the advanced indexing is non-consecutive, False otherwise.
         """
         _, *idxs = node.inputs
-        return _non_contiguous_adv_indexing(idxs)
+        return _non_consecutive_adv_indexing(idxs)


 advanced_subtensor = AdvancedSubtensor()
@@ -2835,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
         if isinstance(batch_idx, TensorVariable)
     )

-    if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
+    if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
         # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
         # which would put the indexed results to the left of the batch dimensions!
         # TODO: Not all cases must be handled by Blockwise, but the logic is complex
@@ -2844,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
         # TODO: Implement these internally, so Blockwise is always a safe fallback
         if any(not isinstance(idx, TensorVariable) for idx in idxs):
             raise NotImplementedError(
-                "Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
+                "Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
                 "and slices or newaxis is currently not supported."
             )
         else:
@@ -2954,10 +2962,17 @@ def grad(self, inpt, output_gradients):

     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
-        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+        Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).

-        This function checks if the advanced indexing is non-contiguous,
+        This function checks if the advanced indexing is non-consecutive,
         in which case the advanced index dimensions are placed on the left of the
         output array, regardless of their original position.
@@ -2972,10 +2987,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         Returns
         -------
         bool
-            True if the advanced indexing is non-contiguous, False otherwise.
+            True if the advanced indexing is non-consecutive, False otherwise.
         """
         _, _, *idxs = node.inputs
-        return _non_contiguous_adv_indexing(idxs)
+        return _non_consecutive_adv_indexing(idxs)


 advanced_inc_subtensor = AdvancedIncSubtensor()
```
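The renamed helper detects the NumPy placement rule referenced in the docstrings above: when advanced indices are interrupted by basic indexing, NumPy moves the advanced dimensions to the front of the output. A minimal NumPy demonstration of the two regimes (my example, not from the commit):

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
i = np.array([0, 1])

# Consecutive advanced indices: the indexed dims stay in their original position
print(x[:, i, i].shape)  # (2, 2)

# Non-consecutive (split by the slice): advanced dims are placed on the left
print(x[i, :, i].shape)  # (2, 3)
```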
