Skip to content

Commit feaa06b

Browse files
tlrmchlsmthMu Huai
authored and
Mu Huai
committed
[Bugfix] Fix tests/kernels/test_mamba_ssm_ssd.py (vllm-project#16623)
Signed-off-by: Tyler Michael Smith <[email protected]> Signed-off-by: Mu Huai <[email protected]>
1 parent f263f23 commit feaa06b

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

tests/kernels/test_mamba_ssm_ssd.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torch.nn.functional as F
66
from einops import rearrange, repeat
77

8+
from vllm.model_executor.layers.mamba.mamba2_metadata import (
9+
_seq_idx_to_chunk_indices_offsets)
810
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
911
mamba_chunk_scan_combined)
1012
from vllm.platforms import current_platform
@@ -160,14 +162,14 @@ def end_boundary(n: int):
160162

161163
# get the metadata
162164
cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
163-
sed_idx = torch.zeros(cu_seqlens[-1],
165+
seq_idx = torch.zeros(cu_seqlens[-1],
164166
dtype=torch.int32,
165167
device=cu_seqlens.device)
166168
for i, (srt, end) in enumerate(zip(
167169
cu_seqlens,
168170
cu_seqlens[1:],
169171
)):
170-
sed_idx[srt:end] = i
172+
seq_idx[srt:end] = i
171173

172174
# for cont batch
173175
if IND_E is None:
@@ -177,7 +179,7 @@ def end_boundary(n: int):
177179
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
178180

179181
yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
180-
cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
182+
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
181183

182184

183185
@pytest.mark.parametrize("itype",
@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
266268
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
267269

268270
states = None
269-
for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
271+
for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
270272
C) in generate_continous_batched_examples(
271273
cases, num_examples, seqlen,
272274
last_taken, exhausted, n_heads,
273275
d_head, itype):
274276

277+
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
278+
seq_idx, chunk_size)
279+
275280
Y, new_states = mamba_chunk_scan_combined(
276281
X,
277282
dt,
@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
281286
chunk_size,
282287
D=None,
283288
cu_seqlens=cu_seqlens,
284-
seq_idx=sed_idx,
289+
seq_idx=seq_idx,
290+
chunk_indices=chunk_indices,
291+
chunk_offsets=chunk_offsets,
285292
return_varlen_states=True,
286293
initial_states=states,
287294
)

0 commit comments

Comments
 (0)