[RFC]: Hybrid Memory Allocator #11382
Comments
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
Still working on this RFC actively. Leave a comment here to keep it open.
Great RFC! Is there a timeline for this? I'm going to be looking into some KV Cache compression techniques in the next month, which would require part 2 (flexibility in page sizes).
Not sure about the timeline for LCM pages. For the first milestone, we plan to tune block_size to align the page sizes of different types, which should be finished in the next few weeks. What KV cache compression techniques are you planning to cover? Would tuning block_size work for them?
Looking at AQUA KV, where KV caches are compressed to 2 bits. We could likely just reuse existing pages upon compression. An additional challenge will be extra workspace (for temporary decompression), which could grow very large for worst-case batching of multiple long-context requests.
Motivation.
In addition to standard self-attention-only models, we now have more and more hybrid models with more than one type of layer, for example:
In these models, the KV cache size is no longer the same for every token. However, vLLM can only allocate the same KV cache size for all tokens, as shown in the figure below (mllama with BLOCK_SIZE=1). The resulting memory waste can be 79.6% in mllama, 25% in Gemma-2, and 56.25% in Ministral.

For Mamba layers, vLLM has a special MambaCacheManager in model_executor/models/mamba_cache.py, which is not compatible with prefix caching. We want a new memory manager to:
We can support them mainly in two milestones:
Milestone 1: per-layer memory allocation, where every layer has the same KV cache size per token
In this milestone, we assume that all layers have the same kv_hidden_size but a different number of tokens depending on the layer type (e.g., encoder, sliding window). Then every layer can use the same page size, which greatly simplifies the design of the memory allocator. This assumption covers almost all models in current vLLM (except Jamba).
We set the page size to [BLOCK_SIZE, kv_hidden_size] (current vLLM uses [num_layer, BLOCK_SIZE, kv_hidden_size]). For each request, we call the memory allocator num_layer times to get one block table per layer. Each layer can then have a different number of slots and a different slot mapping depending on its layer type, as sketched after the figure below.

The software architecture will be as follows, with the memory manager for each layer being one of [SelfAttentionManager, SlidingWindowManager, MambaManager, CrossAttentionManager, …]; a rough interface sketch follows the figure:

To make the per-layer memory allocation faster, we can group the layers to satisfy the following properties, so that layers inside each group share the same block table.
For example, mllama 11B has 8 cross-attention layers & 32 self-attention layers. They will be grouped into 5 groups: (cross-attn $\times$ 8) (self-attn $\times$ 8) (self-attn $\times$ 8) (self-attn $\times$ 8) (self-attn $\times$ 8), and each group has one memory manager, as sketched below.
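A small sketch of the grouping step, assuming layers of the same type are chunked into equal-sized groups so that each group shares one block table (the exact layer layout of mllama is simplified here):

```python
# Hypothetical grouping: collect layer IDs by type, then split them into
# fixed-size groups; each group gets one memory manager and one block table.
from collections import defaultdict

def group_layers(layer_types, group_size):
    by_type = defaultdict(list)
    for layer_id, layer_type in enumerate(layer_types):
        by_type[layer_type].append(layer_id)
    groups = []
    for layer_type, layer_ids in by_type.items():
        for start in range(0, len(layer_ids), group_size):
            groups.append((layer_type, layer_ids[start:start + group_size]))
    return groups

# mllama 11B: 32 self-attention + 8 cross-attention layers -> 5 groups of 8.
layer_types = ["self_attn"] * 32 + ["cross_attn"] * 8   # interleaving simplified
groups = group_layers(layer_types, group_size=8)
print([(t, len(ids)) for t, ids in groups])
# [('self_attn', 8), ('self_attn', 8), ('self_attn', 8), ('self_attn', 8), ('cross_attn', 8)]
```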
We can still use LRU to manage the cached blocks by putting the free blocks of all layers in the same queue and evicting the LRU block across all layers.
Moreover, we will have the freedom to customize the eviction strategy of different layer types:
For a request [sys1 sys2 prompt1 prompt2 prompt3], we can evict sys1, prompt1, and prompt2, and only cache sys2 (the end of the system prompt) and prompt3 (the end of the request, for multi-turn conversation).
The customized eviction strategies can be implemented within the memory manager of each type by assigning the LRU time tag carefully (or by inserting blocks into the FreeKVCacheBlockQueue in a careful order), as sketched below.
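A toy sketch of the shared free queue with biased LRU tags; the queue structure and the tag offset are assumptions for illustration, not the actual FreeKVCacheBlockQueue implementation:

```python
# Hypothetical shared free queue: all layer types return freed blocks here,
# and a manager can bias eviction order via the time tag it assigns.
import heapq
import itertools

_clock = itertools.count()            # monotonically increasing "time"
free_queue = []                       # min-heap of (lru_tag, block_id)

def free_block(block_id, keep_longer=False):
    """Return a block to the shared queue. Blocks we prefer to keep cached
    (e.g. the last block of a system prompt) get a later tag, so they are
    evicted after ordinary blocks."""
    tag = next(_clock) + (10**6 if keep_longer else 0)
    heapq.heappush(free_queue, (tag, block_id))

def evict_block():
    """Evict the least-recently-used block across all layers."""
    _, block_id = heapq.heappop(free_queue)
    return block_id
```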
Milestone 2: allow allocation of pages with different sizes via LCM pages
In this step, we want to build a more general memory allocator to remove the same kv_hidden_size assumption.
Some use cases are listed here:
The above use cases introduce different page sizes, so we need a new allocator to handle them.
The basic idea is to introduce a two-level page table:
For example, for mllama with 2 cross-attention layers (whose KV cache holds image tokens) and 3 self-attention layers (whose KV cache holds text tokens), with kv_hidden_size 128, we have two page sizes (2*128=256 and 3*128=384) and the following memory layout:

Preliminary results
I’ve implemented a prototype on v0 and got the following results on H100.
The speedup comes from both less fragmentation and better prefix caching.
Proposed Change.
The allocator will be implemented on v1, with the following steps:
Feedback Period.
One week.
CC List.
@comaniac @WoosukKwon @zhuohan123 @simon
Any Other Things.
No response