Skip to content

Commit e1f9581

Browse files
authored
Add hip def for cuda v2
1 parent 3bff5c0 commit e1f9581

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

otherarch/ggml_v2-cuda.cu

+57-1
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,66 @@
44
#include <stdio.h>
55
#include <atomic>
66

7+
#if defined(GGML_USE_HIPBLAS)
8+
#include <hip/hip_runtime.h>
9+
#include <hipblas/hipblas.h>
10+
#include <hip/hip_fp16.h>
11+
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
12+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
13+
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
14+
#define CUBLAS_OP_N HIPBLAS_OP_N
15+
#define CUBLAS_OP_T HIPBLAS_OP_T
16+
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
17+
#define CUBLAS_TF32_TENSOR_OP_MATH 0
18+
#define CUDA_R_16F HIPBLAS_R_16F
19+
#define CUDA_R_32F HIPBLAS_R_32F
20+
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
21+
#define cublasCreate hipblasCreate
22+
#define cublasGemmEx hipblasGemmEx
23+
#define cublasHandle_t hipblasHandle_t
24+
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
25+
#define cublasSetStream hipblasSetStream
26+
#define cublasSgemm hipblasSgemm
27+
#define cublasStatus_t hipblasStatus_t
28+
#define cudaDeviceProp hipDeviceProp_t
29+
#define cudaDeviceSynchronize hipDeviceSynchronize
30+
#define cudaError_t hipError_t
31+
#define cudaEventCreateWithFlags hipEventCreateWithFlags
32+
#define cudaEventDisableTiming hipEventDisableTiming
33+
#define cudaEventRecord hipEventRecord
34+
#define cudaEvent_t hipEvent_t
35+
#define cudaFree hipFree
36+
#define cudaFreeHost hipHostFree
37+
#define cudaGetDevice hipGetDevice
38+
#define cudaGetDeviceCount hipGetDeviceCount
39+
#define cudaGetDeviceProperties hipGetDeviceProperties
40+
#define cudaGetErrorString hipGetErrorString
41+
#define cudaGetLastError hipGetLastError
42+
#define cudaMalloc hipMalloc
43+
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
44+
#define cudaMemcpy hipMemcpy
45+
#define cudaMemcpy2DAsync hipMemcpy2DAsync
46+
#define cudaMemcpyAsync hipMemcpyAsync
47+
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
48+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
49+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
50+
#define cudaMemcpyKind hipMemcpyKind
51+
#define cudaMemset hipMemset
52+
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
53+
#define cudaSetDevice hipSetDevice
54+
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
55+
#define cudaStreamNonBlocking hipStreamNonBlocking
56+
#define cudaStreamSynchronize hipStreamSynchronize
57+
#define cudaStreamWaitEvent hipStreamWaitEvent
58+
#define cudaStream_t hipStream_t
59+
#define cudaSuccess hipSuccess
60+
#else
761
#include <cuda_runtime.h>
862
#include <cublas_v2.h>
963
#include <cuda_fp16.h>
1064

65+
#endif
66+
1167
#include "ggml_v2-cuda.h"
1268
#include "ggml_v2.h"
1369

@@ -807,4 +863,4 @@ void ggml_v2_cuda_transform_tensor(ggml_v2_tensor * tensor) {
807863

808864
tensor->data = d_Q;
809865
tensor->backend = GGML_V2_BACKEND_CUDA;
810-
}
866+
}

0 commit comments

Comments
 (0)