# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
from pathlib import Path
from typing import List

import pytest
from huggingface_hub import snapshot_download

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TextPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import merge_async_iterators

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
LORA_MODULE_DOWNLOAD_PATH = None  # Populated by download_and_prepare_lora_module() #noqa
LORA_RANK = 8
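# Enough adapters to split evenly into the dummy, warmup and cold-start
# request groups used in test_add_lora() below.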
DEFAULT_MAX_LORAS = 16 * 3


def download_and_prepare_lora_module():
    """
    Request submission is expensive when the LoRA adapters have their own
    tokenizers. This is because, for each request with a new LoRA adapter ID,
    the front-end loads the tokenizer from disk.

    In this test, as we are comparing request processing times, we want to
    minimize any extra activity. To this end, we download the LoRA adapter
    and remove all of its tokenizer files, so the engine defaults to the
    base-model tokenizer.
    """
    global LORA_MODULE_DOWNLOAD_PATH

    LORA_MODULE_HF_PATH = "yard1/llama-2-7b-sql-lora-test"
    LORA_MODULE_DOWNLOAD_PATH = snapshot_download(repo_id=LORA_MODULE_HF_PATH)

    tokenizer_files = [
        'added_tokens.json', 'tokenizer_config.json', 'tokenizer.json',
        'tokenizer.model'
    ]
    for tokenizer_file in tokenizer_files:
        del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
        del_path.unlink()


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test.
    # This can be promoted up to conftest.py to run for every
    # test in a package.
    pass


def get_lora_requests() -> List[LoRARequest]:
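    """Create DEFAULT_MAX_LORAS LoRA requests, each with a distinct
    lora_int_id, all pointing at the same downloaded adapter."""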
    lora_requests: List[LoRARequest] = [
        LoRARequest(lora_name=f"{i}",
                    lora_int_id=i,
                    lora_path=LORA_MODULE_DOWNLOAD_PATH)
        for i in range(1, DEFAULT_MAX_LORAS + 1)
    ]
    return lora_requests


async def requests_processing_time(llm,
                                   lora_requests: List[LoRARequest]) -> float:
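    """Submit one single-token request per LoRA request and return the
    wall-clock time taken for all of them to complete."""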

    sampling_params = SamplingParams(n=1,
                                     temperature=0.0,
                                     top_p=1.0,
                                     ignore_eos=True,
                                     max_tokens=1)

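    # With max_tokens=1 and greedy sampling, generation work is minimal, so
    # the measured time is dominated by request setup (e.g. LoRA loading).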
    generators = []
    start = time.perf_counter()

    for lora_request in lora_requests:
        lora_int_id = lora_request.lora_int_id
        generator = llm.generate(
            prompt=TextPrompt(prompt=f"hello {lora_int_id}",
                              multi_modal_data=None),  # type: ignore
            sampling_params=sampling_params,
            lora_request=lora_request,
            request_id=f"test{lora_int_id}")
        generators.append(generator)

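    # Drain every result generator; the loop ends only after all requests
    # have finished.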
    all_gens = merge_async_iterators(*generators)
    async for i, res in all_gens:
        pass

    end = time.perf_counter()
    return end - start


@pytest.mark.asyncio
async def test_add_lora():
    """
    The add_lora function is used to pre-load some LoRA adapters into the
    engine in anticipation of future requests using these adapters. To test
    this functionality, we use the async engine to process some requests - we
    do it twice, once with add_lora() pre-loading and once without.

    We measure the request processing time in both cases and expect it to be
    lower when add_lora() has been called.
    """

    download_and_prepare_lora_module()

    lora_requests: List[LoRARequest] = get_lora_requests()

    max_loras = len({lr.lora_int_id for lr in lora_requests})
    # Create the engine in eager mode. Due to the high max_loras, the CI can
    # OOM during CUDA-graph capture.
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        enable_lora=True,
        max_loras=max_loras,
        max_lora_rank=LORA_RANK,
        max_model_len=128,
        gpu_memory_utilization=0.8,  # avoid OOM
        enforce_eager=True)

    # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
    # environment variable. Reload vllm.engine.async_llm_engine because
    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
    # env var.
    import importlib

    import vllm.engine.async_llm_engine
    importlib.reload(vllm.engine.async_llm_engine)
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args)

    # Split lora_requests into 3 parts: a dummy group to warm up the engine,
    # a group that is pre-loaded via add_lora(), and a group that is loaded
    # on demand (cold start).
    part_size = len(lora_requests) // 3
    dummy_run_requests = lora_requests[:part_size]
    warmup_run_requests = lora_requests[part_size:part_size * 2]
    cold_run_requests = lora_requests[part_size * 2:]

    async with build_async_engine_client_from_engine_args(engine_args) as llm:

        # Dummy run - so any one-time work, such as Triton kernel compilation,
        # is complete before the timed runs.
        await requests_processing_time(llm, dummy_run_requests)

        # Run with warmup.
        for lr in warmup_run_requests:
            await llm.add_lora(lr)
        # Wait for the add_lora calls to complete on the server side.
        await asyncio.sleep(30)
        time_with_add_lora = await requests_processing_time(
            llm, warmup_run_requests)

        # Run without any warmup.
        time_cold_start = await requests_processing_time(
            llm, cold_run_requests)

        print(f"time hot-start {time_with_add_lora} vs "
              f"time cold-start {time_cold_start}")

        assert time_with_add_lora < time_cold_start, (
            f"time_with_add_lora={time_with_add_lora}, "
            f"time_cold_start={time_cold_start}. "
            "The engine request processing time with LoRA pre-loading "
            "must be less than with on-demand LoRA loading.")