@@ -2,7 +2,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional, Callable
 from warnings import warn
 
 import torch
@@ -14,9 +14,6 @@
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)
 
-tensor = torch.Tensor
-
-
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
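Side note on the retained prod helper: the comment "Required in Python 3" refers to the built-in reduce having moved to functools; prod multiplies the elements of an iterable (here, a tensor's shape) to detect empty inputs. A minimal check of its behavior, equivalent to math.prod on Python 3.8+:

import operator
from functools import reduce

def prod(iterable):
    return reduce(operator.mul, iterable, 1)

assert prod((3, 0, 5)) == 0   # any zero dimension means an empty tensor
assert prod(()) == 1          # empty shape (a scalar) has one element
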
@@ -56,7 +53,10 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)
 
 
-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+def get_inverse_transform_indices(
+    transform_tile: Callable[[torch.Tensor], torch.Tensor],
+    tile_size: Tuple[int, int],
+):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation
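For reference, the permutation trick behind get_inverse_transform_indices can be illustrated with a small self-contained sketch (this is not the bitsandbytes implementation, just the underlying idea): run the tile transform on a tensor of running indices, then scatter back to obtain the permutation that undoes it.

import torch

def inverse_indices_for_tile(transform_tile, tile_size):
    # Illustrative sketch only: apply the transform to a tile of running
    # indices, then build the permutation mapping each source index to the
    # position it ended up in after the transform.
    d1, d2 = tile_size
    ids = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
    permuted = transform_tile(ids).flatten()   # position -> source index
    inverse = torch.empty_like(permuted)
    inverse[permuted] = torch.arange(d1 * d2, dtype=torch.int64)
    return inverse                             # source index -> position

# Example: a transform that transposes the tile and flattens it back.
tile = torch.arange(6).view(2, 3)
transform = lambda t: t.T.reshape(t.shape)
inv = inverse_indices_for_tile(transform, (2, 3))
assert torch.equal(transform(tile).flatten()[inv], tile.flatten())
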
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
 
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
@@ -549,10 +549,10 @@ def backward(ctx, grad_output):
 
 
 def matmul(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
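The motivation for these annotation changes: `out: tensor = None` relied on the module-level `tensor = torch.Tensor` alias removed above, and a `None` default on a non-Optional parameter is flagged by strict type checkers. A minimal sketch of the corrected pattern (hypothetical names, not the library code):

from typing import Optional
import torch

def matmul_sketch(
    A: torch.Tensor,
    B: torch.Tensor,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Stand-in body; the real matmul dispatches to the 8-bit kernels.
    result = A @ B
    if out is not None:
        out.copy_(result)
        return out
    return result
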
@@ -562,7 +562,7 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
+def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
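For context, a hedged usage sketch of the updated matmul_4bit signature, assuming bitsandbytes' functional API (bitsandbytes.functional.quantize_4bit) and a CUDA device; exact argument names and defaults may differ between library versions.

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Quantize a weight matrix to 4-bit; quantize_4bit returns the packed
# weights together with a QuantState (absmax, blocksize, dtype, ...).
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
W4, quant_state = F.quantize_4bit(W, quant_type="nf4")

# A single-row activation without grad takes the fast inference path
# (A.numel() == A.shape[-1] and A.requires_grad == False).
x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = bnb.matmul_4bit(x, W4.t(), quant_state=quant_state)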