@@ -8,7 +8,7 @@ __global__ void rotary_embedding_neox_kernel(
8
8
const int64_t * __restrict__ positions, // [num_tokens]
9
9
scalar_t * __restrict__ query, // [num_tokens, num_heads, head_size]
10
10
scalar_t * __restrict__ key, // [num_tokens, num_kv_heads, head_size]
11
- const scalar_t * __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
11
+ const float * __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
12
12
const int rot_dim,
13
13
const int query_stride,
14
14
const int key_stride,
@@ -18,7 +18,7 @@ __global__ void rotary_embedding_neox_kernel(
18
18
// Each thread block is responsible for one token.
19
19
const int token_idx = blockIdx .x ;
20
20
int64_t pos = positions[token_idx];
21
- const scalar_t * cache_ptr = cos_sin_cache + pos * rot_dim;
21
+ const float * cache_ptr = cos_sin_cache + pos * rot_dim;
22
22
23
23
const int embed_dim = rot_dim / 2 ;
24
24
const int nq = num_heads * embed_dim;
@@ -33,13 +33,13 @@ __global__ void rotary_embedding_neox_kernel(
33
33
const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
34
34
const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
35
35
36
- const scalar_t cos = __ldg (cache_ptr + x_index);
37
- const scalar_t sin = __ldg (cache_ptr + y_index);
36
+ const float cos = __ldg (cache_ptr + x_index);
37
+ const float sin = __ldg (cache_ptr + y_index);
38
38
39
- const scalar_t q_x = query[token_head + x_index];
40
- const scalar_t q_y = query[token_head + y_index];
41
- query[out_x] = q_x * cos - q_y * sin ;
42
- query[out_y] = q_y * cos + q_x * sin ;
39
+ const float q_x = static_cast < float > ( query[token_head + x_index]) ;
40
+ const float q_y = static_cast < float > ( query[token_head + y_index]) ;
41
+ query[out_x] = static_cast < scalar_t > ( q_x * cos - q_y * sin ) ;
42
+ query[out_y] = static_cast < scalar_t > ( q_y * cos + q_x * sin ) ;
43
43
}
44
44
45
45
const int nk = num_kv_heads * embed_dim;
@@ -54,13 +54,13 @@ __global__ void rotary_embedding_neox_kernel(
54
54
const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
55
55
const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
56
56
57
- const scalar_t cos = __ldg (cache_ptr + x_index);
58
- const scalar_t sin = __ldg (cache_ptr + y_index);
57
+ const float cos = __ldg (cache_ptr + x_index);
58
+ const float sin = __ldg (cache_ptr + y_index);
59
59
60
- const scalar_t k_x = key[token_head + x_index];
61
- const scalar_t k_y = key[token_head + y_index];
62
- key[out_x] = k_x * cos - k_y * sin ;
63
- key[out_y] = k_y * cos + k_x * sin ;
60
+ const float k_x = static_cast < float > ( key[token_head + x_index]) ;
61
+ const float k_y = static_cast < float > ( key[token_head + y_index]) ;
62
+ key[out_x] = static_cast < scalar_t > ( k_x * cos - k_y * sin ) ;
63
+ key[out_y] = static_cast < scalar_t > ( k_y * cos + k_x * sin ) ;
64
64
}
65
65
}
66
66
@@ -93,7 +93,7 @@ void rotary_embedding_neox(
93
93
positions.data_ptr <int64_t >(),
94
94
query.data_ptr <scalar_t >(),
95
95
key.data_ptr <scalar_t >(),
96
- cos_sin_cache.data_ptr <scalar_t >(),
96
+ cos_sin_cache.data_ptr <float >(),
97
97
rot_dim,
98
98
query_stride,
99
99
key_stride,
0 commit comments