-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama

 #include "q4_matmul.cuh"
 #include "column_remap.cuh"

 const int THREADS_X = 32;     // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;      // Block size and thread count along rows in x and out

+const int GROUP_STEP = 32;    // Assumed group size when block_size_z % groupsize != 0
+
 typedef void (*fp_q4_matmul_kernel)
 (
     const half*,
@@ -46,12 +48,15 @@ __global__ void q4_matmul_kernel
     bool no_zero
 )
 {
+    extern __shared__ half2 x_cache[];
+    half* x_cache_h = (half*)x_cache;
+
     // Start of block

     int x_column = block_size_z * blockIdx.z;
     int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

-    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+    int w_column = THREADS_X * blockIdx.x + threadIdx.x;    // assume width of weight matrix divisible by THREADS_X
     int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

     int iterations = (x_column_end - x_column) / 8;
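Note on the added `x_cache`: `extern __shared__` declares dynamic shared memory whose size is fixed per launch by the third `<<<...>>>` argument (see the `shared_mem` computation added in `q4_matmul_cuda` further down). A minimal sketch of the pattern, with hypothetical names:

```cuda
#include <cuda_fp16.h>

// Hypothetical kernel illustrating the dynamic shared-memory pattern used above.
__global__ void demo_fill(const half* in, int n)
{
    extern __shared__ half2 cache[];     // byte size supplied at launch time
    half* cache_h = (half*)cache;        // same buffer reinterpreted as scalar halves

    for (int i = threadIdx.x; i < n; i += blockDim.x)
        cache_h[i] = in[i];              // cooperative, strided fill
    __syncthreads();                     // whole tile visible to every thread
}

// Launch: demo_fill<<<1, 32, n * sizeof(half)>>>(in, n);
```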
@@ -69,8 +74,8 @@ __global__ void q4_matmul_kernel
     if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
     {
         *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
-        __syncthreads();
     }
+    __syncthreads();

     // Loop over part of x row (and w column)

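Why this hunk matters: the barrier used to sit inside a branch taken only by threads with an even `threadIdx.x`, and `__syncthreads()` has undefined behavior unless every thread in the block reaches it. Hoisting the call after the closing brace makes the whole block participate in the same barrier.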
@@ -84,56 +89,64 @@ __global__ void q4_matmul_kernel

         for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
         {
+            for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();
+
             if constexpr (use_half2)
             {
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
-                else                     acc = dot_product_8(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
             }
             else
             {
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
-                else                     acc_h = dot_product_8_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
             }
+            __syncthreads();
         }
     }
     else
     {
-        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+        // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache

-        for (int k = x_column; k < x_column + iterations * 8; k += 8)
+        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
         {
+            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();
+
             if constexpr (use_half2)
             {
                 int group = k / groupsize;
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
-                else                     acc = dot_product_8(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             else
             {
                 int group = k / groupsize;
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
-                else                     acc_h = dot_product_8_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
+            __syncthreads();
         }
     }

     // Add to block result

     if constexpr (use_half2)
     {
-        half result = __hadd(__low2half(acc), __high2half(acc));
+        half result = __hadd(acc.x, acc.y);
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
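The new inner loops are a cooperative, strided fill of the shared cache: with `THREADS_X = 32`, thread `t` loads elements `t, t + 32, t + 64, ...` of the current slice of the `x` row, the first `__syncthreads()` makes the slice visible to every thread before the dot products read it, and the second one keeps the next iteration from overwriting the cache while other threads are still reading. A self-contained sketch of the same tile-and-barrier pattern, with hypothetical names and under the assumption that `dim` is a multiple of `step`:

```cuda
#include <cuda_fp16.h>

// Hypothetical kernel: every thread sums the same row, reading it through a shared tile.
__global__ void cached_row_sum(const half* x, float* out, int dim, int step)
{
    extern __shared__ half tile[];                 // `step` halves, sized at launch

    float acc = 0.0f;
    for (int k = 0; k < dim; k += step)            // assumes dim % step == 0
    {
        for (int i = threadIdx.x; i < step; i += blockDim.x)
            tile[i] = x[k + i];                    // strided cooperative fill
        __syncthreads();                           // fill finished before anyone reads

        for (int i = 0; i < step; i++)
            acc += __half2float(tile[i]);          // all threads read the whole tile
        __syncthreads();                           // reads finished before next overwrite
    }

    if (threadIdx.x == 0) *out = acc;
}

// Launch (e.g.): cached_row_sum<<<1, 32, step * sizeof(half)>>>(x, out, dim, step);
```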
@@ -215,8 +228,8 @@ void q4_matmul_cuda
     );

     fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-
-    kernel<<<blocks, threads, 0, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+    kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
 }
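The dynamic shared-memory request now matches what one block iteration fills: `groupsize` halves when `block_size_z` is a multiple of the group size, otherwise `GROUP_STEP` halves for the fallback path. With `groupsize = 128`, for instance, that is 128 * sizeof(half) = 256 bytes per block; the fallback path requests 32 * 2 = 64 bytes.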

 void q4_matmul_recons_cuda
@@ -240,21 +253,26 @@ void q4_matmul_recons_cuda
     const half* x_mapped = x;
     if (w->cuda_x_map)
     {
-        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
+        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
         column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
         x_mapped = buffers->temp_state;
     }

     w->reconstruct(buffers->temp_dq);

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+
     const float alpha = 1.0f;
     const float beta = no_zero ? 1.0f : 0.0f;
     cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
                   x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
+
 #else
+
     const half alpha = __float2half(1.0f);
     const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
     cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
+
 #endif
+
 }
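For reference, both cuBLAS fallbacks compute, in effect, out = x_mapped · temp_dq + beta · out (the row-major product falls out of the column-major call with the operands passed in this order), so `beta` carries the `no_zero` flag: 1.0 accumulates into the existing contents of `out`, 0.0 overwrites them, matching the zero-initialization path in the custom kernel.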