Skip to content

Commit f94a44c

Browse files
committed
Support multidimensional boolean set/inc_subtensor in Numba via rewrite
1 parent 9dad122 commit f94a44c

File tree

3 files changed

+58
-16
lines changed

3 files changed

+58
-16
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
249249
This is only done when there's a single vector index.
250250
"""
251251

252-
if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates:
252+
if node.op.ignore_duplicates:
253253
# `AdvancedIncSubtensor1` does not ignore duplicate index values
254254
return
255255

@@ -1967,19 +1967,26 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
19671967
return new_out
19681968

19691969

1970-
@node_rewriter(tracks=[AdvancedSubtensor])
1970+
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
19711971
def ravel_multidimensional_bool_idx(fgraph, node):
19721972
"""Convert multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
19731973
19741974
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
1975+
x[eye(3, dtype=bool)].set(1) -> x.ravel()[eye(3).ravel()].set(1).reshape(x.shape)
19751976
"""
1976-
x, *idxs = node.inputs
1977+
if isinstance(node.op, AdvancedSubtensor):
1978+
x, *idxs = node.inputs
1979+
else:
1980+
x, y, *idxs = node.inputs
19771981

19781982
if any(
1979-
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int")
1983+
(
1984+
(isinstance(idx.type, TensorType) and idx.type.dtype.startswith("int"))
1985+
or isinstance(idx.type, NoneTypeT)
1986+
)
19801987
for idx in idxs
19811988
):
1982-
# Get out if there are any other advanced indexes
1989+
# Get out if there are any other advanced indexes or np.newaxis
19831990
return None
19841991

19851992
bool_idxs = [
@@ -2007,7 +2014,16 @@ def ravel_multidimensional_bool_idx(fgraph, node):
20072014
new_idxs = list(idxs)
20082015
new_idxs[bool_idx_pos] = raveled_bool_idx
20092016

2010-
return [raveled_x[tuple(new_idxs)]]
2017+
if isinstance(node.op, AdvancedSubtensor):
2018+
new_out = node.op(raveled_x, *new_idxs)
2019+
else:
2020+
# The dimensions of y that correspond to the boolean indices
2021+
# must already be raveled in the original graph, so we don't need to do anything to it
2022+
new_out = node.op(raveled_x, y, *new_idxs)
2023+
# But we must reshape the output to match the original shape
2024+
new_out = new_out.reshape(x_shape)
2025+
2026+
return [copy_stack_trace(node.outputs[0], new_out)]
20112027

20122028

20132029
@node_rewriter(tracks=[AdvancedSubtensor])
@@ -2024,10 +2040,13 @@ def ravel_multidimensional_int_idx(fgraph, node):
20242040
x, *idxs = node.inputs
20252041

20262042
if any(
2027-
isinstance(idx.type, TensorType) and idx.type.dtype.startswith("bool")
2043+
(
2044+
(isinstance(idx.type, TensorType) and idx.type.dtype == "bool")
2045+
or isinstance(idx.type, NoneTypeT)
2046+
)
20282047
for idx in idxs
20292048
):
2030-
# Get out if there are any other advanced indexes
2049+
# Get out if there are any other advanced indexes or np.newaxis
20312050
return None
20322051

20332052
int_idxs = [
@@ -2059,7 +2078,8 @@ def ravel_multidimensional_int_idx(fgraph, node):
20592078
*int_idx.shape,
20602079
*raveled_shape[int_idx_pos + 1 :],
20612080
)
2062-
return [raveled_subtensor.reshape(unraveled_shape)]
2081+
new_out = raveled_subtensor.reshape(unraveled_shape)
2082+
return [copy_stack_trace(node.outputs[0], new_out)]
20632083

20642084

20652085
optdb["specialize"].register(

pytensor/tensor/subtensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,10 +1456,10 @@ def inc_subtensor(
14561456
views; if they overlap, the result of this `Op` will generally be
14571457
incorrect. This value has no effect if ``inplace=False``.
14581458
ignore_duplicates
1459-
This determines whether or not ``x[indices] += y`` is used or
1459+
This determines whether ``x[indices] += y`` is used or
14601460
``np.add.at(x, indices, y)``. When the special duplicates handling of
14611461
``np.add.at`` isn't required, setting this option to ``True``
1462-
(i.e. using ``x[indices] += y``) can resulting in faster compiled
1462+
(i.e. using ``x[indices] += y``) can result in faster compiled
14631463
graphs.
14641464
14651465
Examples

tests/link/numba/test_subtensor.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,19 @@ def test_AdvancedIncSubtensor1(x, y, indices):
334334
-np.arange(3),
335335
(np.eye(3).astype(bool)), # Boolean index
336336
False,
337-
True,
338-
True,
337+
False,
338+
False,
339+
),
340+
(
341+
np.arange(3 * 3 * 5).reshape((3, 3, 5)),
342+
rng.poisson(size=(3, 2)),
343+
(
344+
np.eye(3).astype(bool),
345+
slice(-2, None),
346+
), # Boolean index, mixed with basic index
347+
False,
348+
False,
349+
False,
339350
),
340351
(
341352
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
@@ -394,10 +405,18 @@ def test_AdvancedIncSubtensor1(x, y, indices):
394405
rng.poisson(size=(2, 2)),
395406
([[1, 2], [2, 3]]), # matrix indices
396407
False,
408+
False, # Gets converted to AdvancedIncSubtensor1
409+
True, # This is actually supported with the default `ignore_duplicates=False`
410+
),
411+
(
412+
np.arange(3 * 5).reshape((3, 5)),
413+
rng.poisson(size=(1, 2, 2)),
414+
(slice(1, 3), [[1, 2], [2, 3]]), # matrix indices, mixed with basic index
415+
False,
397416
True,
398417
True,
399418
),
400-
pytest.param(
419+
(
401420
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
402421
rng.poisson(size=(2, 5)),
403422
([1, 1], [2, 2]), # Repeated indices
@@ -418,6 +437,9 @@ def test_AdvancedIncSubtensor(
418437
inc_requires_objmode,
419438
inplace,
420439
):
440+
# Need rewrite to support certain forms of advanced indexing without object mode
441+
mode = numba_mode.including("specialize")
442+
421443
x_pt = pt.as_tensor(x).type("x")
422444
y_pt = pt.as_tensor(y).type("y")
423445

@@ -432,7 +454,7 @@ def test_AdvancedIncSubtensor(
432454
if set_requires_objmode
433455
else contextlib.nullcontext()
434456
):
435-
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
457+
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
436458

437459
if inplace:
438460
# Test updates inplace
@@ -452,7 +474,7 @@ def test_AdvancedIncSubtensor(
452474
if inc_requires_objmode
453475
else contextlib.nullcontext()
454476
):
455-
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y])
477+
fn, _ = compare_numba_and_py(([x_pt, y_pt], [out_pt]), [x, y], numba_mode=mode)
456478
if inplace:
457479
# Test updates inplace
458480
x_orig = x.copy()

0 commit comments

Comments
 (0)