Skip to content

Commit 1a8d409

Browse files
committed
use glob(), search CUDA_PATH
1 parent efadf2c commit 1a8d409

File tree

1 file changed

+18
-27
lines changed

1 file changed

+18
-27
lines changed

bitsandbytes/__main__.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,18 @@
1111

1212
HEADER_WIDTH = 60
1313

14-
def execute_and_return(command_string: str) -> Tuple[str, str]:
15-
def _decode(subprocess_err_out_tuple):
16-
return tuple(
17-
to_decode.decode("UTF-8").strip()
18-
for to_decode in subprocess_err_out_tuple
19-
)
20-
21-
def execute_and_return_decoded_std_streams(command_string):
22-
return _decode(
23-
subprocess.Popen(
24-
shlex.split(command_string),
25-
stdout=subprocess.PIPE,
26-
stderr=subprocess.PIPE,
27-
).communicate()
28-
)
29-
30-
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
31-
return std_out, std_err
3214

3315
def find_file_recursive(folder, filename):
34-
cmd = f'find {folder} -name {filename}'
35-
out, err = execute_and_return(cmd)
36-
if len(err) > 0:
37-
raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?')
16+
import glob
17+
outs = []
18+
try:
19+
for ext in ["so", "dll", "dylib"]:
20+
out = glob.glob(os.path.join(folder, "**", filename + ext))
21+
outs.extend(out)
22+
except Exception as e:
23+
raise RuntimeError('Error: Something when wrong when trying to find file. {e}')
3824

39-
return out
25+
return outs
4026

4127

4228
def generate_bug_report_information():
@@ -46,26 +32,31 @@ def generate_bug_report_information():
4632
print('')
4733

4834
if 'CONDA_PREFIX' in os.environ:
49-
paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so')
35+
paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*')
5036
print_header("ANACONDA CUDA PATHS")
5137
print(paths)
5238
print('')
5339
if isdir('/usr/local/'):
54-
paths = find_file_recursive('/usr/local', '*cuda*so')
40+
paths = find_file_recursive('/usr/local', '*cuda*')
5541
print_header("/usr/local CUDA PATHS")
5642
print(paths)
5743
print('')
44+
if 'CUDA_PATH' in os.environ and isdir(os.environ['CUDA_PATH']):
45+
paths = find_file_recursive(os.environ['CUDA_PATH'], '*cuda*')
46+
print_header("CUDA PATHS")
47+
print(paths)
48+
print('')
5849

5950
if isdir(os.getcwd()):
60-
paths = find_file_recursive(os.getcwd(), '*cuda*so')
51+
paths = find_file_recursive(os.getcwd(), '*cuda*')
6152
print_header("WORKING DIRECTORY CUDA PATHS")
6253
print(paths)
6354
print('')
6455

6556
print_header("LD_LIBRARY CUDA PATHS")
6657
if 'LD_LIBRARY_PATH' in os.environ:
6758
lib_path = os.environ['LD_LIBRARY_PATH'].strip()
68-
for path in set(lib_path.split(':')):
59+
for path in set(lib_path.split(':' if not os.sep == "\\" else ";")):
6960
try:
7061
if isdir(path):
7162
print_header(f"{path} CUDA PATHS")

0 commit comments

Comments
 (0)