Commit 795dc02

dr75 authored and kylesayrs committed
Support SHA256 as hash function in prefix caching (vllm-project#15297)
Signed-off-by: Marko Rosenmueller <[email protected]>
Signed-off-by: Kyle Sayers <[email protected]>
1 parent e69734f commit 795dc02

File tree

12 files changed: +214 -71 lines changed

docs/source/design/v1/prefix_caching.md

Lines changed: 4 additions & 3 deletions
@@ -15,12 +15,13 @@ Block 3: |<------------------ prefix -------------------->| |<--- block tokens -
 In the example above, the KV cache in the first block can be uniquely identified with the token “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the block hash of `hash(tuple[components])`, where components are:

 * Parent hash value: The hash value of the parent hash block.
-* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
+* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
 * Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).

-Note 1: We only cache full blocks.
+> **Note 1:** We only cache full blocks.

-Note 2: The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value, but this should be nearly impossible to happen. Of course, contributions are welcome if you have an awesome idea to eliminate collusion entirely.
+> **Note 2:** The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we advise to use SHA256** as hash function instead of the default builtin hash.
+SHA256 is supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200ns per token (~6ms for 50k tokens of context).

 **A hashing example with multi-modality inputs**

 In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:
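
To make `hash(tuple[components])` concrete, here is a minimal sketch of how block hashes chain through a prefix with a pluggable hash function. It mirrors the `hash_block_tokens(hash_fn, parent_hash, block_tokens, extra_keys)` call sites in the tests below; the helper and variable names are illustrative, not vLLM's actual implementation.

from typing import Any, Callable, Optional

def chain_block_hashes(hash_fn: Callable[[Any], int],
                       blocks: list[tuple[int, ...]],
                       extra_keys: Optional[tuple] = None) -> list[int]:
    # Each block is hashed as hash_fn((parent_hash, block_tokens, extra_keys)),
    # so a block's hash commits to its entire prefix, not just its own tokens.
    hashes: list[int] = []
    parent_hash: Optional[int] = None
    for block_tokens in blocks:
        h = hash_fn((parent_hash, block_tokens, extra_keys))
        hashes.append(h)
        parent_hash = h
    return hashes

# Two requests sharing the first block yield the same first hash (so the
# cached KV block can be reused); the second hashes diverge with the tokens.
a = chain_block_hashes(hash, [(1, 2, 3), (4, 5, 6)])
b = chain_block_hashes(hash, [(1, 2, 3), (7, 8, 9)])
assert a[0] == b[0] and a[1] != b[1]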

tests/test_utils.py

Lines changed: 22 additions & 1 deletion
@@ -2,6 +2,8 @@
 # ruff: noqa

 import asyncio
+import hashlib
+import pickle
 import socket
 from collections.abc import AsyncIterator
 from unittest.mock import patch
@@ -14,7 +16,8 @@
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
                         PlaceholderModule, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
-                        merge_async_iterators, supports_kw, swap_dict_values)
+                        merge_async_iterators, sha256, supports_kw,
+                        swap_dict_values)

 from .utils import create_new_process_for_each_test, error_on_warning

@@ -476,3 +479,21 @@ def test_swap_dict_values(obj, key1, key2):
         assert obj[key1] == original_obj[key2]
     else:
         assert key1 not in obj
+
+
+@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
+                                   (None, bool, [1, 2, 3])])
+@pytest.mark.parametrize("output", [0, 1, 2])
+def test_sha256(input: tuple, output: int):
+    hash = sha256(input)
+    assert hash is not None
+    assert isinstance(hash, int)
+    assert hash != 0
+
+    bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
+    assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
+                                  byteorder="big")
+
+    # hashing again, returns the same value
+    assert hash == sha256(input)
+
+    # hashing different input, returns different value
+    assert hash != sha256(input + (1, ))
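
The assertions above pin down the contract of the new `vllm.utils.sha256` helper: pickle the input with the highest protocol, digest it with SHA-256, and read the 32-byte digest as a big-endian integer. A minimal sketch consistent with those assertions (the real implementation may differ in detail):

import hashlib
import pickle
from typing import Any

def sha256(input: Any) -> int:
    # Serialize deterministically, then interpret the 256-bit digest as an
    # unsigned big-endian integer. Unlike Python's salted builtin hash(),
    # the result is stable across processes and restarts.
    input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(hashlib.sha256(input_bytes).digest(),
                          byteorder="big")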

tests/v1/core/test_kv_cache_utils.py

Lines changed: 29 additions & 13 deletions
@@ -5,8 +5,12 @@

 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         KVCacheBlock, PrefixCachingMetrics,
+from vllm.utils import sha256
+# disable yapf here as it formats differently than isort such that both fail
+# yapf: disable
+from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
+                                         FreeKVCacheBlockQueue, KVCacheBlock,
+                                         PrefixCachingMetrics,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens,
                                          hash_request_tokens,
@@ -16,6 +20,8 @@
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request

+# yapf: enable
+

 def make_request(request_id,
                  prompt_token_ids,
@@ -40,6 +46,12 @@ def make_request(request_id,
     )


+def test_none_hash():
+    assert NONE_HASH is not None
+    assert isinstance(NONE_HASH, int)
+    assert NONE_HASH != 0
+
+
 def test_kv_cache_block():
     # Test KVCacheBlock initialization
     block = KVCacheBlock(block_id=0)
@@ -190,21 +202,23 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
     assert next_mm_idx == 0


-def test_hash_block_tokens():
+@pytest.mark.parametrize("hash_fn", [sha256, hash])
+def test_hash_block_tokens(hash_fn):
     parent_block_hash = 123
     curr_block_token_ids = (1, 2, 3)
     extra_keys = ("key1", "key2")

-    block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids,
-                                   extra_keys)
+    block_hash = hash_block_tokens(hash_fn, parent_block_hash,
+                                   curr_block_token_ids, extra_keys)
     assert isinstance(block_hash, BlockHashType)
-    assert block_hash.hash_value == hash(
+    assert block_hash.hash_value == hash_fn(
         (parent_block_hash, curr_block_token_ids, extra_keys))
     assert block_hash.token_ids == curr_block_token_ids
     assert block_hash.extra_keys == extra_keys


-def test_hash_request_tokens():
+@pytest.mark.parametrize("hash_fn", [sha256, hash])
+def test_hash_request_tokens(hash_fn):
     request = make_request(
         request_id=0,
         prompt_token_ids=[_ for _ in range(6)],
@@ -219,7 +233,7 @@ def test_hash_request_tokens():
     )

     block_size = 3
-    block_hashes = hash_request_tokens(block_size, request)
+    block_hashes = hash_request_tokens(hash_fn, block_size, request)

     assert len(block_hashes) == 2
     assert isinstance(block_hashes[0], BlockHashType)
@@ -234,7 +248,8 @@ def test_hash_request_tokens():
     assert block_hashes[1].extra_keys == ("hash2", )


-def test_hash_tokens_different_mm_input():
+@pytest.mark.parametrize("hash_fn", [sha256, hash])
+def test_hash_tokens_different_mm_input(hash_fn):
     request1 = make_request(
         request_id=0,
         prompt_token_ids=[_ for _ in range(6)],
@@ -260,13 +275,14 @@ def test_hash_tokens_different_mm_input():
         mm_hashes=["hash3", "hash2"],
     )
     block_size = 3
-    block_hashes1 = hash_request_tokens(block_size, request1)
-    block_hashes2 = hash_request_tokens(block_size, request2)
+    block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
+    block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
     assert block_hashes1[0] != block_hashes2[0]
     assert block_hashes1[1] != block_hashes2[1]


-def test_hash_request_tokens_no_mm_inputs():
+@pytest.mark.parametrize("hash_fn", [sha256, hash])
+def test_hash_request_tokens_no_mm_inputs(hash_fn):
     request = make_request(
         request_id=0,
         prompt_token_ids=[_ for _ in range(6)],
@@ -275,7 +291,7 @@ def test_hash_request_tokens_no_mm_inputs():
     )

     block_size = 3
-    block_hashes = hash_request_tokens(block_size, request)
+    block_hashes = hash_request_tokens(hash_fn, block_size, request)

     assert len(block_hashes) == 2
     assert block_hashes[0].token_ids == (0, 1, 2)

tests/v1/core/test_prefix_caching.py

Lines changed: 17 additions & 5 deletions
@@ -7,7 +7,7 @@

 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.utils import cdiv
+from vllm.utils import cdiv, sha256
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
 from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
@@ -39,16 +39,21 @@ def make_request(request_id,
     )


-def test_prefill():
+@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
+def test_prefill(hash_algo):
     manager = KVCacheManager(
         block_size=16,
         num_gpu_blocks=10,
         max_model_len=8192,
         sliding_window=None,
         enable_caching=True,
+        caching_hash_algo=hash_algo,
         num_preallocate_tokens=16,
     )

+    # choose the hash function according to the parameter
+    hash_fn = sha256 if hash_algo == "sha256" else hash
+
     # Complete 3 blocks (48 tokens)
     common_token_ids = [i for i in range(3) for _ in range(16)]

@@ -68,7 +73,8 @@ def test_prefill():
     parent_block_hash = None
     for block_id in (0, 1, 2):
         block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
-        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
+        block_hash = hash_block_tokens(hash_fn, parent_block_hash,
+                                       block_tokens)
         assert manager.block_pool.blocks[block_id].block_hash == block_hash
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
@@ -163,6 +169,8 @@ def test_prefill_plp():
         enable_caching=True,
         num_preallocate_tokens=16,
     )
+    # the default hash function is hash
+    hash_fn = hash

     # Complete 3 blocks (48 tokens)
     common_token_ids = [i for i in range(3) for _ in range(16)]
@@ -185,7 +193,8 @@ def test_prefill_plp():
     parent_block_hash = None
     for block_id in (0, 1, 2):
         block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
-        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
+        block_hash = hash_block_tokens(hash_fn, parent_block_hash,
+                                       block_tokens)
         assert manager.block_pool.blocks[block_id].block_hash == block_hash
         assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value
@@ -522,7 +531,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
     assert len(blocks) == 1 + num_preallocated_blocks


-def test_cache_blocks():
+@pytest.mark.parametrize("hash_fn", [sha256, hash])
+def test_cache_blocks(hash_fn):
     """
     This is a unit test that tests the correctness of the _cache_full_blocks
     function of KVCacheManager.
@@ -550,6 +560,7 @@ def test_cache_blocks():
         num_cached_blocks=0,
         num_full_blocks=2,
         block_size=block_size,
+        hash_fn=hash_fn,
     )

     assert len(block_pool.cached_block_hash_to_block) == 2
@@ -564,6 +575,7 @@ def test_cache_blocks():
         num_cached_blocks=2,
         num_full_blocks=3,
         block_size=block_size,
+        hash_fn=hash_fn,
     )
     assert len(block_pool.cached_block_hash_to_block) == 3
     assert blocks[0].block_hash is not None
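
Note how `test_prefill` resolves the string parameter to a callable with `hash_fn = sha256 if hash_algo == "sha256" else hash`. A standalone sketch of that resolution step (the function name here is illustrative; `KVCacheManager` presumably does something equivalent internally when given `caching_hash_algo`):

from typing import Any, Callable

from vllm.utils import sha256

def resolve_hash_fn(caching_hash_algo: str) -> Callable[[Any], int]:
    # "sha256" selects the collision-resistant helper; any other value
    # (e.g. the "builtin"/"hash" default) falls back to Python's hash().
    return sha256 if caching_hash_algo == "sha256" else hash

hash_fn = resolve_hash_fn("sha256")
assert hash_fn is sha256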

tests/v1/engine/test_engine_args.py

Lines changed: 20 additions & 0 deletions
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

+from argparse import ArgumentError
+
 import pytest

 from vllm import envs
@@ -32,6 +34,24 @@ def test_prefix_caching_from_cli():
     vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
     assert vllm_config.cache_config.enable_prefix_caching

+    # default hash algorithm is "builtin"
+    assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
+
+    # set hash algorithm to sha256
+    args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
+    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
+    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
+
+    # set hash algorithm to builtin
+    args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
+    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
+    assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
+
+    # an invalid hash algorithm raises an error
+    parser.exit_on_error = False
+    with pytest.raises(ArgumentError):
+        args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
+

 def test_defaults_with_usage_context():
     engine_args = EngineArgs(model="facebook/opt-125m")

vllm/config.py

Lines changed: 9 additions & 0 deletions
@@ -1124,6 +1124,7 @@ def __init__(
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
+        prefix_caching_hash_algo: str = "builtin",
         cpu_offload_gb: float = 0,
         calculate_kv_scales: Optional[bool] = None,
     ) -> None:
@@ -1135,6 +1136,7 @@ def __init__(
         self.is_attention_free = is_attention_free
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
+        self.prefix_caching_hash_algo = prefix_caching_hash_algo
         self.cpu_offload_gb = cpu_offload_gb
         self.calculate_kv_scales = calculate_kv_scales
         self._verify_args()
@@ -1185,6 +1187,13 @@ def _verify_prefix_caching(self) -> None:
                 "Prefix caching is not supported with sliding window. "
                 "Run with --disable-sliding-window to use prefix caching.")

+        if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
+                "builtin", "sha256"):
+            raise ValueError(
+                "Unknown prefix caching hash algorithm: "
+                f"{self.prefix_caching_hash_algo}. Must be either "
+                "'builtin' or 'sha256'.")
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
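
Two properties of the new check are worth noting: it only fires when prefix caching is enabled, and it accepts exactly the two values the CLI offers. A tiny standalone illustration of the guard's behavior (plain Python mirroring `_verify_prefix_caching`, not the vLLM class itself):

def prefix_caching_algo_is_valid(enable_prefix_caching: bool,
                                 prefix_caching_hash_algo: str) -> bool:
    # The algorithm is only validated when prefix caching is actually on.
    return (not enable_prefix_caching
            or prefix_caching_hash_algo in ("builtin", "sha256"))

assert prefix_caching_algo_is_valid(True, "sha256")
assert prefix_caching_algo_is_valid(False, "anything")   # never checked
assert not prefix_caching_algo_is_valid(True, "invalid")  # ValueError in vLLM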

vllm/engine/arg_utils.py

Lines changed: 32 additions & 6 deletions
@@ -118,6 +118,7 @@ class EngineArgs:
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
+    prefix_caching_hash_algo: str = "builtin"
     disable_sliding_window: bool = False
     disable_cascade_attn: bool = False
     use_v2_block_manager: bool = True
@@ -475,6 +476,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help="Enables automatic prefix caching. "
             "Use ``--no-enable-prefix-caching`` to disable explicitly.",
         )
+        parser.add_argument(
+            "--prefix-caching-hash-algo",
+            type=str,
+            choices=["builtin", "sha256"],
+            default=EngineArgs.prefix_caching_hash_algo,
+            help="Set the hash algorithm for prefix caching. "
+            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
+            "(collision resistant but with certain overheads). Defaults "
+            "to 'builtin'.",
+        )
         parser.add_argument('--disable-sliding-window',
                             action='store_true',
                             help='Disables sliding window, '
@@ -1329,6 +1340,7 @@ def create_engine_config(
             num_gpu_blocks_override=self.num_gpu_blocks_override,
             sliding_window=model_config.get_sliding_window(),
             enable_prefix_caching=self.enable_prefix_caching,
+            prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
         )
@@ -1737,12 +1749,22 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None:
             msg = "Chunked prefill is not supported for pooling models"
             raise ValueError(msg)

-        # Disable prefix caching for multimodal models for VLLM_V0.
-        if (model_config.is_multimodal_model and self.enable_prefix_caching):
-            logger.warning(
-                "--enable-prefix-caching is not supported for multimodal "
-                "models in V0 and has been disabled.")
-            self.enable_prefix_caching = False
+        # if using prefix caching, we must set a hash algo
+        if self.enable_prefix_caching:
+            # Disable prefix caching for multimodal models for VLLM_V0.
+            if model_config.is_multimodal_model:
+                logger.warning(
+                    "--enable-prefix-caching is not supported for multimodal "
+                    "models in V0 and has been disabled.")
+                self.enable_prefix_caching = False
+
+            # VLLM_V0 only supports builtin hash algo for prefix caching.
+            if self.prefix_caching_hash_algo is None:
+                self.prefix_caching_hash_algo = "builtin"
+            elif self.prefix_caching_hash_algo == "sha256":
+                raise ValueError(
+                    "sha256 is not supported for prefix caching in V0 engine. "
+                    "Please use 'builtin'.")

         # Set max_num_seqs to 256 for VLLM_V0.
         if self.max_num_seqs is None:
@@ -1758,6 +1780,10 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
         if self.enable_prefix_caching is None:
             self.enable_prefix_caching = True

+        # if using prefix caching, we must set a hash algo
+        if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
+            self.prefix_caching_hash_algo = "builtin"
+
         # V1 should use the new scheduler by default.
         # Swap it only if this arg is set to the original V0 default
         if self.scheduler_cls == EngineArgs.scheduler_cls:
0 commit comments
