 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    _seq_idx_to_chunk_indices_offsets)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.platforms import current_platform
@@ -160,14 +162,14 @@ def end_boundary(n: int):
 
         # get the metadata
         cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
-        sed_idx = torch.zeros(cu_seqlens[-1],
+        seq_idx = torch.zeros(cu_seqlens[-1],
                               dtype=torch.int32,
                               device=cu_seqlens.device)
         for i, (srt, end) in enumerate(zip(
                 cu_seqlens,
                 cu_seqlens[1:],
         )):
-            sed_idx[srt:end] = i
+            seq_idx[srt:end] = i
 
         # for cont batch
         if IND_E is None:
@@ -177,7 +179,7 @@ def end_boundary(n: int):
             IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
         yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
-               cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+               cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
 
 
 @pytest.mark.parametrize("itype",
@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
 
     states = None
-    for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
+    for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
                                      C) in generate_continous_batched_examples(
                                          cases, num_examples, seqlen,
                                          last_taken, exhausted, n_heads,
                                          d_head, itype):
 
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
         Y, new_states = mamba_chunk_scan_combined(
             X,
             dt,
@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             chunk_size,
             D=None,
             cu_seqlens=cu_seqlens,
-            seq_idx=sed_idx,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
             return_varlen_states=True,
             initial_states=states,
         )
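
Below is a minimal sketch (not part of the patch) of how the test now prepares the chunked metadata before calling the kernel. The two-sequence example and its lengths are illustrative assumptions; only the import path and call signatures are taken from the diff above.

import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import (
    _seq_idx_to_chunk_indices_offsets)

chunk_size = 8

# seq_idx labels every token position of the continuously batched input with
# the index of the sequence it belongs to, e.g. lengths (5, 11) -> shape [1, 16].
seq_idx = torch.tensor([[0] * 5 + [1] * 11], dtype=torch.int32)

# The helper turns seq_idx plus the chunk size into the chunk_indices and
# chunk_offsets tensors that the updated test now passes to
# mamba_chunk_scan_combined via chunk_indices= and chunk_offsets=.
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
    seq_idx, chunk_size)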