Skip to content

Commit 5dba257

Browse files
authored
Resolve race conditions in Marlin kernel (#11493)
Signed-off-by: wchen61 <[email protected]>
1 parent 187e329 commit 5dba257

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

csrc/quantization/gptq_marlin/gptq_marlin.cu

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ __global__ void Marlin(
834834
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
835835
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
836836
int4* sh_s = sh_zp + (stages * zp_sh_stage);
837+
int4* sh_red = sh_s + (stages * s_sh_stage);
837838

838839
// Register storage for double buffer of shared memory reads.
839840
FragA frag_a[2][thread_m_blocks];
@@ -932,11 +933,11 @@ __global__ void Marlin(
932933
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
933934

934935
if constexpr (group_blocks >= thread_k_blocks) {
936+
if (s_sh_wr_pred) {
937+
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
938+
}
935939
// Only fetch scales if this tile starts a new group
936-
if (pipe % (group_blocks / thread_k_blocks) == 0) {
937-
if (s_sh_wr_pred) {
938-
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
939-
}
940+
if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
940941
s_gl_rd += s_gl_rd_delta;
941942
}
942943
} else {
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
10381039
// No act-order case
10391040
if constexpr (group_blocks != -1) {
10401041
if constexpr (group_blocks >= thread_k_blocks) {
1041-
int4* sh_s_stage =
1042-
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
1043-
(pipe / (group_blocks / thread_k_blocks)));
1042+
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
10441043
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
10451044
} else {
10461045
int warp_id = threadIdx.x / 32;
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
13391338
int red_sh_wr =
13401339
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
13411340
if (i < red_off) {
1342-
float* c_rd =
1343-
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
1344-
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
1341+
float* c_rd = reinterpret_cast<float*>(
1342+
&sh_red[red_sh_delta * j + red_sh_rd]);
1343+
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
13451344
#pragma unroll
13461345
for (int k = 0; k < 4; k++)
13471346
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
13481347
c_rd[k] + c_wr[k];
13491348
}
1350-
sh[red_sh_wr] =
1349+
sh_red[red_sh_wr] =
13511350
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
13521351
}
13531352
}
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
13571356
#pragma unroll
13581357
for (int i = 0; i < 4 * 2; i++) {
13591358
float* c_rd =
1360-
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
1359+
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
13611360
#pragma unroll
13621361
for (int j = 0; j < 4; j++)
13631362
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
13971396
#pragma unroll
13981397
for (int i = 0; i < thread_m_blocks * 4; i++) {
13991398
cp_async4_pred(
1400-
&sh[c_sh_wr + c_sh_wr_delta * i],
1399+
&sh_red[c_sh_wr + c_sh_wr_delta * i],
14011400
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
14021401
c_gl_wr_delta_i * (i % 2)],
14031402
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
14101409
for (int i = 0; i < thread_m_blocks * 4; i++) {
14111410
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
14121411
if (!first) {
1413-
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
1412+
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
14141413
#pragma unroll
14151414
for (int j = 0; j < 2 * 4; j++) {
14161415
reinterpret_cast<float*>(
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
14611460
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
14621461
#pragma unroll
14631462
for (int k = 0; k < th_size; k++) {
1464-
sh[threadIdx.x] =
1463+
sh_red[threadIdx.x] =
14651464
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
14661465

1467-
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
1466+
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
14681467
#pragma unroll
14691468
for (int f = 0; f < 4; f++) {
14701469
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
15151514
res = __hmul2(res, s[0]);
15161515
}
15171516

1518-
((scalar_t2*)sh)[idx] = res;
1517+
((scalar_t2*)sh_red)[idx] = res;
15191518
};
15201519

15211520
if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
15431542
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
15441543
i++) {
15451544
if (c_gl_wr < c_gl_wr_end) {
1546-
C[c_gl_wr] = sh[c_sh_rd];
1545+
C[c_gl_wr] = sh_red[c_sh_rd];
15471546
c_gl_wr += c_gl_wr_delta;
15481547
c_sh_rd += c_sh_rd_delta;
15491548
}
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
18651864

18661865
float pipe_size = (a_size + b_size) * pipe_stages;
18671866

1867+
float reduce_size = max(th_config.num_threads * 32 * 4,
1868+
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
1869+
18681870
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
18691871

1870-
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
1872+
return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
18711873
}
18721874

18731875
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,

0 commit comments

Comments
 (0)