
Commit 5510852

sshlyapn authored and sumitd2 committed
[OpenVINO] Enable GPU support for OpenVINO vLLM backend (vllm-project#8192)
Signed-off-by: Sumit Dubey <[email protected]>
1 parent c2ca2cb commit 5510852

File tree

8 files changed: +446 −107 lines changed

docs/source/getting_started/openvino-installation.rst

Lines changed: 29 additions & 6 deletions
@@ -3,7 +3,7 @@
 Installation with OpenVINO
 ==========================
 
-vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features:
+vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support, as well as on both integrated and discrete Intel® GPUs (`the list of supported GPUs <https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/system-requirements.html#gpu>`_). OpenVINO vLLM backend supports the following advanced vLLM features:
 
 - Prefix caching (``--enable-prefix-caching``)
 - Chunked prefill (``--enable-chunked-prefill``)
@@ -53,34 +53,57 @@ Install from source
     $ pip install --upgrade pip
     $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
 
-- Finally, install vLLM with OpenVINO backend:
+- Finally, install vLLM with OpenVINO backend:
 
 .. code-block:: console
 
     $ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
 
+- [Optional] To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: `https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html <https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html>`_.
+
 .. _openvino_backend_performance_tips:
 
 Performance tips
 ----------------
 
-vLLM OpenVINO backend uses the following environment variables to control behavior:
+vLLM OpenVINO backend environment variables
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- ``VLLM_OPENVINO_DEVICE`` to specify which device to use for inference. If there are multiple GPUs in the system, additional indexes can be used to choose the proper one (e.g., ``VLLM_OPENVINO_DEVICE=GPU.1``). If the value is not specified, the CPU device is used by default.
+
+- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during the model loading stage. By default, compression is turned off. You can also export a model with different compression techniques using `optimum-cli` and pass the exported folder as `<model_id>`.
+
+CPU performance tips
+~~~~~~~~~~~~~~~~~~~~
+
+CPU uses the following environment variables to control behavior:
 
 - ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users.
 
 - ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used, depending on the platform.
 
-- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `<model_id>`
-
 To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).
 
-OpenVINO best known configuration is:
+OpenVINO best known configuration for CPU is:
 
 .. code-block:: console
 
     $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
         python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
 
+GPU performance tips
+~~~~~~~~~~~~~~~~~~~~
+
+GPU device implements the logic for automatic detection of available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking into account the ``gpu_memory_utilization`` option). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using the ``VLLM_OPENVINO_KVCACHE_SPACE`` environment variable (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=8`` means 8 GB space for the KV cache).
+
+Currently, the best performance using GPU can be achieved with the default vLLM execution parameters for models with quantized weights (8- and 4-bit integer data types are supported) and ``preemption-mode=swap``.
+
+OpenVINO best known configuration for GPU is:
+
+.. code-block:: console
+
+    $ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
+        python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json
+
 .. _openvino_backend_limitations:
 
 Limitations
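
Taken together, the documentation changes above describe a device-selection workflow that can also be driven from Python. A hedged sketch, not part of this commit (the prompt is illustrative and the model is simply the one reused from the benchmark commands above):

import os

# Set before importing vLLM so the OpenVINO backend picks the values up.
os.environ["VLLM_OPENVINO_DEVICE"] = "GPU"  # or "GPU.1" to pick a specific card
os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
outputs = llm.generate(["What is OpenVINO?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)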

requirements-openvino.txt

Lines changed: 3 additions & 2 deletions
@@ -3,5 +3,6 @@
 
 # OpenVINO dependencies
 torch >= 2.1.2
-openvino ~= 2024.3.0
-optimum-intel[openvino] >= 1.18.2
+openvino ~= 2024.4.0
+openvino-tokenizers[transformers] ~= 2024.4.0
+optimum-intel[openvino] >= 1.19.0

vllm/attention/backends/openvino.py

Lines changed: 32 additions & 8 deletions
@@ -9,6 +9,31 @@
 from vllm.attention.backends.utils import CommonAttentionState
 
 
+def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
+                     src_offset: int, dst_offset: int) -> None:
+
+    def create_roi_tensor(
+        tensor: ov.Tensor,
+        block_number: int,
+    ) -> ov.Tensor:
+        roi_begin = ov.runtime.Coordinate([0, 0, 0, 0])
+        roi_end = ov.runtime.Coordinate(tensor.get_shape())
+
+        roi_begin[0] = block_number
+        roi_end[0] = block_number + 1
+
+        if isinstance(tensor, ov.Tensor):
+            return ov.Tensor(tensor, roi_begin, roi_end)
+        else:
+            return ov.RemoteTensor(tensor, roi_begin, roi_end)
+
+    src_roi_tensor = create_roi_tensor(src_tensor, src_offset)
+    dst_roi_tensor = create_roi_tensor(dst_tensor, dst_offset)
+    src_roi_tensor.copy_to(dst_roi_tensor)
+
+
 class OpenVINOAttentionBackend(AttentionBackend):
 
     @staticmethod
@@ -44,13 +69,12 @@ def get_kv_cache_shape(
 
     @staticmethod
     def swap_blocks(
-        src_kv_cache: ov.Tensor,
-        dst_kv_cache: ov.Tensor,
-        src_to_dst: torch.Tensor,
+        src_tensor: ov.Tensor,
+        dst_tensor: ov.Tensor,
+        src_to_dists: List[Tuple[int, int]],
     ) -> None:
-        # OpenVINO currently supports only CPU, which does not require
-        # swap of KV cache blocks
-        raise NotImplementedError
+        for src, dst in src_to_dists:
+            copy_cache_block(src_tensor, dst_tensor, src, dst)
 
     @staticmethod
     def copy_blocks(
@@ -59,8 +83,8 @@ def copy_blocks(
     ) -> None:
         for src, dst in src_to_dists:
             for key_cache, value_cache in kv_caches:
-                key_cache.data[dst, :] = key_cache.data[src, :]
-                value_cache.data[dst, :] = value_cache.data[src, :]
+                copy_cache_block(key_cache, key_cache, src, dst)
+                copy_cache_block(value_cache, value_cache, src, dst)
 
 
 @dataclass
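
A hedged sketch of how the new ROI-based block copy behaves on plain CPU tensors. It assumes ``copy_cache_block`` is importable from the module as added above; the shape follows the CPU ``[num_blocks, heads, block_size, head_size]`` cache layout used elsewhere in this commit:

import numpy as np
import openvino as ov

from vllm.attention.backends.openvino import copy_cache_block

shape = (4, 2, 16, 8)  # num_blocks, heads, block_size, head_size
src = ov.Tensor(np.random.rand(*shape).astype(np.float32))
dst = ov.Tensor(np.zeros(shape, dtype=np.float32))

copy_cache_block(src, dst, src_offset=1, dst_offset=3)  # copy block 1 -> block 3
assert np.allclose(dst.data[3], src.data[1])            # only block 3 was written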

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -35,6 +35,7 @@
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
+    VLLM_OPENVINO_DEVICE: str = "CPU"
     VLLM_OPENVINO_KVCACHE_SPACE: int = 0
     VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
     VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
@@ -302,6 +303,11 @@ def get_default_config_root():
     "VLLM_CPU_OMP_THREADS_BIND":
     lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
 
+    # OpenVINO device selection
+    # default is CPU
+    "VLLM_OPENVINO_DEVICE":
+    lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(),
+
     # OpenVINO key-value cache space
     # default is 4GB
     "VLLM_OPENVINO_KVCACHE_SPACE":

vllm/executor/openvino_executor.py

Lines changed: 53 additions & 20 deletions
@@ -17,15 +17,28 @@
 logger = init_logger(__name__)
 
 
+def is_openvino_cpu() -> bool:
+    return "CPU" in envs.VLLM_OPENVINO_DEVICE
+
+
+def is_openvino_gpu() -> bool:
+    return "GPU" in envs.VLLM_OPENVINO_DEVICE
+
+
 class OpenVINOExecutor(ExecutorBase):
 
     uses_ray: bool = False
 
     def _init_executor(self) -> None:
         assert self.device_config.device_type == "openvino"
         assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
+        assert is_openvino_cpu() or is_openvino_gpu(), \
+            "OpenVINO backend supports only CPU and GPU devices"
+
+        self.ov_core = ov.Core()
         self.model_config = _verify_and_get_model_config(self.model_config)
-        self.cache_config = _verify_and_get_cache_config(self.cache_config)
+        self.cache_config = _verify_and_get_cache_config(
+            self.ov_core, self.cache_config)
 
         # Instantiate the worker and load the model to CPU.
         self._init_worker()
@@ -40,6 +53,7 @@ def _init_worker(self):
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         self.driver_worker = OpenVINOWorker(
+            ov_core=self.ov_core,
             model_config=self.model_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
@@ -68,10 +82,13 @@ def initialize_cache(self, num_gpu_blocks: int,
         # NOTE: We log here to avoid multiple logs when number of workers is
         # greater than one. We could log in the engine, but not all executors
         # have GPUs.
-        # NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is
-        # referred as `gpu block`. Because we want to reuse the existing block
-        # management procedure.
-        logger.info("# CPU blocks: %d", num_gpu_blocks)
+        # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend
+        # is located on CPU memory but is referred as `gpu block`.
+        # Because we want to reuse the existing block management procedure.
+        device_blocks = num_gpu_blocks
+        swap_blocks = num_cpu_blocks
+        logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d",
+                    envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks)
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     def execute_model(
@@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
     return config
 
 
-def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
+def _verify_and_get_cache_config(ov_core: ov.Core,
+                                 config: CacheConfig) -> CacheConfig:
     if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8":
-        logger.info("KV cache type is overried to u8 via "
-                    "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
-        config.cache_dtype = ov.Type.u8
+        if not is_openvino_cpu():
+            logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is"
+                        "ignored for GPU, f16 data type will be used.")
+            config.cache_dtype = ov.Type.f16
+        else:
+            logger.info("KV cache type is overridden to u8 via "
+                        "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.")
+            config.cache_dtype = ov.Type.u8
     else:
-        core = ov.Core()
-        inference_precision = core.get_property("CPU",
-                                                hints.inference_precision)
-        if inference_precision == ov.Type.bf16:
-            config.cache_dtype = ov.Type.bf16
+        if is_openvino_cpu():
+            ov_device = envs.VLLM_OPENVINO_DEVICE
+            inference_precision = ov_core.get_property(
+                ov_device, hints.inference_precision)
+            if inference_precision == ov.Type.bf16:
+                config.cache_dtype = ov.Type.bf16
+            else:
+                config.cache_dtype = ov.Type.f16
         else:
             config.cache_dtype = ov.Type.f16
 
-    if config.block_size != 32:
-        logger.info(
-            f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
-        )
-        config.block_size = 32
+    if is_openvino_cpu():
+        if config.block_size != 32:
+            logger.info(
+                f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}"  # noqa: G004, E501
+            )
+            config.block_size = 32
+    else:
+        if config.block_size != 16:
+            logger.info(
+                f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}"  # noqa: G004, E501
+            )
+            config.block_size = 16
 
     kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
     if kv_cache_space >= 0:
-        if kv_cache_space == 0:
+        if kv_cache_space == 0 and is_openvino_cpu():
             config.openvino_kvcache_space_bytes = 4 * GiB_bytes  # type: ignore
             logger.warning(
                 "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "

vllm/model_executor/model_loader/openvino.py

Lines changed: 10 additions & 16 deletions
@@ -12,6 +12,7 @@
 import vllm.envs as envs
 from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
 from vllm.config import DeviceConfig, ModelConfig
+from vllm.executor.openvino_executor import is_openvino_cpu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
                                                          _prune_hidden_states)
@@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type,
         shape = parameter.get_partial_shape()
         # use real block size if available, just a placeholder
         # to provide the expected rank
-        x_size = 1
         num_blocks = ov.Dimension()
         block_size = ov.Dimension()
         head_size = ov.Dimension()
-        # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD),
-        # pass more parameters to this function to set more static dimensions
         if input_name.startswith("key_cache."):
             cpu_shape = [num_blocks, shape[1], block_size, head_size]
-            gpu_shape = [
-                num_blocks,
-                shape[1],
-                shape[2].get_length() //
-                x_size if shape[2].is_static else ov.Dimension(),
-                block_size,
-                x_size,
-            ]
+            gpu_shape = [num_blocks, shape[1], shape[2], block_size]
         elif input_name.startswith("value_cache."):
             cpu_shape = [num_blocks, shape[1], block_size, head_size]
-            gpu_shape = [num_blocks, shape[1], shape[2], block_size]
+            gpu_shape = [num_blocks, shape[1], block_size, shape[2]]
         else:
             continue
         parameter.set_partial_shape(
@@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module):
 
     def __init__(
         self,
+        ov_core: ov.Core,
         model_config: ModelConfig,
         device_config: DeviceConfig,
         kv_cache_dtype: ov.Type,
@@ -141,12 +133,12 @@ def __init__(
             trust_remote_code=model_config.trust_remote_code,
         )
 
+        ov_device = envs.VLLM_OPENVINO_DEVICE
         paged_attention_transformation(pt_model.model)
         _modify_cache_parameters(pt_model.model, kv_cache_dtype,
-                                 device_config.device.type == "cpu")
+                                 is_openvino_cpu())
 
-        core = ov.Core()
-        ov_compiled = core.compile_model(pt_model.model, "CPU")
+        ov_compiled = ov_core.compile_model(pt_model.model, ov_device)
         self.ov_request = ov_compiled.create_infer_request()
 
     def forward(
@@ -199,11 +191,13 @@ def get_model(
     **kwargs,
 ) -> torch.nn.Module:
     lora_config = kwargs.get("lora_config", None)
+    ov_core = kwargs.get("ov_core")
     if lora_config:
         raise ValueError(
             "OpenVINO modeling does not support LoRA, "
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
 
-    return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype)
+    return OpenVINOCasualLM(ov_core, model_config, device_config,
+                            kv_cache_dtype)
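
A standalone sketch of the per-device KV-cache shapes that ``_modify_cache_parameters`` requests after this change. Here ``heads`` and ``head_size`` stand in for ``shape[1]``/``shape[2]`` of the original parameter (an assumption about what those dimensions hold), and -1 marks a dynamic dimension:

def cache_shape(input_name: str, heads: int, head_size: int,
                is_cpu: bool) -> list:
    dyn = -1  # dynamic num_blocks / block_size / head_size
    if input_name.startswith("key_cache."):
        return ([dyn, heads, dyn, dyn] if is_cpu
                else [dyn, heads, head_size, dyn])  # GPU keys: head_size, then block_size
    if input_name.startswith("value_cache."):
        return ([dyn, heads, dyn, dyn] if is_cpu
                else [dyn, heads, dyn, head_size])  # GPU values: block_size, then head_size
    raise ValueError(input_name)

assert cache_shape("key_cache.0", 32, 128, is_cpu=False) == [-1, 32, 128, -1]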

vllm/worker/openvino_model_runner.py

Lines changed: 6 additions & 5 deletions
@@ -42,6 +42,7 @@ class OpenVINOModelRunner:
 
     def __init__(
         self,
+        ov_core: ov.Core,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
@@ -55,6 +56,7 @@ def __init__(
         *args,
         **kwargs,
     ):
+        self.ov_core = ov_core
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
@@ -89,11 +91,10 @@ def __init__(
         self.model: nn.Module  # Set after init_Model
 
     def load_model(self) -> None:
-        self.model = get_model(
-            model_config=self.model_config,
-            device_config=self.device_config,
-            kv_cache_dtype=self.kv_cache_dtype,
-        )
+        self.model = get_model(model_config=self.model_config,
+                               device_config=self.device_config,
+                               kv_cache_dtype=self.kv_cache_dtype,
+                               ov_core=self.ov_core)
 
     def _prepare_model_input(
         self,
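
The net effect of the plumbing in this file, together with the executor and loader changes above, is that a single ``ov.Core`` is created once and shared down the stack instead of each component constructing its own. A hedged illustration with stand-in class names (not vLLM classes):

import openvino as ov

class Runner:  # illustrative stand-in for OpenVINOModelRunner
    def __init__(self, ov_core: ov.Core):
        self.ov_core = ov_core

    def load_model(self, model: ov.Model, device: str):
        # Compile with the shared Core, mirroring OpenVINOCasualLM.__init__.
        return self.ov_core.compile_model(model, device)

core = ov.Core()        # created once, in the executor
runner = Runner(core)   # threaded down, as through OpenVINOWorker
assert runner.ov_core is core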
