 Run `pytest tests/models/test_mamba.py`.
 """
 import pytest
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from vllm.engine.arg_utils import EngineArgs
 from vllm.sampling_params import SamplingParams

 from ...utils import check_outputs_equal

-MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"]
+MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "tiiuae/falcon-mamba-tiny-dev",
+    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
+    # Mamba2 is buggy for Codestral as it doesn't handle n_groups.
+    # See https://github.com/huggingface/transformers/pull/35943
+    # "mistralai/Mamba-Codestral-7B-v0.1",
+]


 # Use lower-level interfaces to create this greedy generator, as mamba will
@@ -21,6 +29,10 @@ def generate_greedy(model_name, example_prompts, max_tokens):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)

+    # Set the device (GPU if available, else CPU)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
     # Generate texts from the prompts
     outputs = []
     for prompt in example_prompts:
@@ -29,7 +41,9 @@ def generate_greedy(model_name, example_prompts, max_tokens):
         input_ids = inputs["input_ids"].to(model.device)

         # Generate text using the model's generate method directly
-        generated_ids = model.generate(input_ids, max_new_tokens=max_tokens)
+        generated_ids = model.generate(input_ids,
+                                       max_new_tokens=max_tokens,
+                                       do_sample=False)
         generated_text = tokenizer.decode(generated_ids[0],
                                           skip_special_tokens=True)

@@ -50,7 +64,8 @@ def test_models(
 ) -> None:
     hf_outputs = generate_greedy(model, example_prompts, max_tokens)

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Set max_num_seqs to keep Codestral from going OOM at fp32
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

     # This test is for verifying whether the model's extra_repr
@@ -81,7 +96,7 @@ def test_batching(
 ) -> None:
     # To pass the small model tests, we need full precision.
     for_loop_outputs = []
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for prompt in example_prompts:
             for_loop_outputs.append(
                 vllm_model.generate_greedy([prompt], max_tokens)[0])
@@ -165,20 +180,22 @@ def test_parallel_sampling(
     max_tokens: int,
 ) -> None:

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    # Numerical differences produce slightly different output for these
+    if 'state-spaces' in model:
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+        example_prompts.pop(0)
+
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         for_loop_outputs = []
         for _ in range(10):
             for_loop_outputs.append(
-                # using example_prompts index 1 instead of 0 since with 0 the
-                # logprobs get really close and the test doesn't pass
-                vllm_model.generate_greedy([example_prompts[1]], max_tokens)
-                [0])
+                vllm_model.generate_greedy(example_prompts, max_tokens)[0])
         sampling_params = SamplingParams(n=10,
                                          temperature=0.001,
                                          seed=0,
                                          max_tokens=max_tokens)
-        n_lt_1_outputs = vllm_model.generate([example_prompts[1]],
-                                             sampling_params)
+        n_lt_1_outputs = vllm_model.generate(example_prompts, sampling_params)
         token_ids, texts = n_lt_1_outputs[0]
         n_lt_1_outputs = [(token_id, text)
                           for token_id, text in zip(token_ids, texts)]
@@ -232,7 +249,7 @@ def test_models_preemption_recompute(
     # Tests that outputs are identical with and w/o preemptions (recompute)
     assert dtype == "float"

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
         vllm_model.model.llm_engine.scheduler[
             0].ENABLE_ARTIFICIAL_PREEMPT = True
         preempt_vllm_outputs = vllm_model.generate_greedy(
@@ -283,7 +300,7 @@ def test_state_cleanup(
     # This test is for verifying that the Mamba state is cleaned up between
     # steps. If it's not cleaned, an error would be expected.
     try:
-        with vllm_runner(model, dtype=dtype) as vllm_model:
+        with vllm_runner(model, dtype=dtype, max_num_seqs=16) as vllm_model:
             for _ in range(10):
                 vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
     except ValueError:
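
For context beyond these hunks: the greedy HF and vLLM outputs collected in test_models are ultimately compared with the check_outputs_equal helper imported at the top of the file. A minimal sketch of that final step, assuming the keyword-only signature this helper uses elsewhere in vLLM's model tests (hf_outputs and vllm_outputs are the lists built above; this is not part of the diff):

    # Sketch: assert that token ids and decoded text match between backends.
    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )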