14
14
register_stabilize ,
15
15
)
16
16
from pytensor .tensor .shape import Reshape
17
- from pytensor .tensor .subtensor import AdvancedIncSubtensor , AdvancedSubtensor , Subtensor
17
+ from pytensor .tensor .subtensor import (
18
+ AdvancedIncSubtensor ,
19
+ AdvancedSubtensor ,
20
+ Subtensor ,
21
+ indices_from_subtensor ,
22
+ )
18
23
19
24
20
25
@node_rewriter ([Blockwise ])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):
216
221
217
222
Reshape is tricky to vectorize eagerly, because a graph like
218
223
`x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
219
- that must be vectorized before we arrize at the reshape operation.
224
+ that must be vectorized before we arrive at the reshape operation.
220
225
221
- For the square Reshape case, we must wait for all the intemediate
226
+ For the square Reshape case, we must wait for all the intermediate
222
227
operations to be lifted as Allocs
223
228
"""
224
229
if not isinstance (node .op .core_op , Reshape ):
@@ -234,6 +239,29 @@ def local_blockwise_reshape(fgraph, node):
234
239
return [new_out ]
235
240
236
241
242
+ @register_stabilize
243
+ @register_specialize
244
+ @node_rewriter ([Blockwise ])
245
+ def local_blockwise_of_subtensor (fgraph , node ):
246
+ """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
247
+
248
+ Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
249
+ """
250
+ if not isinstance (node .op .core_op , Subtensor ):
251
+ return
252
+
253
+ x , * idxs = node .inputs
254
+ if not all (all (idx .type .broadcastable ) for idx in idxs ):
255
+ return
256
+
257
+ core_idxs = indices_from_subtensor (
258
+ [idx .squeeze () for idx in idxs ], node .op .core_op .idx_list
259
+ )
260
+ # Add empty slices for the batch dims
261
+ none_slices = (slice (None ),) * node .op .batch_ndim (node )
262
+ return [x [(* none_slices , * core_idxs )]]
263
+
264
+
237
265
@node_rewriter (tracks = [Blockwise ], inplace = True )
238
266
def blockwise_inplace (fgraph , node ):
239
267
blockwise_op = node .op
0 commit comments