Skip to content

[Model] Support Mamba2 (Codestral Mamba) #9292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c956a30
Mamba2 changes from #10909
tlrmchlsmth Jan 16, 2025
17923ad
Get Mamba2 working!
tlrmchlsmth Jan 16, 2025
4183d45
Add integration test -- something is wrong!!
tlrmchlsmth Jan 17, 2025
5377644
format
tlrmchlsmth Jan 17, 2025
39f55d1
fixes
tlrmchlsmth Jan 17, 2025
e2e5aac
Fix for conv state shape and update placeholder_attn
tlrmchlsmth Jan 19, 2025
bc1b8af
back out placeholder_attn changes
tlrmchlsmth Jan 19, 2025
cd89283
WIP debugging, restore local mamba and placeholder_attn changes
tlrmchlsmth Jan 20, 2025
9a838a3
Integration tests are now green
tlrmchlsmth Jan 20, 2025
be8318e
remove bamba-specific files
tlrmchlsmth Jan 20, 2025
f34d434
Merge branch 'main' into tms/mamba2
tlrmchlsmth Jan 27, 2025
a65e2cb
Handle grouping in Mixer2RMSNormGated
tlrmchlsmth Jan 30, 2025
0d4bb0f
debug cruft
tlrmchlsmth Jan 30, 2025
74f6088
Remove codestral integration test
tlrmchlsmth Jan 30, 2025
b28cdfc
Merge branch 'main' into tms/mamba2
tlrmchlsmth Feb 7, 2025
bcc93e6
spdx
tlrmchlsmth Feb 7, 2025
2e5f7be
Fix invalid memory accesses in chunk_scan_fwd
tlrmchlsmth Feb 7, 2025
8ab4a90
Merge branch 'main' into tms/mamba2
tlrmchlsmth Feb 13, 2025
0f2cc22
Clean up max batch size logic
tlrmchlsmth Feb 13, 2025
a062d0d
Merge branch 'main' into tms/mamba2
tlrmchlsmth Feb 14, 2025
9cfd012
Add mamba to test-pipeline to see if it passes
tlrmchlsmth Feb 14, 2025
201696b
fix registery, determine max_batch_size in MambaCache
tlrmchlsmth Feb 14, 2025
fa27469
update registry
tlrmchlsmth Feb 15, 2025
d6ef554
Merge branch 'main' into tms/mamba2
tlrmchlsmth Feb 15, 2025
c5d2658
rm debug cruft
tlrmchlsmth Feb 15, 2025
3457501
Merge branch 'main' into tms/mamba2
tlrmchlsmth Feb 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions tests/models/decoder_only/language/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
Run `pytest tests/models/test_mamba.py`.
"""
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams

from ...utils import check_outputs_equal

MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# See https://github.com/huggingface/transformers/pull/35943
# "mistralai/Mamba-Codestral-7B-v0.1",
]


# Use lower-level interfaces to create this greedy generator, as mamba will
Expand All @@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set the device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Generate texts from the prompts
outputs = []
for prompt in example_prompts:
Expand All @@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
input_ids = inputs["input_ids"].to(model.device)

# Generate text using the model's generate method directly
generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
generated_ids = model.generate(input_ids,
max_new_tokens=max_tokens,
do_sample=False)
generated_text = tokenizer.decode(generated_ids[0],
skip_special_tokens=True)

Expand All @@ -50,7 +64,8 @@ def test_models(
) -> None:
hf_outputs = generate_greedy(model, example_prompts, max_tokens)

with vllm_runner(model, dtype=dtype) as vllm_model:
# Set max_num_seqs to keep Codestral from going OOM at fp32
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

# This test is for verifying whether the model's extra_repr
Expand Down Expand Up @@ -81,7 +96,7 @@ def test_batching(
) -> None:
# To pass the small model tests, we need full precision.
for_loop_outputs = []
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for prompt in example_prompts:
for_loop_outputs.append(
vllm_model.generate_greedy([prompt], max_tokens)[0])
Expand Down Expand Up @@ -165,7 +180,7 @@ def test_parallel_sampling(
max_tokens: int,
) -> None:

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for_loop_outputs = []
for _ in range(10):
for_loop_outputs.append(
Expand Down Expand Up @@ -232,7 +247,7 @@ def test_models_preemption_recompute(
# Tests that outputs are identical with and w/o preemptions (recompute)
assert dtype == "float"

with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
vllm_model.model.llm_engine.scheduler[
0].ENABLE_ARTIFICIAL_PREEMPT = True
preempt_vllm_outputs = vllm_model.generate_greedy(
Expand Down Expand Up @@ -283,7 +298,7 @@ def test_state_cleanup(
# This test is for verifying that the Mamba state is cleaned up between
# steps. If it's not cleaned, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def _chunk_scan_fwd_kernel(
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr +
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1,
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
other=0.0).to(tl.float32)

if HAS_SEQ_IDX:
Expand Down Expand Up @@ -463,7 +464,10 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
p += (s % chunk_size > 0)

# get the dimensions
_s, _e = s // chunk_size + p, e // chunk_size + p + 1
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)

# adjust indices and offsets
chunk_indices[_s:_e] -= p
Expand Down
21 changes: 5 additions & 16 deletions vllm/model_executor/models/bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,22 +440,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# follow jamba
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
# for compilation
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
elif self.scheduler_config is not None:
# for eager just take the scheduler_config if avail
self.max_batch_size = self.scheduler_config.max_num_seqs
else:
self.max_batch_size = 8192 + 2
# Determine max batch size to set size of MambaCache
self.max_batch_size = self.scheduler_config.max_num_seqs
if not self.model_config.enforce_eager:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.max_batch_size)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
Expand Down
17 changes: 6 additions & 11 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,17 +426,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2

# Determine max batch size to set size of MambaCache
self.max_batch_size = self.scheduler_config.max_num_seqs
if not self.model_config.enforce_eager:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.max_batch_size)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
Expand Down
20 changes: 7 additions & 13 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
self.scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching"

super().__init__()
self.config = config
self.vllm_config = vllm_config
self.scheduler_config = scheduler_config
self.model_config = vllm_config.model_config
self.backbone = MambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "backbone"))
Expand Down Expand Up @@ -202,17 +201,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors)
if self.scheduler_config is not None and \
not self.model_config.enforce_eager:
if self.scheduler_config.max_num_seqs > \
vllm_config.compilation_config.max_capture_size:
self.max_batch_size = \
vllm_config.compilation_config.max_capture_size
else:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.scheduler_config.max_num_seqs)
else:
self.max_batch_size = 8192 + 2

# Determine max batch size to set size of MambaCache
self.max_batch_size = self.scheduler_config.max_num_seqs
if not self.model_config.enforce_eager:
self.max_batch_size = vllm_config.pad_for_cudagraph(
self.max_batch_size)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.backbone.get_input_embeddings(input_ids)
Expand Down
Loading