Commit 6dea2e7

Config serde
stack-info: PR: #1806, branch: drisspg/stack/41
1 parent eecf59b commit 6dea2e7
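
The change adds JSON and dict serialization ("serde") to the torchao quantization configs, with pydantic>=2 pulled in as a new install requirement to back it. A minimal usage sketch, assuming the to_json/from_json round-trip helpers exercised by the new tests below:

    from torchao.quantization.quant_api import Int4WeightOnlyConfig

    # Build a config, serialize it to a JSON string, and restore it.
    # (Sketch based on the to_json/from_json round trip covered by the new tests.)
    config = Int4WeightOnlyConfig(group_size=32)
    json_str = config.to_json()
    restored = Int4WeightOnlyConfig.from_json(json_str)
    assert restored == config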

File tree: 9 files changed, +435 −82 lines changed


setup.py

Lines changed: 3 additions & 0 deletions
@@ -425,6 +425,9 @@ def bool_to_on_off(value):
     version=version + version_suffix,
     packages=find_packages(),
     include_package_data=True,
+    install_requires=[
+        "pydantic>=2",
+    ],
     package_data={
         "torchao.kernel.configs": ["*.pkl"],
     },

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def get_quantization_functions(
     if device == "cuda" and not is_ROCM():
         base_functions.append(
             int8_dynamic_activation_int4_weight(
-                group_size=None,
+                group_size=32,
                 mapping_type=MappingType.SYMMETRIC,
                 act_mapping_type=MappingType.SYMMETRIC,
                 layout=CutlassInt4PackedLayout(),
Lines changed: 230 additions & 0 deletions (new test file)
@@ -0,0 +1,230 @@
+import json
+import os
+import tempfile
+
+import pytest
+import torch
+
+from torchao.core.config import reconstruct_from_dict, to_reconstructable_dict
+from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
+    FPXWeightOnlyConfig,
+    GemliteUIntXWeightOnlyConfig,
+    Int4DynamicActivationInt4WeightConfig,
+    Int4WeightOnlyConfig,
+    Int8DynamicActivationInt4WeightConfig,
+    Int8DynamicActivationInt8WeightConfig,
+    Int8WeightOnlyConfig,
+    PerRow,
+    UIntXWeightOnlyConfig,
+)
+
+# Define test configurations as fixtures
+configs = [
+    Float8DynamicActivationFloat8WeightConfig(),
+    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
+    Float8WeightOnlyConfig(
+        weight_dtype=torch.float8_e4m3fn,
+    ),
+    UIntXWeightOnlyConfig(dtype=torch.uint1),
+    Int4DynamicActivationInt4WeightConfig(),
+    Int4WeightOnlyConfig(
+        group_size=32,
+    ),
+    Int8DynamicActivationInt4WeightConfig(
+        group_size=64,
+    ),
+    Int8DynamicActivationInt8WeightConfig(),
+    # Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()),
+    Int8WeightOnlyConfig(
+        group_size=128,
+    ),
+    UIntXWeightOnlyConfig(
+        dtype=torch.uint3,
+        group_size=32,
+        use_hqq=True,
+    ),
+    GemliteUIntXWeightOnlyConfig(
+        group_size=128,  # Optional, has default of 64
+        bit_width=8,  # Optional, has default of 4
+        packing_bitwidth=8,  # Optional, has default of 32
+        contiguous=True,  # Optional, has default of None
+    ),
+    FPXWeightOnlyConfig(ebits=4, mbits=8),
+]
+
+
+# Create ids for better test naming
+def get_config_ids(configs):
+    return [config.__class__.__name__ for config in configs]
+
+
+# Parametrized tests
+@pytest.mark.parametrize("config", configs, ids=get_config_ids)
+def test_to_dict_serialization(config):
+    """Test that all configs can be serialized to a dictionary."""
+    # Test to_dict method exists and returns a dict
+    assert hasattr(
+        config, "to_dict"
+    ), f"{config.__class__.__name__} missing to_dict method"
+    result = config.to_dict()
+    assert isinstance(result, dict)
+
+    # Check that all essential attributes are present in the dict
+    for attr_name in config.__dict__:
+        if not attr_name.startswith("_"):  # Skip private attributes
+            assert attr_name in result, f"{attr_name} missing in serialized dict"
+
+
+@pytest.mark.parametrize("config", configs, ids=get_config_ids)
+def test_to_json_serialization(config):
+    """Test that all configs can be serialized to JSON."""
+    # Test to_json method exists and returns a string
+    assert hasattr(
+        config, "to_json"
+    ), f"{config.__class__.__name__} missing to_json method"
+    json_str = config.to_json()
+    assert isinstance(json_str, str)
+
+    # Verify it's valid JSON
+    try:
+        parsed = json.loads(json_str)
+        assert isinstance(parsed, dict)
+    except json.JSONDecodeError as e:
+        pytest.fail(f"Invalid JSON for {config.__class__.__name__}: {e}")
+
+
+@pytest.mark.parametrize("config", configs, ids=get_config_ids)
+def test_from_dict_deserialization(config):
+    """Test that all configs can be deserialized from a dictionary."""
+    # Get the class of the instance
+    cls = config.__class__
+
+    # Serialize to dict
+    data = config.to_dict()
+
+    # Test from_dict class method exists
+    assert hasattr(cls, "from_dict"), f"{cls.__name__} missing from_dict class method"
+
+    # Deserialize back to instance
+    deserialized = cls.from_dict(data)
+
+    # Check it's the right class
+    assert isinstance(deserialized, cls)
+
+    # Compare key attributes
+    for attr_name in config.__dict__:
+        if not attr_name.startswith("_"):  # Skip private attributes
+            original_value = getattr(config, attr_name)
+            deserialized_value = getattr(deserialized, attr_name)
+
+            # Special handling for torch dtypes
+            if (
+                hasattr(original_value, "__module__")
+                and original_value.__module__ == "torch"
+            ):
+                assert str(original_value) == str(
+                    deserialized_value
+                ), f"Attribute {attr_name} mismatch for {cls.__name__}"
+            else:
+                assert (
+                    original_value == deserialized_value
+                ), f"Attribute {attr_name} mismatch for {cls.__name__}"
+
+
+@pytest.mark.parametrize("config", configs, ids=get_config_ids)
+def test_from_json_deserialization(config):
+    """Test that all configs can be deserialized from JSON."""
+    # Get the class of the instance
+    cls = config.__class__
+
+    # Serialize to JSON
+    json_str = config.to_json()
+
+    # Test from_json class method exists
+    assert hasattr(cls, "from_json"), f"{cls.__name__} missing from_json class method"
+
+    # Deserialize back to instance
+    deserialized = cls.from_json(json_str)
+
+    # Check it's the right class
+    assert isinstance(deserialized, cls)
+
+    # Verify the instance is equivalent to the original
+    # This assumes __eq__ is properly implemented
+    assert (
+        config == deserialized
+    ), f"Deserialized instance doesn't match original for {cls.__name__}"
+
+
+@pytest.mark.parametrize("config", configs, ids=get_config_ids)
+def test_round_trip_equivalence(config):
+    """Test complete serialization and deserialization round trip."""
+    # JSON round trip
+    json_str = config.to_json()
+    deserialized_from_json = config.__class__.from_json(json_str)
+    assert (
+        config == deserialized_from_json
+    ), f"JSON round trip failed for {config.__class__.__name__}"
+
+    # Dict round trip
+    data_dict = config.to_dict()
+    deserialized_from_dict = config.__class__.from_dict(data_dict)
+    assert (
+        config == deserialized_from_dict
+    ), f"Dict round trip failed for {config.__class__.__name__}"
+
+
+@pytest.mark.parametrize("config", configs, ids=get_config_ids)
+def test_reconstructable_dict_file_round_trip(config):
+    """Test saving and loading reconstructable dicts to/from JSON files."""
+    # Get a reconstructable dict
+    reconstructable = to_reconstructable_dict(config)
+
+    # Create a temporary file to save the JSON
+    with tempfile.NamedTemporaryFile(
+        mode="w+", suffix=".json", delete=False
+    ) as temp_file:
+        # Write the reconstructable dict as JSON
+        json.dump(reconstructable, temp_file)
+        temp_file_path = temp_file.name
+
+    try:
+        # Read back the JSON file
+        with open(temp_file_path, "r") as file:
+            loaded_dict = json.load(file)
+
+        # Reconstruct from the loaded dict
+        reconstructed = reconstruct_from_dict(loaded_dict)
+
+        # Check it's the right class
+        assert isinstance(reconstructed, config.__class__)
+
+        # Verify attributes match
+        for attr_name in config.__dict__:
+            if not attr_name.startswith("_"):  # Skip private attributes
+                original_value = getattr(config, attr_name)
+                reconstructed_value = getattr(reconstructed, attr_name)
+
+                # Special handling for torch dtypes
+                if (
+                    hasattr(original_value, "__module__")
+                    and original_value.__module__ == "torch"
+                ):
+                    assert (
+                        str(original_value) == str(reconstructed_value)
+                    ), f"Attribute {attr_name} mismatch after file round trip for {config.__class__.__name__}"
+                else:
+                    assert (
+                        original_value == reconstructed_value
+                    ), f"Attribute {attr_name} mismatch after file round trip for {config.__class__.__name__}"
+
+    finally:
+        # Clean up the temporary file
+        if os.path.exists(temp_file_path):
+            os.unlink(temp_file_path)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
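
Besides the per-class to_dict/from_dict and to_json/from_json methods, the tests exercise a module-level pair from torchao.core.config. A short sketch of persisting a config to disk with them, mirroring test_reconstructable_dict_file_round_trip (the file name is illustrative):

    import json

    from torchao.core.config import reconstruct_from_dict, to_reconstructable_dict
    from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

    config = Int8DynamicActivationInt8WeightConfig()

    # Dump a reconstructable dict to JSON on disk, then rebuild the config from it.
    with open("int8_config.json", "w") as f:
        json.dump(to_reconstructable_dict(config), f)

    with open("int8_config.json", "r") as f:
        restored = reconstruct_from_dict(json.load(f))

    assert isinstance(restored, Int8DynamicActivationInt8WeightConfig)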

test/quantization/test_observer.py

Lines changed: 3 additions & 3 deletions
@@ -90,7 +90,7 @@ def test_block_size_calc_success(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity=PerAxis(1),
+            granularity=PerAxis(axis=1),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -105,7 +105,7 @@ def test_block_size_row_errors(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity=PerAxis(0),
+            granularity=PerAxis(axis=0),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
@@ -124,7 +124,7 @@ def test_block_size_row_errors(self):
         obs = AffineQuantizedMinMaxObserver(
             MappingType.SYMMETRIC,
             torch.float8_e4m3fn,
-            granularity=PerAxis(1),
+            granularity=PerAxis(axis=1),
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
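
The only change here is spelling out the axis keyword, consistent with the granularity helpers being constructed with explicit keyword arguments throughout this commit (presumably a requirement of the new pydantic-backed configs; the import path below is an assumption, since the tests do not show it):

    # Assumed import path for the granularity classes used in the tests above.
    from torchao.quantization.granularity import PerAxis, PerGroup

    # Positional forms such as PerAxis(1) or PerGroup(32) are rewritten with keywords:
    per_channel = PerAxis(axis=1)
    per_group = PerGroup(group_size=32)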

test/quantization/test_qat.py

Lines changed: 19 additions & 9 deletions
@@ -798,15 +798,15 @@ def test_fake_quantize_config_granularity(self):
         self.assertIsInstance(per_token_config2.granularity, PerToken)
 
         # per channel
-        per_channel_config1 = FakeQuantizeConfig(torch.int8, PerAxis(0))
+        per_channel_config1 = FakeQuantizeConfig(torch.int8, PerAxis(axis=0))
         per_channel_config2 = FakeQuantizeConfig(torch.int8, "per_channel")
         self.assertIsInstance(per_channel_config1.granularity, PerAxis)
         self.assertIsInstance(per_channel_config2.granularity, PerAxis)
         self.assertEqual(per_channel_config1.granularity.axis, 0)
         self.assertEqual(per_channel_config2.granularity.axis, 0)
 
         # per group
-        per_group_config1 = FakeQuantizeConfig(torch.int8, PerGroup(32))
+        per_group_config1 = FakeQuantizeConfig(torch.int8, PerGroup(group_size=32))
         per_group_config2 = FakeQuantizeConfig(torch.int8, "per_group", group_size=32)
         per_group_config3 = FakeQuantizeConfig(torch.int8, group_size=32)
         self.assertIsInstance(per_group_config1.granularity, PerGroup)
@@ -842,7 +842,7 @@ def test_fake_quantize_config_granularity_error_cases(self):
         with self.assertRaisesRegex(ValueError, msg):
             FakeQuantizeConfig(torch.int8, PerToken(), group_size=32)
         with self.assertRaisesRegex(ValueError, msg):
-            FakeQuantizeConfig(torch.int8, PerGroup(64), group_size=32)
+            FakeQuantizeConfig(torch.int8, PerGroup(group_size=64), group_size=32)
         with self.assertRaisesRegex(ValueError, msg):
             FakeQuantizeConfig(torch.int8, "per_token", group_size=32)
 
@@ -855,7 +855,7 @@ def test_fake_quantize_config_granularity_error_cases(self):
         with self.assertRaisesRegex(ValueError, "not supported"):
             FakeQuantizeConfig(torch.int8, PerRow())
         with self.assertRaisesRegex(ValueError, "Only axis=0 is supported"):
-            FakeQuantizeConfig(torch.int8, PerAxis(1))
+            FakeQuantizeConfig(torch.int8, PerAxis(axis=1))
         with self.assertRaisesRegex(ValueError, "Unexpected granularity"):
             FakeQuantizeConfig(torch.int8, "blah")
         with self.assertRaisesRegex(ValueError, "unexpected type"):
@@ -1240,7 +1240,9 @@ def test_quantize_api_standalone(self):
         weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
         quantize_(
             m,
-            intx_quantization_aware_training(activation_config, weight_config),
+            intx_quantization_aware_training(
+                activation_config=activation_config, weight_config=weight_config
+            ),
         )
         quantize_(
             m,
@@ -1273,15 +1275,19 @@ def test_quantize_api_errors(self):
         ):
             quantize_(
                 m,
-                intx_quantization_aware_training(my_config, my_config),
+                intx_quantization_aware_training(
+                    activation_config=my_config, weight_config=my_config
+                ),
                 lambda m, _: isinstance(m, torch.nn.Embedding),
             )
 
         # Only linear and embedding are supported currently
         with self.assertRaisesRegex(ValueError, "does not have QAT support"):
             quantize_(
                 m,
-                intx_quantization_aware_training(my_config, my_config),
+                intx_quantization_aware_training(
+                    activation_config=my_config, weight_config=my_config
+                ),
                 lambda m, _: isinstance(m, torch.nn.ReLU),
             )
 
@@ -1320,7 +1326,9 @@ def test_quantize_api_convert_path(self):
         weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
         quantize_(
             m,
-            intx_quantization_aware_training(activation_config, weight_config),
+            intx_quantization_aware_training(
+                activation_config=activation_config, weight_config=weight_config
+            ),
         )
 
         # Compare prepared values
@@ -1395,7 +1403,9 @@ def test_qat_linear_bias(self):
         weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)
         quantize_(
             m,
-            intx_quantization_aware_training(activation_config, weight_config),
+            intx_quantization_aware_training(
+                activation_config=activation_config, weight_config=weight_config
+            ),
         )
         example_inputs = m.example_inputs()
         m(*example_inputs)