 namespace vllm {

 template <typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_rotary_embedding(
+inline __device__ void apply_token_rotary_embedding(
   scalar_t* __restrict__ arr,
   const scalar_t* __restrict__ cos_ptr,
   const scalar_t* __restrict__ sin_ptr,
@@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding(
 }

 template <typename scalar_t, bool IS_NEOX>
-__global__ void rotary_embedding_kernel(
-  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
+inline __device__ void apply_rotary_embedding(
   scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
-  const int rot_dim,
-  const int64_t query_stride,
-  const int64_t key_stride,
+  const scalar_t* cache_ptr,
+  const int head_size,
   const int num_heads,
   const int num_kv_heads,
-  const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
+  const int rot_dim,
+  const int token_idx,
+  const int64_t query_stride,
+  const int64_t key_stride)
+{
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
   }

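The body of the renamed per-pair helper (now apply_token_rotary_embedding) is elided by these hunks. For reference, a minimal sketch of the standard rotation such a helper applies to one element pair, assuming the usual GPT-NeoX (half-split) and GPT-J (interleaved) pair layouts; the name rotate_pair_sketch and the plain loads are illustrative, not taken from this diff.

// Illustrative sketch only, not the elided body itself.
// IS_NEOX pairs arr[i] with arr[i + embed_dim] (half-split); otherwise
// arr[2*i] and arr[2*i + 1] are paired (interleaved). cos/sin are read
// from the per-token cache slice at index rot_offset.
template <typename scalar_t, bool IS_NEOX>
inline __device__ void rotate_pair_sketch(
    scalar_t* __restrict__ arr,
    const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr,
    int rot_offset, int embed_dim) {
  const int x_index = IS_NEOX ? rot_offset : 2 * rot_offset;
  const int y_index = IS_NEOX ? embed_dim + rot_offset : 2 * rot_offset + 1;
  const scalar_t c = cos_ptr[rot_offset];
  const scalar_t s = sin_ptr[rot_offset];
  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * c - y * s;  // rotate the (x, y) pair by the cached angle
  arr[y_index] = y * c + x * s;
}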
@@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel(
     const int head_idx = i / embed_dim;
     const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
   }
 }

+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void batched_rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2]
+  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
+}
+
 } // namespace vllm

 void rotary_embedding(
@@ -128,3 +166,61 @@ void rotary_embedding(
       }
     });
 }
+
+/*
+Batched version of rotary embedding: packs multiple LoRAs together
+and processes them in a batched manner.
+*/
+void batched_rotary_embedding(
+  torch::Tensor& positions,              // [batch_size, seq_len] or [num_tokens]
+  torch::Tensor& query,                  // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
+  torch::Tensor& key,                    // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
+  int head_size,
+  torch::Tensor& cos_sin_cache,          // [max_position, rot_dim]
+  bool is_neox,
+  int rot_dim,
+  torch::Tensor& cos_sin_cache_offsets   // [num_tokens]
+) {
+  int64_t num_tokens = cos_sin_cache_offsets.size(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    query.scalar_type(),
+    "rotary_embedding",
+    [&] {
+      if (is_neox) {
+        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          cos_sin_cache_offsets.data_ptr<int64_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          cos_sin_cache_offsets.data_ptr<int64_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
+    });
+}
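A hypothetical host-side sketch of how the new entry point might be driven for multi-LoRA batches, assuming each adapter owns a contiguous block of rows_per_adapter rows in a packed cos_sin_cache. The packing scheme, the helper name call_batched_rope, and the lora_ids tensor are assumptions for illustration; only batched_rotary_embedding itself comes from this change.

// Hypothetical usage sketch (not part of the diff).
#include <torch/torch.h>

// Declaration of the entry point added in this diff (defined above).
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                              torch::Tensor& key, int head_size,
                              torch::Tensor& cos_sin_cache, bool is_neox,
                              int rot_dim, torch::Tensor& cos_sin_cache_offsets);

void call_batched_rope(torch::Tensor& positions,       // [num_tokens], int64
                       torch::Tensor& query,           // [num_tokens, num_heads * head_size]
                       torch::Tensor& key,             // [num_tokens, num_kv_heads * head_size]
                       torch::Tensor& cos_sin_cache,   // [num_adapters * rows_per_adapter, rot_dim]
                       const torch::Tensor& lora_ids,  // [num_tokens], int64 adapter index per token
                       int64_t rows_per_adapter,
                       int head_size,
                       bool is_neox) {
  const int rot_dim = static_cast<int>(cos_sin_cache.size(1));
  // Offsets are measured in cache rows; the kernel adds positions[token] to them.
  torch::Tensor cos_sin_cache_offsets = lora_ids * rows_per_adapter;
  batched_rotary_embedding(positions, query, key, head_size,
                           cos_sin_cache, is_neox, rot_dim,
                           cos_sin_cache_offsets);
}

With cos_sin_cache_offsets built this way, the kernel reads cache row cos_sin_cache_offsets[token] + positions[token], so tokens belonging to different adapters can use differently scaled RoPE tables within a single launch.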