Skip to content

Commit 82a7012

Browse files
committed
metal : fix support check
ggml-ci
1 parent 41b47e5 commit 82a7012

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ggml/src/ggml-metal.m

+5 -1
Original file line number | Diff line number | Diff line change
@@ -942,6 +942,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
942942
case GGML_OP_LEAKY_RELU:
943943
return true;
944944
case GGML_OP_FLASH_ATTN_EXT:
945+
if (op->src[1]->type != op->src[2]->type) {
946+
return false;
947+
}
945948
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
946949
case GGML_OP_SSM_CONV:
947950
case GGML_OP_SSM_SCAN:
@@ -2886,6 +2889,7 @@ static void ggml_metal_encode_node(
28862889
GGML_ASSERT(ne11 % 32 == 0);
28872890

28882891
GGML_ASSERT(src0->type == GGML_TYPE_F32);
2892+
GGML_ASSERT(src1->type == src2->type);
28892893

28902894
GGML_ASSERT(ggml_are_same_shape (src1, src2));
28912895

@@ -3158,7 +3162,7 @@ static void ggml_metal_encode_node(
31583162

31593163
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
31603164
} else {
3161-
// half1x4 kernel
3165+
// half4x4 kernel
31623166
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
31633167
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
31643168

0 commit comments

Comments
 (0)