Commit 890b3bc

RoPE in float32 precision
1 parent 791d79d commit 890b3bc

2 files changed: +18 -23 lines changed

csrc/pos_encoding_kernels.cu (+15 -15)
@@ -8,7 +8,7 @@ __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
-  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const float* __restrict__ cos_sin_cache,      // [max_position, 2, rot_dim // 2]
   const int rot_dim,
   const int query_stride,
   const int key_stride,
@@ -18,7 +18,7 @@ __global__ void rotary_embedding_neox_kernel(
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+  const float* cache_ptr = cos_sin_cache + pos * rot_dim;

   const int embed_dim = rot_dim / 2;
   const int nq = num_heads * embed_dim;
@@ -33,13 +33,13 @@ __global__ void rotary_embedding_neox_kernel(
     const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
     const int out_y = token_idx * query_stride + head_idx * head_size + y_index;

-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
+    const float cos = __ldg(cache_ptr + x_index);
+    const float sin = __ldg(cache_ptr + y_index);

-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
+    const float q_x = static_cast<float>(query[token_head + x_index]);
+    const float q_y = static_cast<float>(query[token_head + y_index]);
+    query[out_x] = static_cast<scalar_t>(q_x * cos - q_y * sin);
+    query[out_y] = static_cast<scalar_t>(q_y * cos + q_x * sin);
   }

   const int nk = num_kv_heads * embed_dim;
@@ -54,13 +54,13 @@ __global__ void rotary_embedding_neox_kernel(
     const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
     const int out_y = token_idx * key_stride + head_idx * head_size + y_index;

-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
+    const float cos = __ldg(cache_ptr + x_index);
+    const float sin = __ldg(cache_ptr + y_index);

-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    const float k_x = static_cast<float>(key[token_head + x_index]);
+    const float k_y = static_cast<float>(key[token_head + y_index]);
+    key[out_x] = static_cast<scalar_t>(k_x * cos - k_y * sin);
+    key[out_y] = static_cast<scalar_t>(k_y * cos + k_x * sin);
   }
 }

@@ -93,7 +93,7 @@ void rotary_embedding_neox(
       positions.data_ptr<int64_t>(),
       query.data_ptr<scalar_t>(),
       key.data_ptr<scalar_t>(),
-      cos_sin_cache.data_ptr<scalar_t>(),
+      cos_sin_cache.data_ptr<float>(),
       rot_dim,
       query_stride,
       key_stride,

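Why a float32 cache matters here: float16 carries only about three significant decimal digits, so computing the cos/sin table and the rotation in the model dtype rounds both the angles and the rotated values. The kernel above instead reads the cache in float32, upcasts q and k, rotates in float32, and casts only the final result back to scalar_t. A minimal PyTorch sketch of the same per-pair math (a sketch only, not vLLM code; rotate_pair_fp32 and the tensor shapes are illustrative):

import torch

# Mirror of the kernel's new per-pair math: cos/sin stay float32, the q
# values are upcast, rotated in float32, and cast back to the storage dtype.
def rotate_pair_fp32(q_x, q_y, cos, sin):
    x32, y32 = q_x.float(), q_y.float()
    out_x = (x32 * cos - y32 * sin).to(q_x.dtype)
    out_y = (y32 * cos + x32 * sin).to(q_x.dtype)
    return out_x, out_y

torch.manual_seed(0)
q_x = torch.randn(1024, dtype=torch.float16)
q_y = torch.randn(1024, dtype=torch.float16)
theta = torch.rand(1024) * 2 * torch.pi   # float32 angles
cos, sin = theta.cos(), theta.sin()

# float64 reference vs. an all-float16 rotation vs. float32 accumulation.
ref = q_x.double() * cos.double() - q_y.double() * sin.double()
all_fp16 = (q_x * cos.half() - q_y * sin.half()).double()
fp32_acc = rotate_pair_fp32(q_x, q_y, cos, sin)[0].double()
print((all_fp16 - ref).abs().max(), (fp32_acc - ref).abs().max())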
vllm/model_executor/layers/attention.py (+3 -8)
@@ -259,18 +259,13 @@ def __init__(
         super().__init__(num_heads, head_size, scale, num_kv_heads)

         # Create the cos and sin cache.
-        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
-        t = torch.arange(max_position).float()
-        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
+        t = torch.arange(max_position, dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

-        # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model.
-        # TODO(woosuk): Make it more robust.
-        torch_dtype = torch.get_default_dtype()
-        cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
         self.register_buffer("cos_sin_cache", cache, persistent=False)

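After this change the cache is built and kept in float32 regardless of the model's default dtype, matching the data_ptr<float>() the kernel now expects. A standalone sketch of the resulting construction (the hyperparameter values below are illustrative, not taken from the commit):

import torch

rotary_dim, max_position, base = 128, 2048, 10000  # illustrative values

# Same construction as the diff above; every intermediate stays float32.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

assert cache.shape == (max_position, rotary_dim)
assert cache.dtype == torch.float32  # no downcast to torch.get_default_dtype() anymore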