diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py
index 94cb51434d..81348b57be 100644
--- a/pytensor/link/numba/dispatch/subtensor.py
+++ b/pytensor/link/numba/dispatch/subtensor.py
@@ -150,7 +150,7 @@ def broadcasted_to(x_bcast: tuple[bool, ...], to_bcast: tuple[bool, ...]):
             for adv_idx in adv_idxs
         )
         # Must be consecutive
-        and not op.non_contiguous_adv_indexing(node)
+        and not op.non_consecutive_adv_indexing(node)
         # y in set/inc_subtensor cannot be broadcasted
         and (
             y is None
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 020a2e04e0..c7a4574a91 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     return [copy_stack_trace(node.outputs[0], new_out)]
 
 
-@node_rewriter(tracks=[AdvancedSubtensor])
+@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
 def ravel_multidimensional_int_idx(fgraph, node):
-    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
-
-    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+    """Convert multidimensional integer indexing into equivalent consecutive vector integer index,
+    supported by Numba or by our specialized dispatchers
+    x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
 
 
     NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
 
-    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+    x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    It also handles multiple integer indices, but only if they don't broadcast
+
+    x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
+
+    x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
+
     """
-    x, *idxs = node.inputs
+    op = node.op
+    non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
+    is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
+
+    if is_inc_subtensor:
+        x, y, *idxs = node.inputs
+        # Inc/SetSubtensor is harder to reason about due to y
+        # We get out if it's broadcasting or if the advanced indices are non-consecutive
+        if non_consecutive_adv_indexing or (
+            y.type.broadcastable != x[tuple(idxs)].type.broadcastable
+        ):
+            return None
+
+    else:
+        x, *idxs = node.inputs
 
     if any(
         (
@@ -2049,39 +2072,90 @@ def ravel_multidimensional_int_idx(fgraph, node):
         )
         for idx in idxs
     ):
-        # Get out if there are any other advanced indexes or np.newaxis
+        # Get out if there are any other advanced indices or np.newaxis
         return None
 
-    int_idxs = [
+    int_idxs_and_pos = [
         (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
     ]
 
-    if len(int_idxs) != 1:
-        # Get out if there are no or multiple integer idxs
+    if not int_idxs_and_pos:
         return None
 
-    [(int_idx_pos, int_idx)] = int_idxs
-    if int_idx.type.ndim < 2:
-        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+    int_idxs_pos, int_idxs = zip(
+        *int_idxs_and_pos, strict=False
+    )  # strict=False because by definition it's true
+
+    first_int_idx_pos = int_idxs_pos[0]
+    first_int_idx = int_idxs[0]
+    first_int_idx_bcast = first_int_idx.type.broadcastable
+
+    if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
+        # We don't have a view-only broadcasting operation
+        # Explicitly broadcasting the indices can incur a memory / copy overhead
         return None
 
-    raveled_int_idx = int_idx.ravel()
-    new_idxs = list(idxs)
-    new_idxs[int_idx_pos] = raveled_int_idx
-    raveled_subtensor = x[tuple(new_idxs)]
-
-    # Reshape into correct shape
-    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
-    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
-    raveled_shape = raveled_subtensor.shape
-    unraveled_shape = (
-        *raveled_shape[:int_idx_pos],
-        *int_idx.shape,
-        *raveled_shape[int_idx_pos + 1 :],
-    )
-    new_out = raveled_subtensor.reshape(unraveled_shape)
+    int_idxs_ndim = len(first_int_idx_bcast)
+    if (
+        int_idxs_ndim == 0
+    ):  # This should be a basic indexing operation, rewrite elsewhere
+        return None
+
+    int_idxs_need_raveling = int_idxs_ndim > 1
+    if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
+        # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
+        return None
+
+    # Reorder non-consecutive indices
+    if non_consecutive_adv_indexing:
+        assert not is_inc_subtensor  # Sanity check that we got out if this was the case
+        # This case works as if all the advanced indices were on the front
+        transposition = list(int_idxs_pos) + [
+            i for i in range(len(idxs)) if i not in int_idxs_pos
+        ]
+        idxs = tuple(idxs[a] for a in transposition)
+        x = x.transpose(transposition)
+        first_int_idx_pos = 0
+        del int_idxs_pos  # Make sure they are not wrongly used
+
+    # Ravel multidimensional indices
+    if int_idxs_need_raveling:
+        idxs = list(idxs)
+        for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
+            idxs[idx_pos] = int_idx.ravel()
+
+    # Index with reordered and/or raveled indices
+    new_subtensor = x[tuple(idxs)]
+
+    if is_inc_subtensor:
+        y_shape = tuple(y.shape)
+        y_raveled_shape = (
+            *y_shape[:first_int_idx_pos],
+            -1,
+            *y_shape[first_int_idx_pos + int_idxs_ndim :],
+        )
+        y_raveled = y.reshape(y_raveled_shape)
+
+        new_out = inc_subtensor(
+            new_subtensor,
+            y_raveled,
+            set_instead_of_inc=op.set_instead_of_inc,
+            ignore_duplicates=op.ignore_duplicates,
+            inplace=op.inplace,
+        )
+
+    else:
+        # Unravel advanced indexing dimensions
+        raveled_shape = tuple(new_subtensor.shape)
+        unraveled_shape = (
+            *raveled_shape[:first_int_idx_pos],
+            *first_int_idx.shape,
+            *raveled_shape[first_int_idx_pos + 1 :],
+        )
+        new_out = new_subtensor.reshape(unraveled_shape)
+
     return [copy_stack_trace(node.outputs[0], new_out)]
@@ -2089,10 +2163,12 @@ def ravel_multidimensional_int_idx(fgraph, node):
     ravel_multidimensional_bool_idx.__name__,
     ravel_multidimensional_bool_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
 
 optdb["specialize"].register(
     ravel_multidimensional_int_idx.__name__,
     ravel_multidimensional_int_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 3a2304eb7b..3de4f41068 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -1,5 +1,6 @@
 import logging
 import sys
+import warnings
 from collections.abc import Callable, Iterable
 from itertools import chain, groupby
 from textwrap import dedent
@@ -59,6 +60,7 @@
     zscalar,
 )
 from pytensor.tensor.type_other import (
+    MakeSlice,
     NoneConst,
     NoneTypeT,
     SliceConstant,
@@ -527,11 +529,20 @@ def basic_shape(shape, indices):
         if isinstance(idx, slice):
             res_shape += (slice_len(idx, n),)
         elif isinstance(getattr(idx, "type", None), SliceType):
-            if idx.owner:
-                idx_inputs = idx.owner.inputs
+            if idx.owner is None:
+                if not isinstance(idx, Constant):
+                    # This is an input slice, we can't reason symbolically on it.
+                    # We don't even know if we will get None entries or integers
+                    res_shape += (None,)
+                    continue
+                else:
+                    sl: slice = idx.data
+                    slice_inputs = (sl.start, sl.stop, sl.step)
+            elif isinstance(idx.owner.op, MakeSlice):
+                slice_inputs = idx.owner.inputs
             else:
-                idx_inputs = (None,)
-            res_shape += (slice_len(slice(*idx_inputs), n),)
+                raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
+            res_shape += (slice_len(slice(*slice_inputs), n),)
         elif idx is None:
             res_shape += (ps.ScalarConstant(ps.int64, 1),)
         elif isinstance(getattr(idx, "type", None), NoneTypeT):
@@ -570,8 +581,8 @@ def group_indices(indices):
     return idx_groups
 
 
-def _non_contiguous_adv_indexing(indices) -> bool:
-    """Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
+def _non_consecutive_adv_indexing(indices) -> bool:
+    """Check if the advanced indexing is non-consecutive (i.e., split by basic indexing)."""
     idx_groups = group_indices(indices)
     # This means that there are at least two groups of advanced indexing separated by basic indexing
     return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
@@ -601,7 +612,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
     remaining_dims = range(pytensor.tensor.basic.get_vector_length(array_shape))
     idx_groups = group_indices(indices)
 
-    if _non_contiguous_adv_indexing(indices):
+    if _non_consecutive_adv_indexing(indices):
         # In this case NumPy places the advanced index groups in the front of the array
         # https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
         idx_groups = sorted(idx_groups, key=lambda x: x[0])
@@ -2728,6 +2739,11 @@ def is_bool_index(idx):
         res_shape = list(
             indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
         )
+        for i, res_dim_length in enumerate(res_shape):
+            if res_dim_length is None:
+                # This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
+                # We must compute the Op to find its shape
+                res_shape[i] = Shape_i(i)(node.out)
 
         adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
         bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
@@ -2781,10 +2797,17 @@ def grad(self, inputs, grads):
 
     @staticmethod
     def non_contiguous_adv_indexing(node: Apply) -> bool:
+        warnings.warn(
+            "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
+        )
+        return AdvancedSubtensor.non_consecutive_adv_indexing(node)
+
+    @staticmethod
+    def non_consecutive_adv_indexing(node: Apply) -> bool:
         """
-        Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
+        Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing).
 
-        This function checks if the advanced indexing is non-contiguous,
+        This function checks if the advanced indexing is non-consecutive,
         in which case the advanced index dimensions are placed on the left of the output array,
         regardless of their opriginal position.
 
@@ -2799,10 +2822,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool:
         Returns
         -------
         bool
-            True if the advanced indexing is non-contiguous, False otherwise.
+            True if the advanced indexing is non-consecutive, False otherwise.
""" _, *idxs = node.inputs - return _non_contiguous_adv_indexing(idxs) + return _non_consecutive_adv_indexing(idxs) advanced_subtensor = AdvancedSubtensor() @@ -2820,7 +2843,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): if isinstance(batch_idx, TensorVariable) ) - if idxs_are_batched or (x_is_batched and op.non_contiguous_adv_indexing(node)): + if idxs_are_batched or (x_is_batched and op.non_consecutive_adv_indexing(node)): # Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing # which would put the indexed results to the left of the batch dimensions! # TODO: Not all cases must be handled by Blockwise, but the logic is complex @@ -2829,7 +2852,7 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): # TODO: Implement these internally, so Blockwise is always a safe fallback if any(not isinstance(idx, TensorVariable) for idx in idxs): raise NotImplementedError( - "Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing " + "Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing " "and slices or newaxis is currently not supported." ) else: @@ -2939,10 +2962,17 @@ def grad(self, inpt, output_gradients): @staticmethod def non_contiguous_adv_indexing(node: Apply) -> bool: + warnings.warn( + "Method was renamed to `non_consecutive_adv_indexing`", FutureWarning + ) + return AdvancedIncSubtensor.non_consecutive_adv_indexing(node) + + @staticmethod + def non_consecutive_adv_indexing(node: Apply) -> bool: """ - Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing). + Check if the advanced indexing is non-consecutive (i.e. interrupted by basic indexing). - This function checks if the advanced indexing is non-contiguous, + This function checks if the advanced indexing is non-consecutive, in which case the advanced index dimensions are placed on the left of the output array, regardless of their opriginal position. @@ -2957,10 +2987,10 @@ def non_contiguous_adv_indexing(node: Apply) -> bool: Returns ------- bool - True if the advanced indexing is non-contiguous, False otherwise. + True if the advanced indexing is non-consecutive, False otherwise. 
""" _, _, *idxs = node.inputs - return _non_contiguous_adv_indexing(idxs) + return _non_consecutive_adv_indexing(idxs) advanced_inc_subtensor = AdvancedIncSubtensor() diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index a7cb4a1826..99dd26a26e 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -11,7 +11,6 @@ pytensor/link/numba/dispatch/scan.py pytensor/printing.py pytensor/raise_op.py pytensor/sparse/basic.py -pytensor/sparse/type.py pytensor/tensor/basic.py pytensor/tensor/blas_c.py pytensor/tensor/blas_headers.py diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index c2e87560cd..34cc810647 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -142,7 +142,13 @@ def check_no_unexpected_results(mypy_lines: Iterable[str]): print(*missing, sep="\n") sys.exit(1) cp = subprocess.run( - ["mypy", "--show-error-codes", "pytensor"], + [ + "mypy", + "--show-error-codes", + "--disable-error-code", + "annotation-unchecked", + "pytensor", + ], capture_output=True, ) output = cp.stdout.decode() diff --git a/tests/link/jax/test_pad.py b/tests/link/jax/test_pad.py index 8ecb460ace..13d71be9ad 100644 --- a/tests/link/jax/test_pad.py +++ b/tests/link/jax/test_pad.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from packaging import version import pytensor.tensor as pt from pytensor import config @@ -16,7 +17,14 @@ "mode, kwargs", [ ("constant", {"constant_values": 0}), - ("constant", {"constant_values": (1, 2)}), + pytest.param( + "constant", + {"constant_values": (1, 2)}, + marks=pytest.mark.skipif( + version.parse(jax.__version__) > version.parse("0.4.35"), + reason="Bug in JAX: https://github.com/jax-ml/jax/issues/26888", + ), + ), ("edge", {}), ("linear_ramp", {"end_values": 0}), ("linear_ramp", {"end_values": (1, 2)}), diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index 8b95de34b7..675afdc996 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -81,11 +81,6 @@ def test_AdvancedSubtensor1_out_of_bounds(): (np.array([True, False, False])), False, ), - ( - pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - ([1, 2], [2, 3]), - False, - ), # Single multidimensional indexing (supported after specialization rewrites) ( as_tensor(np.arange(3 * 3).reshape((3, 3))), @@ -117,6 +112,12 @@ def test_AdvancedSubtensor1_out_of_bounds(): (slice(2, None), np.eye(3).astype(bool)), False, ), + # Multiple vector indexing (supported by our dispatcher) + ( + pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + ([1, 2], [2, 3]), + False, + ), ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (slice(None), [1, 2], [3, 4]), @@ -127,18 +128,35 @@ def test_AdvancedSubtensor1_out_of_bounds(): ([1, 2], [3, 4], [5, 6]), False, ), - # Non-contiguous vector indexing, only supported in obj mode + # Non-consecutive vector indexing, supported by our dispatcher after rewriting ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], slice(None), [3, 4]), - True, + False, + ), + # Multiple multidimensional integer indexing (supported by our dispatcher) + ( + as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + ([[1, 2], [2, 1]], [[0, 0], [0, 0]]), + False, + ), + ( + as_tensor(np.arange(2 * 3 * 4 * 5).reshape((2, 3, 4, 5))), + (slice(None), [[1, 2], [2, 1]], slice(None), [[0, 0], [0, 0]]), + False, ), - # >1d vector indexing, only supported in obj mode + # Multiple multidimensional indexing with broadcasting, only supported in obj mode ( as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 
             ([[1, 2], [2, 1]], [0, 0]),
             True,
         ),
+        # multiple multidimensional integer indexing mixed with basic indexing, only supported in obj mode
+        (
+            as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
+            ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
+            True,
+        ),
     ],
 )
 @pytest.mark.filterwarnings("error")  # Raise if we did not expect objmode to be needed
@@ -297,7 +315,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             -np.arange(4 * 5).reshape(4, 5),
-            (0, [1, 2, 2, 3]),  # Broadcasted vector index
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index with repeated values
             True,
             False,
             True,
@@ -305,7 +323,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             np.array([-99]),  # Broadcasted value
-            (0, [1, 2, 2, 3]),  # Broadcasted vector index
+            (0, [1, 2, 2, 3]),  # Broadcasted vector index with repeated values
             True,
             False,
             True,
@@ -380,7 +398,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(3 * 4 * 5).reshape((3, 4, 5)),
             rng.poisson(size=(2, 4)),
-            ([1, 2], slice(None), [3, 4]),  # Non-contiguous vector indices
+            ([1, 2], slice(None), [3, 4]),  # Non-consecutive vector indices
             False,
             True,
             True,
@@ -400,15 +418,23 @@ def test_AdvancedIncSubtensor1(x, y, indices):
         (
             np.arange(5),
             rng.poisson(size=(2, 2)),
-            ([[1, 2], [2, 3]]),  # matrix indices
+            ([[1, 2], [2, 3]]),  # matrix index
+            False,
+            False,
+            False,
+        ),
+        (
+            np.arange(3 * 5).reshape((3, 5)),
+            rng.poisson(size=(2, 2, 2)),
+            (slice(1, 3), [[1, 2], [2, 3]]),  # matrix index, mixed with basic index
+            False,
+            False,
             False,
-            False,  # Gets converted to AdvancedIncSubtensor1
-            True,  # This is actually supported with the default `ignore_duplicates=False`
         ),
         (
             np.arange(3 * 5).reshape((3, 5)),
-            rng.poisson(size=(1, 2, 2)),
-            (slice(1, 3), [[1, 2], [2, 3]]),  # matrix indices, mixed with basic index
+            rng.poisson(size=(1, 2, 2)),  # Same as before, but Y broadcasts
+            (slice(1, 3), [[1, 2], [2, 3]]),
             False,
             True,
             True,
@@ -421,6 +447,14 @@ def test_AdvancedIncSubtensor1(x, y, indices):
             False,
             False,
         ),
+        (
+            np.arange(3 * 4 * 5).reshape((3, 4, 5)),
+            rng.poisson(size=(3, 2, 2)),
+            (slice(None), [[1, 2], [2, 1]], [[2, 3], [0, 0]]),  # 2 matrix indices
+            False,
+            False,
+            False,
+        ),
     ],
 )
 @pytest.mark.parametrize("inplace", (False, True))
diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py
index 3886a08f48..ebe07f4947 100644
--- a/tests/tensor/test_subtensor.py
+++ b/tests/tensor/test_subtensor.py
@@ -15,6 +15,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
+from pytensor.graph import Constant
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
 from pytensor.printing import pprint
@@ -37,6 +38,7 @@
     advanced_inc_subtensor1,
     advanced_set_subtensor,
     advanced_set_subtensor1,
+    advanced_subtensor,
     advanced_subtensor1,
     as_index_literal,
     basic_shape,
@@ -2145,7 +2147,17 @@ def test_adv_sub_slice(self):
         slc = slicetype()
         f = pytensor.function([slc], var[slc], mode=self.mode)
         s = slice(1, 3)
-        f(s)
+        assert f(s).shape == (2, 3)
+
+        f_shape0 = pytensor.function([slc], var[slc].shape[0], mode=self.mode)
+        assert f_shape0(s) == 2
+
+        f_shape1 = pytensor.function([slc], var[slc].shape[1], mode=self.mode)
+        assert not any(
+            isinstance(node.op, AdvancedSubtensor)
+            for node in f_shape1.maker.fgraph.toposort()
+        )
+        assert f_shape1(s) == 3
 
     def test_adv_grouped(self):
         # Reported in https://github.com/Theano/Theano/issues/6152
@@ -2611,6 +2623,14 @@ def test_AdvancedSubtensor_bool_mixed(self):
             AdvancedSubtensor,
         )
 
+    def test_advanced_subtensor_constant_slice(self):
+        x = dmatrix("x")
+        constant_slice = pytensor.as_symbolic(slice(1, None, None))
+        assert isinstance(constant_slice, Constant)
+        adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int")
+        y = advanced_subtensor(x, constant_slice, adv_indices)
+        assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3)
+
 
 @config.change_flags(compute_test_value="raise")
 def test_basic_shape():
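The ravel_multidimensional_int_idx rewrite patched above leans on two NumPy indexing equivalences: a multidimensional integer index can be replaced by its raveled (1-D) form followed by a reshape of the result, and non-consecutive advanced indices behave as if the advanced axes had been transposed to the front. A minimal NumPy-only sketch of both equivalences (illustrative only, not part of the patch; shapes chosen to mirror the test cases):

import numpy as np

x = np.arange(3 * 4 * 5).reshape((3, 4, 5))

# 1) A multidimensional integer index equals a raveled vector index plus a reshape of the result
idx = np.array([[1, 2], [2, 1]])              # 2-D integer index
direct = x[idx, 2:]                           # shape (2, 2, 2, 5)
raveled = x[idx.ravel(), 2:]                  # shape (4, 2, 5)
np.testing.assert_array_equal(
    direct, raveled.reshape((*idx.shape, *raveled.shape[1:]))
)

# 2) Non-consecutive advanced indices place the advanced dimensions first,
#    as if the advanced axes had been moved to the front before indexing
vec = np.array([1, 2])
non_consecutive = x[vec, :, vec]              # shape (2, 4): advanced dims lead
consecutive = x.transpose(0, 2, 1)[vec, vec]  # advanced axes made consecutive
np.testing.assert_array_equal(non_consecutive, consecutive)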