Skip to content

Commit 1e0e873

Browse files
authored
CLBlast: Fix matrix-vector multiplication (#3544)
1 parent 370359e commit 1e0e873

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

Diff for: ggml-opencl.cpp

+17-15
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#pragma warning(disable: 4244 4267) // possible loss of data
2020
#endif
2121

22-
#define CL_DMMV_BLOCK_SIZE 32
22+
#define CL_DMMV_LOCAL_SIZE 32
2323

2424
#ifndef K_QUANTS_PER_ITERATION
2525
#define K_QUANTS_PER_ITERATION 1
@@ -338,7 +338,7 @@ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx,
338338
const int row = get_group_id(0);
339339

340340
const int num_blocks_per_row = ncols / QK_K;
341-
const int ib0 = row*num_blocks_per_row;
341+
const int ib0 = row*num_blocks_per_row + get_global_offset(0);
342342

343343
__global const struct block_q2_K * x = xx + ib0;
344344

@@ -413,7 +413,7 @@ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx,
413413
const int row = get_group_id(0);
414414

415415
const int num_blocks_per_row = ncols / QK_K;
416-
const int ib0 = row*num_blocks_per_row;
416+
const int ib0 = row*num_blocks_per_row + get_global_offset(0);
417417

418418
__global const struct block_q3_K * x = xx + ib0;
419419

@@ -489,7 +489,7 @@ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx,
489489

490490
const int row = get_group_id(0);
491491
const int num_blocks_per_row = ncols / QK_K;
492-
const int ib0 = row*num_blocks_per_row;
492+
const int ib0 = row*num_blocks_per_row + get_global_offset(0);
493493

494494
const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
495495
const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
@@ -562,7 +562,7 @@ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx,
562562

563563
const int row = get_group_id(0);
564564
const int num_blocks_per_row = ncols / QK_K;
565-
const int ib0 = row*num_blocks_per_row;
565+
const int ib0 = row*num_blocks_per_row + get_global_offset(0);
566566

567567
const int tid = get_local_id(0)/2; // 0...15
568568
const int ix = get_local_id(0)%2;
@@ -641,7 +641,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
641641
const int row = get_group_id(0);
642642

643643
const int num_blocks_per_row = ncols / QK_K;
644-
const int ib0 = row*num_blocks_per_row;
644+
const int ib0 = row*num_blocks_per_row + get_global_offset(0);
645645

646646
__global const struct block_q6_K * x = xx + ib0;
647647

@@ -745,19 +745,21 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
745745

746746
std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
747747
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
748-
const int block_size = get_local_size(0);
748+
const int local_size = get_local_size(0);
749749
const int row = get_group_id(0);
750750
const int tid = get_local_id(0);
751751

752752
const uint qk = QUANT_K;
753753
const uint qr = QUANT_R;
754754

755+
const int col_step = local_size * 2;
755756
const int y_offset = qr == 1 ? 1 : qk/2;
756757

758+
x += get_global_offset(0);
759+
757760
tmp[tid] = 0;
758761

759-
for (int i = 0; i < ncols/block_size; i += 2) {
760-
const int col = i*block_size + 2*tid;
762+
for (int col = tid*2; col < ncols; col += col_step) {
761763
const int ib = (row*ncols + col)/qk; // block index
762764
const int iqs = (col%qk)/qr; // quant index
763765
const int iybs = col - col%qk; // y block start index
@@ -773,7 +775,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
773775

774776
// sum up partial sums and write back result
775777
barrier(CLK_LOCAL_MEM_FENCE);
776-
for (int s=block_size/2; s>0; s>>=1) {
778+
for (int s=local_size/2; s>0; s>>=1) {
777779
if (tid < s) {
778780
tmp[tid] += tmp[tid + s];
779781
}
@@ -1704,7 +1706,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
17041706
const int nb2 = dst->nb[2];
17051707
const int nb3 = dst->nb[3];
17061708
const ggml_type type = src0->type;
1707-
const bool mul_mat_vec = ne11 == 1;
1709+
const bool mul_mat_vec = ne11 == 1 && ne00%2 == 0;
17081710

17091711
const int64_t r2 = ne12 / ne02;
17101712
const int64_t r3 = ne13 / ne03;
@@ -1737,7 +1739,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
17371739
GGML_ASSERT(to_fp32_cl != nullptr);
17381740

17391741
const size_t global_denom = ggml_cl_global_denom(type);
1740-
const size_t local = ggml_cl_local_size(type);
1742+
const size_t local = mul_mat_vec ? CL_DMMV_LOCAL_SIZE : ggml_cl_local_size(type);
17411743

17421744
size_t ev_idx = 0;
17431745
std::vector<cl_event> events;
@@ -1770,16 +1772,16 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
17701772
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
17711773

17721774
// compute
1773-
const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
1774-
const size_t local = CL_DMMV_BLOCK_SIZE;
1775+
const size_t global = ne01 * local;
1776+
const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
17751777
const cl_int ncols = ne00;
17761778
events.emplace_back();
17771779
CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
17781780
CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
17791781
CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
17801782
CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
17811783
CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
1782-
CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
1784+
CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, &offset, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
17831785
} else { // general dequantization kernel + CLBlast matrix matrix multiplication
17841786
// convert src0 to fp32 on device
17851787
const size_t global = x_ne / global_denom;

0 commit comments

Comments
 (0)