
Commit 0d2f3dd

Merge remote-tracking branch 'origin/main' into coalesce-stream
2 parents a4f49d5 + 99d01a5 commit 0d2f3dd

23 files changed (+524, -218 lines)

CMakeLists.txt

Lines changed: 20 additions & 25 deletions

@@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 # Suppress potential warnings about unused manually-specified variables
 set(ignoreMe "${VLLM_PYTHON_PATH}")
 
-# Prevent installation of dependencies (cutlass) by default.
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
-
 #
 # Supported python versions. These versions will be searched in order, the
 # first match will be selected. These should be kept in sync with setup.py.
@@ -535,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
 endif()
 
 # vllm-flash-attn currently only supported on CUDA
-if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
+if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
   return()
 endif ()
 
@@ -558,7 +555,7 @@ endif()
 # They should be identical but if they aren't, this is a massive footgun.
 #
 # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
-# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
+# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
 # If no component is specified, vllm-flash-attn is still installed.
 
 # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
@@ -570,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
 endif()
 
 if(VLLM_FLASH_ATTN_SRC_DIR)
-  FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
+  FetchContent_Declare(
+        vllm-flash-attn SOURCE_DIR
+        ${VLLM_FLASH_ATTN_SRC_DIR}
+        BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
+  )
 else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
+        GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
   )
 endif()
 
-# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
-set(VLLM_PARENT_BUILD ON)
-
-# Ensure the vllm/vllm_flash_attn directory exists before installation
-install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
-
-# Make sure vllm-flash-attn install rules are nested under vllm/
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
-install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
-install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
 
 # Fetch the vllm-flash-attn library
 FetchContent_MakeAvailable(vllm-flash-attn)
 message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
 
-# Restore the install prefix
-install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
-install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
+# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
+# case only one is built, in the case both are built redundant work is done)
+install(
+        DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+        DESTINATION vllm_flash_attn
+        COMPONENT _vllm_fa2_C
+        FILES_MATCHING PATTERN "*.py"
+)
 
-# Copy over the vllm-flash-attn python files
 install(
-        DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
-        DESTINATION vllm/vllm_flash_attn
-        COMPONENT vllm_flash_attn_c
-        FILES_MATCHING PATTERN "*.py"
+        DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
+        DESTINATION vllm_flash_attn
+        COMPONENT _vllm_fa3_C
+        FILES_MATCHING PATTERN "*.py"
 )
 
 # Nothing after vllm-flash-attn, see comment about macros above

Dockerfile.rocm

Lines changed: 2 additions & 1 deletion

@@ -72,7 +72,8 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
 RUN cd /vllm-workspace \
     && rm -rf vllm \
    && python3 -m pip install -e tests/vllm_test_utils \
-    && python3 -m pip install lm-eval[api]==0.4.4
+    && python3 -m pip install lm-eval[api]==0.4.4 \
+    && python3 -m pip install pytest-shard
 
 # -----------------------
 # Final vLLM image

docs/source/contributing/vulnerability_management.md

Lines changed: 17 additions & 0 deletions

@@ -41,3 +41,20 @@ You may use the `#security` channel in the [VLLM Slack](https://slack.vllm.ai)
 to discuss security-related topics. However, please do not disclose any
 vulnerabilities in this channel. If you need to report a vulnerability, please
 use the GitHub security advisory system or contact a VMT member privately.
+
+## Vulnerability Disclosure
+
+The process for disclosing vulnerabilities is the following:
+
+- The VMT will work with the project maintainers to develop a fix for the
+  vulnerability.
+- The VMT will coordinate with the reporter and project maintainers to prepare a
+  security advisory that adequately describes the vulnerability and its impact.
+- The VMT will coordinate with the project maintainers to publish a fix and
+  release an update that includes that fix.
+- The VMT will publish the security advisory on GitHub. Release notes will be
+  updated to include a reference to the security advisory.
+
+The VMT and project maintainers will work to minimize the amount of time in
+between disclosing any public information about the vulnerability and making a
+release and advisory available.

docs/source/features/quantization/fp8_e4m3_kvcache.md

Lines changed: 0 additions & 44 deletions
This file was deleted.

docs/source/features/quantization/fp8_e5m2_kvcache.md

Lines changed: 0 additions & 31 deletions
This file was deleted.

docs/source/features/quantization/index.md

Lines changed: 1 addition & 2 deletions

@@ -14,6 +14,5 @@ bnb
 gguf
 int8
 fp8
-fp8_e5m2_kvcache
-fp8_e4m3_kvcache
+quantized_kvcache
 ```

docs/source/features/quantization/quantized_kvcache.md

Lines changed: 145 additions & 0 deletions (new file)

(quantized-kvcache)=

# Quantized KV Cache

## FP8 KV Cache

Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.

### FP8 Formats

[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:

- E5M2 (5 exponent bits and 2 mantissa bits)
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened to E4M3)

The E4M3 format offers higher precision than E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.

### Current Limitations

For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of finer granularity (e.g. per-channel).
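
To make the per-tensor scheme concrete, here is a minimal sketch of FP8 E4M3 quantization with a single scalar scale. It is an illustration only, not vLLM's internal kernel; it assumes a recent PyTorch build that provides the `torch.float8_e4m3fn` dtype, and the tensor shape is arbitrary.

```python
import torch

def quantize_per_tensor_e4m3(x: torch.Tensor):
    """Quantize `x` to FP8 E4M3 using one scalar (per-tensor) scale."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max      # format's representable maximum
    scale = x.abs().amax().clamp(min=1e-12) / fp8_max   # FP32 scaling factor
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale

def dequantize(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x_fp8.to(torch.float32) * scale

k = torch.randn(2, 8, 128)  # stand-in for a slice of the key cache
k_fp8, k_scale = quantize_per_tensor_e4m3(k)
err = (dequantize(k_fp8, k_scale) - k).abs().max()
print(f"scale={k_scale.item():.5f}, max abs error={err.item():.4f}")
```
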

### Performance Impact

The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation (see the rough arithmetic sketch after this list). This enables either:

- Processing longer context lengths for individual requests, or
- Handling more concurrent request batches
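
As a back-of-the-envelope illustration of the capacity gain, consider the KV cache cost per token. The model dimensions below are hypothetical (chosen to resemble an 8B-class model) and are not taken from any specific checkpoint:

```python
# Back-of-the-envelope KV cache cost per token; dimensions are illustrative only.
num_layers, num_kv_heads, head_dim = 32, 8, 128

def kv_bytes_per_token(bytes_per_elem: int) -> int:
    # 2x for keys and values, per layer, per KV head, per head dimension
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

fp16_cost = kv_bytes_per_token(2)  # 16-bit KV cache
fp8_cost = kv_bytes_per_token(1)   # 8-bit (FP8) KV cache
print(f"FP16: {fp16_cost // 1024} KiB/token, FP8: {fp8_cost // 1024} KiB/token")
print(f"-> roughly {fp16_cost / fp8_cost:.1f}x more tokens fit in the same memory")
```
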

However, there are currently no latency improvements, as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized.

Studies have shown that FP8 E4M3 quantization typically degrades inference accuracy only minimally, making it a practical choice for throughput optimization.

## Usage Example

Here is an example of how to enable FP8 quantization:

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)

# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
```

The `kv_cache_dtype` argument specifies the data type for KV cache storage:

- `"auto"`: Uses the model's default "unquantized" data type
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
- `"fp8_e5m2"`: Supported on CUDA 11.8+
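
For instance, here is a minimal sketch of selecting the E5M2 variant explicitly, reusing the model from the example above:

```python
from vllm import LLM

# Same API as the example above; only the kv_cache_dtype value changes.
# E5M2 keeps a wider exponent range at the cost of mantissa precision.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8_e5m2")
```
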

## Calibrated Scales for Better Accuracy

For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.

### Installation

First, install the required dependencies:

```console
pip install llmcompressor
```

### Example Usage

Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.transformers import oneshot

# Select model and load it
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Configure calibration parameters
NUM_CALIBRATION_SAMPLES = 512  # 512 samples is a good starting point
MAX_SEQUENCE_LENGTH = 2048

# Load and preprocess dataset
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def process_and_tokenize(example):
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(
        text,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

# Configure quantization settings
recipe = """
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            kv_cache_scheme:
                num_bits: 8
                type: float
                strategy: tensor
                dynamic: false
                symmetric: true
"""

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save quantized model
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.

When running the model, you must specify `kv_cache_dtype="fp8"` to enable KV cache quantization and use the calibrated scales:

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
