Skip to content

Commit 81130a0

Browse files
committed
Don't use objectmode with vector boolean inc_subtensor
1 parent 5714253 commit 81130a0

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

Diff for: pytensor/link/numba/dispatch/subtensor.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
130130
if isinstance(idx.type, TensorType)
131131
]
132132

133-
# Special case for consecutive consecutive vector indices
134133
def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
135134
# Check that x is not broadcasted to y based on broadcastable info
136135
if len(x_bcast) < len(to_bcast):
@@ -176,7 +175,14 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
176175
or (
177176
isinstance(op, AdvancedIncSubtensor)
178177
and not op.set_instead_of_inc
179-
and not op.ignore_duplicates
178+
and not (
179+
op.ignore_duplicates
180+
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
181+
or all(
182+
adv_idx["ndim"] == 0 or adv_idx["dtype"] == "bool"
183+
for adv_idx in adv_idxs
184+
)
185+
)
180186
)
181187
):
182188
return generate_fallback_impl(op, node, **kwargs)

Diff for: tests/link/numba/test_subtensor.py

+8
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,16 @@ def test_AdvancedIncSubtensor1(x, y, indices):
314314
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
315315
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
316316
(np.array([True, False, False])), # Broadcasted boolean index
317+
False, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
317318
False,
318319
False,
320+
),
321+
(
322+
np.arange(3 * 4 * 5).reshape((3, 4, 5)),
323+
-np.arange(1 * 4 * 5).reshape(1, 4, 5),
324+
(np.array([True, False, False])), # Broadcasted boolean index
325+
True, # It shouldn't matter what we set this to, boolean indices cannot be duplicate
326+
False,
319327
False,
320328
),
321329
(

0 commit comments

Comments
 (0)