@@ -147,9 +147,41 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <Accelerate/Accelerate.h>
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUBLAS) || defined(GGML_USE_HIPBLAS)
+
+#if defined(GGML_USE_HIPBLAS)
+#include "hipblas/hipblas.h"
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define cudaError_t hipError_t
+#define cudaFree hipFree
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaStream_t hipStream_t
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaSuccess hipSuccess
+#define GGML_USE_CUBLAS
+#else
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#endif
 #include "ggml-cuda.h"

 #define CUDA_CHECK(err) \
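This hunk ports ggml's cuBLAS path to AMD GPUs by aliasing the CUDA and cuBLAS names to their HIP/hipBLAS equivalents at the preprocessor level, so the rest of the file compiles unchanged; it also defines GGML_USE_CUBLAS itself, so every existing #if defined(GGML_USE_CUBLAS) guard still fires on the HIP build. A minimal, self-contained sketch of the same aliasing technique follows (the USE_HIP flag and buffer size are hypothetical, not part of this patch):

// aliasing_sketch.c: compile with -DUSE_HIP under hipcc, or without it under nvcc.
#if defined(USE_HIP)
#include <hip/hip_runtime.h>
// Redirect the CUDA runtime names used below to their HIP counterparts.
#define cudaMalloc         hipMalloc
#define cudaFree           hipFree
#define cudaError_t        hipError_t
#define cudaSuccess        hipSuccess
#define cudaGetErrorString hipGetErrorString
#else
#include <cuda_runtime.h>
#endif

#include <stdio.h>

int main(void) {
    void * buf = NULL;
    // On a HIP build this statement is really hipMalloc(&buf, 1024).
    cudaError_t err = cudaMalloc(&buf, 1024);
    if (err != cudaSuccess) {
        fprintf(stderr, "alloc failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    cudaFree(buf);
    return 0;
}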
@@ -8040,9 +8072,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else if (type == GGML_TYPE_Q4_2) {
             dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
         }
-        else if (type == GGML_TYPE_Q4_3) {
-            dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
-        }
         else {
             GGML_ASSERT(false);
         }
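The removed branch drops GGML_TYPE_Q4_3 from the switch that picks a per-type GPU dequantization routine; what remains is a plain function-pointer dispatch with a hard failure for unknown types. A standalone sketch of that pattern, with hypothetical type tags and row functions standing in for ggml's:

// Hypothetical sketch of the dequantize-dispatch pattern, not ggml's API.
#include <assert.h>
#include <stddef.h>

typedef enum { TYPE_Q4_0, TYPE_Q4_2 } quant_type;
typedef void (*dequantize_row_fn)(const void * src, float * dst, int k);

static void dequantize_row_q4_0(const void * src, float * dst, int k) { (void) src; (void) dst; (void) k; }
static void dequantize_row_q4_2(const void * src, float * dst, int k) { (void) src; (void) dst; (void) k; }

static dequantize_row_fn select_dequantize_row(quant_type type) {
    // Unsupported types fail hard, mirroring the GGML_ASSERT(false) fallback above.
    switch (type) {
        case TYPE_Q4_0: return dequantize_row_q4_0;
        case TYPE_Q4_2: return dequantize_row_q4_2;
        default:        assert(0 && "unsupported quantization type"); return NULL;
    }
}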
@@ -8076,7 +8105,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         const float * x = wdata;
 #endif

-
 #if defined(GGML_USE_CUBLAS)
         // copy data to device
         CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
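The host-to-device copy is asynchronous on a dedicated stream, and every runtime call is wrapped in CUDA_CHECK (whose definition begins at the #define CUDA_CHECK(err) context line in the first hunk but lies outside this diff); under GGML_USE_HIPBLAS the same statement expands to hipMemcpyAsync via the defines above. A sketch of the usual do/while(0) shape such a macro takes, offered as an assumption since the real body is not shown:

// Assumed shape of CUDA_CHECK; the actual definition is outside this diff.
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

#define CUDA_CHECK(err)                                             \
    do {                                                            \
        cudaError_t err_ = (err);                                   \
        if (err_ != cudaSuccess) {                                  \
            fprintf(stderr, "CUDA error %d at %s:%d: %s\n",         \
                    (int) err_, __FILE__, __LINE__,                 \
                    cudaGetErrorString(err_));                      \
            exit(1);                                                \
        }                                                           \
    } while (0)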