@@ -834,6 +834,7 @@ __global__ void Marlin(
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
   int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
   int4* sh_s = sh_zp + (stages * zp_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
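For orientation when reading the sh to sh_red renames in the later hunks: the new pointer carves a dedicated reduction scratch region out of dynamic shared memory right after the scale staging area, where the old code reused sh, the start of the staging buffers, directly. Below is a minimal host-side sketch of the same offset arithmetic with made-up per-stage sizes; only the ordering of the regions mirrors the kernel, and the sh_a/sh_b offsets are assumptions not shown in this hunk.

#include <cstdio>

int main() {
  // Illustrative per-stage extents in int4 units; the real values are derived
  // from the tile shape and are not part of this diff.
  const int stages = 4;
  const int a_sh_stage = 512, b_sh_stage = 1024, g_idx_stage = 32,
            zp_sh_stage = 16, s_sh_stage = 64;

  int sh_a = 0;                                 // assumed: A tiles sit at the base
  int sh_b = sh_a + stages * a_sh_stage;        // assumed: B tiles follow
  int sh_g_idx = sh_b + stages * b_sh_stage;    // act-order indices (as in the hunk)
  int sh_zp = sh_g_idx + stages * g_idx_stage;  // zero-points
  int sh_s = sh_zp + stages * zp_sh_stage;      // scales
  int sh_red = sh_s + stages * s_sh_stage;      // new: dedicated reduction scratch

  std::printf("sh_red begins at int4 offset %d (%d bytes into shared memory)\n",
              sh_red, sh_red * 16);
}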
@@ -932,11 +933,11 @@ __global__ void Marlin(
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
         if constexpr (group_blocks >= thread_k_blocks) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
           // Only fetch scales if this tile starts a new group
-          if (pipe % (group_blocks / thread_k_blocks) == 0) {
-            if (s_sh_wr_pred) {
-              cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
-            }
+          if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
             s_gl_rd += s_gl_rd_delta;
           }
         } else {
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
     // No act-order case
     if constexpr (group_blocks != -1) {
       if constexpr (group_blocks >= thread_k_blocks) {
-        int4* sh_s_stage =
-            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
-                                 (pipe / (group_blocks / thread_k_blocks)));
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
         int warp_id = threadIdx.x / 32;
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
               int red_sh_wr =
                   red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
               if (i < red_off) {
-                float* c_rd =
-                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+                float* c_rd = reinterpret_cast<float*>(
+                    &sh_red[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
 #pragma unroll
                 for (int k = 0; k < 4; k++)
                   reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                       c_rd[k] + c_wr[k];
               }
-              sh[red_sh_wr] =
+              sh_red[red_sh_wr] =
                   reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
             }
           }
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
 #pragma unroll
           for (int i = 0; i < 4 * 2; i++) {
             float* c_rd =
-                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+                reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
 #pragma unroll
             for (int j = 0; j < 4; j++)
               reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
 #pragma unroll
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         cp_async4_pred(
-            &sh[c_sh_wr + c_sh_wr_delta * i],
+            &sh_red[c_sh_wr + c_sh_wr_delta * i],
             &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                c_gl_wr_delta_i * (i % 2)],
             i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
     for (int i = 0; i < thread_m_blocks * 4; i++) {
       if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
         if (!first) {
-          int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+          int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
 #pragma unroll
           for (int j = 0; j < 2 * 4; j++) {
             reinterpret_cast<float*>(
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
     float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
 #pragma unroll
     for (int k = 0; k < th_size; k++) {
-      sh[threadIdx.x] =
+      sh_red[threadIdx.x] =
           C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
 
-      float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+      float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
 #pragma unroll
       for (int f = 0; f < 4; f++) {
         frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
       res = __hmul2(res, s[0]);
     }
 
-    ((scalar_t2*)sh)[idx] = res;
+    ((scalar_t2*)sh_red)[idx] = res;
   };
 
   if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh[c_sh_rd];
+        C[c_gl_wr] = sh_red[c_sh_rd];
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
 
   float pipe_size = (a_size + b_size) * pipe_stages;
 
+  float reduce_size = max(th_config.num_threads * 32 * 4,
+                          (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
+
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
 
-  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+  return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
 }
 
 bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
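Since the reduction scratch now lives in its own shared-memory region, the host-side validity check has to budget for it as well. Below is a hedged, stand-alone restatement of that arithmetic: the reduce_size expression and the final comparison are taken from the hunk above, while the helper name fits_in_shared_mem and every concrete number fed to it in main() are illustrative assumptions, not values from this patch.

#include <algorithm>
#include <cstdio>

// Returns true if the pipelined A/B tiles plus the new reduction scratch fit
// in 95% of the shared memory left over after the scale cache.
bool fits_in_shared_mem(float pipe_size, float scales_cache_size,
                        int num_threads, int tb_n, int tb_max_m,
                        int max_shared_mem) {
  float reduce_size =
      std::max(num_threads * 32 * 4,
               (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
  return pipe_size + reduce_size <
         0.95f * (max_shared_mem - scales_cache_size);
}

int main() {
  // Assumed example: a 4-stage pipeline whose A and B tiles total 24 KiB per
  // stage, 2 KiB of scales, 256 threads, a 64 (m) x 128 (n) thread-block tile,
  // and 160 KiB of dynamic shared memory.
  float pipe_size = 24.0f * 1024.0f * 4;
  bool ok = fits_in_shared_mem(pipe_size, 2048.0f, 256, 128, 64, 160 * 1024);
  std::printf("fits in shared memory: %s\n", ok ? "yes" : "no");
}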