
Commit aff4045

Add Bamba Model (#10909)
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Parent: 467a96a
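
For context, once this commit lands, Bamba checkpoints can be exercised through vLLM's usual offline-inference API. The sketch below is illustrative only and is not part of this commit: the model id "ibm-ai-platform/Bamba-9B" and the sampling settings are assumptions made here for the example.

from vllm import LLM, SamplingParams

# Assumed Bamba checkpoint id; substitute whichever Bamba weights you actually have.
llm = LLM(model="ibm-ai-platform/Bamba-9B")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The Bamba architecture combines"], params)
print(outputs[0].outputs[0].text)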

File tree

17 files changed (+3706, -112 lines)


tests/kernels/test_mamba_mixer2.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0

import unittest.mock
from typing import Tuple

import pytest
import torch

from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "hidden_size_n_groups",
    [
        (64, 1),
        (64, 2),
        (64, 4),  # hidden_size must be divisible by num_gpus
        (100, 5),  # and n_groups must divide hidden_size
    ])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
    batch_size: int,
    seq_len: int,
    hidden_size_n_groups: Tuple[int, int],
    dtype: torch.dtype,
    device: str = 'cuda',
):
    hidden_size, n_groups = hidden_size_n_groups
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn, otherwise there are problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(fn,
                                    args=(
                                        num_processes,
                                        batch_size,
                                        seq_len,
                                        hidden_size,
                                        n_groups,
                                        dtype,
                                        device,
                                    ),
                                    nprocs=nprocs)

    run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)


def mixer2_gated_norm_tensor_parallel(
    local_rank: int,
    world_size: int,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    n_groups: int,
    dtype: torch.dtype,
    device: str,
):
    current_platform.seed_everything(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # create random weights and inputs
    weight = torch.rand((hidden_size, ), dtype=dtype, device=device)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    gate_states = torch.randn(batch_size, seq_len, hidden_size)

    # create gated-norm with TP
    mixer = Mixer2RMSNormGated(
        full_hidden_size=hidden_size,
        full_n_groups=n_groups,
    )
    mixer.weight.weight_loader(mixer.weight, weight)  # load

    # create gated-norm without TP to compute the reference
    # - use mock patching to disable TP while building the single-GPU copy
    with (unittest.mock.patch(
            "vllm.model_executor.layers.mamba.mamba_mixer2."
            "get_tensor_model_parallel_world_size",
            return_value=1),
          unittest.mock.patch(
              "vllm.model_executor.layers.mamba.mamba_mixer2."
              "get_tensor_model_parallel_rank",
              return_value=0)):
        mixer_single_gpu = Mixer2RMSNormGated(
            full_hidden_size=hidden_size,
            full_n_groups=n_groups,
        )
        # assign the full weight to the single-gpu mixer
        mixer_single_gpu.weight.data = weight

    # generate and compare: each rank only sees its hidden_size // world_size
    # slice, which must match the same slice of the full-width reference
    N = hidden_size // world_size
    output = mixer(
        hidden_states[..., local_rank * N:(local_rank + 1) * N],
        gate_states[..., local_rank * N:(local_rank + 1) * N],
    )
    ref_output = mixer_single_gpu(hidden_states, gate_states)
    assert torch.allclose(output,
                          ref_output[..., local_rank * N:(local_rank + 1) * N],
                          atol=1e-3,
                          rtol=1e-3)
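
For orientation, the reference path above can be understood as a gated grouped RMS norm. The sketch below is an illustrative single-GPU reference in plain PyTorch, not vLLM's Mixer2RMSNormGated implementation; it assumes the standard Mamba-2 formulation (SiLU-gate the hidden states, then RMS-normalize within each of n_groups contiguous channel groups and scale by a learned per-channel weight), and the function name and eps default are assumptions made here.

import torch
import torch.nn.functional as F

def gated_group_rms_norm(hidden_states: torch.Tensor,
                         gate: torch.Tensor,
                         weight: torch.Tensor,
                         n_groups: int,
                         eps: float = 1e-5) -> torch.Tensor:
    """Illustrative reference: SiLU-gate, then per-group RMS norm."""
    hidden_size = hidden_states.shape[-1]
    group_size = hidden_size // n_groups
    # gate the hidden states
    x = hidden_states * F.silu(gate)
    # RMS-normalize within each contiguous group of channels
    x = x.view(*x.shape[:-1], n_groups, group_size)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x.view(*x.shape[:-2], hidden_size)
    # scale by the learned per-channel weight
    return x * weight

The test then checks that, under tensor parallelism, each rank's output for its hidden_size // world_size channel slice matches the same slice of a full-width reference. With two GPUs available, it can be run with a plain pytest invocation, e.g. pytest tests/kernels/test_mamba_mixer2.py.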
