const int THREADS_X = 32;       // Block size and thread count along columns in w and out
const int THREADS_Y = 1;        // Block size and thread count along rows in x and out

+const int GROUP_STEP = 32;      // Assumed group size when block_size_z % groupsize != 0
+
typedef void (*fp_q4_matmul_kernel)
(
    const half*,
@@ -52,7 +54,7 @@ __global__ void q4_matmul_kernel
    int x_column = block_size_z * blockIdx.z;
    int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

-    int w_column = THREADS_X * blockIdx.x + threadIdx.x;    // assume width of weight matrix divisible by THREADS_X (32)
+    int w_column = THREADS_X * blockIdx.x + threadIdx.x;    // assume width of weight matrix divisible by THREADS_X
    int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

    int iterations = (x_column_end - x_column) / 8;
@@ -109,11 +111,11 @@ __global__ void q4_matmul_kernel
    }
    else
    {
-        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+        // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache

-        for (int k = x_column; k < x_column + iterations * 8; k += 8)
+        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
        {
-            for (int i = threadIdx.x; i < 8; i += THREADS_X)
+            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
            {
                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
@@ -125,14 +127,14 @@ __global__ void q4_matmul_kernel
            int group = k / groupsize;
            half2 w_scale = w_scales_.item_half2half2(group, w_column);
            uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-            acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, 1);
+            acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
        }
        else
        {
            int group = k / groupsize;
            half w_scale = w_scales_.item(group, w_column);
            uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
-            acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, 1);
+            acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
        }
        __syncthreads();
    }
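// Note on the count argument (an assumption, not stated in the diff): the final
// argument to dot_product_8 / dot_product_8_h is taken here to be the number of
// consecutive 8-column blocks to accumulate. With GROUP_STEP = 32, each outer
// step of k += GROUP_STEP stages 32 x-columns in shared memory and consumes them
// in one call covering GROUP_STEP / 8 = 32 / 8 = 4 blocks, versus 1 block per
// call before this change.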
@@ -224,7 +226,8 @@ void q4_matmul_cuda
    );

    fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-    kernel<<<blocks, threads, w->groupsize * sizeof(half), alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+    kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
}

void q4_matmul_recons_cuda
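For reference, a minimal standalone sketch (not taken from the repository) of the shared-memory sizing rule introduced at the launch site: a full group's worth of half-precision x values when block_size_z is a multiple of groupsize, and GROUP_STEP columns otherwise. The block_size_z and groupsize values below are hypothetical, and sizeof(half) is written out as 2 bytes so the snippet compiles as plain C++ without CUDA headers.

#include <cstdio>

// Mirrors the constant added in the diff above.
const int GROUP_STEP = 32;

// Same selection as the new launch line; one half (2 bytes) per cached x column.
static int shared_mem_bytes(int block_size_z, int groupsize)
{
    int cached_columns = (block_size_z % groupsize == 0) ? groupsize : GROUP_STEP;
    return cached_columns * 2;
}

int main()
{
    // block_size_z a multiple of groupsize: cache one full group of x values.
    printf("%d\n", shared_mem_bytes(512, 128));   // prints 256
    // Not a multiple: the kernel falls back to GROUP_STEP columns per iteration.
    printf("%d\n", shared_mem_bytes(384, 256));   // prints 64
    return 0;
}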