-
-
Notifications
You must be signed in to change notification settings - Fork 7.6k
Implement PagedAttention V2 #1348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 28 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
7d057f9
PagedAttention V1
WoosukKwon 2cc7bff
Mid
WoosukKwon 8946093
PagedAttention V1
WoosukKwon f5b05fc
Undef DIVIDE_ROUND_UP
WoosukKwon 235f273
Add empty PagedAttention V2
WoosukKwon 472ee66
Minor
WoosukKwon 3827e24
Minor
WoosukKwon 2605c6e
Implement PagedAttention V2
WoosukKwon 877a3f5
Add comment
WoosukKwon 634f961
Fix performance bug
WoosukKwon 7585101
Fix attention test
WoosukKwon 3ea3891
Add heuristic
WoosukKwon ab89848
Minor optimization
WoosukKwon d83ce92
Add benchmark
WoosukKwon 760e7a2
Minor
WoosukKwon e6d8a15
yapf
WoosukKwon 4313691
Minor fix on comments
WoosukKwon c0021c1
Add comment on heuristic
WoosukKwon 8ddb426
Fix test_attention
WoosukKwon ae14bba
Merge branch 'main' into pa-v2
WoosukKwon 08e92c3
yapf
WoosukKwon dac5e24
Minor
WoosukKwon d674616
Minor
WoosukKwon 612236b
Reimplement
WoosukKwon 3d2eff1
Rename
WoosukKwon 57b3071
Minor
WoosukKwon cb3af6d
yapf
WoosukKwon 000abdf
Remove unnecessary fns
WoosukKwon f80f49f
Address comments
WoosukKwon 5b0a536
Minor fix
WoosukKwon f3c8cb0
Support attention with ALiBi
WoosukKwon bfa8569
yapf
WoosukKwon 9451b2d
yapf
WoosukKwon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import argparse | ||
import random | ||
import time | ||
|
||
import torch | ||
|
||
from vllm import attention_ops | ||
|
||
NUM_BLOCKS = 1024 | ||
PARTITION_SIZE = 512 | ||
|
||
|
||
@torch.inference_mode() | ||
def main( | ||
version: int, | ||
num_seqs: int, | ||
context_len: int, | ||
num_query_heads: int, | ||
num_kv_heads: int, | ||
head_size: int, | ||
use_alibi: bool, | ||
block_size: int, | ||
dtype: torch.dtype, | ||
seed: int, | ||
do_profile: bool, | ||
) -> None: | ||
random.seed(seed) | ||
torch.random.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
|
||
scale = float(1.0 / (head_size**0.5)) | ||
query = torch.empty(num_seqs, | ||
num_query_heads, | ||
head_size, | ||
dtype=dtype, | ||
device="cuda") | ||
query.uniform_(-scale, scale) | ||
|
||
assert num_query_heads % num_kv_heads == 0 | ||
num_queries_per_kv = num_query_heads // num_kv_heads | ||
head_mapping = torch.repeat_interleave( | ||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), | ||
num_queries_per_kv) | ||
alibi_slopes = None | ||
if use_alibi: | ||
alibi_slopes = torch.randn(num_query_heads, | ||
dtype=torch.float, | ||
device="cuda") | ||
|
||
context_lens = [context_len for _ in range(num_seqs)] | ||
max_context_len = max(context_lens) | ||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") | ||
|
||
# Create the block tables. | ||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size | ||
block_tables = [] | ||
for _ in range(num_seqs): | ||
block_table = [ | ||
random.randint(0, NUM_BLOCKS - 1) | ||
for _ in range(max_num_blocks_per_seq) | ||
] | ||
block_tables.append(block_table) | ||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") | ||
|
||
# Create the KV cache. | ||
x = 16 // torch.tensor([], dtype=dtype).element_size() | ||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) | ||
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda") | ||
key_cache.uniform_(-scale, scale) | ||
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) | ||
value_cache = torch.empty(size=value_cache_shape, | ||
dtype=dtype, | ||
device="cuda") | ||
value_cache.uniform_(-scale, scale) | ||
|
||
# Prepare for the paged attention kernel. | ||
output = torch.empty_like(query) | ||
if version == 2: | ||
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // | ||
PARTITION_SIZE) | ||
tmp_output = torch.empty( | ||
size=(num_seqs, num_query_heads, num_partitions, head_size), | ||
dtype=output.dtype, | ||
device=output.device, | ||
) | ||
exp_sums = torch.empty( | ||
size=(num_seqs, num_query_heads, num_partitions), | ||
dtype=torch.float32, | ||
device=output.device, | ||
) | ||
max_logits = torch.empty_like(exp_sums) | ||
|
||
def run_benchmark(num_iters: int, profile: bool = False) -> float: | ||
if profile: | ||
torch.cuda.cudart().cudaProfilerStart() | ||
start_time = time.perf_counter() | ||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
for _ in range(num_iters): | ||
if version == 1: | ||
attention_ops.paged_attention_v1( | ||
output, | ||
query, | ||
key_cache, | ||
value_cache, | ||
head_mapping, | ||
scale, | ||
block_tables, | ||
context_lens, | ||
block_size, | ||
max_context_len, | ||
alibi_slopes, | ||
) | ||
else: | ||
attention_ops.paged_attention_v2( | ||
output, | ||
exp_sums, | ||
max_logits, | ||
tmp_output, | ||
query, | ||
key_cache, | ||
value_cache, | ||
head_mapping, | ||
scale, | ||
block_tables, | ||
context_lens, | ||
block_size, | ||
max_context_len, | ||
alibi_slopes, | ||
) | ||
torch.cuda.synchronize() | ||
|
||
end_time = time.perf_counter() | ||
if profile: | ||
torch.cuda.cudart().cudaProfilerStart() | ||
return (end_time - start_time) / num_iters | ||
|
||
# Warmup. | ||
print("Warming up...") | ||
run_benchmark(num_iters=3, profile=False) | ||
torch.cuda.synchronize() | ||
WoosukKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Benchmark. | ||
if do_profile: | ||
latency = run_benchmark(num_iters=1, profile=True) | ||
else: | ||
latency = run_benchmark(num_iters=100, profile=False) | ||
print(f"Kernel running time: {latency * 1000000:.3f} us") | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser( | ||
description="Benchmark the paged attention kernel.") | ||
parser.add_argument("--version", type=int, choices=[1, 2], default=2) | ||
parser.add_argument("--batch-size", type=int, default=8) | ||
parser.add_argument("--context-len", type=int, default=4096) | ||
parser.add_argument("--num-query-heads", type=int, default=64) | ||
parser.add_argument("--num-kv-heads", type=int, default=8) | ||
parser.add_argument("--head-size", | ||
type=int, | ||
choices=[64, 80, 96, 112, 128, 256], | ||
default=128) | ||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) | ||
parser.add_argument("--use-alibi", action="store_true") | ||
parser.add_argument("--dtype", | ||
type=str, | ||
choices=["half", "bfloat16", "float"], | ||
default="half") | ||
parser.add_argument("--seed", type=int, default=0) | ||
parser.add_argument("--profile", action="store_true") | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
if args.num_query_heads % args.num_kv_heads != 0: | ||
raise ValueError("num_query_heads must be divisible by num_kv_heads") | ||
dtype_to_torch_dtype = { | ||
"half": torch.half, | ||
"bfloat16": torch.bfloat16, | ||
"float": torch.float, | ||
} | ||
main( | ||
version=args.version, | ||
num_seqs=args.batch_size, | ||
context_len=args.context_len, | ||
num_query_heads=args.num_query_heads, | ||
num_kv_heads=args.num_kv_heads, | ||
head_size=args.head_size, | ||
block_size=args.block_size, | ||
use_alibi=args.use_alibi, | ||
dtype=dtype_to_torch_dtype[args.dtype], | ||
seed=args.seed, | ||
do_profile=args.profile, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.