
Commit 034e0ee

Implement PagedAttention V2 (vllm-project#1348)
1 parent 5e0d608 commit 034e0ee

6 files changed: +764 additions, -139 deletions
Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
import argparse
import random
import time

import torch

from vllm import attention_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512


@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    context_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if version == "v1":
                attention_ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            elif version == "v2":
                attention_ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # Stop CUDA profiling after the timed region.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--context-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
        context_len=args.context_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
    )
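
For context on the V2-only buffers allocated above: the V2 kernel splits each sequence's context into PARTITION_SIZE-token partitions and, for every partition, writes a partial attention output (tmp_output) plus that partition's softmax maximum (max_logits) and exponential sum (exp_sums); a second pass reduces these into the final output. The sketch below shows the standard way such partial softmax results can be recombined. It is illustrative only: the function name reduce_v2_partitions is ours, it assumes each partition's partial output is already normalized by its own exp-sum, and the real reduction runs inside the CUDA kernel, not in Python.

import torch


def reduce_v2_partitions(tmp_output: torch.Tensor,
                         exp_sums: torch.Tensor,
                         max_logits: torch.Tensor) -> torch.Tensor:
    # tmp_output:  [num_seqs, num_heads, num_partitions, head_size]
    # exp_sums:    [num_seqs, num_heads, num_partitions] (float32)
    # max_logits:  [num_seqs, num_heads, num_partitions] (float32)
    # Rescale each partition's exp-sum to a shared global maximum, then take
    # the exp-sum-weighted average of the per-partition partial outputs.
    global_max = max_logits.max(dim=-1, keepdim=True).values
    rescaled_exp_sums = exp_sums * torch.exp(max_logits - global_max)
    weights = rescaled_exp_sums / rescaled_exp_sums.sum(dim=-1, keepdim=True)
    out = (tmp_output.float() * weights.unsqueeze(-1)).sum(dim=2)
    return out.to(tmp_output.dtype)

Because the per-partition statistics carry everything needed for this rescaling, the partitions can be computed by independent thread blocks, which is what lets V2 keep the GPU busy when the batch is small but the context is long.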

csrc/attention.cpp

Lines changed: 24 additions & 4 deletions
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <c10/util/Optional.h>
 
-void single_query_cached_kv_attention(
+void paged_attention_v1(
   torch::Tensor& out,
   torch::Tensor& query,
   torch::Tensor& key_cache,
@@ -14,9 +14,29 @@ void single_query_cached_kv_attention(
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes);
 
+void paged_attention_v2(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "single_query_cached_kv_attention",
-    &single_query_cached_kv_attention,
-    "Compute the attention between an input query and the cached key/value tensors");
+    "paged_attention_v1",
+    &paged_attention_v1,
+    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
+  m.def(
+    "paged_attention_v2",
+    &paged_attention_v2,
+    "PagedAttention V2.");
 }
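
With these bindings, Python callers that previously used attention_ops.single_query_cached_kv_attention now call attention_ops.paged_attention_v1 or attention_ops.paged_attention_v2. The argument order matches the benchmark above; everything else in the sketch below is hypothetical — the wrapper name paged_attention and the "partitioned kernel only for contexts longer than one partition" heuristic are ours, not part of this diff.

import torch
from vllm import attention_ops

PARTITION_SIZE = 512  # same constant as the benchmark above


def paged_attention(output, query, key_cache, value_cache, head_mapping,
                    scale, block_tables, context_lens, block_size,
                    max_context_len, alibi_slopes):
    # Hypothetical caller-side dispatch: single-pass V1 for short contexts,
    # partitioned V2 (with its extra workspace tensors) for long ones.
    if max_context_len <= PARTITION_SIZE:
        attention_ops.paged_attention_v1(
            output, query, key_cache, value_cache, head_mapping, scale,
            block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)
    else:
        num_seqs, num_heads, head_size = output.shape
        num_partitions = (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        tmp_output = torch.empty(num_seqs, num_heads, num_partitions, head_size,
                                 dtype=output.dtype, device=output.device)
        exp_sums = torch.empty(num_seqs, num_heads, num_partitions,
                               dtype=torch.float32, device=output.device)
        max_logits = torch.empty_like(exp_sums)
        attention_ops.paged_attention_v2(
            output, exp_sums, max_logits, tmp_output, query, key_cache,
            value_cache, head_mapping, scale, block_tables, context_lens,
            block_size, max_context_len, alibi_slopes)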
