
Commit 3949441

use hipblas based on cublas
1 parent 66aab46 commit 3949441

4 files changed, +69 -5 lines changed

CMakeLists.txt (+26)

@@ -67,6 +67,7 @@ endif()
 option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
 option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
 option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)

 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -168,6 +169,31 @@ if (LLAMA_CUBLAS)
     endif()
 endif()

+if (LLAMA_HIPBLAS)
+    cmake_minimum_required(VERSION 3.21)
+
+    find_package(hip)
+    find_package(hipblas)
+
+    if (hipblas_FOUND)
+        message(STATUS "hipBLAS found")
+
+        set(LLAMA_HIPBLAS_PLATFORM "AMD" CACHE STRING "hip device type" FORCE)
+        set_property(CACHE LLAMA_HIPBLAS_PLATFORM PROPERTY STRINGS "AMD" "NVIDIA")
+
+        add_compile_definitions(GGML_USE_HIPBLAS "__HIP_PLATFORM_${LLAMA_HIPBLAS_PLATFORM}__")
+
+        add_library(ggml-hip OBJECT ggml-cuda.cu)
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-hip hip::device)
+
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::host roc::hipblas ggml-hip)
+
+    else()
+        message(WARNING "hipBLAS not found")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags

Makefile (+4)

@@ -107,6 +107,10 @@ ifdef LLAMA_CUBLAS
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	nvcc -arch=native -c -o $@ $<
 endif
+ifdef LLAMA_HIPBLAS
+	CFLAGS += -DGGML_USE_HIPBLAS -D__HIP_PLATFORM_AMD__ -I/opt/rocm/include
+	LDFLAGS += -lhipblas -lamdhip64 -L/opt/rocm/lib
+endif
 ifdef LLAMA_GPROF
 	CFLAGS += -pg
 	CXXFLAGS += -pg
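
Note (not part of the commit): a minimal sketch, in C preprocessor form, of what the two build paths above select at compile time. It assumes the defaults shown in the diffs, i.e. LLAMA_HIPBLAS_PLATFORM=AMD for the CMake route and a ROCm install under /opt/rocm for the Makefile route; the NVIDIA platform is reachable only through the CMake cache variable.

/* Sketch only: which backend branch the sources end up taking. */
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    /* hipBLAS/ROCm build: ggml.c takes the alias branch added below,
       and ggml-cuda.cu pulls in hip/hip_runtime.h instead of cuda_fp16.h */
#elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_NVIDIA__)
    /* hipBLAS layered on top of CUDA (CMake-only configuration) */
#elif defined(GGML_USE_CUBLAS)
    /* existing cuBLAS/CUDA build, unchanged */
#endif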

ggml-cuda.cu (+6)

@@ -1,5 +1,11 @@
 #include <stdint.h>
+#if defined(__HIP_PLATFORM_AMD__)
+#include "hip/hip_runtime.h"
+#define cudaStream_t hipStream_t
+#define __half _Float16
+#else
 #include <cuda_fp16.h>
+#endif
 #include "ggml-cuda.h"

 typedef uint16_t ggml_fp16_t;
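
As context, a hypothetical example (not from this commit) of the kind of code in ggml-cuda.cu that the two #define lines keep portable: device code reads __half values and host-side launchers take a cudaStream_t, which on the AMD HIP platform become the compiler-native _Float16 and hipStream_t, so cuda_fp16.h is not needed there.

// Hypothetical sketch; assumes the includes/defines from the hunk above.
// The <<<...>>> launch syntax compiles under both nvcc and hipcc.
__global__ void convert_fp16_to_fp32(const __half * x, float * y, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = (float) x[i]; // __half -> float on CUDA, _Float16 -> float on HIP
    }
}

static void convert_fp16_to_fp32_cuda(const __half * x, float * y, int n, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (n + block_size - 1) / block_size;
    convert_fp16_to_fp32<<<num_blocks, block_size, 0, stream>>>(x, y, n);
}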

ggml.c (+33 -5)

@@ -147,9 +147,41 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <Accelerate/Accelerate.h>
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUBLAS) || defined(GGML_USE_HIPBLAS)
+
+#if defined(GGML_USE_HIPBLAS)
+#include "hipblas/hipblas.h"
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define cudaError_t hipError_t
+#define cudaFree hipFree
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaStream_t hipStream_t
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaSuccess hipSuccess
+#define GGML_USE_CUBLAS
+#else
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#endif
 #include "ggml-cuda.h"

 #define CUDA_CHECK(err) \
@@ -8040,9 +8072,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else if (type == GGML_TYPE_Q4_2) {
             dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
         }
-        else if (type == GGML_TYPE_Q4_3) {
-            dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
-        }
         else {
             GGML_ASSERT(false);
         }
@@ -8076,7 +8105,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
             const float * x = wdata;
 #endif

-
 #if defined(GGML_USE_CUBLAS)
             // copy data to device
             CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
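
To show why the alias block in the first ggml.c hunk is enough, here is a minimal hypothetical helper (not part of the commit) written only against the CUDA names, in the same style as existing call sites like the cudaMemcpyAsync line above: with GGML_USE_HIPBLAS defined, the preprocessor rewrites each identifier to its hip equivalent before compilation, so the unchanged cuBLAS-era code links against hipBLAS/ROCm.

// Hypothetical sketch; assumes the CUDA_CHECK macro and headers from ggml.c.
static void ggml_example_upload(const float * host, size_t n, cudaStream_t stream) {
    float * dev = NULL;

    CUDA_CHECK(cudaMalloc((void **) &dev, n * sizeof(float)));                                 // -> hipMalloc
    CUDA_CHECK(cudaMemcpyAsync(dev, host, n * sizeof(float), cudaMemcpyHostToDevice, stream)); // -> hipMemcpyAsync
    CUDA_CHECK(cudaStreamSynchronize(stream));                                                 // -> hipStreamSynchronize
    CUDA_CHECK(cudaFree(dev));                                                                 // -> hipFree
}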
