Commit 719440e

Add Int4CPULayout and update int4 woq (#1278)

* Add Int4CPULayout and update int4 woq
* Apply automatic Ruff fixes
* Fix CI
* Remove nightly
* Apply automatic Ruff fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

1 parent 543209b commit 719440e
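In short, this commit adds an `Int4CPULayout` so int4 weight-only quantization (woq) can target CPU. A minimal usage sketch, assuming PyTorch >= 2.6 and a torchao build that includes this commit (the toy model shape is illustrative):

```python
import torch

from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

# Illustrative bfloat16 model; int4 weight-only quantization expects bf16
# weights and in_features divisible by group_size.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)

# On CPU with PyTorch >= 2.6, select the new Int4CPULayout explicitly;
# elsewhere the default (TensorCoreTiledLayout) still applies, mirroring
# the diffs below.
quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))
```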

File tree

14 files changed: +448 -93 lines changed


.github/workflows/regression_test.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -70,6 +70,7 @@ jobs:
             torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.1"
+
           - name: CPU 2.3
             runs-on: linux.4xlarge
             torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
```

test/dtypes/test_affine_quantized.py

Lines changed: 29 additions & 26 deletions
```diff
@@ -8,7 +8,7 @@
     run_tests,
 )
 
-from torchao.dtypes import SemiSparseLayout
+from torchao.dtypes import Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
@@ -17,20 +17,25 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-def get_quantization_functions(do_sparse: bool, do_int4: bool):
+def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
         int8_dynamic_activation_int8_weight(),
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        base_functions.append(int4_weight_only(group_size=32))
+        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+            base_functions.append(
+                int4_weight_only(group_size=32, layout=Int4CPULayout())
+            )
+        else:
+            base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
         base_functions.append(
@@ -152,30 +157,28 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
-    def test_flatten_unflatten(self, apply_quant, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(linear)
-        lp_tensor = ql.weight
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
-        }
-        outer_size = lp_tensor.size()
-        outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(
-            tensor_data_dict, tensor_attributes, outer_size, outer_stride
-        )
-        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
-        ref = ql(*example_inputs)
-        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
-        reconstruct_res = ql(*example_inputs)
-        self.assertEqual(reconstruct_res, ref)
+    def test_flatten_unflatten(self, device, dtype):
+        apply_quant_list = get_quantization_functions(False, True, device)
+        for apply_quant in apply_quant_list:
+            linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+            ql = apply_quant(linear)
+            lp_tensor = ql.weight
+            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+            tensor_data_dict = {
+                name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            }
+            outer_size = lp_tensor.size()
+            outer_stride = lp_tensor.stride()
+            reconstructed = type(lp_tensor).__tensor_unflatten__(
+                tensor_data_dict, tensor_attributes, outer_size, outer_stride
+            )
+            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+            ref = ql(*example_inputs)
+            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+            reconstruct_res = ql(*example_inputs)
+            self.assertEqual(reconstruct_res, ref)
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
```
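For reference, the flatten/unflatten round trip the rewritten test exercises can be reproduced standalone. A sketch assuming CPU with PyTorch >= 2.6 (variable names are ours):

```python
import torch

from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear, int4_weight_only(group_size=32, layout=Int4CPULayout()))

# The quantized weight is a tensor subclass; flatten it into its inner
# tensors plus metadata, then rebuild an equivalent tensor from them.
w = linear.weight
names, attrs = w.__tensor_flatten__()
inner = {name: getattr(w, name) for name in names}
rebuilt = type(w).__tensor_unflatten__(inner, attrs, w.size(), w.stride())
```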

test/integration/test_integration.py

Lines changed: 14 additions & 4 deletions
```diff
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -93,6 +93,7 @@
     is_fbcode,
     benchmark_model
 )
+from torchao.dtypes.utils import is_device
 
 logger = logging.getLogger("INFO")
 
@@ -133,7 +134,10 @@ def _int8da_int8w_api(mod):
     change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
-    if TORCH_VERSION_AT_LEAST_2_4:
+    if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
+        unwrap_tensor_subclass(mod)
+    elif TORCH_VERSION_AT_LEAST_2_4:
         quantize_(mod, int4_weight_only(), set_inductor_config=False)
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
@@ -935,10 +939,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
+        layout_list = []
+        if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6:
+            layout_list.append(Int4CPULayout())
+        else:
+            for inner_k_tiles in [4, 2]:
+                layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
-                for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
+                for layout in layout_list:
+                    kwargs = {"groupsize": groupsize, "layout": layout}
 
                     def api(mod):
                         kwargs_copy = kwargs.copy()
```
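The same device dispatch can be written as a standalone helper. A hedged sketch restating `_int4wo_api` above (the name `int4_woq_for_device` is ours, and the `set_inductor_config`/`unwrap_tensor_subclass` details are omitted):

```python
import torch

from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

def int4_woq_for_device(mod: torch.nn.Module) -> None:
    # Choose the CPU layout only when the module's weights live on CPU and
    # PyTorch is at least 2.6; otherwise keep the default layout.
    on_cpu = next(mod.parameters()).device.type == "cpu"
    if on_cpu and TORCH_VERSION_AT_LEAST_2_6:
        quantize_(mod, int4_weight_only(layout=Int4CPULayout()))
    else:
        quantize_(mod, int4_weight_only())  # default TensorCoreTiledLayout
```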

test/quantization/test_quant_primitives.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -33,6 +33,7 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
 )
+from torchao.dtypes.utils import is_device
 
 _SEED = 1234
 torch.manual_seed(_SEED)
@@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .reshape_as(w)
     )
     if TORCH_VERSION_AT_LEAST_2_5:
-        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+        if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 
@@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            input_tmp = input
+            if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+                input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
```
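The guarded expression packs two int4 values into one uint8, with the even column in the high nibble; on CPU with PyTorch >= 2.6 this packing is now skipped. A self-contained illustration of the pack/unpack pair (the unpacking half is our assumption, simply inverting the packing expression from the diff):

```python
import torch

# Toy int4-range values, one row per output channel.
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)

# Packing as in the diff: even columns become the high nibble.
packed = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)  # shape (4, 4)

# Unpacking sketch: split each byte back into its two nibbles and
# re-interleave them in the original column order.
high = (packed >> 4).to(torch.int32)
low = (packed & 0xF).to(torch.int32)
unpacked = torch.stack([high, low], dim=-1).reshape(w.shape)

assert torch.equal(unpacked, w)
```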

torchao/dtypes/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
+    Int4CPULayout,
     MarlinQQQLayout,
     MarlinSparseLayout,
     SemiSparseLayout,
@@ -48,4 +49,5 @@
     "UintxLayout",
     "MarlinQQQTensor",
     "MarlinQQQLayout",
+    "Int4CPULayout",
 ]
```
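With this re-export in place, the layout becomes importable from the top-level dtypes package:

```python
# Import path made available by this commit's re-export.
from torchao.dtypes import Int4CPULayout

layout = Int4CPULayout()
```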

torchao/dtypes/uintx/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -11,6 +11,7 @@
     SemiSparseLayout,
 )
 from .tensor_core_tiled_layout import (
+    Int4CPULayout,
     TensorCoreTiledLayout,
 )
 from .uintx_layout import (
@@ -23,5 +24,6 @@
     "MarlinSparseLayout",
     "SemiSparseLayout",
     "TensorCoreTiledLayout",
+    "Int4CPULayout",
     "MarlinQQQLayout",
 ]
```
