Commit 53c69a2

cyang49 authored and fabianlim committed
[Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (vllm-project#15423)
Signed-off-by: Chih-Chieh-Yang <[email protected]> Co-authored-by: Yu Chin Fabian Lim <[email protected]>
1 parent 70c6016 commit 53c69a2

File tree

8 files changed: +186 -132 lines changed

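At its core, the change hoists the Mamba2 bookkeeping (prefill and initial-state flags, seq_idx, and the chunk index/offset tables) out of the individual mixer layers: it is computed once per model forward, and the resulting Mamba2Metadata object is reused by every mamba layer in that iteration. A minimal, hypothetical sketch of the intended call pattern follows; the model object and its attributes are illustrative placeholders, not code from this commit:

# Hedged sketch of the "compute once, reuse in every layer" pattern.
# `model`, `model.chunk_size`, `model.embed_tokens`, and `model.layers` are
# assumed/hypothetical names; only prepare_mamba2_metadata and the
# mamba2_metadata argument of the mixer come from this commit.
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    prepare_mamba2_metadata)


def mamba2_model_forward(model, input_ids, attn_metadata, mamba_cache_params):
    # Prepare the shared metadata exactly once per iteration, instead of
    # recomputing prefill/initial-state flags inside every mixer layer.
    mamba2_metadata = prepare_mamba2_metadata(
        chunk_size=model.chunk_size,
        input_ids=input_ids,
        attn_metadata=attn_metadata,
    )

    hidden_states = model.embed_tokens(input_ids)
    for layer in model.layers:
        # Each MambaMixer2 layer consumes the same precomputed metadata
        # (matching the new forward_cuda signature further below).
        hidden_states = layer(hidden_states, mamba_cache_params,
                              mamba2_metadata)
    return hidden_states
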
vllm/model_executor/layers/mamba/mamba2_metadata.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+from dataclasses import dataclass
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.attention.backends.placeholder_attn import (
+    PlaceholderAttentionMetadata)
+from vllm.attention.backends.xformers import XFormersMetadata
+
+
+@dataclass
+class Mamba2Metadata:
+    has_prefill: bool
+
+    has_initial_states: torch.Tensor
+    prep_initial_states: bool
+
+    chunk_size: int
+    seq_idx: torch.Tensor
+    chunk_indices: torch.Tensor
+    chunk_offsets: torch.Tensor
+
+
+def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+
+    # convert seq_idx to chunk indices and offsets
+    # - derive the cu_seqlens
+    _, cu_seqlens = torch.where(seq_idx.diff())
+    cu_seqlens += 1
+
+    # outputs will have length expansion of chunks that do not divide
+    # chunk_size
+    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
+                                                     > 0).sum()
+    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
+    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+
+    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
+    p = 0  # num of insertions
+    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+        # if does not divide chunk_size, then there is one chunk insertion
+        p += (s % chunk_size > 0)
+
+        # get the dimensions
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
+
+        # adjust inidces and offsets
+        chunk_indices[_s:_e] -= p
+        chunk_offsets[_s] = s % chunk_size
+
+    return chunk_indices, chunk_offsets
+
+
+def prepare_mamba2_metadata(
+    chunk_size: int,
+    input_ids: torch.Tensor,
+    attn_metadata: AttentionMetadata,
+) -> Mamba2Metadata:
+
+    # Need flags to indicate if there are initial states
+    # currently we really only support the FlashAttention backend
+    has_initial_states = None
+    prep_initial_states = False
+    if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
+                                   PlaceholderAttentionMetadata))
+            and attn_metadata.context_lens_tensor is not None):
+        has_initial_states = attn_metadata.context_lens_tensor > 0
+        # precompute flag to avoid device syncs later in mamba2 forwards
+        prep_initial_states = torch.any(has_initial_states).item()
+
+    has_prefill = attn_metadata.num_prefills > 0
+
+    seq_idx = None
+    chunk_indices, chunk_offsets = None, None
+    if has_prefill:
+        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
+        for i, (srt, end) in enumerate(
+                zip(
+                    attn_metadata.query_start_loc,
+                    attn_metadata.query_start_loc[1:],
+                )):
+            seq_idx[srt:end] = i
+        seq_idx.unsqueeze_(0)
+
+        # compute metadata for chunked prefill.
+        # actually this is only needed if there are initial states,
+        # but this is determinable only from attention metadata yet
+        # unavailable from the top-level model forward. Rather than
+        # complicating things to extract said metadata, we simply just
+        # compute them once at the top level model forward and reuse
+        # them in mamba layers. If not needed, they will be ignored
+        # inside mamba kernels.
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
+    return Mamba2Metadata(has_prefill=has_prefill,
+                          has_initial_states=has_initial_states,
+                          prep_initial_states=prep_initial_states,
+                          chunk_size=chunk_size,
+                          seq_idx=seq_idx,
+                          chunk_indices=chunk_indices,
+                          chunk_offsets=chunk_offsets)
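
For intuition, here is a small worked example of _seq_idx_to_chunk_indices_offsets (the call is illustrative; the numbers follow directly from the function above). With chunk_size=4 and two packed sequences of 6 tokens each, the second sequence starts in the middle of physical chunk 1, so one extra logical chunk is inserted and given a non-zero offset:

# Illustrative example only; runs on CPU.
import torch

from vllm.model_executor.layers.mamba.mamba2_metadata import (
    _seq_idx_to_chunk_indices_offsets)

# Two packed sequences of length 6 in one continuous batch of 12 tokens.
seq_idx = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]],
                       dtype=torch.int32)

chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
    seq_idx, chunk_size=4)

# Sequence 1 begins at token 6, i.e. at offset 2 inside physical chunk 1,
# so that chunk is visited twice: once for the tail of sequence 0 and once
# for the head of sequence 1 (starting at offset 2).
print(chunk_indices)  # tensor([0, 1, 1, 2], dtype=torch.int32)
print(chunk_offsets)  # tensor([0, 0, 2, 0], dtype=torch.int32)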

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 20 additions & 32 deletions
@@ -6,10 +6,6 @@
 from torch import nn
 
 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.placeholder_attn import (
-    PlaceholderAttentionMetadata)
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
@@ -18,6 +14,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -221,7 +218,6 @@ def __init__(self,
                  head_dim: int = 64,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 chunk_size: int = 256,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
@@ -257,7 +253,6 @@ def __init__(self,
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
-        self.chunk_size = chunk_size
         self.intermediate_size = intermediate_size
         self.head_dim = head_dim
         self.num_heads = num_heads
@@ -388,25 +383,17 @@ def forward_cuda(
         self,
         hidden_states: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
-        sequence_idx: Optional[torch.Tensor] = None,
+        mamba2_metadata: Mamba2Metadata,
     ):
+        # mamba2_metadata contains metadata necessary for the mamba2 triton
+        # kernels to operate in continuous batching and in chunked prefill
+        # modes; they are computed at top-level model forward since they
+        # are the same and reused for all mamba layers in the same iteration
         attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
 
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        # detect if there are prefills
-        has_prefill = attn_metadata.num_prefills > 0
-
-        # - also need flags to indicate if there are initial states
-        # - currently we really only support the FlashAttention backend
-        has_initial_states = None
-        if (isinstance(attn_metadata,
-                       (FlashAttentionMetadata, XFormersMetadata,
-                        PlaceholderAttentionMetadata))
-                and attn_metadata.context_lens_tensor is not None):
-            has_initial_states = attn_metadata.context_lens_tensor > 0
-
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
         gate, hidden_states_B_C, dt = torch.split(
@@ -423,7 +410,7 @@ def forward_cuda(
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if has_prefill:
+        if mamba2_metadata.has_prefill:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -439,7 +426,7 @@ def forward_cuda(
                 self.conv1d.bias,
                 activation=self.activation,
                 conv_states=mamba_cache_params.conv_state,
-                has_initial_state=has_initial_states,
+                has_initial_state=mamba2_metadata.has_initial_states,
                 cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc).transpose(
                     0, 1)[:seq_len]
@@ -467,16 +454,15 @@ def forward_cuda(
         )
 
         # 3. State Space Model sequence transformation
-        if has_prefill:
-
+        if mamba2_metadata.has_prefill:
             initial_states = None
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-                zero_init_indices = mamba_cache_params.state_indices_tensor[
-                    ~has_initial_states]
-                mamba_cache_params.ssm_state[zero_init_indices] = 0
-                initial_states = mamba_cache_params.ssm_state[
-                    mamba_cache_params.state_indices_tensor]
+            if (mamba2_metadata.has_initial_states is not None
+                    and mamba2_metadata.prep_initial_states):
+                # making a copy of the states
+                initial_states = torch.where(
+                    mamba2_metadata.has_initial_states[:, None, None, None],
+                    mamba_cache_params.ssm_state[
+                        mamba_cache_params.state_indices_tensor], 0)
 
             scan_output, varlen_state = mamba_chunk_scan_combined(
                 hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
@@ -485,11 +471,13 @@ def forward_cuda(
                 self.A,
                 B.view(1, seq_len, self.n_groups // self.tp_size, -1),
                 C.view(1, seq_len, self.n_groups // self.tp_size, -1),
-                chunk_size=self.chunk_size,
+                chunk_size=mamba2_metadata.chunk_size,
                 D=self.D,
                 z=None,
                 dt_bias=self.dt_bias,
-                seq_idx=sequence_idx,
+                seq_idx=mamba2_metadata.seq_idx,
+                chunk_indices=mamba2_metadata.chunk_indices,
+                chunk_offsets=mamba2_metadata.chunk_offsets,
                 cu_seqlens=attn_metadata.query_start_loc,
                 initial_states=initial_states,
                 return_varlen_states=True,
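
Note the rewritten initial-state handling above: instead of zeroing entries of the SSM cache in place (which needed a per-layer torch.any() host sync to decide whether any sequence had context), the new code checks the precomputed prep_initial_states flag and builds a masked copy of the gathered states with a broadcasted torch.where, leaving the cache untouched. A small, self-contained illustration of that masking, with toy shapes rather than the real cache dimensions:

# Hedged illustration of the masked-copy pattern; shapes are toy values.
import torch

num_seqs, nheads, headdim, dstate = 3, 2, 4, 8

# Stand-in for mamba_cache_params.ssm_state[state_indices_tensor]:
# the cached SSM states gathered for the sequences in this batch.
gathered_states = torch.randn(num_seqs, nheads, headdim, dstate)

# True where a sequence continues a chunked prefill (has cached context),
# False where it starts fresh.
has_initial_states = torch.tensor([True, False, True])

# Broadcast the per-sequence mask over (nheads, headdim, dstate): fresh
# sequences get zero initial states, the rest keep a copy of their cached
# state. The cache itself is not modified.
initial_states = torch.where(has_initial_states[:, None, None, None],
                             gathered_states, 0)

assert torch.equal(initial_states[1], torch.zeros(nheads, headdim, dstate))
assert torch.equal(initial_states[0], gathered_states[0])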

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 10 additions & 41 deletions
@@ -5,8 +5,6 @@
 
 # ruff: noqa: E501,SIM102
 
-import math
-
 import torch
 import triton
 import triton.language as tl
@@ -442,40 +440,6 @@ def _chunk_scan_fwd_kernel(
             (offs_out_n[None, :] < hdim))
 
 
-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
-
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
-
-    # outputs will have length expansion of chunks that do not divide
-    # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
-
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
-    p = 0  # num of insertions
-    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
-
-        # if does not divide chunk_size, then there is one chunk insertion
-        p += (s % chunk_size > 0)
-
-        # get the dimensions
-        # - the + 1 for _e is to shift the boundary by one chunk
-        # - this shifting is not needed if chunk_size divides e
-        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
-                                                             > 0)
-
-        # adjust inidces and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 def _chunk_scan_fwd(
     cb,
     x,
@@ -486,6 +450,8 @@ def _chunk_scan_fwd(
     D=None,
     z=None,
     seq_idx=None,
+    chunk_indices=None,
+    chunk_offsets=None,
     initial_states=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
@@ -502,23 +468,26 @@ def _chunk_scan_fwd(
     assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
     assert states.shape == (batch, nchunks, nheads, headdim, dstate)
 
-    chunk_indices, chunk_offsets = None, None
     if seq_idx is not None:
         assert seq_idx.shape == (batch, seqlen)
 
         if initial_states is not None:
             # with initial states, we need to take care of how
             # seq_idx crosses the boundaries
             assert batch == 1, "chunk scan only supports initial states with batch 1"
-            assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
-                                            headdim, dstate)
 
             if initial_states.shape[0] == 1:
                 # no in this case no point to use initial states
                 initial_states = None
             else:
-                chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-                    seq_idx, chunk_size)
+                assert chunk_indices is not None and chunk_offsets is not None, \
+                    (
+                        "chunk_indices and chunk_offsets should have been set"
+                    )
+        else:
+            chunk_indices, chunk_offsets = None, None
+    else:
+        chunk_indices, chunk_offsets = None, None
 
     # Allocates output.
     out = torch.empty(batch,

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 9 additions & 1 deletion
@@ -30,6 +30,8 @@ def _mamba_chunk_scan_combined_fwd(x,
                                    dt_bias=None,
                                    initial_states=None,
                                    seq_idx=None,
+                                   chunk_indices=None,
+                                   chunk_offsets=None,
                                    cu_seqlens=None,
                                    dt_softplus=False,
                                    dt_limit=(0.0, float("inf"))):
@@ -96,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x,
     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     # (middle term of factorization of off-diag blocks; A terms)
     # - for handling chunked prefill, this requires i) initial_states
-    #   ii) seq_idx and iii) has_cu_seqlens to be all specified.
+    #   ii) seq_idx and iii) is_cont_batched to be all specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     # and switch accordingly to the init_state corresponding to the new seq_idx.
     # - this will ensure that states will be updated with the rightmost flushed seq_idx
@@ -141,6 +143,8 @@ def _mamba_chunk_scan_combined_fwd(x,
         D=D,
         z=z,
         seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
         initial_states=initial_states,
     )
     if cu_seqlens is None:
@@ -170,6 +174,8 @@ def mamba_chunk_scan_combined(x,
                              dt_bias=None,
                              initial_states=None,
                              seq_idx=None,
+                             chunk_indices=None,
+                             chunk_offsets=None,
                              cu_seqlens=None,
                              dt_softplus=False,
                              dt_limit=(0.0, float("inf")),
@@ -210,6 +216,8 @@ def mamba_chunk_scan_combined(x,
                                              dt_bias=dt_bias,
                                              initial_states=initial_states,
                                              seq_idx=seq_idx,
+                                             chunk_indices=chunk_indices,
+                                             chunk_offsets=chunk_offsets,
                                              cu_seqlens=cu_seqlens,
                                              dt_softplus=dt_softplus,
                                              dt_limit=dt_limit)

vllm/model_executor/layers/mamba/ops/ssd_state_passing.py

Lines changed: 0 additions & 2 deletions
@@ -150,8 +150,6 @@ def _state_passing_fwd(
         # are used for continuous batching. In which case we
         # require seq_idx to be provided
         assert seq_idx is not None, ""
-        assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
-                                        dim)
     else:
         # - this is the regular batching case, where initial
         # states are used are for each example of the batch.
