Skip to content

Commit bf83bff

Browse files
authored
metal : matrix-matrix multiplication kernel (#2615)
* metal: matrix-matrix multiplication kernel This commit removes MPS and uses custom matrix-matrix multiplication kernels for all quantization types. This commit also adds grouped-query attention to support llama2 70B. * metal: fix performance degradation from gqa Integers are slow on the GPU, and 64-bit divides are extremely slow. In the context of GQA, we introduce a 64-bit divide that cannot be optimized out by the compiler, which results in a decrease of ~8% in inference performance. This commit fixes that issue by calculating a part of the offset with a 32-bit divide. Naturally, this limits the size of a single matrix to ~4GB. However, this limitation should suffice for the near future. * metal: fix bugs for GQA and perplexity test. I mixed up ne02 and nb02 in previous commit.
1 parent b5ffb28 commit bf83bff

File tree

6 files changed

+528
-636
lines changed

6 files changed

+528
-636
lines changed

CMakeLists.txt

-2
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,6 @@ if (LLAMA_METAL)
296296
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
297297
find_library(METAL_FRAMEWORK Metal REQUIRED)
298298
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
299-
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
300299

301300
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
302301

@@ -313,7 +312,6 @@ if (LLAMA_METAL)
313312
${FOUNDATION_LIBRARY}
314313
${METAL_FRAMEWORK}
315314
${METALKIT_FRAMEWORK}
316-
${METALPERFORMANCE_FRAMEWORK}
317315
)
318316
endif()
319317

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
283283
ifdef LLAMA_METAL
284284
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
285285
CXXFLAGS += -DGGML_USE_METAL
286-
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
286+
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
287287
OBJS += ggml-metal.o
288288
endif # LLAMA_METAL
289289

flake.nix

-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
with pkgs.darwin.apple_sdk_11_0.frameworks; [
1515
Accelerate
1616
MetalKit
17-
MetalPerformanceShaders
18-
MetalPerformanceShadersGraph
1917
]
2018
else if isAarch32 && isDarwin then
2119
with pkgs.darwin.apple_sdk.frameworks; [

0 commit comments

Comments
 (0)