We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3db70b5 commit 1f6294dCopy full SHA for 1f6294d
ggml-cuda.cu
@@ -10,6 +10,7 @@
10
#include <hip/hip_runtime.h>
11
#include <hipblas/hipblas.h>
12
#include <hip/hip_fp16.h>
13
+#include "rocblas/rocblas.h"
14
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
15
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
16
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
@@ -2531,6 +2532,10 @@ void ggml_init_cublas() {
2531
2532
static bool initialized = false;
2533
2534
if (!initialized) {
2535
+#ifdef GGML_USE_HIPBLAS
2536
+ rocblas_initialize();
2537
+ hipDeviceSynchronize();
2538
+#endif
2539
CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
2540
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
2541
int64_t total_vram = 0;
0 commit comments