Commit cbc4012

varun-sundar-rabindranath and Varun Sundar Rabindranath authored
[V1] LoRA - Enable Serving Usecase (#12883)
Signed-off-by: Varun Sundar Rabindranath <[email protected]> Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent f0b2da7 commit cbc4012

File tree

7 files changed: +210 −7 lines
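
For reference, the snippet below is a minimal usage sketch (not part of this commit) of the serving flow the change enables: pre-loading LoRA adapters with add_lora() before the requests that use them arrive. It reuses the interfaces exercised in tests/lora/test_add_lora.py below (AsyncEngineArgs, build_async_engine_client_from_engine_args, LoRARequest); the model name, adapter path, prompt, and request id are placeholders.

# Minimal sketch: pre-load a LoRA adapter, then serve a request with it.
# Model name, adapter path, prompt, and request id are placeholders.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams


async def main():
    engine_args = AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf",
                                  enable_lora=True,
                                  max_loras=4,
                                  max_lora_rank=8,
                                  enforce_eager=True)
    lora = LoRARequest(lora_name="sql-lora",
                       lora_int_id=1,
                       lora_path="/path/to/lora/adapter")

    async with build_async_engine_client_from_engine_args(engine_args) as llm:
        # Pre-load the adapter so the first request that uses it does not
        # pay the on-demand loading cost.
        await llm.add_lora(lora)

        params = SamplingParams(temperature=0.0, max_tokens=32)
        async for output in llm.generate(prompt="Write a SQL query.",
                                         sampling_params=params,
                                         lora_request=lora,
                                         request_id="req-1"):
            print(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())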

tests/lora/test_add_lora.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+import time
+from pathlib import Path
+from typing import List
+
+import pytest
+from huggingface_hub import snapshot_download
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.inputs import TextPrompt
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.utils import merge_async_iterators
+
+MODEL_PATH = "meta-llama/Llama-2-7b-hf"
+LORA_MODULE_DOWNLOAD_PATH = None  # Populated by download_and_prepare_lora_module() #noqa
+LORA_RANK = 8
+DEFAULT_MAX_LORAS = 16 * 3
+
+
+def download_and_prepare_lora_module():
+    """
+    Request submission is expensive when the LoRA adapters have their own
+    tokenizers. This is because, for each request with a new LoRA adapter ID,
+    the front-end loads the tokenizer from disk.
+
+    In this test, as we are comparing request processing times, we want to
+    minimize any extra activity. To this effect, we download the LoRA
+    adapter and remove all the tokenizer files, so the engine will default
+    to the base model tokenizer.
+    """
+    global LORA_MODULE_DOWNLOAD_PATH
+
+    LORA_MODULE_HF_PATH = "yard1/llama-2-7b-sql-lora-test"
+    LORA_MODULE_DOWNLOAD_PATH = snapshot_download(repo_id=LORA_MODULE_HF_PATH)
+
+    tokenizer_files = [
+        'added_tokens.json', 'tokenizer_config.json', 'tokenizer.json',
+        'tokenizer.model'
+    ]
+    for tokenizer_file in tokenizer_files:
+        del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
+        del_path.unlink()
+
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test.
+    # This can be promoted up to conftest.py to run for every
+    # test in a package.
+    pass
+
+
+def get_lora_requests() -> List[LoRARequest]:
+    lora_requests: List[LoRARequest] = [
+        LoRARequest(lora_name=f"{i}",
+                    lora_int_id=i,
+                    lora_path=LORA_MODULE_DOWNLOAD_PATH)
+        for i in range(1, DEFAULT_MAX_LORAS + 1)
+    ]
+    return lora_requests
+
+
+async def requests_processing_time(llm,
+                                   lora_requests: List[LoRARequest]) -> float:
+
+    sampling_params = SamplingParams(n=1,
+                                     temperature=0.0,
+                                     top_p=1.0,
+                                     ignore_eos=True,
+                                     max_tokens=1)
+
+    generators = []
+    start = time.perf_counter()
+
+    for lora_request in lora_requests:
+        lora_int_id = lora_request.lora_int_id
+        generator = llm.generate(
+            prompt=TextPrompt(prompt=f"hello {lora_int_id}",
+                              multi_modal_data=None),  # type: ignore
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+            request_id=f"test{lora_int_id}")
+        generators.append(generator)
+
+    all_gens = merge_async_iterators(*generators)
+    async for i, res in all_gens:
+        pass
+
+    end = time.perf_counter()
+    return end - start
+
+
+@pytest.mark.asyncio
+async def test_add_lora():
+    """
+    The add_lora function is used to pre-load some LoRA adapters into the
+    engine in anticipation of future requests using these adapters. To test
+    this functionality, we use the async engine to process some requests - we
+    do it twice, once with add_lora() pre-loading and once without.
+
+    We measure the request processing time in both cases and expect the time
+    to be lower in the case with add_lora() calls.
+    """
+
+    download_and_prepare_lora_module()
+
+    lora_requests: List[LoRARequest] = get_lora_requests()
+
+    max_loras = len(set([lr.lora_int_id for lr in lora_requests]))
+    # Create engine in eager-mode. Due to high max_loras, the CI can
+    # OOM during cuda-graph capture.
+    engine_args = AsyncEngineArgs(
+        model=MODEL_PATH,
+        enable_lora=True,
+        max_loras=max_loras,
+        max_lora_rank=LORA_RANK,
+        max_model_len=128,
+        gpu_memory_utilization=0.8,  # avoid OOM
+        enforce_eager=True)
+
+    # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
+    # environment variable. Reload vllm.engine.async_llm_engine as
+    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
+    # env var.
+    import importlib
+
+    import vllm.engine.async_llm_engine
+    importlib.reload(vllm.engine.async_llm_engine)
+    from vllm.entrypoints.openai.api_server import (
+        build_async_engine_client_from_engine_args)
+
+    # Split lora_requests into 3 parts.
+    part_size = len(lora_requests) // 3
+    dummy_run_requests = lora_requests[:part_size]
+    warmup_run_requests = lora_requests[part_size:part_size * 2]
+    cold_run_requests = lora_requests[part_size * 2:]
+
+    async with build_async_engine_client_from_engine_args(engine_args) as llm:
+
+        # Dummy run - so any one-time functionality like triton kernel
+        # compilation is complete here.
+        await requests_processing_time(llm, dummy_run_requests)
+
+        # Run with warmup.
+        for lr in warmup_run_requests:
+            await llm.add_lora(lr)
+        # Wait for the add_lora function to complete on the server side.
+        await asyncio.sleep(30)
+        time_with_add_lora = await requests_processing_time(
+            llm, warmup_run_requests)
+
+        # Run without any warmup.
+        time_cold_start = await requests_processing_time(
+            llm, cold_run_requests)
+
+        print(f"time hot-start {time_with_add_lora} vs "
+              f"time cold-start {time_cold_start}")
+
+        assert time_with_add_lora < time_cold_start, (
+            f"time_with_add_lora={time_with_add_lora}, "
+            f"time_cold_start={time_cold_start}. "
+            "The engine request processing time with LoRA pre-loading "
+            "must be less than the version that does on-demand LoRA loading.")

vllm/v1/engine/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -134,3 +134,4 @@ class EngineCoreRequestType(enum.Enum):
     ABORT = b'\x01'
     PROFILE = b'\x02'
     RESET_PREFIX_CACHE = b'\x03'
+    ADD_LORA = b'\x04'

vllm/v1/engine/async_llm.py

Lines changed: 4 additions & 4 deletions
@@ -361,6 +361,10 @@ async def stop_profile(self) -> None:
     async def reset_prefix_cache(self) -> None:
         await self.engine_core.reset_prefix_cache_async()
 
+    async def add_lora(self, lora_request: LoRARequest) -> None:
+        """Load a new LoRA adapter into the engine for future requests."""
+        await self.engine_core.add_lora_async(lora_request)
+
     @property
     def is_running(self) -> bool:
         return True
@@ -376,7 +380,3 @@ def errored(self) -> bool:
     @property
     def dead_error(self) -> BaseException:
         return Exception()  # TODO: implement
-
-    async def add_lora(self, lora_request: LoRARequest) -> None:
-        """Load a new LoRA adapter into the engine for future requests."""
-        raise NotImplementedError("LoRA not yet supported in V1")

vllm/v1/engine/core.py

Lines changed: 15 additions & 3 deletions
@@ -13,6 +13,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
@@ -146,6 +147,9 @@ def profile(self, is_start: bool = True):
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        self.model_executor.add_lora(lora_request)
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -262,12 +266,15 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
             self.reset_prefix_cache()
         elif request_type == EngineCoreRequestType.PROFILE:
             self.model_executor.profile(request)
+        elif request_type == EngineCoreRequestType.ADD_LORA:
+            self.model_executor.add_lora(request)
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
+        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -277,9 +284,14 @@ def process_input_socket(self, input_path: str):
                 request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
                 # Deserialize the request data.
-                decoder = add_request_decoder if (
-                    request_type
-                    == EngineCoreRequestType.ADD) else generic_decoder
+                decoder = None
+                if request_type == EngineCoreRequestType.ADD:
+                    decoder = add_request_decoder
+                elif request_type == EngineCoreRequestType.ADD_LORA:
+                    decoder = add_lora_decoder
+                else:
+                    decoder = generic_decoder
+
                 request = decoder.decode(data_frame.buffer)
 
                 # Push to input queue for core busy loop.

vllm/v1/engine/core_client.py

Lines changed: 16 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
@@ -77,6 +78,9 @@ def reset_prefix_cache(self) -> None:
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
@@ -92,6 +96,9 @@ async def reset_prefix_cache_async(self) -> None:
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
+    async def add_lora_async(self, lora_request: LoRARequest) -> None:
+        raise NotImplementedError
+
 
 class InprocClient(EngineCoreClient):
     """
@@ -125,6 +132,9 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         self.engine_core.reset_prefix_cache()
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        self.engine_core.add_lora(lora_request)
+
 
 class MPClient(EngineCoreClient):
     """
@@ -242,6 +252,9 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -295,3 +308,6 @@ async def profile_async(self, is_start: bool = True) -> None:
 
     async def reset_prefix_cache_async(self) -> None:
         await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+
+    async def add_lora_async(self, lora_request: LoRARequest) -> None:
+        await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,7 @@
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes
@@ -234,6 +235,9 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
         return

vllm/v1/worker/lora_model_runner_mixin.py

Lines changed: 5 additions & 0 deletions
@@ -127,3 +127,8 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig,
 
         # __exit__ code
         self.lora_manager.remove_all_adapters()
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_adapter(lora_request)
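
Taken together, the changes above thread add_lora from the front-end (AsyncLLM) through the engine-core client/process boundary down to the worker's LoRA manager. The following is an illustrative sketch only (stand-in classes, not the vLLM API) of that layering:

# Illustrative only: a stripped-down trace of the add_lora call path that this
# commit wires up, using hypothetical stand-in classes.
from dataclasses import dataclass


@dataclass
class LoRARequest:  # stand-in for vllm.lora.request.LoRARequest
    lora_name: str
    lora_int_id: int
    lora_path: str


class LoRAModelRunner:  # cf. vllm/v1/worker/lora_model_runner_mixin.py
    def __init__(self, lora_enabled: bool):
        # The real mixin holds a worker LoRA manager; a dict stands in here.
        self.lora_manager = {} if lora_enabled else None

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if self.lora_manager is None:
            raise RuntimeError("LoRA is not enabled.")
        # The real code delegates to self.lora_manager.add_adapter(lora_request).
        self.lora_manager[lora_request.lora_int_id] = lora_request
        return True


class Worker:  # cf. vllm/v1/worker/gpu_worker.py
    def __init__(self, model_runner: LoRAModelRunner):
        self.model_runner = model_runner

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)


class EngineCore:  # cf. vllm/v1/engine/core.py
    def __init__(self, model_executor: Worker):
        self.model_executor = model_executor

    def add_lora(self, lora_request: LoRARequest) -> None:
        self.model_executor.add_lora(lora_request)


if __name__ == "__main__":
    core = EngineCore(Worker(LoRAModelRunner(lora_enabled=True)))
    core.add_lora(LoRARequest("sql-lora", 1, "/path/to/adapter"))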

0 commit comments