diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 6dc4d4c294..8e79efda00 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -130,7 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): if isinstance(idx.type, TensorType) ] - # Special case for consecutive consecutive vector indices def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]): # Check that x is not broadcasted to y based on broadcastable info if len(x_bcast) < len(to_bcast): @@ -176,7 +175,14 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]): or ( isinstance(op, AdvancedIncSubtensor) and not op.set_instead_of_inc - and not op.ignore_duplicates + and not ( + op.ignore_duplicates + # Only vector integer indices can have "duplicates", not scalars or boolean vectors + or all( + adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool" + for adv_idx in adv_idxs + ) + ) ) ): return generate_fallback_impl(op, node, **kwargs) diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index d28c94f5b5..8b95de34b7 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -314,8 +314,16 @@ def test_AdvancedIncSubtensor1(x, y, indices): np.arange(3 * 4 * 5).reshape((3, 4, 5)), -np.arange(1 * 4 * 5).reshape(1, 4, 5), (np.array([True, False, False])), # Broadcasted boolean index + False, # It shouldn't matter what we set this to, boolean indices cannot be duplicate False, False, + ), + ( + np.arange(3 * 4 * 5).reshape((3, 4, 5)), + -np.arange(1 * 4 * 5).reshape(1, 4, 5), + (np.array([True, False, False])), # Broadcasted boolean index + True, # It shouldn't matter what we set this to, boolean indices cannot be duplicate + False, False, ), (