Commit ec99fca

Support more cases of multi-dimensional advanced indexing and updating in Numba
Extends the pre-existing rewrite to ravel multiple integer indices and to place them consecutively. The following cases should now be supported without object mode:

* Advanced integer indexing (not mixed with basic or boolean indexing) that does not require broadcasting of the indices.
* Consecutive advanced integer index updates (set/inc, not mixed with basic or boolean indexing) that do not require broadcasting of the indices or of y.
1 parent 5975304 commit ec99fca
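
As a rough illustration of the two supported families above, here is a minimal sketch (assuming a PyTensor installation with the Numba backend available; variable names, shapes, and values are made up for the example):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x")
rows = pt.lmatrix("rows")  # multidimensional integer index
cols = pt.lmatrix("cols")  # same shape as `rows`, so the indices need no broadcasting

# Case 1: advanced integer indexing with multiple matrix indices (no basic/boolean indexing mixed in)
take = x[rows, cols]

# Case 2: consecutive advanced integer index update (set), with y matching the indexed shape
vals = pt.tensor3("vals")
update = pt.set_subtensor(x[rows, cols], vals)

fn = pytensor.function([x, rows, cols, vals], [take, update], mode="NUMBA")

x_val = np.arange(3 * 4 * 5).reshape(3, 4, 5).astype(x.dtype)
rows_val = np.array([[0, 1], [2, 2]], dtype="int64")
cols_val = np.array([[1, 1], [0, 3]], dtype="int64")
vals_val = np.zeros((2, 2, 5), dtype=x.dtype)
taken, updated = fn(x_val, rows_val, cols_val, vals_val)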

4 files changed: +176 -50 lines

Diff for: pytensor/link/numba/dispatch/subtensor.py

+1 -1

@@ -150,7 +150,7 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
             for adv_idx in adv_idxs
         )
         # Must be consecutive
-        and not op.non_contiguous_adv_indexing(node)
+        and not op.non_consecutive_adv_indexing(node)
         # y in set/inc_subtensor cannot be broadcasted
         and (
             y is None
Diff for: pytensor/tensor/rewriting/subtensor.py

+105 -28

@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     return [copy_stack_trace(node.outputs[0], new_out)]
 
 
-@node_rewriter(tracks=[AdvancedSubtensor])
+@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
 def ravel_multidimensional_int_idx(fgraph, node):
-    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
-
-    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+    """Convert multidimensional integer indexing into equivalent consecutive vector integer index,
+    supported by Numba or by our specialized dispatchers
 
+    x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
 
     NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
 
-    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+    x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    It also handles multiple integer indices, but only if they don't broadcast
+
+    x[eye(3), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
+
+    x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
+
     """
-    x, *idxs = node.inputs
+    op = node.op
+    non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
+    is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
+
+    if is_inc_subtensor:
+        x, y, *idxs = node.inputs
+        # Inc/SetSubtensor is harder to reason about due to y
+        # We get out if it's broadcasting or if the advanced indices are non-consecutive
+        if non_consecutive_adv_indexing or (
+            y.type.broadcastable != x[tuple(idxs)].type.broadcastable
+        ):
+            return None
+
+    else:
+        x, *idxs = node.inputs
 
     if any(
         (
@@ -2049,50 +2072,104 @@ def ravel_multidimensional_int_idx(fgraph, node):
         )
         for idx in idxs
     ):
-        # Get out if there are any other advanced indexes or np.newaxis
+        # Get out if there are any other advanced indices or np.newaxis
         return None
 
-    int_idxs = [
+    int_idxs_and_pos = [
         (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
     ]
 
-    if len(int_idxs) != 1:
-        # Get out if there are no or multiple integer idxs
+    if not int_idxs_and_pos:
         return None
 
-    [(int_idx_pos, int_idx)] = int_idxs
-    if int_idx.type.ndim < 2:
-        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+    int_idxs_pos, int_idxs = zip(
+        *int_idxs_and_pos, strict=False
+    )  # strict=False because by definition it's true
+
+    first_int_idx_pos = int_idxs_pos[0]
+    first_int_idx = int_idxs[0]
+    first_int_idx_bcast = first_int_idx.type.broadcastable
+
+    if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
+        # We don't have a view-only broadcasting operation
+        # Explicitly broadcasting the indices can incur a memory / copy overhead
         return None
 
-    raveled_int_idx = int_idx.ravel()
-    new_idxs = list(idxs)
-    new_idxs[int_idx_pos] = raveled_int_idx
-    raveled_subtensor = x[tuple(new_idxs)]
-
-    # Reshape into correct shape
-    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
-    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
-    raveled_shape = raveled_subtensor.shape
-    unraveled_shape = (
-        *raveled_shape[:int_idx_pos],
-        *int_idx.shape,
-        *raveled_shape[int_idx_pos + 1 :],
-    )
-    new_out = raveled_subtensor.reshape(unraveled_shape)
+    int_idxs_ndim = len(first_int_idx_bcast)
+    if (
+        int_idxs_ndim == 0
+    ):  # This should be a basic indexing operation, rewrite elsewhere
+        return None
+
+    int_idxs_need_raveling = int_idxs_ndim > 1
+    if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
+        # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
+        return None
+
+    # Reorder non-consecutive indices
+    if non_consecutive_adv_indexing:
+        assert not is_inc_subtensor  # Sanity check that we got out if this was the case
+        # This case works as if all the advanced indices were on the front
+        transposition = list(int_idxs_pos) + [
+            i for i in range(len(idxs)) if i not in int_idxs_pos
+        ]
+        idxs = tuple(idxs[a] for a in transposition)
+        x = x.transpose(transposition)
+        first_int_idx_pos = 0
+        del int_idxs_pos  # Make sure they are not wrongly used
+
+    # Ravel multidimensional indices
+    if int_idxs_need_raveling:
+        idxs = list(idxs)
+        for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
+            idxs[idx_pos] = int_idx.ravel()
+
+    # Index with reordered and/or raveled indices
+    new_subtensor = x[tuple(idxs)]
+
+    if is_inc_subtensor:
+        int_idx_ndim = len(first_int_idx_bcast)
+        y_shape = tuple(y.shape)
+        y_raveled_shape = (
+            *y_shape[:first_int_idx_pos],
+            -1,
+            *y_shape[first_int_idx_pos + int_idx_ndim :],
+        )
+        y_raveled = y.reshape(y_raveled_shape)
+
+        new_out = inc_subtensor(
+            new_subtensor,
+            y_raveled,
+            set_instead_of_inc=op.set_instead_of_inc,
+            ignore_duplicates=op.ignore_duplicates,
+            inplace=op.inplace,
+        )
+
+    else:
+        # Unravel advanced indexing dimensions
+        raveled_shape = tuple(new_subtensor.shape)
+        unraveled_shape = (
+            *raveled_shape[:first_int_idx_pos],
+            *first_int_idx.shape,
+            *raveled_shape[first_int_idx_pos + 1 :],
+        )
+        new_out = new_subtensor.reshape(unraveled_shape)
 
     return [copy_stack_trace(node.outputs[0], new_out)]
 
 
 optdb["specialize"].register(
     ravel_multidimensional_bool_idx.__name__,
     ravel_multidimensional_bool_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
 
 optdb["specialize"].register(
     ravel_multidimensional_int_idx.__name__,
     ravel_multidimensional_int_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
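
The equivalence this rewrite relies on can be spot-checked in plain NumPy (illustrative only; the rewrite itself builds the corresponding PyTensor graph):

import numpy as np

x = np.arange(3 * 4 * 5).reshape(3, 4, 5)
idx = np.eye(3, dtype=int)  # a (3, 3) integer index

# Direct multidimensional advanced indexing mixed with a non-full slice
direct = x[idx, 2:]  # shape (3, 3, 2, 5)

# What the rewrite builds instead: index with the raveled vector (natively
# supported by Numba), then reshape the advanced dimensions back.
raveled = x[idx.ravel(), 2:]  # shape (9, 2, 5)
unraveled = raveled.reshape((*idx.shape, *raveled.shape[1:]))
assert np.array_equal(direct, unraveled)

# The AdvancedIncSubtensor branch applies the same idea to y: ravel the advanced
# dimensions of y so it lines up with the raveled index.
y = np.ones((3, 3, 2, 5))
set_direct = x.copy()
set_direct[idx, 2:] = y
set_raveled = x.copy()
set_raveled[idx.ravel(), 2:] = y.reshape((-1, *y.shape[2:]))
assert np.array_equal(set_direct, set_raveled)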

Diff for: pytensor/tensor/subtensor.py

+20 -5

@@ -1,5 +1,6 @@
 import logging
 import sys
+import warnings
 from collections.abc import Callable, Iterable
 from itertools import chain, groupby
 from textwrap import dedent
@@ -580,7 +581,7 @@ def group_indices(indices):
     return idx_groups
 
 
-def _non_contiguous_adv_indexing(indices) -> bool:
+def _non_consecutive_adv_indexing(indices) -> bool:
     """Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
     idx_groups = group_indices(indices)
     # This means that there are at least two groups of advanced indexing separated by basic indexing
@@ -611,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
     remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
     idx_groups = group_indices(indices)
 
-    if _non_contiguous_adv_indexing(indices):
+    if _non_consecutive_adv_indexing(indices):
         # In this case NumPy places the advanced index groups in the front of the array
         # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
         idx_groups = sorted(idx_groups, key=lambda x: x[0])
@@ -2796,6 +2797,13 @@ def grad(self, inputs, grads):
 
     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
         Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
 
@@ -2817,7 +2825,7 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         True if the advanced indexing is non-contiguous, False otherwise.
         """
         _, *idxs = node.inputs
-        return _non_contiguous_adv_indexing(idxs)
+        return _non_consecutive_adv_indexing(idxs)
 
 
 advanced_subtensor = AdvancedSubtensor()
@@ -2835,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
         if isinstance(batch_idx, TensorVariable)
     )
 
-    if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)):
+    if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)):
         # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
         # which would put the indexed results to the left of the batch dimensions!
         # TODO: Not all cases must be handled by Blockwise, but the logic is complex
@@ -2954,6 +2962,13 @@ def grad(self, inpt, output_gradients):
 
     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
         Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
 
@@ -2975,7 +2990,7 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         True if the advanced indexing is non-contiguous, False otherwise.
         """
         _, _, *idxs = node.inputs
-        return _non_contiguous_adv_indexing(idxs)
+        return _non_consecutive_adv_indexing(idxs)
 
 
 advanced_inc_subtensor = AdvancedIncSubtensor()
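
A small sketch of how downstream code experiences the rename (illustrative; the index pattern and variable names are made up): the old static method still answers, but it now emits a FutureWarning and forwards to the new name.

import warnings

import pytensor.tensor as pt
from pytensor.tensor.subtensor import AdvancedSubtensor

x = pt.tensor3("x")
node = x[[0, 1], :, [1, 2]].owner  # advanced indices split by a slice

# New name: True here, because the advanced indices are interrupted by basic indexing
assert AdvancedSubtensor.non_consecutive_adv_indexing(node)

# Old name still works, but warns about the rename
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert AdvancedSubtensor.non_contiguous_adv_indexing(node)
assert any(issubclass(w.category, FutureWarning) for w in caught)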

Diff for: tests/link/numba/test_subtensor.py

+50 -16

@@ -81,11 +81,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (np.array([True, False, False])),
             False,
         ),
-        (
-            pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
-            ([1, 2], [2, 3]),
-            False,
-        ),
         # Single multidimensional indexing (supported after specialization rewrites)
         (
             as_tensor(np.arange(3 * 3).reshape((3, 3))),
@@ -117,6 +112,12 @@ def test_AdvancedSubtensor1_out_of_bounds():
             (slice(2, None), np.eye(3).astype(bool)),
             False,
         ),
+        # Multiple vector indexing (supported by our dispatcher)
+        (
+            pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([1, 2], [2, 3]),
+            False,
+        ),
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             (slice(None), [1, 2], [3, 4]),
@@ -127,18 +128,35 @@ def test_AdvancedSubtensor1_out_of_bounds():
             ([1, 2], [3, 4], [5, 6]),
             False,
         ),
-        # Non-contiguous vector indexing, only supported in obj mode
+        # Non-consecutive vector indexing, supported by our dispatcher after rewriting
         (
             as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
             ([1, 2], slice(None), [3, 4]),
-            True,
+            False,
+        ),
+        # Multiple multidimensional integer indexing (supported by our dispatcher)
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], [[0, 0], [0, 0]]),
+            False,
+        ),
+        (
+            as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))),
+            (slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]),
+            False,
         ),
-        # >1d vector indexing, only supported in obj mode
+        # Multiple multidimensional indexing with broadcasting, only supported in obj mode
         (
            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
            ([[1, 2], [2, 1]], [0, 0]),
            True,
         ),
+        # multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
+            True,
+        ),
     ],
 )
 @pytest.mark.filterwarnings("error")  # Raise if we did not expect objmode to be needed
@@ -297,15 +315,15 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             -np.arange(4 * 5).reshape(4, 5),
-            (0, [1, 2, 2, 3]),  # Broadcasted vector index
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index with repeated values
             True,
             False,
             True,
         ),
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             np.array([-99]),  # Broadcasted value
-            (0, [1, 2, 2, 3]),  # Broadcasted vector index
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index with repeated values
             True,
             False,
             True,
@@ -380,7 +398,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             rng.poisson(size=(2, 4)),
-            ([1, 2], slice(None), [3, 4]),  # Non-contiguous vector indices
+            ([1, 2], slice(None), [3, 4]),  # Non-consecutive vector indices
             False,
             True,
             True,
@@ -400,15 +418,23 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(5),
             rng.poisson(size=(2, 2)),
-            ([[1, 2], [2, 3]]),  # matrix indices
+            ([[1, 2], [2, 3]]),  # matrix index
+            False,
+            False,
+            False,
+        ),
+        (
+            np.arange(3 * 5).reshape((3, 5)),
+            rng.poisson(size=(2, 2, 2)),
+            (slice(1, 3), [[1, 2], [2, 3]]),  # matrix index, mixed with basic index
+            False,
+            False,
             False,
-            False,  # Gets converted to AdvancedIncSubtensor1
-            True,  # This is actually supported with the default `ignore_duplicates=False`
         ),
         (
             np.arange(3 * 5).reshape((3, 5)),
-            rng.poisson(size=(1, 2, 2)),
-            (slice(1, 3), [[1, 2], [2, 3]]),  # matrix indices, mixed with basic index
+            rng.poisson(size=(1, 2, 2)),  # Same as before, but Y broadcasts
+            (slice(1, 3), [[1, 2], [2, 3]]),
             False,
             True,
             True,
@@ -421,6 +447,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
             False,
             False,
         ),
+        (
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(3, 2, 2)),
+            (slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]),  # 2 matrix indices
+            False,
+            False,
+            False,
+        ),
     ],
 )
 @pytest.mark.parametrize("inplace", (False, True))
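
Outside of pytest, an updated expectation can be spot-checked with a sketch like the following (assuming, as the test's filterwarnings("error") marker does, that a fallback to object mode would be reported through a warning during compilation):

import warnings

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
out = x[[1, 2], :, [3, 4]]  # non-consecutive vector indices, no longer objmode-only

with warnings.catch_warnings():
    # Escalate UserWarnings (which the object-mode fallback is assumed to use)
    # so an unexpected fallback raises instead of passing silently.
    warnings.simplefilter("error", UserWarning)
    fn = pytensor.function([], out, mode="NUMBA")
    res = fn()

np.testing.assert_array_equal(
    res, np.arange(3 * 4 * 5).reshape((3, 4, 5))[[1, 2], :, [3, 4]]
)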
