Skip to content

Commit bf49a93

Browse files
committed
move HIPBLAS definitions into ggml-cuda.h
1 parent 540f4e0 commit bf49a93

File tree

2 files changed

+57
-54
lines changed

2 files changed

+57
-54
lines changed

ggml-cuda.cu

+2-54
Original file line numberDiff line numberDiff line change
@@ -5,60 +5,8 @@
55
#include <stdio.h>
66
#include <atomic>
77
#include <assert.h>
8-
#if defined(GGML_USE_HIPBLAS)
9-
#include <hip/hip_runtime.h>
10-
#include <hipblas/hipblas.h>
11-
#include <hip/hip_fp16.h>
12-
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
13-
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
14-
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
15-
#define CUBLAS_OP_N HIPBLAS_OP_N
16-
#define CUBLAS_OP_T HIPBLAS_OP_T
17-
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
18-
#define CUBLAS_TF32_TENSOR_OP_MATH 0
19-
#define CUDA_R_16F HIPBLAS_R_16F
20-
#define CUDA_R_32F HIPBLAS_R_32F
21-
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
22-
#define cublasCreate hipblasCreate
23-
#define cublasGemmEx hipblasGemmEx
24-
#define cublasHandle_t hipblasHandle_t
25-
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
26-
#define cublasSetStream hipblasSetStream
27-
#define cublasSgemm hipblasSgemm
28-
#define cublasStatus_t hipblasStatus_t
29-
#define cudaDeviceProp hipDeviceProp_t
30-
#define cudaDeviceSynchronize hipDeviceSynchronize
31-
#define cudaError_t hipError_t
32-
#define cudaEventCreateWithFlags hipEventCreateWithFlags
33-
#define cudaEventDisableTiming hipEventDisableTiming
34-
#define cudaEventRecord hipEventRecord
35-
#define cudaEvent_t hipEvent_t
36-
#define cudaFree hipFree
37-
#define cudaFreeHost hipHostFree
38-
#define cudaGetDevice hipGetDevice
39-
#define cudaGetDeviceCount hipGetDeviceCount
40-
#define cudaGetDeviceProperties hipGetDeviceProperties
41-
#define cudaGetErrorString hipGetErrorString
42-
#define cudaGetLastError hipGetLastError
43-
#define cudaMalloc hipMalloc
44-
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
45-
#define cudaMemcpy hipMemcpy
46-
#define cudaMemcpy2DAsync hipMemcpy2DAsync
47-
#define cudaMemcpyAsync hipMemcpyAsync
48-
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
49-
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
50-
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
51-
#define cudaMemcpyKind hipMemcpyKind
52-
#define cudaMemset hipMemset
53-
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
54-
#define cudaSetDevice hipSetDevice
55-
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
56-
#define cudaStreamNonBlocking hipStreamNonBlocking
57-
#define cudaStreamSynchronize hipStreamSynchronize
58-
#define cudaStreamWaitEvent hipStreamWaitEvent
59-
#define cudaStream_t hipStream_t
60-
#define cudaSuccess hipSuccess
61-
#else
8+
9+
#ifndef GGML_USE_HIPBLAS
6210
#include <cuda_runtime.h>
6311
#include <cublas_v2.h>
6412
#include <cuda_fp16.h>

ggml-cuda.h

+55
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,61 @@
11
#pragma once
22

33
#include "ggml.h"
4+
#if defined(GGML_USE_HIPBLAS)
5+
#include <hip/hip_runtime.h>
6+
#include <hipblas/hipblas.h>
7+
#include <hip/hip_fp16.h>
8+
9+
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
10+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
11+
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
12+
#define CUBLAS_OP_N HIPBLAS_OP_N
13+
#define CUBLAS_OP_T HIPBLAS_OP_T
14+
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
15+
#define CUBLAS_TF32_TENSOR_OP_MATH 0
16+
#define CUDA_R_16F HIPBLAS_R_16F
17+
#define CUDA_R_32F HIPBLAS_R_32F
18+
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
19+
#define cublasCreate hipblasCreate
20+
#define cublasGemmEx hipblasGemmEx
21+
#define cublasHandle_t hipblasHandle_t
22+
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
23+
#define cublasSetStream hipblasSetStream
24+
#define cublasSgemm hipblasSgemm
25+
#define cublasStatus_t hipblasStatus_t
26+
#define cudaDeviceProp hipDeviceProp_t
27+
#define cudaDeviceSynchronize hipDeviceSynchronize
28+
#define cudaError_t hipError_t
29+
#define cudaEventCreateWithFlags hipEventCreateWithFlags
30+
#define cudaEventDisableTiming hipEventDisableTiming
31+
#define cudaEventRecord hipEventRecord
32+
#define cudaEvent_t hipEvent_t
33+
#define cudaFree hipFree
34+
#define cudaFreeHost hipHostFree
35+
#define cudaGetDevice hipGetDevice
36+
#define cudaGetDeviceCount hipGetDeviceCount
37+
#define cudaGetDeviceProperties hipGetDeviceProperties
38+
#define cudaGetErrorString hipGetErrorString
39+
#define cudaGetLastError hipGetLastError
40+
#define cudaMalloc hipMalloc
41+
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
42+
#define cudaMemcpy hipMemcpy
43+
#define cudaMemcpy2DAsync hipMemcpy2DAsync
44+
#define cudaMemcpyAsync hipMemcpyAsync
45+
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
46+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
47+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
48+
#define cudaMemcpyKind hipMemcpyKind
49+
#define cudaMemset hipMemset
50+
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
51+
#define cudaSetDevice hipSetDevice
52+
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
53+
#define cudaStreamNonBlocking hipStreamNonBlocking
54+
#define cudaStreamSynchronize hipStreamSynchronize
55+
#define cudaStreamWaitEvent hipStreamWaitEvent
56+
#define cudaStream_t hipStream_t
57+
#define cudaSuccess hipSuccess
58+
#endif
459

560
#ifdef __cplusplus
661
extern "C" {

0 commit comments

Comments
 (0)