|
8 | 8 | )
|
9 | 9 |
|
10 | 10 | from pytensor import Variable
|
| 11 | +from pytensor.compile import optdb |
11 | 12 | from pytensor.graph import Constant, FunctionGraph, node_rewriter
|
12 | 13 | from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
|
13 | 14 | from pytensor.scalar import basic as ps
|
|
42 | 43 | )
|
43 | 44 | from pytensor.tensor.special import Softmax, softmax
|
44 | 45 | from pytensor.tensor.subtensor import (
|
| 46 | + AdvancedSubtensor, |
45 | 47 | AdvancedSubtensor1,
|
46 | 48 | Subtensor,
|
| 49 | + _non_contiguous_adv_indexing, |
47 | 50 | as_index_literal,
|
48 | 51 | get_canonical_form_slice,
|
49 | 52 | get_constant_idx,
|
50 | 53 | get_idx_list,
|
51 | 54 | indices_from_subtensor,
|
52 | 55 | )
|
53 | 56 | from pytensor.tensor.type import TensorType
|
54 |
| -from pytensor.tensor.type_other import SliceType |
| 57 | +from pytensor.tensor.type_other import NoneTypeT, SliceType |
55 | 58 | from pytensor.tensor.variable import TensorVariable
|
56 | 59 |
|
57 | 60 |
|
@@ -818,3 +821,79 @@ def local_subtensor_shape_constant(fgraph, node):
|
818 | 821 | return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
|
819 | 822 | elif shape_parts:
|
820 | 823 | return [as_tensor(1, dtype=np.int64)]
|
| 824 | + |
| 825 | + |
@node_rewriter([Subtensor])
def local_subtensor_of_adv_subtensor(fgraph, node):
    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.

    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]

    Restricted to a single advanced indexing dimension.

    An alternative approach could have fused the basic and advanced indices,
    so it is not clear this rewrite should be canonical or a specialization.
    Users must include it manually if it fits their use case.
    """
    adv_out, *raw_basic_idxs = node.inputs

    adv_node = adv_out.owner
    if adv_node is None or not isinstance(adv_node.op, AdvancedSubtensor):
        return None

    # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
    if len(fgraph.clients[adv_out]) > 1:
        return None

    x, *adv_idxs = adv_node.inputs

    # Advanced indexing is a minefield, avoid all cases except for consecutive
    # integer indices: no newaxis, no boolean masks, no partial slices.
    def _unsupported(idx):
        idx_type = idx.type
        if isinstance(idx_type, NoneTypeT):
            return True
        if isinstance(idx_type, TensorType) and idx_type.dtype == "bool":
            return True
        return isinstance(idx_type, SliceType) and not is_full_slice(idx)

    if any(_unsupported(idx) for idx in adv_idxs) or _non_contiguous_adv_indexing(
        adv_idxs
    ):
        return None

    # Find the first advanced (tensor) index; after the filtering above the
    # only other entries left are full slices.
    first_adv_idx_dim = next(
        (dim for dim, idx in enumerate(adv_idxs) if isinstance(idx.type, TensorType)),
        None,
    )
    if first_adv_idx_dim is None:
        # Not sure if this should ever happen, but better safe than sorry
        return None

    basic_idxs = indices_from_subtensor(raw_basic_idxs, node.op.idx_list)
    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
    basic_idxs_kept = (slice(None),) * len(basic_idxs_lifted) + basic_idxs[
        first_adv_idx_dim:
    ]

    if all(idx == slice(None) for idx in basic_idxs_lifted):
        # All basic indices happen to the right of the advanced indices
        return None

    [basic_subtensor] = node.outputs
    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)

    # Apply the lifted basic indices directly on x ...
    x_indexed = x[basic_idxs_lifted]
    copy_stack_trace([basic_subtensor, adv_out], x_indexed)

    # ... restore the dropped dims so the advanced indices still align,
    # then redo the advanced indexing on the pre-indexed input.
    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
    x_after_adv_idx = adv_node.op(x_after_index_lift, *adv_idxs)
    copy_stack_trace([basic_subtensor, adv_out], x_after_adv_idx)

    # Finally apply the remaining basic indices and squeeze out the dims
    # that the lifted integer indices would have dropped.
    return [squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)]
| 894 | + |
# Rewrite will only be included if tagged by name
_rewrite = local_subtensor_of_adv_subtensor
for _db_name in ("canonicalize", "specialize"):
    optdb[_db_name].register(_rewrite.__name__, _rewrite, use_db_name_as_tag=False)
del _rewrite, _db_name
0 commit comments