
Commit 8bddb73

Akshat-Tripathi, mosalov, jeejeelee, and Isotr0py authored
[Hardware][CPU] Multi-LoRA implementation for the CPU backend (#11100)
Signed-off-by: Akshat Tripathi <[email protected]>
Signed-off-by: Oleg Mosalov <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Co-authored-by: Oleg Mosalov <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent f967e51 commit 8bddb73

25 files changed: +855, -193 lines

.buildkite/run-cpu-test.sh

Lines changed: 6 additions & 0 deletions

@@ -75,6 +75,12 @@ function cpu_tests() {
       --num-prompts 20 \
       --endpoint /v1/completions \
       --tokenizer facebook/opt-125m"
+
+  # Run multi-lora tests
+  docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
+    set -e
+    pytest -s -v \
+      tests/lora/test_qwen2vl.py"
 }
 
 # All of CPU tests are expected to be finished less than 25 mins.
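For reference, the kind of multi-LoRA request this enables on the CPU backend looks roughly like the sketch below. The base model, adapter name, and adapter path are placeholders; the sketch assumes a CPU build of vLLM (which selects the CPU backend automatically) and the standard LLM/LoRARequest API, and is not part of this commit.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Placeholder base model and adapter path; any LoRA-compatible pair works.
    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", enable_lora=True, max_loras=2)
    outputs = llm.generate(
        ["Write a SQL query that counts users."],
        SamplingParams(max_tokens=32),
        lora_request=LoRARequest("sql_adapter", 1, "/path/to/lora_adapter"))
    print(outputs[0].outputs[0].text)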

docs/source/features/compatibility_matrix.md

Lines changed: 1 addition & 1 deletion

@@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
   - ✅
   - ✅
   - ✅
-  - [✗](gh-pr:4830)
+  - ✅
   - ✅
 * - <abbr title="Prompt Adapter">prmpt adptr</abbr>
   - ✅

tests/lora/conftest.py

Lines changed: 19 additions & 13 deletions

@@ -21,6 +21,7 @@
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader import get_model
+from vllm.platforms import current_platform
 
 
 class ContextIDInfo(TypedDict):
@@ -65,13 +66,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
 @pytest.fixture
 def dist_init():
     temp_file = tempfile.mkstemp()[1]
-    init_distributed_environment(
-        world_size=1,
-        rank=0,
-        distributed_init_method=f"file://{temp_file}",
-        local_rank=0,
-        backend="nccl",
-    )
+
+    backend = "nccl"
+    if current_platform.is_cpu():
+        backend = "gloo"
+
+    init_distributed_environment(world_size=1,
+                                 rank=0,
+                                 distributed_init_method=f"file://{temp_file}",
+                                 local_rank=0,
+                                 backend=backend)
     initialize_model_parallel(1, 1)
     yield
     cleanup_dist_env_and_memory(shutdown_ray=True)
@@ -81,13 +85,15 @@ def dist_init():
 def dist_init_torch_only():
     if torch.distributed.is_initialized():
         return
+    backend = "nccl"
+    if current_platform.is_cpu():
+        backend = "gloo"
+
     temp_file = tempfile.mkstemp()[1]
-    torch.distributed.init_process_group(
-        backend="nccl",
-        world_size=1,
-        rank=0,
-        init_method=f"file://{temp_file}",
-    )
+    torch.distributed.init_process_group(world_size=1,
+                                         rank=0,
+                                         init_method=f"file://{temp_file}",
+                                         backend=backend)
 
 
 @pytest.fixture
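Condensed, the platform-driven backend selection both fixtures now share is the pattern below; it mirrors the diff (gloo on CPU-only hosts, nccl otherwise) and introduces nothing new.

    import tempfile

    import torch
    from vllm.platforms import current_platform

    # NCCL requires GPUs; gloo runs the same single-process group on CPU-only hosts.
    backend = "gloo" if current_platform.is_cpu() else "nccl"
    temp_file = tempfile.mkstemp()[1]
    torch.distributed.init_process_group(backend=backend,
                                         world_size=1,
                                         rank=0,
                                         init_method=f"file://{temp_file}")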

tests/lora/test_layers.py

Lines changed: 30 additions & 9 deletions

@@ -48,10 +48,14 @@
     torch.float32: (5e-3, 5e-3),
     torch.bfloat16: (3e-2, 2e-2),
 }
-# TODO: Modify this based on platform
-DEVICES = [
+
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
+    reason="Backend not supported")
+
+DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
-]
+] if current_platform.is_cuda_alike() else ["cpu"])
 
 #For GPU, we will launch different triton kernels between the prefill and decode
 # stages, so we need to verify this. prefill stage(True) or decode stage(False)
@@ -198,6 +202,10 @@ def check_punica_wrapper(punica_wrapper) -> bool:
         from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
 
         return type(punica_wrapper) is PunicaWrapperGPU
+    elif current_platform.is_cpu():
+        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
+
+        return type(punica_wrapper) is PunicaWrapperCPU
     else:
         return False
 
@@ -211,7 +219,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
     # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA
     # device, see: https://github.com/triton-lang/triton/issues/2925
     # Same below.
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
 
     torch.set_default_device(device)
     max_loras = 8
@@ -313,7 +322,9 @@ def create_random_embedding_layer():
 def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
                                         vocab_size, stage) -> None:
 
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = get_punica_wrapper(8192, 256, device)
@@ -450,7 +461,9 @@ def create_random_embedding_layer():
 def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                                   stage) -> None:
 
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     max_loras = 8
     punica_wrapper = get_punica_wrapper(8192, 256, device)
@@ -582,7 +595,9 @@ def _pretest():
 def test_linear_replicated(dist_init, num_loras, device, stage,
                            bias_enabled) -> None:
 
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     punica_wrapper = get_punica_wrapper(8192, 256, device)
     assert check_punica_wrapper(punica_wrapper)
@@ -695,7 +710,9 @@ def create_random_linear_replicated_layer():
 def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                          device, stage, bias_enabled) -> None:
 
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     punica_wrapper = get_punica_wrapper(8192, 256, device)
     assert check_punica_wrapper(punica_wrapper)
@@ -818,7 +835,9 @@ def create_random_linear_parallel_layer():
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                 device, stage, bias_enabled) -> None:
 
-    torch.cuda.set_device(device)
+    if current_platform.is_cuda_alike():
+        torch.cuda.set_device(device)
+
     torch.set_default_device(device)
     punica_wrapper = get_punica_wrapper(8192, 256, device)
     assert check_punica_wrapper(punica_wrapper)
@@ -971,6 +990,8 @@ class FakeConfig:
 @pytest.mark.parametrize("rotary_dim", [None, 32])
 @pytest.mark.parametrize("head_size", [32, 108])
 @pytest.mark.parametrize("seq_len", [11, 1024])
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
+                    reason="Only CUDA backends are supported")
 def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        scaling_factors, max_position,
                                        is_neox_style, rotary_dim, head_size,
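Pieced together from the hunks above, the full platform check in test_layers.py now reads as follows; this is a consolidated view of the diff, not new behaviour.

    from vllm.platforms import current_platform


    def check_punica_wrapper(punica_wrapper) -> bool:
        """Return True if the wrapper type matches the current platform."""
        if current_platform.is_cuda_alike():
            from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

            return type(punica_wrapper) is PunicaWrapperGPU
        elif current_platform.is_cpu():
            from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU

            return type(punica_wrapper) is PunicaWrapperCPU
        else:
            return False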

tests/lora/test_lora_manager.py

Lines changed: 11 additions & 10 deletions

@@ -20,6 +20,7 @@
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
 from vllm.model_executor.layers.linear import RowParallelLinear
+from vllm.platforms import current_platform
 
 EMBEDDING_MODULES = {
     "embed_tokens": "input_embeddings",
@@ -28,9 +29,9 @@
 
 EMBEDDING_PADDING_MODULES = ["lm_head"]
 
-CUDA_DEVICES = [
+DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
-]
+] if current_platform.is_cuda_alike() else ["cpu"])
 
 
 def test_peft_helper(sql_lora_files):
@@ -83,7 +84,7 @@ def test_peft_helper(sql_lora_files):
         PEFTHelper.from_dict(config)
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
@@ -171,7 +172,7 @@ def test_replace_submodules(dist_init, dummy_model):
     manager = LoRAModelManager(
         model, 1, 1, 1,
         LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device("cuda"))
+        torch.device(DEVICES[0]))
     model = manager.model
 
     assert isinstance(model.get_submodule("dense1"),
@@ -183,7 +184,7 @@ def test_replace_submodules(dist_init, dummy_model):
                       RowParallelLinearWithLoRA)
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lora_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -244,7 +245,7 @@ def test_lora_model_manager(dist_init, dummy_model, device):
     assert manager.punica_wrapper.device == device
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     model = dummy_model
     model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
@@ -336,7 +337,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
     assert manager.device == device
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lru_lora_model_manager(dist_init, dummy_model, device):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
@@ -466,7 +467,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
     assert manager.device == device
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
@@ -545,7 +546,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           device)
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
@@ -621,7 +622,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 device)
 
 
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_packed_loras(dist_init, dummy_model_gate_up, device):
     model = dummy_model_gate_up
     model.supported_lora_modules = ["gate_up_proj"]
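The renamed DEVICES list keeps every test parametrized identically on both backends. A minimal, self-contained sketch of that pattern is below; the test body is illustrative only and does not appear in the diff.

    import pytest
    import torch
    from vllm.platforms import current_platform

    DEVICES = ([
        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ] if current_platform.is_cuda_alike() else ["cpu"])


    @pytest.mark.parametrize("device", DEVICES)
    def test_device_parametrization(device):  # illustrative test name
        # CUDA needs the device set explicitly for multi-GPU Triton kernels;
        # on CPU this step is skipped and the default device is enough.
        if current_platform.is_cuda_alike():
            torch.cuda.set_device(device)
        torch.set_default_device(device)
        assert torch.zeros(1).device.type in ("cuda", "cpu")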

tests/lora/test_mixtral.py

Lines changed: 3 additions & 1 deletion

@@ -5,6 +5,7 @@
 
 import vllm
 from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
 
 MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
@@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
 @pytest.mark.parametrize("tp_size", [4])
 def test_mixtral_lora(mixtral_lora_files, tp_size):
     """Original test, the LoRA model has the common target modules, not all"""
-    if torch.cuda.device_count() < tp_size:
+    if torch.cuda.device_count(
+    ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike():
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
     prompts = [
