Skip to content

Commit 1117d06

Browse files
authored
opencl : fix element-wise multiplication (#3656)
1 parent cb33f43 commit 1117d06

File tree

1 file changed

+23
-52
lines changed

1 file changed

+23
-52
lines changed

ggml-opencl.cpp

+23-52
Original file line numberDiff line numberDiff line change
@@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
13951395
const int64_t ne01 = src0->ne[1];
13961396
const int64_t ne02 = src0->ne[2];
13971397
const int64_t ne03 = src0->ne[3];
1398-
const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
13991398
const int64_t ne10 = src1->ne[0];
14001399
const int64_t ne11 = src1->ne[1];
14011400
const int64_t ne12 = src1->ne[2];
14021401
const int64_t ne13 = src1->ne[3];
1403-
const int64_t nb10 = src1->nb[0];
14041402
const int nb2 = dst->nb[2];
14051403
const int nb3 = dst->nb[3];
14061404
size_t x_size;
14071405
size_t d_size;
14081406

1409-
cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
1407+
cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
14101408
cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
1411-
cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
1409+
cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst
14121410

14131411

14141412
for (int64_t i03 = 0; i03 < ne03; i03++) {
14151413
for (int64_t i02 = 0; i02 < ne02; i02++) {
1416-
const int i0 = i03*ne02 + i02;
1417-
14181414
cl_event ev;
14191415

14201416
// copy src0 to device
1421-
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
1422-
1423-
if (nb10 == sizeof(float)) {
1424-
// Contiguous, avoid overhead from queueing many kernel runs
1425-
const int64_t i13 = i03%ne13;
1426-
const int64_t i12 = i02%ne12;
1427-
const int i1 = i13*ne12*ne11 + i12*ne11;
1428-
1429-
cl_int x_offset = 0;
1430-
cl_int y_offset = i1*ne10;
1431-
cl_int d_offset = 0;
1432-
1433-
size_t global = ne00 * ne01;
1434-
cl_int ky = ne10;
1435-
CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
1436-
CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
1437-
CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
1438-
CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
1439-
CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
1440-
CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
1441-
CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
1442-
CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
1443-
} else {
1444-
for (int64_t i01 = 0; i01 < ne01; i01++) {
1445-
const int64_t i13 = i03%ne13;
1446-
const int64_t i12 = i02%ne12;
1447-
const int64_t i11 = i01%ne11;
1448-
const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
1449-
1450-
cl_int x_offset = i01*ne00;
1451-
cl_int y_offset = i1*ne10;
1452-
cl_int d_offset = i01*ne00;
1453-
1454-
// compute
1455-
size_t global = ne00;
1456-
cl_int ky = ne10;
1457-
CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
1458-
CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
1459-
CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
1460-
CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
1461-
CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
1462-
CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
1463-
CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
1464-
CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
1465-
}
1466-
}
1417+
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
1418+
1419+
const int64_t i13 = i03%ne13;
1420+
const int64_t i12 = i02%ne12;
1421+
const int i1 = i13*ne12*ne11 + i12*ne11;
1422+
1423+
cl_int x_offset = 0;
1424+
cl_int y_offset = i1*ne10;
1425+
cl_int d_offset = 0;
1426+
1427+
size_t global = ne00 * ne01;
1428+
cl_int ky = ne10 * ne11;
1429+
1430+
CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
1431+
CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
1432+
CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
1433+
CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
1434+
CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
1435+
CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
1436+
CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
1437+
CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
14671438

14681439
CL_CHECK(clReleaseEvent(ev));
14691440
CL_CHECK(clFinish(queue));

0 commit comments

Comments
 (0)