Skip to content

Commit dadbed9

Browse files
authored
metal : fix synchronization in new matrix multiplication kernel (#2686)
1 parent cb1c072 commit dadbed9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ggml-metal.metal

+2-1
Original file line numberDiff line numberDiff line change
@@ -1898,10 +1898,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
18981898
threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
18991899
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
19001900
for (int i = 0; i < 8; i++) {
1901+
threadgroup_barrier(mem_flags::mem_device);
19011902
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
19021903
}
19031904

1904-
threadgroup_barrier(mem_flags::mem_threadgroup);
1905+
threadgroup_barrier(mem_flags::mem_device);
19051906
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
19061907
if (sgitg==0) {
19071908
for (int i = 0; i < n_rows; i++) {

0 commit comments

Comments
 (0)