Skip to content

Commit 29a637b

Browse files
authored
Don't crash Python interpreter via assert(false) (#998)
1 parent 706ec24 commit 29a637b

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

bitsandbytes/functional.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1944,7 +1944,10 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
19441944
ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
19451945
)
19461946

1947-
if has_error == 1:
1947+
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
1948+
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)")
1949+
1950+
if has_error:
19481951
print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
19491952
raise Exception('cublasLt ran into an error!')
19501953

csrc/ops.cu

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
#include <cassert>
1212
#include <common.h>
1313

14+
#define ERR_NOT_IMPLEMENTED 100
15+
1416

1517
using namespace BinSearch;
1618
using std::cout;
@@ -421,14 +423,7 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
421423
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
422424
{
423425
#ifdef NO_CUBLASLT
424-
cout << "" << endl;
425-
cout << "=============================================" << endl;
426-
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
427-
cout << "=============================================" << endl;
428-
cout << "" << endl;
429-
assert(false);
430-
431-
return 0;
426+
return ERR_NOT_IMPLEMENTED;
432427
#else
433428
int has_error = 0;
434429
cublasLtMatmulDesc_t matmulDesc = NULL;
@@ -484,7 +479,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
484479
printf("error detected");
485480

486481
return has_error;
487-
#endif
482+
#endif // NO_CUBLASLT
488483
}
489484

490485
int fill_up_to_nearest_multiple(int value, int multiple)

0 commit comments

Comments
 (0)