Commit 8e37f2d

Optimize q4_matmul for non-128 group sizes
1 parent 215a715 commit 8e37f2d

1 file changed: +10 -7 lines changed

exllama_ext/cuda_func/q4_matmul.cu (+10 -7)
```diff
@@ -11,6 +11,8 @@
 const int THREADS_X = 32;       // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;        // Block size and thread count along rows in x and out
 
+const int GROUP_STEP = 32;      // Assumed group size when block_size_z % groupsize != 0
+
 typedef void (*fp_q4_matmul_kernel)
 (
     const half*,
```
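The new GROUP_STEP constant widens the non-group-aligned fallback path from 8 cached columns of x per iteration to 32. Two invariants make this safe: GROUP_STEP must be a multiple of 8 (dot_product_8 consumes 8 columns per count step), and, as the updated kernel comment further down states, groupsize must be a multiple of GROUP_STEP so a whole chunk falls inside one quantization group. A minimal sketch of how those assumptions could be asserted (hypothetical, not part of this commit):

```cuda
#include <cassert>

const int GROUP_STEP = 32;  // assumed chunk width on the fallback path

// dot_product_8 advances 8 columns per count step, so the chunk must split evenly.
static_assert(GROUP_STEP % 8 == 0, "GROUP_STEP must be a multiple of 8");

// Hypothetical host-side guard before picking the fallback kernel:
inline void check_fallback_assumptions(int block_size_z, int groupsize)
{
    if (block_size_z % groupsize != 0)
        assert(groupsize % GROUP_STEP == 0);  // whole chunk stays in one group
}
```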
```diff
@@ -52,7 +54,7 @@ __global__ void q4_matmul_kernel
     int x_column = block_size_z * blockIdx.z;
     int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
 
-    int w_column = THREADS_X * blockIdx.x + threadIdx.x;    // assume width of weight matrix divisible by THREADS_X (32)
+    int w_column = THREADS_X * blockIdx.x + threadIdx.x;    // assume width of weight matrix divisible by THREADS_X
     int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
 
     int iterations = (x_column_end - x_column) / 8;
@@ -109,11 +111,11 @@ __global__ void q4_matmul_kernel
         }
         else
         {
-            // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+            // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache
 
-            for (int k = x_column; k < x_column + iterations * 8; k += 8)
+            for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
             {
-                for (int i = threadIdx.x; i < 8; i += THREADS_X)
+                for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
                 {
                     if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
                     else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
```
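With the wider chunk, the cooperative load now keeps all THREADS_X = 32 threads busy: each thread copies exactly one of the GROUP_STEP = 32 half values into shared memory, where the old 8-column loop left 24 of the 32 threads idle. A self-contained sketch of that strided copy pattern (simplified: a static shared buffer and a raw pointer stand in for the repo's x_ accessor):

```cuda
#include <cuda_fp16.h>

const int THREADS_X  = 32;
const int GROUP_STEP = 32;

// Simplified model of the fallback path's shared-memory fill: adjacent
// threads read adjacent columns (coalesced), then the block synchronizes
// so every thread can read the whole cached chunk.
__global__ void cache_chunk_sketch(const half* x, half* out, int k)
{
    __shared__ half x_cache_h[GROUP_STEP];

    for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
        x_cache_h[i] = x[k + i];     // one element per thread when sizes match

    __syncthreads();                 // chunk is now visible block-wide
    out[threadIdx.x] = x_cache_h[threadIdx.x];  // stand-in for the dot products
}
```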
```diff
@@ -125,14 +127,14 @@ __global__ void q4_matmul_kernel
                 int group = k / groupsize;
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             else
             {
                 int group = k / groupsize;
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             __syncthreads();
         }
```
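The final argument to dot_product_8 / dot_product_8_h acts as a repeat count: each unit consumes 8 columns of the cache against one scale/zero pair. Passing GROUP_STEP / 8 = 4 therefore walks the entire 32-column chunk in a single call instead of re-fetching w_scale and w_zero every 8 columns. A scalar mock of that count semantic (hypothetical; the repo's version operates on packed 4-bit weights and half2 lanes):

```cuda
// Hypothetical scalar stand-in for dot_product_8's trailing count argument.
// Scale and zero stay fixed across all count * 8 columns because the whole
// chunk lies inside one quantization group.
float dot_product_8_mock(float acc, const float* x_cache, const int* w_q,
                         int k, float w_scale, int w_zero, int count)
{
    for (int c = 0; c < count; c++)
        for (int j = 0; j < 8; j++)
        {
            int col = c * 8 + j;
            acc += x_cache[col] * (float)(w_q[k + col] - w_zero) * w_scale;
        }
    return acc;
}
```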
```diff
@@ -224,7 +226,8 @@ void q4_matmul_cuda
     );
 
     fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-    kernel<<<blocks, threads, w->groupsize * sizeof(half), alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+    kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
 }
 
 void q4_matmul_recons_cuda
```
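The launch site now sizes dynamic shared memory to match the path the kernel will actually take: a full quantization group of halfs when block_size_z divides evenly by groupsize, or just GROUP_STEP halfs for the fallback. Restated as a standalone helper (a sketch of the same expression, with hypothetical naming):

```cuda
#include <cstddef>
#include <cuda_fp16.h>

const int GROUP_STEP = 32;

// Mirrors the sizing expression in the commit: the aligned path caches one
// whole group of x, the fallback path caches a GROUP_STEP-wide chunk.
size_t q4_matmul_shared_mem(int block_size_z, int groupsize)
{
    int halfs = (block_size_z % groupsize == 0) ? groupsize : GROUP_STEP;
    return halfs * sizeof(half);
}
```

With groupsize = 128 on the aligned path this reserves 256 bytes per block; on the fallback path it reserves 64 bytes regardless of groupsize, where the old launch always reserved groupsize * sizeof(half).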
