Skip to content

Commit 1ec9c0f

Browse files
committed
narrow arguments
Signed-off-by: Kyle Sayers <[email protected]>
1 parent e884298 commit 1ec9c0f

File tree

3 files changed

+12
-16
lines changed

3 files changed

+12
-16
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: int = 42):
4040

4141
def create_transform(self, module: Module, args: TransformArgs):
4242
assert isinstance(module, Linear)
43-
size = get_matrix_size(module, args)
43+
size = get_matrix_size(module, args.location)
4444
dtype = module.weight.dtype
4545
device = get_offloaded_device(module)
4646

@@ -76,4 +76,4 @@ def forward(self, value: Tensor) -> Tensor:
7676
# if self.permutation is not None:
7777
# weight = apply_permutation(weight, self.permutation)
7878

79-
return apply_transform_weight(weight, value, self.args)
79+
return apply_transform_weight(weight, value, self.args.location)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
16-
1715
import torch
1816
from compressed_tensors.transform import TransformArgs, TransformScheme
1917
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -42,7 +40,7 @@ def __init__(
4240

4341
def create_transform(self, module: Module, args: TransformArgs):
4442
assert isinstance(module, Linear)
45-
size = get_matrix_size(module, args)
43+
size = get_matrix_size(module, args.location)
4644
dtype = module.weight.dtype
4745
device = get_offloaded_device(module)
4846

@@ -68,7 +66,7 @@ def __init__(self, weight: Tensor, args: TransformArgs):
6866
self.args = args
6967

7068
def forward(self, value: Tensor) -> Parameter:
71-
return apply_transform_weight(self.weight, value, self.args)
69+
return apply_transform_weight(self.weight, value, self.args.location)
7270

7371

7472
def high_precision_invert(weight: Tensor) -> Tensor:

src/compressed_tensors/transform/utils/utils.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Literal
16-
1715
import torch
18-
from compressed_tensors.transform import TransformArgs, TransformLocation
16+
from compressed_tensors.transform import TransformLocation
1917

2018

2119
__all__ = ["get_matrix_size", "apply_transform_weight", "apply_permutation"]
2220

2321

24-
def get_matrix_size(module: torch.nn.Module, args: TransformArgs) -> int:
22+
def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
2523
assert isinstance(module, torch.nn.Linear)
26-
if args.location in ("input", TransformLocation.WEIGHT_INPUT):
24+
if location in ("input", TransformLocation.WEIGHT_INPUT):
2725
return module.in_features
2826
else:
2927
return module.out_features
@@ -32,7 +30,7 @@ def get_matrix_size(module: torch.nn.Module, args: TransformArgs) -> int:
3230
def apply_transform_weight(
3331
weight: torch.Tensor,
3432
value: torch.Tensor,
35-
args: TransformArgs, # TODO: only pass location
33+
location: TransformLocation,
3634
) -> torch.Tensor:
3735
# let x be input activation
3836
# W be weight,
@@ -57,16 +55,16 @@ def apply_transform_weight(
5755
# = y U
5856
# = yh
5957

60-
if args.location == TransformLocation.INPUT:
58+
if location == TransformLocation.INPUT:
6159
return value @ weight
6260

63-
elif args.location == TransformLocation.WEIGHT_INPUT:
61+
elif location == TransformLocation.WEIGHT_INPUT:
6462
return value @ weight.T
6563

66-
elif args.location == TransformLocation.WEIGHT_OUTPUT:
64+
elif location == TransformLocation.WEIGHT_OUTPUT:
6765
return weight.T @ value
6866

69-
elif args.location == TransformLocation.OUTPUT:
67+
elif location == TransformLocation.OUTPUT:
7068
return value @ weight
7169

7270

0 commit comments

Comments
 (0)