Skip to content

Commit 4c672ff

Browse files
fabianlimcyang49
authored andcommitted
refactored mixer and bamba
Signed-off-by: Yu Chin Fabian Lim <[email protected]> Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 28c13c8 commit 4c672ff

File tree

4 files changed

+73
-48
lines changed

4 files changed

+73
-48
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,13 @@ def forward_cuda(
392392
chunk_indices: Optional[torch.Tensor] = None,
393393
chunk_offsets: Optional[torch.Tensor] = None,
394394
):
395+
# For the mamba2 triton kernels to operate in continuous batching,
396+
# the sequence_idx is needed to be passed in. Also, for the kernels
397+
# to operate in chunked prefill, the chunk_indices and chunk_offsets
398+
# can be optionally passed in; it is more efficient to pre-compute
399+
# once since they are common to all layers. If they are not provided
400+
# then they will be derived from sequence_idx inside the kernels
401+
395402
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
396403

397404
seq_len, _ = hidden_states.shape

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
# ruff: noqa: E501,SIM102
77

8+
import math
9+
810
import torch
911
import triton
1012
import triton.language as tl
@@ -440,6 +442,40 @@ def _chunk_scan_fwd_kernel(
440442
(offs_out_n[None, :] < hdim))
441443

442444

445+
def seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
446+
447+
# convert seq_idx to chunk indices and offsets
448+
# - derive the cu_seqlens
449+
_, cu_seqlens = torch.where(seq_idx.diff())
450+
cu_seqlens += 1
451+
452+
# outputs will have length expansion of chunks that do not divide
453+
# chunk_size
454+
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
455+
> 0).sum()
456+
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
457+
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
458+
459+
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
460+
p = 0 # num of insertions
461+
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
462+
463+
# if does not divide chunk_size, then there is one chunk insertion
464+
p += (s % chunk_size > 0)
465+
466+
# get the dimensions
467+
# - the + 1 for _e is to shift the boundary by one chunk
468+
# - this shifting is not needed if chunk_size divides e
469+
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
470+
> 0)
471+
472+
# adjust inidces and offsets
473+
chunk_indices[_s:_e] -= p
474+
chunk_offsets[_s] = s % chunk_size
475+
476+
return chunk_indices, chunk_offsets
477+
478+
443479
def _chunk_scan_fwd(
444480
cb,
445481
x,
@@ -481,8 +517,20 @@ def _chunk_scan_fwd(
481517
if initial_states.shape[0] == 1:
482518
# no in this case no point to use initial states
483519
initial_states = None
484-
485-
if initial_states is None:
520+
elif chunk_indices is None and chunk_offsets is None:
521+
# if chunk_indices and chunk_offsets both unset, then derive
522+
# from seq_idx
523+
chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
524+
seq_idx, chunk_size)
525+
else:
526+
assert chunk_indices is not None and chunk_offsets is not None, \
527+
(
528+
"chunk_indices and chunk_offsets should either "
529+
"be left unset, or else both should be set."
530+
)
531+
else:
532+
chunk_indices, chunk_offsets = None, None
533+
else:
486534
chunk_indices, chunk_offsets = None, None
487535

488536
# Allocates output.
@@ -509,7 +557,6 @@ def _chunk_scan_fwd(
509557
if chunk_offsets is None else len(chunk_offsets), nheads)
510558
z_strides = ((z.stride(0), z.stride(1), z.stride(2),
511559
z.stride(3)) if z is not None else (0, 0, 0, 0))
512-
513560
_chunk_scan_fwd_kernel[grid](
514561
cb,
515562
x,

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
9898
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
9999
# (middle term of factorization of off-diag blocks; A terms)
100100
# - for handling chunked prefill, this requires i) initial_states
101-
# ii) seq_idx and iii) has_cu_seqlens to be all specified.
101+
# ii) seq_idx and iii) is_cont_batched to be all specified.
102102
# - When a new seq_idx is detected, we will stop passing the prev_state
103103
# and switch accordingly to the init_state corresponding to the new seq_idx.
104104
# - this will ensure that states will be updated with the rightmost flushed seq_idx

vllm/model_executor/models/bamba.py

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
2323
MambaMixer2, extra_groups_for_head_shards)
2424
from vllm.model_executor.layers.quantization import QuantizationConfig
25+
from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
26+
seq_idx_to_chunk_indices_offsets)
2527
from vllm.model_executor.layers.rotary_embedding import get_rope
2628
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
2729
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -256,41 +258,6 @@ def forward(
256258
"mamba": BambaMixerDecoderLayer
257259
}
258260

259-
260-
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
261-
262-
# convert seq_idx to chunk indices and offsets
263-
# - derive the cu_seqlens
264-
_, cu_seqlens = torch.where(seq_idx.diff())
265-
cu_seqlens += 1
266-
267-
# outputs will have length expansion of chunks that do not divide
268-
# chunk_size
269-
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
270-
> 0).sum()
271-
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
272-
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
273-
274-
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
275-
p = 0 # num of insertions
276-
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
277-
278-
# if does not divide chunk_size, then there is one chunk insertion
279-
p += (s % chunk_size > 0)
280-
281-
# get the dimensions
282-
# - the + 1 for _e is to shift the boundary by one chunk
283-
# - this shifting is not needed if chunk_size divides e
284-
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
285-
> 0)
286-
287-
# adjust inidces and offsets
288-
chunk_indices[_s:_e] -= p
289-
chunk_offsets[_s] = s % chunk_size
290-
291-
return chunk_indices, chunk_offsets
292-
293-
294261
class BambaModel(nn.Module):
295262

296263
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -361,8 +328,17 @@ def forward(
361328
)):
362329
seq_idx[srt:end] = i
363330
seq_idx.unsqueeze_(0)
364-
# Compute mamba2 metadata tensors that are reused across layers
365-
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
331+
332+
# compute metadata for chunked prefill.
333+
# actually this is only needed if there are
334+
# initial states, but this is determinable
335+
# only from attention metadata yet
336+
# unavailable from the current top-level forward.
337+
# Rather than complicating things to extract said
338+
# metadata, we simply just compute redundently and
339+
# will be silently ignored inside the mamba kernels.
340+
# if not needed.
341+
chunk_indices, chunk_offsets = seq_idx_to_chunk_indices_offsets(
366342
seq_idx, self.config.mamba_chunk_size)
367343

368344
if get_pp_group().is_first_rank:
@@ -378,7 +354,6 @@ def forward(
378354

379355
residual = None
380356
num_attn = 0
381-
extra_args = {}
382357
for i in range(len(self.layers)):
383358
layer = self.layers[i]
384359
if isinstance(layer, BambaAttentionDecoderLayer):
@@ -388,19 +363,15 @@ def forward(
388363
if isinstance(layer, BambaMixerDecoderLayer):
389364
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
390365
i - num_attn)
391-
extra_args = {
392-
'chunk_indices': chunk_indices,
393-
'chunk_offsets': chunk_offsets,
394-
}
395366

396-
# print(f"{len(extra_args)=}")
397367
hidden_states, residual = layer(
398368
positions=positions,
399369
hidden_states=hidden_states,
400370
residual=residual,
401371
mamba_cache_params=layer_mamba_cache_params,
402372
sequence_idx=seq_idx,
403-
**extra_args,
373+
chunk_indices=chunk_indices,
374+
chunk_offsets=chunk_offsets,
404375
)
405376

406377
if not get_pp_group().is_last_rank:

0 commit comments

Comments
 (0)