Commit 2b5008a

tlrmchlsmth and fabianlim authored and committed
[Model] Support Mamba2 (Codestral Mamba) (vllm-project#9292)
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 217f1ce commit 2b5008a
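This commit registers Mamba2ForCausalLM and wires up the Mamba2 (Codestral Mamba) architecture. As a quick orientation (not part of the commit itself), the new support would be exercised through vLLM's offline LLM API roughly as sketched below; the prompt is made up, and running the 7B checkpoint assumes access to the gated weights and sufficient GPU memory.

# Minimal sketch, not part of the diff: load a Mamba2 checkpoint via vLLM's
# offline API once this commit is installed. Assumes access to the gated
# mistralai/Mamba-Codestral-7B-v0.1 weights and enough GPU memory.
from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mamba-Codestral-7B-v0.1")
params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy decoding

outputs = llm.generate(["Write a function that reverses a string."], params)
print(outputs[0].outputs[0].text)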

File tree

9 files changed: +376 -65 lines


tests/models/decoder_only/language/test_mamba.py

+30-13
@@ -4,14 +4,22 @@
 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams
 
 from ...utils import check_outputs_equal
 
-MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
+MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
+]
 
 
 # Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
 
+    # Set the device (GPU if available, else CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
         input_ids = inputs["input_ids"].to(model.device)
 
         # Generate text using the model's generate method directly
-        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
+        generated_ids = model.generate(input_ids,
+                                       max_new_tokens=max_tokens,
+                                       do_sample=False)
         generated_text = tokenizer.decode(generated_ids[0],
                                           skip_special_tokens=True)
 
@@ -50,7 +64,8 @@ def test_models(
 ) -> None:
     hf_outputs = generate_greedy(model, example_prompts, max_tokens)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Set max_num_seqs to keep Codestral from going OOM at fp32
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     # This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
 ) -> None:
     # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
        for prompt in example_prompts:
            for_loop_outputs.append(
                vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,20 +180,22 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Numerical differences produce slightly different output for these
+    if 'state-spaces' in model:
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
             for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
+                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
         sampling_params = SamplingParams(n=10,
                                          temperature=0.001,
                                          seed=0,
                                          max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
+        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
         token_ids, texts = n_lt_1_outputs[0]
         n_lt_1_outputs = [(token_id, text)
                           for token_id, text in zip(token_ids, texts)]
@@ -232,7 +249,7 @@ def test_models_preemption_recompute(
     # Tests that outputs are identical with and w/o preemtions (recompute)
     assert dtype == "float"
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +300,7 @@ def test_state_cleanup(
     # This test is for verifying that the Mamba state is cleaned up between
     # steps, If its not cleaned, an error would be expected.
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
             for _ in range(10):
                 vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
     except ValueError:
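For reference (not part of the diff), the HF-side pattern the updated generate_greedy helper pins down — explicit device placement plus do_sample=False for strictly greedy, deterministic output — looks like this in isolation, using the small test checkpoint from MODELS above and a made-up prompt.

# Standalone sketch of the reference decoding path used by the test above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Place the model on GPU if one is available, mirroring the test change
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)

# do_sample=False forces greedy decoding, so the HF output is deterministic
# and can be compared token-for-token against vLLM's generate_greedy.
generated_ids = model.generate(input_ids, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))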

tests/models/registry.py

+2
@@ -145,6 +145,8 @@ def check_available_online(
     "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                         is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
+    "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
+                                         is_available_online=False),
     "FalconMambaForCausalLM": _HfExamplesInfo("tiiuae/falcon-mamba-7b-instruct"),  # noqa: E501
     "MiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-2B-sft-bf16",
                                           trust_remote_code=True),

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

+6-2
@@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel(
         dA_cs_m_boundary = tl.load(
             dA_cumsum_ptr +
             (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
-            mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1,
+            mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
+                  and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
             other=0.0).to(tl.float32)
 
     if HAS_SEQ_IDX:
@@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
         p += (s % chunk_size > 0)
 
         # get the dimensions
-        _s, _e = s // chunk_size + p, e // chunk_size + p + 1
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
 
         # adjust inidces and offsets
         chunk_indices[_s:_e] -= p
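Not part of the diff: a tiny numeric check of the boundary arithmetic changed above. The old expression always shifted the end chunk index by one, while the new (e % chunk_size > 0) term only shifts when the sequence end e is not chunk-aligned. The values below are made up for illustration; p stands in for the accumulated shift from earlier loop iterations.

# Illustration only: compare the old and new end-boundary expressions from
# _seq_idx_to_chunk_indices_offsets. p = 1 is an assumed accumulated shift.
chunk_size = 8
p = 1

for e in (12, 16):  # one unaligned and one chunk-aligned sequence end
    old_e = e // chunk_size + p + 1                     # previous code
    new_e = e // chunk_size + p + (e % chunk_size > 0)  # patched code
    print(f"e={e}: old _e={old_e}, new _e={new_e}")
# e=12 -> both give 3 (shift still applied)
# e=16 -> old gives 4, new gives 3 (no extra shift when chunk_size divides e)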

vllm/model_executor/models/bamba.py

+2-19
@@ -440,23 +440,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
-        # follow jamba
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            # for compilation
-            if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        elif self.scheduler_config is not None:
-            # for eager just take the scheduler_config if avail
-            self.max_batch_size = self.scheduler_config.max_num_seqs
-        else:
-            self.max_batch_size = 8192 + 2
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -474,8 +457,8 @@ def forward(self,
             self.vllm_config.parallel_config, LayerBlockType.mamba)
 
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
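Not part of the diff: the block deleted above, and the equivalent blocks deleted from jamba.py and mamba.py below, all derived a Mamba cache batch size from the scheduler and CUDA-graph capture settings. After this commit, MambaCacheManager receives the vllm_config and can do that sizing itself. A rough sketch of the removed selection logic, expressed as a hypothetical free function:

# Hypothetical helper mirroring the duplicated logic this commit removes from
# bamba.py/jamba.py/mamba.py; the cache manager now derives this internally
# from the vllm_config it is handed.
def _mamba_cache_batch_size(vllm_config) -> int:
    scheduler_config = vllm_config.scheduler_config
    if scheduler_config is not None and not vllm_config.model_config.enforce_eager:
        # CUDA-graph mode: pad to a capture size, capped at max_capture_size
        max_capture = vllm_config.compilation_config.max_capture_size
        if scheduler_config.max_num_seqs > max_capture:
            return max_capture
        return vllm_config.pad_for_cudagraph(scheduler_config.max_num_seqs)
    if scheduler_config is not None:
        # eager mode: take the scheduler's configured max batch size
        return scheduler_config.max_num_seqs
    return 8192 + 2  # legacy fallback kept from the removed code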

vllm/model_executor/models/jamba.py

+2-13
@@ -426,17 +426,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        else:
-            self.max_batch_size = 8192 + 2
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
@@ -453,8 +442,8 @@ def forward(self,
         num_mamba_layers = self.model_config.get_num_layers_by_block_type(
             self.vllm_config.parallel_config, LayerBlockType.mamba)
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

vllm/model_executor/models/mamba.py

+3-15
@@ -166,14 +166,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         lora_config = vllm_config.lora_config
-        scheduler_config = vllm_config.scheduler_config
+        self.scheduler_config = vllm_config.scheduler_config
         assert not cache_config.enable_prefix_caching, \
             "Mamba does not support prefix caching"
 
         super().__init__()
         self.config = config
         self.vllm_config = vllm_config
-        self.scheduler_config = scheduler_config
         self.model_config = vllm_config.model_config
         self.backbone = MambaModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
@@ -202,17 +201,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.make_empty_intermediate_tensors = (
             self.backbone.make_empty_intermediate_tensors)
-        if self.scheduler_config is not None and \
-                not self.model_config.enforce_eager:
-            if self.scheduler_config.max_num_seqs > \
-                    vllm_config.compilation_config.max_capture_size:
-                self.max_batch_size = \
-                    vllm_config.compilation_config.max_capture_size
-            else:
-                self.max_batch_size = vllm_config.pad_for_cudagraph(
-                    self.scheduler_config.max_num_seqs)
-        else:
-            self.max_batch_size = 8192 + 2
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
@@ -229,8 +217,8 @@ def forward(self,
         num_mamba_layers = self.model_config.get_num_layers_by_block_type(
             self.vllm_config.parallel_config, LayerBlockType.mamba)
         self.mamba_cache = MambaCacheManager(
-            self.lm_head.weight.dtype, num_mamba_layers,
-            self.max_batch_size, *self._get_mamba_cache_shape())
+            self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+            *self._get_mamba_cache_shape())
 
         mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
