Commit 7e9bd08

Add batched RoPE kernel (#3095)
1 parent ae0ccb4 commit 7e9bd08

File tree

6 files changed: +417 -37 lines


benchmarks/kernels/benchmark_rope.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
from typing import Optional

import argparse
import torch
import nvtx
from itertools import accumulate
from vllm.model_executor.layers.rotary_embedding import get_rope


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # non-batched RoPE takes only one scaling factor, so we create multiple
    # instances to simulate the same behavior
    non_batched_ropes = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE: we concat multiple kv caches
    # together and each query needs to find the right kv cache of its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # batch queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before starting the timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
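
Note (not part of the diff): offset_map above is just a running prefix sum of max_position * scaling_factor * 2 over the preceding scaling factors, i.e. the row at which each query type's cos/sin entries start inside the concatenated cache. A minimal standalone sketch with the benchmark's defaults (max_position = 8192, scaling_factors = [1, 2, 4, 8]):

# Standalone sketch of the offset_map computation used in the benchmark above;
# the printed values are the per-type starting rows in the concatenated cache.
from itertools import accumulate

max_position = 8192
scaling_factors = [1, 2, 4, 8]
offsets = list(
    accumulate([0] + [
        max_position * scaling_factor * 2
        for scaling_factor in scaling_factors[:-1]
    ]))
print(offsets)  # [0, 16384, 49152, 114688]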

csrc/ops.h

Lines changed: 10 additions & 0 deletions
@@ -53,6 +53,16 @@ void rotary_embedding(
   torch::Tensor& cos_sin_cache,
   bool is_neox);
 
+void batched_rotary_embedding(
+  torch::Tensor& positions,
+  torch::Tensor& query,
+  torch::Tensor& key,
+  int head_size,
+  torch::Tensor& cos_sin_cache,
+  bool is_neox,
+  int rot_dim,
+  torch::Tensor& cos_sin_cache_offsets);
+
 void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);

csrc/pos_encoding_kernels.cu

Lines changed: 111 additions & 15 deletions
@@ -8,7 +8,7 @@
 namespace vllm {
 
 template<typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_rotary_embedding(
+inline __device__ void apply_token_rotary_embedding(
   scalar_t* __restrict__ arr,
   const scalar_t* __restrict__ cos_ptr,
   const scalar_t* __restrict__ sin_ptr,

@@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding(
 }
 
 template<typename scalar_t, bool IS_NEOX>
-__global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+inline __device__ void apply_rotary_embedding(
   scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
+  const scalar_t* cache_ptr,
+  const int head_size,
   const int num_heads,
   const int num_kv_heads,
-  const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
+  const int rot_dim,
+  const int token_idx,
+  const int64_t query_stride,
+  const int64_t key_stride)
+{
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;

@@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
   }
 

@@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
   }
 }
 
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+}
+
+template<typename scalar_t, bool IS_NEOX>
+__global__ void batched_rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+}
+
 } // namespace vllm
 
 void rotary_embedding(

@@ -128,3 +166,61 @@ void rotary_embedding(
       }
     });
 }
+
+/*
+Batched version of rotary embedding, which packs multiple LoRAs together
+and processes them in a batched manner.
+*/
+void batched_rotary_embedding(
+  torch::Tensor& positions,              // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,                  // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,                    // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache,          // [max_position, rot_dim]
+  bool is_neox,
+  int rot_dim,
+  torch::Tensor& cos_sin_cache_offsets   // [num_tokens]
+) {
+  int64_t num_tokens = cos_sin_cache_offsets.size(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    query.scalar_type(),
+    "rotary_embedding",
+    [&] {
+      if (is_neox) {
+        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          cos_sin_cache_offsets.data_ptr<int64_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          cos_sin_cache_offsets.data_ptr<int64_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
+    });
+}
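
Note (not part of the diff): as a reading aid, a minimal PyTorch sketch of what batched_rotary_embedding_kernel computes per token on the is_neox == true path, assuming flattened [num_tokens, num_heads * head_size] query/key layouts as passed to the host function above. Function and variable names are illustrative, not part of the commit.

import torch


def batched_rope_reference(positions, query, key, cos_sin_cache,
                           cos_sin_cache_offsets, rot_dim, head_size):
    # Pick one cache row per token at index (offset + position); each row is
    # laid out as [cos(rot_dim // 2) | sin(rot_dim // 2)].
    rows = cos_sin_cache[cos_sin_cache_offsets + positions]
    cos, sin = rows.chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads

    num_tokens = positions.numel()
    q = query.view(num_tokens, -1, head_size)
    k = key.view(num_tokens, -1, head_size)
    half = rot_dim // 2
    for t in (q, k):
        x = t[..., :half].clone()
        y = t[..., half:rot_dim].clone()
        # NeoX-style rotation applied in place to the first rot_dim elements
        # of every head; the rest of the head is left untouched.
        t[..., :half] = x * cos - y * sin
        t[..., half:rot_dim] = y * cos + x * sin
    return query, key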

csrc/pybind.cpp

Lines changed: 5 additions & 0 deletions
@@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
+  ops.def(
+    "batched_rotary_embedding",
+    &batched_rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
+
   // Quantization ops
 #ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
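
Note (not part of the diff): a hypothetical call site for the newly bound op. The vllm._C import path is an assumption about how the compiled extension is exposed; the argument order mirrors the batched_rotary_embedding declaration added to csrc/ops.h.

import torch
from vllm._C import ops  # assumed import path for the compiled extension

num_tokens, num_heads, head_size, rot_dim, max_position = 16, 8, 128, 32, 8192

positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, device="cuda")
key = torch.randn_like(query)
# a real cache would concatenate one block per scaling factor; all-zero
# offsets reduce to the single-cache (non-batched) behavior
cos_sin_cache = torch.randn(max_position, rot_dim, device="cuda")
cos_sin_cache_offsets = torch.zeros(num_tokens, dtype=torch.int64,
                                    device="cuda")

ops.batched_rotary_embedding(positions, query, key, head_size,
                             cos_sin_cache, True, rot_dim,
                             cos_sin_cache_offsets)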
