
Commit 9ef98d5

ZZBoom and qscqesze authored

[Model][MiniMaxText01] Support MiniMaxText01 model inference (#13454)

Signed-off-by: qscqesze <[email protected]>
Co-authored-by: qingjun <[email protected]>
Co-authored-by: qscqesze <[email protected]>
1 parent 93491ae commit 9ef98d5

File tree

11 files changed: +2440 -130 lines changed


docs/source/models/supported_models.md

Lines changed: 5 additions & 0 deletions
@@ -503,6 +503,11 @@ See [this page](#generative-models) for more information on how to use generativ
   * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
   * ✅︎
   * ✅︎
+- * `MiniMaxText01ForCausalLM`
+  * MiniMax-Text
+  * `MiniMaxAI/MiniMax-Text-01`, etc.
+  *
+  * ✅︎
 - * `Zamba2ForCausalLM`
   * Zamba2
   * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc.
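
With this row added, the model should be reachable through vLLM's standard entry points. A minimal offline-inference sketch, assuming the environment has enough GPUs for the checkpoint (the tensor_parallel_size value below is an illustrative assumption, not something specified by this commit):

from vllm import LLM, SamplingParams

# Hypothetical usage sketch for the newly supported architecture.
llm = LLM(
    model="MiniMaxAI/MiniMax-Text-01",
    trust_remote_code=True,    # the checkpoint ships custom modeling code
    tensor_parallel_size=8,    # assumption: the full model spans several GPUs
)

outputs = llm.generate(
    ["Lightning attention is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)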

tests/kernels/test_lightning_attn.py

Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@ (new file)

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm.model_executor.layers.lightning_attn import (
    linear_decode_forward_triton)
from vllm.platforms import current_platform

NUM_HEADS = [4, 8]
HEAD_SIZES = [64]
BATCH_SIZES = [1, 2]
SEQ_LENGTHS = [16]
DTYPES = [torch.float32]


def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
    """Reference implementation of lightning attention core algorithm

    The difference from the main implementation is that this processes
    each step sequentially, instead of using parallelized triton kernels
    """
    B, H, S, D = q.shape
    E = v.shape[-1]
    dtype = q.dtype
    output = torch.zeros((B, H, S, E), dtype=dtype, device=q.device)

    # Use clone() to ensure an independent copy
    if kv_history is None:
        kv_cache = torch.zeros((B, H, D, E), dtype=dtype, device=q.device)
    else:
        kv_cache = kv_history.clone()

    # More efficient implementation
    # Convert decay factors to matrix form
    if ed.dim() == 1:
        decay = torch.exp(-ed).view(1, -1, 1, 1)
    else:
        decay = torch.exp(-ed)

    for b in range(B):
        for step in range(S):
            # Process all heads at once for this position
            q_bs = q[b, :, step]  # [H, D]
            k_bs = k[b, :, step]  # [H, D]
            v_bs = v[b, :, step]  # [H, E]

            # Calculate KV outer products for all heads
            for h in range(H):
                # Calculate KV outer product
                kv_outer = torch.outer(k_bs[h], v_bs[h])

                # Update KV cache with decay
                # Note: Using the same order as in the Triton kernel
                kv_cache[b, h] = decay[0, h, 0, 0] * kv_cache[b, h] + kv_outer

                # Calculate attention output
                output[b, h, step] = torch.matmul(q_bs[h], kv_cache[b, h])

    # Match the shape returned by the actual implementation
    # The actual implementation returns a tensor of shape [B, H, 2, D, E]
    # where dimension 2 contains both KV and KV history
    kv_reshaped = kv_cache.unsqueeze(2)  # [B, H, 1, D, E]
    final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
                               dim=2)  # [B, H, 2, D, E]

    return output, final_kv_cache


def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
    """Reference implementation: linear attention decode function"""
    B, H, _, D = q.shape
    output = torch.zeros(B, H * D, dtype=q.dtype, device=q.device)

    # Calculate decay factors once (more efficient)
    decay = torch.exp(-slope_rate).view(-1, 1, 1)  # [H, 1, 1]

    # Process each batch
    for b in range(B):
        slot_id = slot_idx[b].item()

        # Skip padding positions
        if slot_id == -1:
            continue

        # Process all heads at once for this batch
        q_b = q[b, :, 0]  # [H, D]
        k_b = k[b, :, 0]  # [H, D]
        v_b = v[b, :, 0]  # [H, D]

        # Process each attention head
        for h in range(H):
            # Get current query, key and value
            q_bh = q_b[h]
            k_bh = k_b[h]
            v_bh = v_b[h]

            # Get cache
            kv_cache_old = kv_caches[b, h]

            # Calculate new key-value outer product
            kv_outer = torch.outer(k_bh, v_bh)

            # Apply decay and update cache
            kv_new = kv_outer + decay[h, 0, 0] * kv_cache_old

            # Calculate output
            out_h = torch.matmul(q_bh, kv_new)

            # Update output and cache
            output[b, h * D:(h + 1) * D] = out_h
            kv_caches[b, h] = kv_new

    return output


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton(
    batch_size: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.arange(batch_size, device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)
    torch.testing.assert_close(triton_output,
                               reference_output,
                               rtol=1e-1,
                               atol=1e-1)
    torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_linear_decode_forward_triton_with_padding(
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    batch_size = 4
    base = 0.01
    q = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
    v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)

    kv_caches = base * torch.randn(batch_size,
                                   num_heads,
                                   head_size,
                                   head_size,
                                   dtype=dtype,
                                   device="cuda")

    kv_caches_copy = kv_caches.clone()

    slope_rate = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        slope_rate[h] = 0.1 * (h + 1)

    slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")

    triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
                                                 slope_rate, slot_idx)

    reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
                                               slope_rate, slot_idx)

    padding_mask = (slot_idx
                    != -1).unsqueeze(1).expand(-1, num_heads * head_size)

    triton_masked = triton_output[padding_mask]
    reference_masked = reference_output[padding_mask]

    atol, rtol = 1.5e-1, 1.5e-1

    valid_indices = slot_idx != -1

    for i in range(batch_size):
        if valid_indices[i] > 0:
            torch.testing.assert_close(kv_caches[i],
                                       kv_caches_copy[i],
                                       rtol=rtol,
                                       atol=atol)

    torch.testing.assert_close(triton_masked,
                               reference_masked,
                               rtol=rtol,
                               atol=atol)

    assert triton_output.shape == (batch_size, num_heads * head_size)


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENGTHS)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_lightning_attention_reference(
    batch_size: int,
    num_heads: int,
    head_size: int,
    seq_len: int,
    dtype: torch.dtype,
):
    torch.set_default_device("cuda")
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    current_platform.seed_everything(42)

    base = 0.01
    q = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)
    k = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)
    v = base * torch.randn(
        batch_size, num_heads, seq_len, head_size, dtype=dtype)

    ed = torch.zeros(num_heads, device="cuda")
    for h in range(num_heads):
        ed[h] = 0.1 * (h + 1)

    kv_history = base * torch.randn(batch_size,
                                    num_heads,
                                    head_size,
                                    head_size,
                                    dtype=dtype,
                                    device="cuda")

    kv_history_clone = kv_history.clone()

    ref_output, ref_kv_cache = reference_lightning_attention(
        q, k, v, ed, 256, kv_history)

    from vllm.model_executor.layers.lightning_attn import lightning_attention
    actual_output, actual_kv_cache = lightning_attention(
        q, k, v, ed, 256, kv_history_clone)

    atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
    torch.testing.assert_close(ref_kv_cache,
                               actual_kv_cache,
                               rtol=rtol,
                               atol=atol)

    assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
    assert ref_kv_cache.shape == actual_kv_cache.shape
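
For reference, both helpers above implement the same per-batch, per-head recurrence that the Triton kernels are expected to compute in parallel. With s_h the per-head slope rate (so the decay factor is e^{-s_h}), the update the tests check is

    \mathrm{KV}_t = e^{-s_h}\,\mathrm{KV}_{t-1} + k_t v_t^{\top}, \qquad o_t = q_t\,\mathrm{KV}_t

where \mathrm{KV}_t is the running [D, E] key-value state held in kv_cache / kv_history.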

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -176,6 +176,8 @@ def check_available_online(
                                              trust_remote_code=True),
     "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
                                            trust_remote_code=True),
+    "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
+                                                trust_remote_code=True),
     "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
     "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1"),  # noqa: E501
     "QuantMixtralForCausalLM": _HfExamplesInfo("mistral-community/Mixtral-8x22B-v0.1-AWQ"),  # noqa: E501

vllm/config.py

Lines changed: 25 additions & 17 deletions
@@ -971,26 +971,34 @@ def get_num_layers_by_block_type(
             return sum(not bc.attention.no_op
                        for bc in block_configs[start:end])
         else:
-            # Hybrid model
+            # Hybrid model Jamba
             layers_block_type_value = getattr(self.hf_config,
                                               "layers_block_type", None)
-            if layers_block_type_value is None:
-                raise ValueError("The model is an hybrid without a "
-                                 "layers_block_type in the hf_config, "
-                                 "cannot determine the num of "
-                                 f"{block_type.value} layers")
-
-            if hasattr(self.hf_text_config,
-                       "model_type") and (self.hf_text_config.model_type
-                                          == "zamba2"):
-                if attn_block_type:
-                    return sum(t == "hybrid"
-                               for t in layers_block_type_value[start:end])
-                else:
-                    return self.get_num_layers(parallel_config)
+            if layers_block_type_value is not None:
+                if hasattr(self.hf_text_config,
+                           "model_type") and (self.hf_text_config.model_type
+                                              == "zamba2"):
+                    if attn_block_type:
+                        return sum(t == "hybrid"
+                                   for t in layers_block_type_value[start:end])
+                    else:
+                        return self.get_num_layers(parallel_config)
+                return sum(t == block_type.value
+                           for t in layers_block_type_value[start:end])
+
+            # Hybrid model Minimax
+            attn_type_list = getattr(self.hf_config, "attn_type_list", None)
+            if attn_type_list:
+                return sum(t == 1 for t in attn_type_list[start:end])
+
+            if layers_block_type_value is None and attn_type_list is None:
+                raise ValueError(
+                    "The model is an hybrid without a"
+                    "layers_block_type or an attn_type_list in the hf_config,"
+                    "cannot determine the num of "
+                    f"{block_type.value} layers")
 
-            return sum(t == block_type.value
-                       for t in layers_block_type_value[start:end])
+            return sum(t == 1 for t in attn_type_list[start:end])
 
     def get_multimodal_config(self) -> "MultiModalConfig":
         """

vllm/engine/async_llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -303,8 +303,11 @@ async def step_async(
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs
 
-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
+            if not scheduler_outputs.is_empty():
+                # this will cause mamba_cache/minimax_cache failed
+                # to release finished_requests_ids of the last steps
+                finished_requests_ids = self.scheduler[
+                    virtual_engine].get_and_reset_finished_requests_ids()
 
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
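
As I read the guard above: when a step schedules no work, the finished-request IDs stay queued in the scheduler so that a later, non-empty step can hand them to the per-request state caches (mamba_cache/minimax_cache) for release; draining them on an empty step would drop them before the caches ever see them. A toy, framework-free sketch of that pattern (the class and method names are illustrative, not vLLM's actual APIs, apart from get_and_reset_finished_requests_ids):

from typing import Dict, List


class ToyScheduler:
    """Minimal stand-in for the scheduler's finished-request bookkeeping."""

    def __init__(self) -> None:
        self._finished: List[str] = []

    def mark_finished(self, request_id: str) -> None:
        self._finished.append(request_id)

    def get_and_reset_finished_requests_ids(self) -> List[str]:
        finished, self._finished = self._finished, []
        return finished


class ToyRecurrentStateCache:
    """Stand-in for a mamba/minimax-style per-request state cache."""

    def __init__(self) -> None:
        self._slots: Dict[str, object] = {}

    def allocate(self, request_id: str) -> None:
        self._slots[request_id] = object()

    def release_finished(self, finished_ids: List[str]) -> None:
        for rid in finished_ids:
            self._slots.pop(rid, None)


def step(scheduler: ToyScheduler, cache: ToyRecurrentStateCache,
         has_scheduled_work: bool) -> None:
    # Only drain the finished IDs when this step actually runs the model;
    # otherwise the cache's release path would never see them.
    if has_scheduled_work:
        cache.release_finished(scheduler.get_and_reset_finished_requests_ids())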

0 commit comments
