
Commit 6e66565

Commit message: upd
Parent: f653fe9

10 files changed (+32 -30 lines)

Diff for: .gitignore (+1)

@@ -1,3 +1,4 @@
+*.pdf
 *.py
 *.pyc
 __pycache__

Diff for: Gemfile (+2)

@@ -31,3 +31,5 @@ gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]
 # Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem
 # do not have a Java counterpart.
 gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby]
+
+gem 'jekyll-redirect-from'

Diff for: Gemfile.lock (+3)

@@ -33,6 +33,8 @@ GEM
       webrick (~> 1.7)
     jekyll-feed (0.17.0)
       jekyll (>= 3.7, < 5.0)
+    jekyll-redirect-from (0.16.0)
+      jekyll (>= 3.3, < 5.0)
     jekyll-sass-converter (3.0.0)
       sass-embedded (~> 1.54)
     jekyll-seo-tag (2.8.0)
@@ -75,6 +77,7 @@ DEPENDENCIES
   http_parser.rb (~> 0.6.0)
   jekyll (~> 4.3.2)
   jekyll-feed (~> 0.12)
+  jekyll-redirect-from
   minima (~> 2.5)
   tzinfo (>= 1, < 3)
   tzinfo-data

Diff for: _config.yml (+1)

@@ -30,6 +30,7 @@ github_username: flashinfer-ai
 theme: minima
 plugins:
   - jekyll-feed
+  - jekyll-redirect-from
 
 # Exclude from processing.
 # The following items will not be processed, by default.

Diff for: _posts/2024-01-08-cascade-inference.md (+25 -30)

@@ -1,20 +1,20 @@
 ---
 layout: post
 title: "Cascade Inference: Memory Bandwidth Efficient Shared Prefix Batch Decoding"
-date: 2024-01-08
+date: 2024-02-02
 comments: true
-author: Zihao Ye (UW), Ruihang Lai (CMU), Roy Lu (UW), Chien-Yu Lin (UW), Size Zheng (UW & PKU), Lequn Chen (UW), Tianqi Chen (CMU & OctoAI), Luis Ceze (UW & OctoAI)
+author: Zihao Ye (UW), Ruihang Lai (CMU), Bo-Ru Lu (UW), Chien-Yu Lin (UW), Size Zheng (UW & PKU), Lequn Chen (UW), Tianqi Chen (CMU & OctoML), Luis Ceze (UW & OctoML)
+redirect_from: "/2024/01/08/cascade-inference"
 ---
 
 Many LLM inference tasks involves multiple independent text generation from a shared prefix (prompt), e.g. [Self-Consistency](https://arxiv.org/abs/2203.11171), [Tree of Thoughts](https://arxiv.org/abs/2305.10601) and [Skeleton-of-thought](https://arxiv.org/abs/2307.15337). Serving LLMs with common prefix could be memory and time-consuming, especially when common prefix is long and the number of requests is large: a possible use case is long document QA (Figure 1), multiple users interacts with ChatBot with the same document as prompt. While [vLLM](https://arxiv.org/abs/2309.06180) alleviate the memory issue by only storing one copy of the common prefix. However, it still suffers from the low-efficiency because the default PageAttention implementation do not optimize KV-Cache access to the shared prompt.
 
 In this blog post, we introduce Cascade Inference, which simply decouples attention of shared prefix and unique suffixes, and enables storing shared KV-Cache in GPU shared memory (SMEM for short) for fast access in multiple requests. We show that Cascade Inference can greatly accelerate shared-prefix batch decoding operator, with up to 31x speedup compared to the baseline vLLM PageAttention implementation and 26x speedup compared to FlashInfer batch decoding operator without cascading on a H100 SXM 80GB. The kernels have been supported in [FlashInfer](https://github.com/flashinfer-ai/flashinfer/) as [PyTorch](https://docs.flashinfer.ai/api/python/cascade.html#cascade-attention) and C++ APIs.
 
 <p align="center">
-<figure>
 <img src="/assets/imgs/document-qa-serving.png" alt="Document QA Serving" width="800"/>
-<figcaption> Figure 1. An example of serving Document QA for multiple users, all of the requests share the same book as prompt. </figcaption>
-</figure>
+<br>
+Figure 1. An example of serving Document QA for multiple users, all of the requests share the same book as prompt.
 </p>
 
 ## Background
@@ -33,7 +33,7 @@ The single-query attention kernel (used in decode), on the other hand, assumes t
 
 Neither multi-query attention nor single-query attention kernel is a good fit for shared-prefix batch decoding. However, multi-query attention is perfect for attention between queries and shared prefix, while single-query attention can deal with the attention between queries and unique suffixes. Can we combine the advantages of both approaches?
 
-### Recursive Softmax/Attention
+### Recursive Attention
 
 The answer is "yes" if we can find a way to "merge" the attention of the same queries with shared prefix and unique suffixes. Fortunately, FlashAttention has shown it's possible to combine local
 softmax/attention results by not only storing the local attention result, but also the normalization scales and renormalizing local attention results on the fly. Here we formulate the idea in concise notations:
@@ -48,20 +48,27 @@ $$ s(I) = \log\left(\sum_{i\in I} \exp(s_i) \right),$$
 
 let's also generalize the value vector $\mathbf{v}$ from index to index sets (Note that the generalization of both $s$ and $v$ are self-consistent because when $I$ equals $\{i\}$, we have $s(I) = s_i$ and $\mathbf{v}(I) = \mathbf{v}_i$):
 
-$$ \mathbf{v}(I)=\frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}, $$
+$$ \mathbf{v}(I) = \sum_{i\in I}\textrm{softmax}(s_i) \mathbf{v}_i = \frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}, $$
 
-The **attention state** between a query with KV of an index set $I$ can be defined as a tuple $\begin{bmatrix}\mathbf{v}(I) \\\ s(I)\end{bmatrix}$,
-then we can define the **merge** operator $\oplus$ to combine two states as [^3]:
+the $\textrm{softmax}$ function are restricted to the index set $I$. Note that $\mathbf{v}(\{1,2,\cdots, n\})$ is the self-attention output of the entire sequence. The **attention state** between a query with KV of an index set $I$ can be defined as a tuple $\begin{bmatrix}\mathbf{v}(I) \\\ s(I)\end{bmatrix}$,
+then we can define a binary **merge** operator $\oplus$ to combine two states as (in practice we will minus $s$ with maximum value to guarantee numerical stability and here we omit them for simplicity):
 
 $$\begin{bmatrix}\mathbf{v}(I\cup J)\\s(I\cup J)\end{bmatrix}=\begin{bmatrix}\mathbf{v}(I)\\s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J)\\s(J)\end{bmatrix}=\begin{bmatrix} \frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\ \log(\exp(s(I)) + \exp(s(J))) \end{bmatrix},$$
 
-and we can define **attention state** on the entire sequence (suppose sequence length is $n$):
+the **merge** operator can be generalized to any number of attention state inputs:
 
-$$\begin{bmatrix}\mathbf{v}(\{1,2,\dots, n\})\\s(\{1,2,\dots, n\})\end{bmatrix} = \bigoplus_{i=1}^{n} \begin{bmatrix}\mathbf{v}_i\\s_i\end{bmatrix}$$
+$$\begin{bmatrix}\mathbf{v}(\bigcup_{i=1}^{n}I_i) \\ s(\bigcup_{i=1}^{n}I_i) \end{bmatrix} = \bigoplus_{i=1}^{n}\begin{bmatrix}\mathbf{v}(I_i) \\ s(I_i)\end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{n} \textrm{softmax}(s(I_i))\mathbf{v}(I_i) \\ \log(\sum_{i=1}^{n} \exp (s(I_i))) \end{bmatrix} $$
 
-Then $\mathbf{v}(\\{1,2,\cdots, n\\})$ is the self-attention result between query and the entire KV. Note that $\oplus$ is communicative and associative, which means we can get the exact attention result by merging the attention states of index subsets as long as their disjoint union is the $\\{1,2,\cdots, n\\}$, regardless of merge order.
+The above n-ary merge operator is consistent with the binary merge operator, and we can prove the operator is *communicative* and *associative*. There are different ways to get the attention state of the entire sequence by merging the attention states of index subsets, and the final outcome is mathematically equivalent:
 
-The KV sequence partitioning trick in FlashInfer and Flash-Decoding uses the same idea to merge partial attention states from different thread blocks.
+<p align="center">
+<img src="/assets/imgs/recursive-attention.png" alt="recursive-attention" width="800"/>
+<br>
+Figure 3. Different order to merge attention states are mathematically equivalent.
+</p>
+
+Recursive Attention allow us to decompose attention computation into multiple stages, different stages
+can be dispatched to different compute units/devices. The KV sequence partitioning trick in FlashInfer and Flash-Decoding uses the same idea to merge partial attention states from different thread blocks.
 
 ### Cascade Inference: The Algorithm
 
@@ -74,10 +81,9 @@ we propose the following Divide-and-Conquer algorithm:
 The overall workflow is explained on the left side of Figure 2, different color of rectangles are processed in different thread blocks in GPU. Note that for multi-query attention kernels, we access KV-Cache through SMEM or registers and for decode kernels we can only access KV-Cache through L2 Cache or Global Memory. Cascade Inference allow us to maximize memory reuse for common prefix, thus making the attention computation much more memory efficient.
 
 <p align="center">
-<figure>
 <img src="/assets/imgs/cascade-inference.png" alt="Cascade Inference" width="800"/>
-<figcaption> Figure 2. Workflow of Cascade Inference, throughput values adapted from blog: <a href="https://khairy2011.medium.com/tpu-vs-gpu-vs-cerebras-vs-graphcore-a-fair-comparison-between-ml-hardware-3f5a19d89e38">TPU vs GPU vs Cerebras vs Graphcore: A Fair Comparison between ML Hardware</a></figcaption>
-</figure>
+<br>
+Figure 2. Workflow of Cascade Inference, throughput values adapted from blog: <a href="https://khairy2011.medium.com/tpu-vs-gpu-vs-cerebras-vs-graphcore-a-fair-comparison-between-ml-hardware-3f5a19d89e38">TPU vs GPU vs Cerebras vs Graphcore: A Fair Comparison between ML Hardware</a>
 </p>
 
 We call the divide-and-conquer approach for shared-prefix attention the "Cascade Inference".
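The two-stage idea described in that hunk can be sanity-checked in plain PyTorch: attend to the shared prefix and to a request's unique suffix separately, merge the two states, and compare against full attention over the concatenated KV. The lengths and shapes below are made-up assumptions for a single query head; this is a numerical check of the math, not the FlashInfer kernels.

```python
import torch

def attention_state(q, k, v):
    """Return (v(I), s(I)) for a single query over one index set's keys/values."""
    logits = (k @ q) * q.shape[-1] ** -0.5      # [len]
    s = torch.logsumexp(logits, dim=0)          # s(I)
    out = torch.softmax(logits, dim=0) @ v      # v(I)
    return out, s

torch.manual_seed(0)
head_dim, prefix_len, suffix_len = 128, 1024, 64   # assumed sizes for illustration
q = torch.randn(head_dim)
k = torch.randn(prefix_len + suffix_len, head_dim)
v = torch.randn(prefix_len + suffix_len, head_dim)

# Stage 1: shared prefix.  Stage 2: this request's unique suffix.
v1, s1 = attention_state(q, k[:prefix_len], v[:prefix_len])
v2, s2 = attention_state(q, k[prefix_len:], v[prefix_len:])

# Merge the two partial states with the operator from the previous sketch.
s_max = torch.maximum(s1, s2)
w1, w2 = torch.exp(s1 - s_max), torch.exp(s2 - s_max)
merged = (w1 * v1 + w2 * v2) / (w1 + w2)

# Reference: ordinary attention over the whole sequence gives the same result.
full, _ = attention_state(q, k, v)
assert torch.allclose(merged, full, atol=1e-5)
```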
@@ -87,25 +93,15 @@ We call the divide-and-conquer approach for shared-prefix attention the "Cascade
 We evaluate Cascade Inference on H100 SXM 80GB and A100 PCIE 80GB GPUs. The input shape are adapted from LLaMA2-7B (32 heads, 128 dimension per head). We varies three parameters: number of requests (batch size), shared prefix length and unique suffix length per request. The baseline implementations is PageAttention kernel implemented in vLLM 0.2.6, we also show the performance of FlashInfer batch decoding operator without cascading. The page size (or block size, equivalently) is fixed to 16 for all implementations.
 
 <p align="center">
-<figure>
 <img src="/assets/imgs/cascade-inference-performance-h100.png" alt="speedup-h100" width="800"/>
-<figcaption>
-<center>
+<br>
 Figure 3. Speedup over vLLM PageAttention on H100 SXM 80GB
-</center>
-</figcaption>
-</figure>
 </p>
 
 <p align="center">
-<figure>
 <img src="/assets/imgs/cascade-inference-performance-a100.png" alt="speedup-a100" width="800"/>
-<figcaption>
-<center>
-Speedup over vLLM PageAttention on A100 PCIe 80GB
-</center>
-</figcaption>
-</figure>
+<br>
+Figure 4. Speedup over vLLM PageAttention on A100 PCIe 80GB
 </p>
 
 Figure 3 and 4 show the normalized performance on FlashInfer kernels in cascading and non-cascading setting
@@ -122,4 +118,3 @@ Recently, [SGLang](https://arxiv.org/abs/2312.07104) (a domain-specific language
 
 [^1]: thread block: the programming abstraction that represents a group of cooperative threads, one SM can execute multiple thread blocks and one thread block cannot span multiple SMs.
 [^2]: [Hopper architecture](https://resources.nvidia.com/en-us-tensor-core) introduces a new abstraction called Thread Block Clusters which enables a thread block to access shared memory of other thread blocks within the same SM. Hopper also supports direct SM-to-SM communication without accessing global memory (a.k.a Distributed Shared Memory), which can greatly accelerate cross SM communication. However, these features are not available in pre-Hopper architectures such as A100 GPUs.
-[^3]: The tricks such as minus $s$ with max value to avoid numerically issues are omitted for simplicity

Diff for: assets/imgs/devices-roofline.png (-84.6 KB, binary image)

Diff for: assets/imgs/flashinfer-roofline-devices.pdf (32.9 KB, binary file not shown)

Diff for: assets/imgs/recursive-attention.pdf (62.9 KB, binary file not shown)

Diff for: assets/imgs/recursive-attention.png (207 KB, binary image)
