[Docs] Update FP8 KV Cache documentation #12238

Merged
44 changes: 0 additions & 44 deletions docs/source/features/quantization/fp8_e4m3_kvcache.md

This file was deleted.

31 changes: 0 additions & 31 deletions docs/source/features/quantization/fp8_e5m2_kvcache.md

This file was deleted.

3 changes: 1 addition & 2 deletions docs/source/features/quantization/index.md
@@ -14,6 +14,5 @@ bnb
 gguf
 int8
 fp8
-fp8_e5m2_kvcache
-fp8_e4m3_kvcache
+quantized_kvcache
 ```
145 changes: 145 additions & 0 deletions docs/source/features/quantization/quantized_kvcache.md
@@ -0,0 +1,145 @@
(quantized-kvcache)=

# Quantized KV Cache

## FP8 KV Cache

Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.

### FP8 Formats

[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:

- E5M2 (5 exponent bits and 2 mantissa bits)
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened to E4M3)

The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.
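
The trade-off is easy to see by round-tripping a few values through each format. Below is a minimal sketch, assuming PyTorch 2.1+ (which exposes `torch.float8_e4m3fn` and `torch.float8_e5m2`); it is purely illustrative and not part of vLLM:

```python
import torch

# Round-trip a few values through each FP8 format to compare rounding behaviour.
vals = torch.tensor([3.14159, 100.0, 240.0])

# E4M3: 3 mantissa bits -> finer precision, narrower representable range.
print(vals.to(torch.float8_e4m3fn).to(torch.float32))

# E5M2: 5 exponent bits -> wider range, coarser precision.
print(vals.to(torch.float8_e5m2).to(torch.float32))
```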

### Current Limitations

For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
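
To illustrate what a per-tensor (scalar) scale means, the sketch below maps an entire tensor into the FP8 range with a single FP32 scale. This is a simplified stand-in for the idea, not vLLM's internal implementation:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable E4M3 magnitude

def quantize_per_tensor(x: torch.Tensor):
    # One scalar scale covers the whole tensor (per-tensor granularity).
    scale = x.abs().max().float() / FP8_MAX
    q = (x / scale).to(torch.float8_e4m3fn)
    return q, scale

def dequantize(q: torch.Tensor, scale: torch.Tensor):
    # The FP32 scale restores the original magnitude on the way back.
    return q.to(torch.float32) * scale

k = torch.randn(2, 8, 128)                          # stand-in for a key tensor
k_q, k_scale = quantize_per_tensor(k)
print((dequantize(k_q, k_scale) - k).abs().max())   # small quantization error
```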

### Performance Impact

The current FP8 KV cache implementation primarily benefits throughput by roughly doubling the space available for KV cache allocation. This enables either:

- Processing longer context lengths for individual requests, or
- Handling more concurrent request batches

However, there are currently no latency improvements, as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (FP32, FP16, BF16), this benefit is not yet fully realized.

Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization.

## Usage Example

Here is an example of how to enable FP8 KV cache quantization:

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)

# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
```

The `kv_cache_dtype` argument specifies the data type for KV cache storage:
- `"auto"`: Uses the model's default "unquantized" data type
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
- `"fp8_e5m2"`: Supported on CUDA 11.8+

## Calibrated Scales for Better Accuracy

For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.

### Installation

First, install the required dependencies:

```console
pip install llmcompressor
```

### Example Usage

Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor.transformers import oneshot

# Select model and load it
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Configure calibration parameters
NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
MAX_SEQUENCE_LENGTH = 2048

# Load and preprocess dataset
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))

def process_and_tokenize(example):
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return tokenizer(
        text,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)

# Configure quantization settings
recipe = """
quant_stage:
quant_modifiers:
QuantizationModifier:
kv_cache_scheme:
num_bits: 8
type: float
strategy: tensor
dynamic: false
symmetric: true
"""

# Apply quantization
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save quantized model
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.
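
As a quick sanity check, you can verify that the KV cache scheme was recorded in the saved checkpoint's `config.json`. This is a sketch; the exact key layout may differ between `llmcompressor` versions:

```python
import json

with open("Llama-3.1-8B-Instruct-FP8-KV/config.json") as f:
    cfg = json.load(f)

# Expect an FP8, per-tensor kv_cache_scheme entry under the quantization config.
print(cfg.get("quantization_config", {}).get("kv_cache_scheme"))
```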

When running the model, you must specify `kv_cache_dtype="fp8"` to enable KV cache quantization and use the calibrated scales.

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
```
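
The calibrated checkpoint can also be used for online serving with the same flag (the path is assumed to be the folder created above):

```console
vllm serve Llama-3.1-8B-Instruct-FP8-KV --kv-cache-dtype fp8
```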