|
4 | 4 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
|
5 | 5 |
|
6 | 6 | from pytensor import Variable
|
| 7 | +from pytensor.compile import optdb |
7 | 8 | from pytensor.graph import Constant, FunctionGraph, node_rewriter
|
8 | 9 | from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
|
9 | 10 | from pytensor.scalar import basic as ps
|
|
38 | 39 | )
|
39 | 40 | from pytensor.tensor.special import Softmax, softmax
|
40 | 41 | from pytensor.tensor.subtensor import (
|
| 42 | + AdvancedSubtensor, |
41 | 43 | AdvancedSubtensor1,
|
42 | 44 | Subtensor,
|
| 45 | + _non_contiguous_adv_indexing, |
43 | 46 | as_index_literal,
|
44 | 47 | get_canonical_form_slice,
|
45 | 48 | get_constant_idx,
|
46 | 49 | get_idx_list,
|
47 | 50 | indices_from_subtensor,
|
48 | 51 | )
|
49 | 52 | from pytensor.tensor.type import TensorType
|
50 |
| -from pytensor.tensor.type_other import SliceType |
| 53 | +from pytensor.tensor.type_other import NoneTypeT, SliceType |
51 | 54 |
|
52 | 55 |
|
53 | 56 | def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
|
@@ -812,3 +815,79 @@ def local_subtensor_shape_constant(fgraph, node):
|
812 | 815 | return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
|
813 | 816 | elif shape_parts:
|
814 | 817 | return [as_tensor(1, dtype=np.int64)]
|
| 818 | + |
| 819 | + |
@node_rewriter([Subtensor])
def local_subtensor_of_adv_subtensor(fgraph, node):
    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.

    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]

    Restricted to a single advanced indexing dimension.

    An alternative approach could have fused the basic and advanced indices,
    so it is not clear this rewrite should be canonical or a specialization.
    Users must include it manually if it fits their use case.
    """
    # node is a Subtensor; its first input may itself be an AdvancedSubtensor output.
    adv_subtensor, *idxs = node.inputs

    if not (
        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
    ):
        # Not a Subtensor-of-AdvancedSubtensor pattern; nothing to lift.
        return None

    if len(fgraph.clients[adv_subtensor]) > 1:
        # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
        return None

    x, *adv_idxs = adv_subtensor.owner.inputs

    # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
    # (bail on newaxis, boolean masks, non-trivial slices, or non-contiguous
    # advanced indices, where NumPy's placement rules get subtle).
    if any(
        (
            isinstance(adv_idx.type, NoneTypeT)
            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
        )
        for adv_idx in adv_idxs
    ) or _non_contiguous_adv_indexing(adv_idxs):
        return None

    # Locate the first dimension indexed by a tensor (advanced) index.
    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
        # We already made sure there were only None slices besides integer indexes
        if isinstance(adv_idx.type, TensorType):
            break
    else:  # no-break
        # Not sure if this should ever happen, but better safe than sorry
        return None

    # Reconstruct the concrete basic indices of the outer Subtensor, then split
    # them into those that can be lifted before the advanced indexing (to its
    # left) and those that must stay applied afterwards.
    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
        first_adv_idx_dim:
    ]

    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
        # All basic indices happen to the right of the advanced indices
        return None

    [basic_subtensor] = node.outputs
    # Integer (non-slice) lifted indices drop dimensions; track them so the
    # advanced indexing still sees the ranks it expects.
    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)

    # Apply the lifted basic indices directly to x.
    x_indexed = x[basic_idxs_lifted]
    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)

    # Re-expand the dropped dims so the original AdvancedSubtensor op applies
    # unchanged, then squeeze them back out after re-applying the kept indices.
    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)

    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
    return [new_out]
| 887 | + |
| 888 | + |
# Rewrite will only be included if tagged by name
_rewrite = local_subtensor_of_adv_subtensor
for _db in ("canonicalize", "specialize"):
    optdb[_db].register(_rewrite.__name__, _rewrite, use_db_name_as_tag=False)
del _rewrite, _db
0 commit comments