
Commit 467a96a

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath) and Varun Sundar Rabindranath
[V1] LoRA Support (#10957)
Signed-off-by: Varun Sundar Rabindranath <[email protected]> Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 8108ac8 commit 467a96a

16 files changed: +453 −56 lines changed

tests/lora/conftest.py

Lines changed: 17 additions & 0 deletions
@@ -306,3 +306,20 @@ def get_model_patched(**kwargs):
 def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
     yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
            model_runner.model)
+
+
+@pytest.fixture(params=[True, False])
+def run_with_both_engines_lora(request, monkeypatch):
+    # Automatically runs tests twice, once with V1 and once without
+    use_v1 = request.param
+    # Tests decorated with `@skip_v1` are only run without v1
+    skip_v1 = request.node.get_closest_marker("skip_v1")
+
+    if use_v1:
+        if skip_v1:
+            pytest.skip("Skipping test on vllm V1")
+        monkeypatch.setenv('VLLM_USE_V1', '1')
+    else:
+        monkeypatch.setenv('VLLM_USE_V1', '0')
+
+    yield
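For context, each test module touched by this commit opts into the new fixture with a small autouse wrapper, and tests that cannot yet run on V1 are tagged with the skip_v1 marker that the fixture checks. A minimal sketch of that pattern (the test name below is illustrative, not taken from the commit):

import pytest


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Autouse wrapper: every test in this module now runs twice,
    # once with VLLM_USE_V1=1 and once with VLLM_USE_V1=0.
    pass


@pytest.mark.skip_v1  # hypothetical example test; runs only on the V0 engine
def test_lora_v0_only():
    ...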

tests/lora/test_baichuan.py

Lines changed: 8 additions & 0 deletions
@@ -42,6 +42,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_baichuan_lora(baichuan_lora_files):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,

tests/lora/test_chatglm3_tp.py

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,8 @@

 from typing import List

+import pytest
+
 import vllm
 from tests.utils import fork_new_process_for_each_test
 from vllm.lora.request import LoRARequest
@@ -47,6 +49,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+@pytest.mark.skip_v1
 @fork_new_process_for_each_test
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,
@@ -66,6 +77,7 @@ def test_chatglm3_lora(chatglm3_lora_files):
         assert output2[i] == EXPECTED_LORA_OUTPUT[i]


+@pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
 @fork_new_process_for_each_test
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
@@ -87,6 +99,7 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
         assert output2[i] == EXPECTED_LORA_OUTPUT[i]


+@pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
 @fork_new_process_for_each_test
 def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):

tests/lora/test_gemma.py

Lines changed: 8 additions & 0 deletions
@@ -33,6 +33,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.xfail(current_platform.is_rocm(),
                    reason="There can be output mismatch on ROCm")
 def test_gemma_lora(gemma_lora_files):

tests/lora/test_llama_tp.py

Lines changed: 12 additions & 0 deletions
@@ -2,6 +2,7 @@

 from typing import List

+import pytest
 import ray

 import vllm
@@ -73,6 +74,14 @@ def generate_and_test(llm, sql_lora_files):
     print("removing lora")


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @fork_new_process_for_each_test
 def test_llama_lora(sql_lora_files):

@@ -85,6 +94,9 @@ def test_llama_lora(sql_lora_files):
     generate_and_test(llm, sql_lora_files)


+# Skipping for v1 as v1 doesn't have a good way to expose the num_gpu_blocks
+# used by the engine yet.
+@pytest.mark.skip_v1
 @fork_new_process_for_each_test
 def test_llama_lora_warmup(sql_lora_files):
     """Test that the LLM initialization works with a warmup LORA path and

tests/lora/test_lora_bias_e2e.py

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+# Skipping for V1 for now as we are hitting,
+# "Head size 80 is not supported by FlashAttention." error.
+@pytest.mark.skip_v1
 @pytest.mark.parametrize("lora_bias", [True])
 @pytest.mark.parametrize("fully_sharded", [True, False])
 def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):

tests/lora/test_phi.py

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,8 @@

 from typing import List

+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest

@@ -48,6 +50,17 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
+# Skipping for V1 for now as we are hitting,
+# "Head size 80 is not supported by FlashAttention." error.
+@pytest.mark.skip_v1
 def test_phi2_lora(phi2_lora_files):
     # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
     # Otherwise, the lora-test will fail due to CUDA OOM.

tests/lora/test_quant_model.py

Lines changed: 8 additions & 0 deletions
@@ -70,6 +70,14 @@ def format_prompt_tuples(prompt):
     return generated_texts


+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", [1])
 def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,

tests/v1/core/test_kv_cache_utils.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ def test_generate_block_hash_extra_keys():

     # Test with no overlap
     extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 6, 10, 0)
-    assert extra_keys == ()
+    assert extra_keys is None
     assert next_mm_idx == 1

     # Test with multiple extra keys

vllm/lora/layers.py

Lines changed: 5 additions & 3 deletions
@@ -16,8 +16,7 @@
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather,
-                             tensor_model_parallel_all_reduce,
-                             tensor_model_parallel_gather)
+                             tensor_model_parallel_all_reduce)
 from vllm.distributed.utils import divide
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -1043,7 +1042,10 @@ def _get_logits(
         logits = lm_head.linear_method.apply(lm_head, hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
+
+        # Gather logits for TP
+        logits = self.base_layer._gather_logits(logits)
+
         if logits is None:
             return None


vllm/model_executor/layers/logits_processor.py

Lines changed: 17 additions & 11 deletions
@@ -51,7 +51,6 @@ def __init__(self,
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-
         parallel_config = get_current_vllm_config().parallel_config
         self.use_all_gather = current_platform.is_tpu() \
             or envs.VLLM_USE_V1 \
@@ -88,6 +87,20 @@ def forward(

         return logits

+    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
+        """gather/all-gather the logits tensor across model parallel group."""
+        if self.use_all_gather:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
+        else:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
+        return logits
+
     def _get_logits(
         self,
         hidden_states: torch.Tensor,
@@ -99,16 +112,9 @@ def _get_logits(
                                                hidden_states,
                                                bias=embedding_bias)

-        if self.use_all_gather:
-            # Gather is not supported for some devices such as TPUs.
-            # Use all-gather instead.
-            # NOTE(woosuk): Here, the outputs of every device should not be None
-            # because XLA requires strict SPMD among all devices. Every device
-            # should execute the same operations after gathering the logits.
-            logits = tensor_model_parallel_all_gather(logits)
-        else:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
+        # Gather logits for TP
+        logits = self._gather_logits(logits)
+
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
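Taken together with the vllm/lora/layers.py change above, the net effect is that the gather/all-gather decision now lives in a single helper, LogitsProcessor._gather_logits, and the LoRA logits processor reuses it through the base layer it wraps. A minimal, self-contained sketch of that delegation pattern (simplified stand-in classes, not the real vLLM ones):

import torch


class SimpleLogitsProcessor:
    # Simplified stand-in for vllm's LogitsProcessor (illustrative only).

    def __init__(self, use_all_gather: bool):
        # In the real code this flag is derived from the platform (e.g. TPU),
        # VLLM_USE_V1, and the parallel configuration.
        self.use_all_gather = use_all_gather

    def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
        # Single place that would choose gather vs. all-gather across the
        # tensor-parallel group; with no TP group here we just pass through.
        return logits

    def _get_logits(self, hidden_states: torch.Tensor,
                    lm_head: torch.nn.Module) -> torch.Tensor:
        logits = lm_head(hidden_states)
        return self._gather_logits(logits)


class SimpleLoRALogitsProcessor:
    # Simplified stand-in for LogitsProcessorWithLoRA: it delegates the
    # gather step to the wrapped base layer instead of calling
    # tensor_model_parallel_gather directly, mirroring this commit's refactor.

    def __init__(self, base_layer: SimpleLogitsProcessor):
        self.base_layer = base_layer

    def _get_logits(self, hidden_states: torch.Tensor,
                    lm_head: torch.nn.Module) -> torch.Tensor:
        logits = lm_head(hidden_states)
        return self.base_layer._gather_logits(logits)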
