Commit 3173fce

Optimize q4_matmul

turboderp/exllama#275

1 parent 45a17c8 commit 3173fce

File tree

2 files changed (+41 −140 lines changed)

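In short: the kernel now stages a slice of the activation row x into dynamic shared memory once per block, applying the act-order permutation x_map during that load when present, and dot_product_8 / dot_product_8_h then read the activations back from the shared cache instead of re-fetching global memory for every weight column; this is what makes the dedicated dot_product_8_x_map variants in matrix.cuh removable. Below is a minimal sketch of that staging pattern only, not the committed kernel: it assumes pre-dequantized weights (w_dq), float accumulation, no group scales or zero-points, and a width divisible by THREADS_X; only THREADS_X, GROUP_STEP and x_map carry over from the diff.

// Hedged sketch only: reproduces the shared-memory staging idea from this commit
// in a standalone kernel. w_dq (dequantized weights) and float accumulation are
// simplifications for illustration, not part of the real q4_matmul_kernel.

#include <cuda_fp16.h>
#include <cstdint>

const int THREADS_X  = 32;   // threads (and output columns) per block
const int GROUP_STEP = 32;   // activations staged per iteration

__global__ void stage_and_dot(const half* x, const half* w_dq, float* out,
                              const uint32_t* x_map, int dim, int width)
{
    extern __shared__ half x_cache[];                     // GROUP_STEP halves, sized at launch

    int w_column = blockIdx.x * THREADS_X + threadIdx.x;  // assumes width % THREADS_X == 0
    float acc = 0.0f;

    for (int k = 0; k < dim; k += GROUP_STEP)
    {
        // Cooperative staging: the block's threads stride over the slice, applying
        // the act-order permutation here instead of inside the dot product.
        for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
            x_cache[i] = x_map ? x[x_map[k + i]] : x[k + i];
        __syncthreads();

        // Every output column reuses the same staged activations from shared memory.
        for (int i = 0; i < GROUP_STEP; i++)
            acc += __half2float(x_cache[i]) * __half2float(w_dq[(size_t)(k + i) * width + w_column]);
        __syncthreads();                                   // keep x_cache stable until all threads are done
    }

    out[w_column] = acc;
}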

autogptq_extension/exllama/cuda_func/q4_matmul.cu (+39 −21)

@@ -1,4 +1,4 @@
-// Adapted from turboderp exllama: https://github.com/turboderp/exllama
+// Adapted from turboderp exllama: https://github.com/turboderp/exllama
 
 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
@@ -13,6 +13,8 @@
 const int THREADS_X = 32;       // Block size and thread count along columns in w and out
 const int THREADS_Y = 1;        // Block size and thread count along rows in x and out
 
+const int GROUP_STEP = 32;      // Assumed group size when block_size_z % groupsize != 0
+
 typedef void (*fp_q4_matmul_kernel)
 (
     const half*,
@@ -46,12 +48,15 @@ __global__ void q4_matmul_kernel
     bool no_zero
 )
 {
+    extern __shared__ half2 x_cache[];
+    half* x_cache_h = (half*)x_cache;
+
     // Start of block
 
     int x_column = block_size_z * blockIdx.z;
     int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
 
-    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+    int w_column = THREADS_X * blockIdx.x + threadIdx.x;  // assume width of weight matrix divisible by THREADS_X
     int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
 
     int iterations = (x_column_end - x_column) / 8;
@@ -69,8 +74,8 @@ __global__ void q4_matmul_kernel
     if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
     {
         *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
-        __syncthreads();
     }
+    __syncthreads();
 
     // Loop over part of x row (and w column)
 
@@ -84,56 +89,64 @@ __global__ void q4_matmul_kernel
 
         for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
         {
+            for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();
+
             if constexpr (use_half2)
             {
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
-                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
             }
             else
             {
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
-                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
             }
+            __syncthreads();
         }
     }
     else
    {
-        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+        // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache
 
-        for (int k = x_column; k < x_column + iterations * 8; k += 8)
+        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
         {
+            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else                     x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();
+
             if constexpr (use_half2)
             {
                 int group = k / groupsize;
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
-                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             else
            {
                 int group = k / groupsize;
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-
-                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
-                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
            }
+            __syncthreads();
         }
     }
 
     // Add to block result
 
     if constexpr (use_half2)
     {
-        half result = __hadd(__low2half(acc), __high2half(acc));
+        half result = __hadd(acc.x, acc.y);
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
@@ -215,8 +228,8 @@ void q4_matmul_cuda
     );
 
     fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-
-    kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+    kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
 }
 
 void q4_matmul_recons_cuda
@@ -240,21 +253,26 @@ void q4_matmul_recons_cuda
     const half* x_mapped = x;
     if (w->cuda_x_map)
     {
-        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
+        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
         column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
         x_mapped = buffers->temp_state;
     }
 
     w->reconstruct(buffers->temp_dq);
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+
     const float alpha = 1.0f;
     const float beta = no_zero ? 1.0f : 0.0f;
     cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
                   x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
+
 #else
+
     const half alpha = __float2half(1.0f);
     const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
     cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
+
 #endif
+
 }
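The launch change at the end of q4_matmul_cuda sizes the dynamic shared memory to exactly one staged slice: groupsize half values when block_size_z is a multiple of groupsize, otherwise GROUP_STEP. A hypothetical launch of the stage_and_dot sketch above follows the same pattern (x, w_dq, out, x_map, dim, width and stream are assumed to be set up by the caller):

// Hypothetical launch of the stage_and_dot sketch; the third launch parameter
// requests just enough dynamic shared memory for one staged slice of halves.
dim3 threads(THREADS_X, 1, 1);
dim3 blocks(width / THREADS_X, 1, 1);
size_t shared_mem = GROUP_STEP * sizeof(half);
stage_and_dot<<<blocks, threads, shared_mem, stream>>>(x, w_dq, out, x_map, dim, width);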

autogptq_extension/exllama/matrix.cuh (+2 −119)

@@ -87,9 +87,7 @@ public:
 __device__ __forceinline__ half2 dot_product_8
 (
     const half2 acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column,                // divisible by 8
+    const half2* h_ptr,
     MatrixView_q4_column& v_,
     const int v_row,                   // divisible by 8
     const int v_column,
@@ -98,7 +96,6 @@ __device__ __forceinline__ half2 dot_product_8
     const int count
 )
 {
-    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
     const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
     half2 result = acc;
 
@@ -138,9 +135,7 @@ __device__ __forceinline__ half2 dot_product_8
 __device__ __forceinline__ half dot_product_8_h
 (
     const half acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column,                // divisible by 8
+    const half* h_ptr,
     MatrixView_q4_column& v_,
     const int v_row,                   // divisible by 8
     const int v_column,
@@ -149,7 +144,6 @@ __device__ __forceinline__ half dot_product_8_h
     const int count
 )
 {
-    const half* h_ptr = h_.item_ptr(h_row, h_column);
     const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
     half result = acc;
 
@@ -180,115 +174,4 @@ __device__ __forceinline__ half dot_product_8_h
     return result;
 }
 
-// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
-
-__device__ __forceinline__ half2 dot_product_8_x_map
-(
-    const half2 acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column,                // divisible by 8
-    MatrixView_q4_column& v_,
-    const int v_row,                   // divisible by 8
-    const int v_column,
-    const half2 v_scale_2,
-    const uint32_t v_zero,
-    const int count,
-    const uint32_t* x_map
-)
-{
-    const half* h_ptr = h_.item_ptr(h_row, 0);
-    const uint32_t* x_map_ptr = x_map + h_column;
-    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
-    half2 result = acc;
-
-    for (int i = 0; i < count; i++)
-    {
-        uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
-        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
-        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
-        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
-        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
-        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
-        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
-        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
-        half v_7 = __int2half_rn((int)((v_read >> 28)        ) - v_zero);
-
-        half2 v_01 = __halves2half2(v_0, v_1);
-        half2 v_23 = __halves2half2(v_2, v_3);
-        half2 v_45 = __halves2half2(v_4, v_5);
-        half2 v_67 = __halves2half2(v_6, v_7);
-
-        half h_0 = h_ptr[*x_map_ptr++];
-        half h_1 = h_ptr[*x_map_ptr++];
-        half h_2 = h_ptr[*x_map_ptr++];
-        half h_3 = h_ptr[*x_map_ptr++];
-        half h_4 = h_ptr[*x_map_ptr++];
-        half h_5 = h_ptr[*x_map_ptr++];
-        half h_6 = h_ptr[*x_map_ptr++];
-        half h_7 = h_ptr[*x_map_ptr++];
-
-        half2 h_01 = __halves2half2(h_0, h_1);
-        half2 h_23 = __halves2half2(h_2, h_3);
-        half2 h_45 = __halves2half2(h_4, h_5);
-        half2 h_67 = __halves2half2(h_6, h_7);
-
-        half2 tmp = __hmul2(h_01, v_01);
-        tmp = __hfma2(h_23, v_23, tmp);
-        tmp = __hfma2(h_45, v_45, tmp);
-        tmp = __hfma2(h_67, v_67, tmp);
-        result = __hfma2(v_scale_2, tmp, result);
-    }
-
-    return result;
-}
-
-__device__ __forceinline__ half dot_product_8_x_map_h
-(
-    const half acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column,                // divisible by 8
-    MatrixView_q4_column& v_,
-    const int v_row,                   // divisible by 8
-    const int v_column,
-    const half v_scale,
-    const uint32_t v_zero,
-    const int count,
-    const uint32_t* x_map
-)
-{
-    const half* h_ptr = h_.item_ptr(h_row, 0);
-    const uint32_t* x_map_ptr = x_map + h_column;
-    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
-    half result = acc;
-
-    for (int i = 0; i < count; i++)
-    {
-        uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
-        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
-        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
-        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
-        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
-        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
-        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
-        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
-        half v_7 = __int2half_rn((int)((v_read >> 28)        ) - v_zero);
-
-        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
-        result = __hfma(v_scale, tmp, result);
-    }
-
-    return result;
-}
-
 #endif
