Skip to content

Commit 7686ffe

Browse files
committed
fixed library loading
1 parent 8eb0934 commit 7686ffe

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

bitsandbytes/cuda_setup/env_vars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def to_be_ignored(env_var: str, value: str) -> bool:
2525

2626

2727
def might_contain_a_path(candidate: str) -> bool:
28-
return "/" in candidate
28+
return os.sep in candidate
2929

3030

3131
def is_active_conda_env(env_var: str) -> bool:

bitsandbytes/cuda_setup/main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import ctypes as ct
2020
import os
2121
import errno
22+
import platform
2223
import torch
2324
from warnings import warn
2425
from itertools import product
@@ -31,7 +32,13 @@
3132
# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead
3233
# we have libcudart.so.11.0 which causes a lot of errors before
3334
# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
34-
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0']
35+
system = platform.system()
36+
if system == 'Darwin':
37+
CUDA_RUNTIME_LIBS: list = ["libcuda.dylib", '/usr/local/cuda/lib/libcuda.dylib']
38+
elif system == 'Windows':
39+
CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
40+
else: # Linux or other
41+
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0']
3542

3643
# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
3744
backup_paths = []
@@ -194,7 +201,7 @@ def is_cublasLt_compatible(cc):
194201
return has_cublaslt
195202

196203
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
197-
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
204+
return {Path(ld_path) for ld_path in paths_list_candidate.split(";" if os.sep == "\\" else ":") if ld_path}
198205

199206

200207
def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:

0 commit comments

Comments
 (0)