@@ -98,7 +98,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi
98
98
for (int m = 0 ; m < size<1 >(tOgO); ++m) {
99
99
const int row = get<0 >(tOcO (0 , m, 0 ));
100
100
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1 >(tOcO (0 , m, 0 )) == 0 ) {
101
- gLSE (row) = INFINITY ;
101
+ gLSE (row) = std::numeric_limits<ElementAccum>:: infinity () ;
102
102
}
103
103
}
104
104
return ;
@@ -499,7 +499,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons
499
499
for (int m = 0 ; m < size<1 >(tOgOaccum); ++m) {
500
500
const int row = get<0 >(tOcO (0 , m, 0 ));
501
501
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1 >(tOcO (0 , m, 0 )) == 0 ) {
502
- gLSEaccum (row) = Split ? -INFINITY : INFINITY ;
502
+ gLSEaccum (row) = Split ? -std::numeric_limits<ElementAccum>:: infinity () : std::numeric_limits<ElementAccum>:: infinity () ;
503
503
}
504
504
}
505
505
return ;
@@ -1061,7 +1061,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
1061
1061
for (int l = 0 ; l < kNLsePerThread ; ++l) {
1062
1062
const int row = l * kRowsPerLoadLSE + tidx / kBlockM ;
1063
1063
const int col = tidx % kBlockM ;
1064
- ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM ) ? gLSEaccum (row, col) : -INFINITY ;
1064
+ ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM ) ? gLSEaccum (row, col) : -std::numeric_limits<ElementAccum>:: infinity () ;
1065
1065
if (row < kMaxSplits ) {
1066
1066
sLSE [row][col] = lse;
1067
1067
}
@@ -1082,7 +1082,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
1082
1082
for (int l = 0 ; l < kNLsePerThread ; ++l) {
1083
1083
const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose ;
1084
1084
const int col = tidx / kRowsPerLoadTranspose ;
1085
- lse_accum (l) = (row < kMaxSplits && col < kBlockM ) ? sLSE [row][col] : -INFINITY ;
1085
+ lse_accum (l) = (row < kMaxSplits && col < kBlockM ) ? sLSE [row][col] : -std::numeric_limits<ElementAccum>:: infinity () ;
1086
1086
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
1087
1087
}
1088
1088
@@ -1094,7 +1094,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
1094
1094
}
1095
1095
MaxOp<float > max_op;
1096
1096
lse_max = Allreduce<kRowsPerLoadTranspose >::run (lse_max, max_op);
1097
- lse_max = lse_max == -INFINITY ? 0 .0f : lse_max; // In case all local LSEs are -inf
1097
+ lse_max = lse_max == -std::numeric_limits<ElementAccum>:: infinity () ? 0 .0f : lse_max; // In case all local LSEs are -inf
1098
1098
float lse_sum = expf (lse_accum (0 ) - lse_max);
1099
1099
#pragma unroll
1100
1100
for (int l = 1 ; l < kNLsePerThread ; ++l) {
@@ -1104,7 +1104,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params& params) {
1104
1104
lse_sum = Allreduce<kRowsPerLoadTranspose >::run (lse_sum, sum_op);
1105
1105
// For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
1106
1106
// lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
1107
- ElementAccum lse_logsum = (lse_sum == 0 .f || lse_sum != lse_sum) ? INFINITY : logf (lse_sum) + lse_max;
1107
+ ElementAccum lse_logsum = (lse_sum == 0 .f || lse_sum != lse_sum) ? std::numeric_limits<ElementAccum>:: infinity () : logf (lse_sum) + lse_max;
1108
1108
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
1109
1109
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM ) {
1110
1110
gLSE (tidx / kRowsPerLoadTranspose ) = lse_logsum;
0 commit comments