@@ -2029,18 +2029,41 @@ def ravel_multidimensional_bool_idx(fgraph, node):
     return [copy_stack_trace(node.outputs[0], new_out)]


-@node_rewriter(tracks=[AdvancedSubtensor])
+@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
 def ravel_multidimensional_int_idx(fgraph, node):
-    """Convert multidimensional integer indexing into equivalent vector integer index, supported by Numba
-
-    x[eye(3, dtype=int)] -> x[eye(3).ravel()].reshape((3, 3))
+    """Convert multidimensional integer indexing into equivalent consecutive vector integer index,
+    supported by Numba or by our specialized dispatchers

+    x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))

     NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices

-    x[eye(3, dtype=int), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+    x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    It also handles multiple integer indices, but only if they don't broadcast
+
+    x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
+
+    Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
+
+    x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
+
     """
-    x, *idxs = node.inputs
+    op = node.op
+    non_consecutive_adv_indexing = op.non_consecutive_adv_indexing(node)
+    is_inc_subtensor = isinstance(op, AdvancedIncSubtensor)
+
+    if is_inc_subtensor:
+        x, y, *idxs = node.inputs
+        # Inc/SetSubtensor is harder to reason about due to y
+        # We get out if it's broadcasting or if the advanced indices are non-consecutive
+        if non_consecutive_adv_indexing or (
+            y.type.broadcastable != x[tuple(idxs)].type.broadcastable
+        ):
+            return None
+
+    else:
+        x, *idxs = node.inputs

     if any(
         (
@@ -2049,50 +2072,103 @@ def ravel_multidimensional_int_idx(fgraph, node):
         )
         for idx in idxs
     ):
-        # Get out if there are any other advanced indexes or np.newaxis
+        # Get out if there are any other advanced indices or np.newaxis
         return None

-    int_idxs = [
+    int_idxs_and_pos = [
         (i, idx)
         for i, idx in enumerate(idxs)
         if (isinstance(idx.type, TensorType) and idx.dtype in integer_dtypes)
     ]

-    if len(int_idxs) != 1:
-        # Get out if there are no or multiple integer idxs
+    if not int_idxs_and_pos:
         return None

-    [(int_idx_pos, int_idx)] = int_idxs
-    if int_idx.type.ndim < 2:
-        # No need to do anything if it's a vector or scalar, as it's already supported by Numba
+    int_idxs_pos, int_idxs = zip(
+        *int_idxs_and_pos, strict=False
+    )  # strict=False because by definition it's true
+
+    first_int_idx_pos = int_idxs_pos[0]
+    first_int_idx = int_idxs[0]
+    first_int_idx_bcast = first_int_idx.type.broadcastable
+
+    if any(int_idx.type.broadcastable != first_int_idx_bcast for int_idx in int_idxs):
+        # We don't have a view-only broadcasting operation
+        # Explicitly broadcasting the indices can incur a memory / copy overhead
         return None

-    raveled_int_idx = int_idx.ravel()
-    new_idxs = list(idxs)
-    new_idxs[int_idx_pos] = raveled_int_idx
-    raveled_subtensor = x[tuple(new_idxs)]
-
-    # Reshape into correct shape
-    # Because we only allow one advanced indexing, the output dimension corresponding to the raveled integer indexing
-    # must match the input position. If there were multiple advanced indexes, this could have been forcefully moved to the front
-    raveled_shape = raveled_subtensor.shape
-    unraveled_shape = (
-        *raveled_shape[:int_idx_pos],
-        *int_idx.shape,
-        *raveled_shape[int_idx_pos + 1 :],
-    )
-    new_out = raveled_subtensor.reshape(unraveled_shape)
+    int_idxs_ndim = len(first_int_idx_bcast)
+    if (
+        int_idxs_ndim == 0
+    ):  # This should be a basic indexing operation, rewrite elsewhere
+        return None
+
+    int_idxs_need_raveling = int_idxs_ndim > 1
+    if not (int_idxs_need_raveling or non_consecutive_adv_indexing):
+        # Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
+        return None
+
+    # Reorder non-consecutive indices
+    if non_consecutive_adv_indexing:
+        assert not is_inc_subtensor  # Sanity check that we got out if this was the case
+        # This case works as if all the advanced indices were on the front
+        transposition = list(int_idxs_pos) + [
+            i for i in range(len(idxs)) if i not in int_idxs_pos
+        ]
+        idxs = tuple(idxs[a] for a in transposition)
+        x = x.transpose(transposition)
+        first_int_idx_pos = 0
+        del int_idxs_pos  # Make sure they are not wrongly used
+
+    # Ravel multidimensional indices
+    if int_idxs_need_raveling:
+        idxs = list(idxs)
+        for idx_pos, int_idx in enumerate(int_idxs, start=first_int_idx_pos):
+            idxs[idx_pos] = int_idx.ravel()
+
+    # Index with reordered and/or raveled indices
+    new_subtensor = x[tuple(idxs)]
+
+    if is_inc_subtensor:
+        y_shape = tuple(y.shape)
+        y_raveled_shape = (
+            *y_shape[:first_int_idx_pos],
+            -1,
+            *y_shape[first_int_idx_pos + int_idxs_ndim :],
+        )
+        y_raveled = y.reshape(y_raveled_shape)
+
+        new_out = inc_subtensor(
+            new_subtensor,
+            y_raveled,
+            set_instead_of_inc=op.set_instead_of_inc,
+            ignore_duplicates=op.ignore_duplicates,
+            inplace=op.inplace,
+        )
+
+    else:
+        # Unravel advanced indexing dimensions
+        raveled_shape = tuple(new_subtensor.shape)
+        unraveled_shape = (
+            *raveled_shape[:first_int_idx_pos],
+            *first_int_idx.shape,
+            *raveled_shape[first_int_idx_pos + 1 :],
+        )
+        new_out = new_subtensor.reshape(unraveled_shape)

     return [copy_stack_trace(node.outputs[0], new_out)]


 optdb["specialize"].register(
     ravel_multidimensional_bool_idx.__name__,
     ravel_multidimensional_bool_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )

 optdb["specialize"].register(
     ravel_multidimensional_int_idx.__name__,
     ravel_multidimensional_int_idx,
     "numba",
+    use_db_name_as_tag=False,  # Not included if only "specialize" is requested
 )
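
For reference, the equivalence this rewrite exploits is plain NumPy advanced-indexing semantics: a multidimensional integer index can be replaced by its raveled version, with the index shape restored afterwards (on the output for the gather case, on `y` for the set case). The snippet below is a minimal NumPy-only sketch, not part of the patch; `x`, `idx`, and `y` are illustrative names, and it checks both the `AdvancedSubtensor` (gather) and `AdvancedIncSubtensor` (set, non-broadcasting, consecutive indices) cases the rewrite targets.

```python
import numpy as np

x = np.arange(4 * 5).reshape(4, 5)
idx = np.array([[0, 1], [2, 3]])  # multidimensional integer index

# Gather case: x[idx, 2:] equals indexing with the raveled idx and then
# restoring the index shape on the leading output dimensions.
direct = x[idx, 2:]
raveled = x[idx.ravel(), 2:].reshape((*idx.shape, *direct.shape[idx.ndim :]))
assert np.array_equal(direct, raveled)

# Set case: with consecutive advanced indices and a non-broadcasting y,
# setting through the raveled index with y reshaped accordingly is equivalent.
y = np.ones((2, 2, 3))
a = x.copy()
a[idx, 2:] = y
b = x.copy()
b[idx.ravel(), 2:] = y.reshape((-1, *y.shape[idx.ndim :]))
assert np.array_equal(a, b)
```

The rewrite performs the same ravel-and-reshape at the graph level, so Numba and the specialized dispatchers only ever see consecutive vector integer indices.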