
Commit a8c9dfa

Fix some issues found by Mypy (#995)
* Fix erroneous type aliasing
* Fix `Optional` typings (see PEP 484)
* Add Mypy ignores
* Fix Mypy complaints for method tables
* Fix type for get_ptr
* Fix various Mypy errors
* Fix missed call to is_triton_available
1 parent 32be289 commit a8c9dfa


7 files changed: +168 additions, -117 deletions


bitsandbytes/autograd/_functions.py

Lines changed: 11 additions & 11 deletions
@@ -2,7 +2,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional, Callable
 from warnings import warn

 import torch
@@ -14,9 +14,6 @@
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)

-tensor = torch.Tensor
-
-
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
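
The removed `tensor = torch.Tensor` line is the "erroneous type aliasing" named in the commit message: a lowercase module-level alias that annotations elsewhere used as if it were a declared type. The commit drops the alias and annotates with `torch.Tensor` directly; if an alias were genuinely wanted, an explicit `TypeAlias` declaration keeps type checkers from confusing it with an ordinary assignment. A minimal illustrative sketch, not code from the repository:

import torch
from typing_extensions import TypeAlias  # typing.TypeAlias on Python 3.10+

# Explicit alias: unambiguously a type for Mypy, unlike a bare assignment.
Tensor: TypeAlias = torch.Tensor

def scale(A: Tensor, factor: float) -> Tensor:
    return A * factor
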

@@ -56,7 +53,10 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)


-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+def get_inverse_transform_indices(
+    transform_tile: Callable[[torch.Tensor], torch.Tensor],
+    tile_size: Tuple[int, int],
+):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation
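
Here the builtin `callable` in the annotation is replaced with `typing.Callable[[torch.Tensor], torch.Tensor]`: `callable` is a runtime predicate, so it tells Mypy nothing about the expected signature, whereas the `Callable` form states that the argument takes one tensor and returns one. A short self-contained sketch of the same pattern, with a made-up helper name for illustration:

from typing import Callable, Tuple

import torch

def apply_tile_transform(
    transform_tile: Callable[[torch.Tensor], torch.Tensor],
    tile_size: Tuple[int, int],
) -> torch.Tensor:
    # Mypy now checks that the supplied callable accepts and returns a Tensor.
    tile = torch.zeros(tile_size)
    return transform_tile(tile)

# Example caller: an identity-style transform satisfies the declared signature.
out = apply_tile_transform(lambda t: t.clone(), (8, 32))
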
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
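
This is the `Optional` fix referenced in the commit message: PEP 484 originally let `arg: T = None` imply `Optional[T]`, but that implicit behaviour is deprecated and recent Mypy versions reject it by default, so the annotation now admits `None` explicitly. A small stand-alone sketch of the before/after (the `QuantState` class here is a stand-in, not the bitsandbytes one):

from typing import Optional

class QuantState:  # stand-in for F.QuantState, purely for illustration
    blocksize: int = 64

# Rejected by recent Mypy: the annotation says QuantState but the default is None.
# def dequant_scale(x: float, quant_state: QuantState = None) -> float: ...

# Accepted: the annotation admits None explicitly, matching the default value.
def dequant_scale(x: float, quant_state: Optional[QuantState] = None) -> float:
    if quant_state is None:
        quant_state = QuantState()
    return x / quant_state.blocksize
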
@@ -549,10 +549,10 @@ def backward(ctx, grad_output):


 def matmul(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
@@ -562,7 +562,7 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
+def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:

bitsandbytes/cuda_setup/main.py

Lines changed: 2 additions & 2 deletions
@@ -34,9 +34,9 @@
 # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
 system = platform.system()
 if system == 'Windows':
-    CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
+    CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
 else: # Linux or other
-    CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
+    CUDA_RUNTIME_LIBS = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']

 # this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
 backup_paths = []
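
Re-annotating the same module-level name in both branches is the kind of redefinition Mypy tends to flag, which is presumably why the `: list` annotations are dropped here and the type is left to inference. If an explicit annotation were still wanted, one option (illustrative only, not what the commit does) is to declare the name once before the branch:

import platform
from typing import List

# Single annotation up front; each branch then only assigns to the declared name.
CUDA_RUNTIME_LIBS: List[str]
if platform.system() == "Windows":
    CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
else:  # Linux or other
    CUDA_RUNTIME_LIBS = ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12.0"]
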
