---
author: Zihao Ye (UW), Ruihang Lai (CMU), Bo-Ru Lu (UW), Chien-Yu Lin (UW), Size Zheng (UW & PKU), Lequn Chen (UW), Tianqi Chen (CMU & OctoML), Luis Ceze (UW & OctoML)
redirect_from: "/2024/01/08/cascade-inference"
---
Many LLM inference tasks involve multiple independent text generations from a shared prefix (prompt), e.g. [Self-Consistency](https://arxiv.org/abs/2203.11171), [Tree of Thoughts](https://arxiv.org/abs/2305.10601) and [Skeleton-of-Thought](https://arxiv.org/abs/2307.15337). Serving LLMs with a common prefix can be memory- and time-consuming, especially when the common prefix is long and the number of requests is large: one use case is long document QA (Figure 1), where multiple users interact with a chatbot using the same document as the prompt. [vLLM](https://arxiv.org/abs/2309.06180) alleviates the memory issue by storing only one copy of the common prefix, but it still suffers from low efficiency because the default PageAttention implementation does not optimize KV-Cache access to the shared prompt.
In this blog post, we introduce Cascade Inference, which decouples the attention of the shared prefix and the unique suffixes, enabling the shared KV-Cache to be kept in GPU shared memory (SMEM for short) for fast access across multiple requests. We show that Cascade Inference can greatly accelerate the shared-prefix batch decoding operator, with up to 31x speedup over the baseline vLLM PageAttention implementation and up to 26x speedup over the FlashInfer batch decoding operator without cascading, on an H100 SXM 80GB. The kernels are available in [FlashInfer](https://github.com/flashinfer-ai/flashinfer/) as [PyTorch](https://docs.flashinfer.ai/api/python/cascade.html#cascade-attention) and C++ APIs.
<p align="center">
<br>
Figure 1. An example of serving document QA for multiple users, where all requests share the same book as the prompt.
</p>
## Background
Neither the multi-query attention kernel nor the single-query attention kernel is a good fit for shared-prefix batch decoding on its own. However, multi-query attention is perfect for the attention between queries and the shared prefix, while single-query attention can handle the attention between queries and the unique suffixes. Can we combine the advantages of both approaches?
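To make the contrast concrete, here is a minimal NumPy sketch of the two access patterns (the sizes are made up for illustration): with a shared prefix, all queries attend to one KV-Cache in a single matrix-matrix product, so the KV is loaded once and reused; with per-request suffixes, each request issues its own small matrix-vector product with no KV reuse across requests.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 128                                  # head dimension (as in LLaMA2-7B)
n_queries, prefix_len, suffix_len = 64, 1024, 32

# Multi-query pattern: many query rows share ONE KV-Cache.
# A single (64 x 128) @ (128 x 1024) GEMM: the prefix KV is read once
# and reused by all 64 queries (high operational intensity).
Q = rng.standard_normal((n_queries, d))
K_shared = rng.standard_normal((prefix_len, d))
scores_shared = Q @ K_shared.T           # shape (64, 1024)

# Single-query pattern: each request has its OWN KV-Cache.
# 64 independent (1 x 128) x (128 x 32) GEMVs: no KV reuse across
# requests, so the computation is memory-bandwidth bound.
K_unique = rng.standard_normal((n_queries, suffix_len, d))
scores_unique = np.einsum("bd,bsd->bs", Q, K_unique)  # shape (64, 32)
```

The shapes alone show the asymmetry: the shared-prefix part is a dense GEMM amenable to compute-bound multi-query kernels, while the suffix part is a batch of tiny GEMVs that a decode kernel handles.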
### Recursive Attention
The answer is "yes" if we can find a way to "merge" the attention of the same queries with the shared prefix and with the unique suffixes. Fortunately, FlashAttention has shown it is possible to combine local softmax/attention results by storing not only the local attention result but also the normalization scales, and renormalizing the local attention results on the fly. Here we formulate the idea in concise notation: let $s_i$ denote the pre-softmax attention score between the query and the $i$-th key, and generalize it from an index to an index set $I$:

$$s(I) = \log\left(\sum_{i\in I}\exp\left(s_i\right)\right),$$

let's also generalize the value vector $\mathbf{v}$ from an index to an index set:

$$\mathbf{v}(I) = \sum_{i\in I}\textrm{softmax}(s)_i \cdot \mathbf{v}_i,$$

where the $\textrm{softmax}$ function is restricted to the index set $I$. The generalizations of both $s$ and $\mathbf{v}$ are self-consistent: when $I$ equals $\\{i\\}$, we have $s(I) = s_i$ and $\mathbf{v}(I) = \mathbf{v}_i$. Note that $\mathbf{v}(\\{1,2,\cdots, n\\})$ is the self-attention output of the entire sequence. The **attention state** between a query and the KV of an index set $I$ can be defined as the tuple $\begin{bmatrix}\mathbf{v}(I) \\\ s(I)\end{bmatrix}$, then we can define a binary **merge** operator $\oplus$ to combine two states (in practice we subtract the maximum value from $s$ to guarantee numerical stability; it is omitted here for simplicity):

$$\begin{bmatrix}\mathbf{v}(I\cup J) \\\ s(I\cup J)\end{bmatrix} = \begin{bmatrix}\mathbf{v}(I) \\\ s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J) \\\ s(J)\end{bmatrix} = \begin{bmatrix}\frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\\ \log\left(\exp(s(I)) + \exp(s(J))\right)\end{bmatrix}.$$

The merge operator can be generalized to an arbitrary number of attention states:

$$\begin{bmatrix}\mathbf{v}\left(\bigcup_i I_i\right) \\\ s\left(\bigcup_i I_i\right)\end{bmatrix} = \bigoplus_i \begin{bmatrix}\mathbf{v}(I_i) \\\ s(I_i)\end{bmatrix}.$$

This n-ary merge operator is consistent with the binary merge operator, and we can prove the operator is *commutative* and *associative*: we can get the exact attention result of the entire sequence by merging the attention states of index subsets in any order, as long as their disjoint union is $\\{1,2,\cdots, n\\}$.
<p align="center">
<br>
Figure 3. Different orders of merging attention states are mathematically equivalent.
</p>

Recursive Attention allows us to decompose the attention computation into multiple stages; different stages can be dispatched to different compute units or devices. The KV sequence partitioning trick in FlashInfer and Flash-Decoding uses the same idea to merge partial attention states from different thread blocks.
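The attention state and the merge operator can be sketched in a few lines of NumPy. This is an illustrative reference implementation (not the FlashInfer kernel) that also checks the order-independence claim of Figure 3: merging disjoint chunks in different orders recovers the exact full-sequence attention output.

```python
import numpy as np

def attention_state(q, K, V):
    """Return the attention state (v(I), s(I)) of a query against a KV subset:
    the locally-normalized attention output and the log-sum-exp of the scores."""
    s = K @ q                          # pre-softmax scores s_i = q . k_i
    m = s.max()                        # subtract the max for numerical stability
    p = np.exp(s - m)
    return (p[:, None] * V).sum(0) / p.sum(), m + np.log(p.sum())

def merge(a, b):
    """The binary merge operator (oplus) on two attention states."""
    (va, sa), (vb, sb) = a, b
    m = max(sa, sb)                    # again, stabilize with the running max
    wa, wb = np.exp(sa - m), np.exp(sb - m)
    return (wa * va + wb * vb) / (wa + wb), m + np.log(wa + wb)

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
K, V = rng.standard_normal((12, 8)), rng.standard_normal((12, 8))

# Attention states of three disjoint chunks of the KV sequence.
a, b, c = (attention_state(q, K[i:i + 4], V[i:i + 4]) for i in (0, 4, 8))
full, _ = attention_state(q, K, V)

# Any merge order recovers the exact full-sequence attention output.
assert np.allclose(merge(merge(a, b), c)[0], full)
assert np.allclose(merge(a, merge(c, b))[0], full)
```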
### Cascade Inference: The Algorithm
For shared-prefix batch decoding, we propose the following divide-and-conquer algorithm: use the multi-query attention kernel to compute the attention between the queries and the shared prefix KV-Cache, use the batch decode (single-query attention) kernel to compute the attention between each query and its own unique suffix KV-Cache, and finally merge the two attention states with the $\oplus$ operator to recover the exact attention output.
The overall workflow is shown on the left side of Figure 2; rectangles of different colors are processed in different thread blocks on the GPU. Note that the multi-query attention kernels access the KV-Cache through SMEM or registers, while the decode kernels can only access the KV-Cache through the L2 cache or global memory. Cascade Inference lets us maximize memory reuse for the common prefix, making the attention computation much more memory efficient.
<p align="center">
<br>
Figure 2. Workflow of Cascade Inference; throughput values adapted from the blog post <a href="https://khairy2011.medium.com/tpu-vs-gpu-vs-cerebras-vs-graphcore-a-fair-comparison-between-ml-hardware-3f5a19d89e38">TPU vs GPU vs Cerebras vs Graphcore: A Fair Comparison between ML Hardware</a>.
</p>
We call the divide-and-conquer approach for shared-prefix attention the "Cascade Inference".
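The three stages can be sketched end-to-end in NumPy. This is an illustrative model of the algorithm with made-up sizes, not the GPU implementation (the real kernels operate on paged KV-Cache and fuse these steps):

```python
import numpy as np

def attention_state(q, K, V):
    """(v(I), s(I)): locally-normalized attention output and log-sum-exp score."""
    s = K @ q
    m = s.max()
    p = np.exp(s - m)
    return (p[:, None] * V).sum(0) / p.sum(), m + np.log(p.sum())

def merge(a, b):
    """The binary merge operator on attention states."""
    (va, sa), (vb, sb) = a, b
    m = max(sa, sb)
    wa, wb = np.exp(sa - m), np.exp(sb - m)
    return (wa * va + wb * vb) / (wa + wb), m + np.log(wa + wb)

rng = np.random.default_rng(0)
d, prefix_len, suffix_len, batch_size = 8, 32, 4, 3
K_prefix = rng.standard_normal((prefix_len, d))   # ONE copy, shared by all requests
V_prefix = rng.standard_normal((prefix_len, d))

outputs = []
for _ in range(batch_size):
    q = rng.standard_normal(d)
    K_suffix = rng.standard_normal((suffix_len, d))   # per-request unique suffix
    V_suffix = rng.standard_normal((suffix_len, d))
    # Stage 1: queries vs. the shared prefix (multi-query attention kernel).
    prefix_state = attention_state(q, K_prefix, V_prefix)
    # Stage 2: each query vs. its own unique suffix (batch decode kernel).
    suffix_state = attention_state(q, K_suffix, V_suffix)
    # Stage 3: merge the two partial states into the exact attention output.
    v, _ = merge(prefix_state, suffix_state)
    # Reference: single-pass attention over the concatenated KV sequence.
    v_ref, _ = attention_state(q, np.vstack([K_prefix, K_suffix]),
                               np.vstack([V_prefix, V_suffix]))
    assert np.allclose(v, v_ref)
    outputs.append(v)
```

The per-request asserts confirm that merging the prefix and suffix states reproduces attention over the full concatenated sequence, which is why cascading is exact rather than approximate.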
## Evaluation

We evaluate Cascade Inference on H100 SXM 80GB and A100 PCIE 80GB GPUs. The input shapes are adapted from LLaMA2-7B (32 heads, 128 dimensions per head). We vary three parameters: the number of requests (batch size), the shared prefix length, and the unique suffix length per request. The baseline implementation is the PageAttention kernel in vLLM 0.2.6; we also show the performance of the FlashInfer batch decoding operator without cascading. The page size (or block size, equivalently) is fixed at 16 for all implementations.
<p align="center">
<br>
Figure 4. Speedup over vLLM PageAttention on A100 PCIe 80GB.
</p>
Figures 3 and 4 show the normalized performance of FlashInfer kernels in the cascading and non-cascading settings.
[^1]: thread block: the programming abstraction that represents a group of cooperative threads, one SM can execute multiple thread blocks and one thread block cannot span multiple SMs.
[^2]: [Hopper architecture](https://resources.nvidia.com/en-us-tensor-core) introduces a new abstraction called Thread Block Clusters, which enables thread blocks in the same cluster to access each other's shared memory. Hopper also supports direct SM-to-SM communication without going through global memory (a.k.a. Distributed Shared Memory), which can greatly accelerate cross-SM communication. However, these features are not available in pre-Hopper architectures such as A100 GPUs.
[^3]: The tricks such as minus $s$ with max value to avoid numerically issues are omitted for simplicity