
Commit 8af890a

jeejeelee and Yard1 authored
Enable more models to inference based on LoRA (#3382)
Co-authored-by: Antoni Baum <[email protected]>
1 parent dfeb2ec commit 8af890a

File tree

10 files changed: +402 −45 lines changed

csrc/punica/bgmv/bgmv_config.h

Lines changed: 6 additions & 0 deletions

@@ -16,21 +16,26 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 512) \
   f(in_T, out_T, W_T, narrow, 768) \
   f(in_T, out_T, W_T, narrow, 1024) \
+  f(in_T, out_T, W_T, narrow, 1152) \
   f(in_T, out_T, W_T, narrow, 1280) \
+  f(in_T, out_T, W_T, narrow, 1536) \
   f(in_T, out_T, W_T, narrow, 1728) \
   f(in_T, out_T, W_T, narrow, 1792) \
   f(in_T, out_T, W_T, narrow, 2048) \
+  f(in_T, out_T, W_T, narrow, 2304) \
   f(in_T, out_T, W_T, narrow, 2560) \
   f(in_T, out_T, W_T, narrow, 2752) \
   f(in_T, out_T, W_T, narrow, 2816) \
   f(in_T, out_T, W_T, narrow, 3072) \
   f(in_T, out_T, W_T, narrow, 3456) \
   f(in_T, out_T, W_T, narrow, 3584) \
   f(in_T, out_T, W_T, narrow, 4096) \
+  f(in_T, out_T, W_T, narrow, 4608) \
   f(in_T, out_T, W_T, narrow, 5120) \
   f(in_T, out_T, W_T, narrow, 5504) \
   f(in_T, out_T, W_T, narrow, 5632) \
   f(in_T, out_T, W_T, narrow, 6144) \
+  f(in_T, out_T, W_T, narrow, 6848) \
   f(in_T, out_T, W_T, narrow, 6912) \
   f(in_T, out_T, W_T, narrow, 7168) \
   f(in_T, out_T, W_T, narrow, 8192) \
@@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 20480) \
   f(in_T, out_T, W_T, narrow, 22016) \
   f(in_T, out_T, W_T, narrow, 24576) \
+  f(in_T, out_T, W_T, narrow, 27392) \
   f(in_T, out_T, W_T, narrow, 28672) \
   f(in_T, out_T, W_T, narrow, 32000) \
   f(in_T, out_T, W_T, narrow, 32256) \
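
The `f(...)` entries above instantiate the bgmv CUDA kernel for a fixed whitelist of feature dimensions, so a LoRA adapter can only be applied to layers whose widths appear in this list. The six new sizes (1152, 1536, 2304, 4608, 6848, 27392) line up with the ChatGLM3 and Baichuan projection shapes exercised by the new tests below; that mapping is my reading of the PR rather than something the diff states. A standalone Python sketch of the idea, using hypothetical names rather than vLLM's actual API:

```python
# Hypothetical sketch: the macro list above is effectively a compile-time
# whitelist of feature sizes; a width outside it has no bgmv kernel.
SUPPORTED_BGMV_DIMS = {
    512, 768, 1024, 1152, 1280, 1536, 1728, 1792, 2048, 2304, 2560,
    2752, 2816, 3072, 3456, 3584, 4096, 4608, 5120, 5504, 5632, 6144,
    6848, 6912, 7168, 8192, 20480, 22016, 24576, 27392, 28672, 32000,
    32256,  # truncated to the dims visible in this diff
}


def check_lora_dims(dims):
    """Raise if any layer width lacks a compiled bgmv kernel."""
    missing = sorted(d for d in dims if d not in SUPPORTED_BGMV_DIMS)
    if missing:
        raise ValueError(f"no punica kernel compiled for widths {missing}")


# 27392 equals 2 * 13696, which matches ChatGLM3's merged gate/up width
# (my inference); without these entries the adapters tested below could
# not be applied.
check_lora_dims([4096, 4608, 6848, 27392])
```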

tests/lora/conftest.py

Lines changed: 10 additions & 0 deletions

@@ -134,6 +134,16 @@ def gemma_lora_files():
     return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")


+@pytest.fixture(scope="session")
+def chatglm3_lora_files():
+    return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")
+
+
+@pytest.fixture(scope="session")
+def baichuan_lora_files():
+    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
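
Both new fixtures are session-scoped, so each adapter repo is fetched from the Hugging Face Hub at most once per pytest session and the cached local path is reused by every test that names the fixture as an argument. A minimal sketch of how a test consumes one (the fixture name is real, the test itself is hypothetical):

```python
import os


def test_adapter_downloaded(chatglm3_lora_files):
    # pytest injects the fixture's return value -- the local directory
    # that snapshot_download() resolved -- by matching the parameter name.
    assert os.path.isdir(chatglm3_lora_files)
```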

tests/lora/test_baichuan.py

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+import pytest
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+from .conftest import cleanup
+
+MODEL_PATH = "baichuan-inc/Baichuan-7B"
+
+PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
+
+
+def do_sample(llm, lora_path: str, lora_id: int) -> str:
+    prompts = [
+        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
+        PROMPT_TEMPLATE.format(
+            query=
+            "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
+        ),
+        PROMPT_TEMPLATE.format(
+            query=
+            "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
+        ),
+    ]
+    print(prompts)
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def test_baichuan_lora(baichuan_lora_files):
+    llm = vllm.LLM(MODEL_PATH,
+                   max_model_len=1024,
+                   enable_lora=True,
+                   max_loras=4,
+                   max_lora_rank=64,
+                   trust_remote_code=True)
+
+    expected_lora_output = [
+        "SELECT count(*) FROM singer",
+        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501
+        "SELECT name , country , age FROM singer ORDER BY age ASC",
+    ]
+
+    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
+    for i in range(len(expected_lora_output)):
+        assert output1[i] == expected_lora_output[i]
+    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
+    for i in range(len(expected_lora_output)):
+        assert output2[i] == expected_lora_output[i]
+
+
+@pytest.mark.skip("Requires multiple GPUs")
+def test_llama_tensor_parallel_equality(baichuan_lora_files):
+    # Cannot use as it will initialize torch.cuda too early...
+    # if torch.cuda.device_count() < 4:
+    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
+
+    llm_tp1 = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       max_loras=4,
+                       max_lora_rank=64,
+                       tensor_parallel_size=1,
+                       trust_remote_code=True)
+    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
+
+    del llm_tp1
+    cleanup()
+
+    llm_tp2 = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       max_loras=4,
+                       max_lora_rank=64,
+                       tensor_parallel_size=2,
+                       trust_remote_code=True)
+    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
+
+    del llm_tp2
+    cleanup()
+
+    assert output_tp1 == output_tp2
+
+    llm_tp4 = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       max_loras=4,
+                       max_lora_rank=64,
+                       tensor_parallel_size=4,
+                       trust_remote_code=True)
+    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
+
+    del llm_tp4
+    cleanup()
+
+    assert output_tp1 == output_tp4
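
A note on the API exercised here: `LoRARequest(name, id, path)` bundles a human-readable adapter name, a positive integer slot id, and the local adapter directory. `test_baichuan_lora` deliberately runs the same weights under ids 1 and 2 so that, with greedy decoding (`temperature=0`), both LoRA slots must reproduce identical SQL. A condensed sketch of that pattern, with a hypothetical prompt and adapter path:

```python
import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("baichuan-inc/Baichuan-7B",
               enable_lora=True,
               max_loras=4,
               max_lora_rank=64,
               trust_remote_code=True)
greedy = vllm.SamplingParams(temperature=0, max_tokens=64)

adapter = "/path/to/baichuan7b-text2sql-spider"  # hypothetical local path
prompts = ["How many singers do we have?"]

# Same adapter weights loaded into two different LoRA slots: greedy
# outputs must match exactly, which is what the test above asserts.
out1 = llm.generate(prompts, greedy,
                    lora_request=LoRARequest("sql-1", 1, adapter))
out2 = llm.generate(prompts, greedy,
                    lora_request=LoRARequest("sql-2", 2, adapter))
assert out1[0].outputs[0].text == out2[0].outputs[0].text
```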

tests/lora/test_chatglm3.py

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "THUDM/chatglm3-6b"
+
+PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
+
+
+def do_sample(llm, lora_path: str, lora_id: int) -> str:
+    prompts = [
+        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
+        PROMPT_TEMPLATE.format(
+            query=
+            "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
+        ),
+        PROMPT_TEMPLATE.format(
+            query=
+            "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
+        ),
+    ]
+    print(prompts)
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def test_chatglm3_lora(chatglm3_lora_files):
+    llm = vllm.LLM(MODEL_PATH,
+                   max_model_len=1024,
+                   enable_lora=True,
+                   max_loras=4,
+                   max_lora_rank=64,
+                   trust_remote_code=True)
+
+    expected_lora_output = [
+        "SELECT count(*) FROM singer",
+        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
+        "SELECT name , country , age FROM singer ORDER BY age",
+    ]
+
+    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
+    for i in range(len(expected_lora_output)):
+        assert output1[i] == expected_lora_output[i]
+    output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
+    for i in range(len(expected_lora_output)):
+        assert output2[i] == expected_lora_output[i]

tests/lora/test_layers.py

Lines changed: 10 additions & 3 deletions

@@ -8,12 +8,16 @@
 import torch.nn.functional as F

 from vllm.config import LoRAConfig
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               LogitsProcessorWithLoRA, LoRAMapping,
                               MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLora,
                               QKVParallelLinearWithLora,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
+# yapf: enable
 from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
                               convert_mapping)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -93,8 +97,7 @@ def populate_loras(
     lora_dict: Dict[int, LoRALayerWeights] = dict()

     # Dictionary that maps the lora ID to the
-    # corresponding subloras. Only useful when
-    # repeats > 1.
+    # corresponding subloras.
     sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()

     for slot_idx, lora_id in enumerate(id_to_index):
@@ -607,7 +610,7 @@ def create_random_linear_parallel_layer():

 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
-@pytest.mark.parametrize("repeats", [2, 3])
+@pytest.mark.parametrize("repeats", [1, 2, 3])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:

@@ -623,6 +626,10 @@ def create_column_parallel_packed_layer():
                                                 bias=False)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedColumnParallelLinearWithLoRA(linear)
+        elif repeats == 3:
+            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear.weight.data = torch.rand_like(linear.weight.data)
+            lora_linear = MergedQKVParallelLinearWithLora(linear)
         else:
             linear = QKVParallelLinear(4096, 64, 32, bias=False)
             linear.weight.data = torch.rand_like(linear.weight.data)
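
The `repeats` parameter encodes how many stacked sub-projections share one packed base layer, and the widened parametrization now covers all three LoRA wrappers. A condensed, illustrative version of the dispatch inside `create_column_parallel_packed_layer` (the classes are the real ones imported in the diff above; the helper function is mine):

```python
from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora)


def pick_lora_wrapper(repeats: int):
    if repeats == 2:
        # two stacked outputs, e.g. merged gate_proj/up_proj
        return MergedColumnParallelLinearWithLoRA
    if repeats == 3:
        # three stacked outputs (merged q/k/v) -- the case this PR adds
        return MergedQKVParallelLinearWithLora
    # repeats == 1: plain QKV wrapper
    return QKVParallelLinearWithLora
```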

tests/lora/test_punica.py

Lines changed: 4 additions & 3 deletions

@@ -43,9 +43,10 @@ def _lora_ref_impl(


 H1 = H2 = [
-    128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
-    22016, 24576, 32000, 32256, 32512, 32768, 33024
+    128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
+    3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
+    10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
+    32768, 33024
 ]
 SEED = [0xabcdabcd987]
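
The sizes added to `H1`/`H2` are exactly the six added to `bgmv_config.h`, so the Python reference test now sweeps every newly compiled kernel width. A quick standalone check of that correspondence, with the values copied from this diff:

```python
# Old and new H1/H2 sweeps from tests/lora/test_punica.py.
old_h = {128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584,
         4096, 5120, 5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240,
         11008, 13824, 14336, 22016, 24576, 32000, 32256, 32512, 32768,
         33024}
new_h = {128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752,
         3072, 3456, 3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848,
         6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 22016,
         24576, 27392, 32000, 32256, 32512, 32768, 33024}
# These are the same six widths added to csrc/punica/bgmv/bgmv_config.h:
assert new_h - old_h == {1152, 1536, 2304, 4608, 6848, 27392}
```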
