Skip to content

Commit 4762127

Browse files
committed
Optimize caclulation of shared memory size for reduction
Signed-off-by: wchen61 <[email protected]>
1 parent aa2f07a commit 4762127

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

csrc/quantization/gptq_marlin/gptq_marlin.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1864,7 +1864,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
18641864

18651865
float pipe_size = (a_size + b_size) * pipe_stages;
18661866

1867-
float reduce_size = max(th_config.num_threads * 2 * 4 * 4,
1867+
float reduce_size = max(th_config.num_threads * 32 * 4,
18681868
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
18691869

18701870
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity

0 commit comments

Comments
 (0)