
Commit 0664d5c

Config serde
stack-info: PR: #1806, branch: drisspg/stack/41
1 parent ed64709 commit 0664d5c

File tree

3 files changed: +214, -31 lines

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
import json

import pytest
import torch

from torchao.quantization.quant_api import (
    Float8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
)

# Test configurations shared by the parametrized tests below
configs = [
    # Float8DynamicActivationFloat8WeightConfig(),
    Float8WeightOnlyConfig(
        weight_dtype=torch.float8_e4m3fn,
    ),
    UIntXWeightOnlyConfig(dtype=torch.uint1),
    # Int4DynamicActivationInt4WeightConfig(),
    Int4WeightOnlyConfig(
        group_size=32,
    ),
    # Int8DynamicActivationInt4WeightConfig(
    #     group_size=64,
    # ),
    # Int8DynamicActivationInt8WeightConfig(),
    # Int8WeightOnlyConfig(
    #     group_size=128,
    # ),
    # UIntXWeightOnlyConfig(
    #     bit_width=4,
    #     group_size=32,
    # ),
    # Float8StaticActivationFloat8WeightConfig(
    #     activation_dtype=torch.float8_e4m3fn,
    #     weight_dtype=torch.float8_e4m3fn,
    # ),
]


# Create ids for better test naming
def get_config_ids(configs):
    return [config.__class__.__name__ for config in configs]


# Parametrized tests; ids is precomputed here because get_config_ids returns a
# list of names rather than acting as a per-value id callable.
@pytest.mark.parametrize("config", configs, ids=get_config_ids(configs))
def test_to_dict_serialization(config):
    """Test that all configs can be serialized to a dictionary."""
    # Test that the to_dict method exists and returns a dict
    assert hasattr(
        config, "to_dict"
    ), f"{config.__class__.__name__} missing to_dict method"
    result = config.to_dict()
    assert isinstance(result, dict)

    # Check that all essential attributes are present in the dict
    for attr_name in config.__dict__:
        if not attr_name.startswith("_"):  # Skip private attributes
            assert attr_name in result, f"{attr_name} missing in serialized dict"


@pytest.mark.parametrize("config", configs, ids=get_config_ids(configs))
def test_to_json_serialization(config):
    """Test that all configs can be serialized to JSON."""
    # Test that the to_json method exists and returns a string
    assert hasattr(
        config, "to_json"
    ), f"{config.__class__.__name__} missing to_json method"
    json_str = config.to_json()
    assert isinstance(json_str, str)

    # Verify it's valid JSON
    try:
        parsed = json.loads(json_str)
        assert isinstance(parsed, dict)
    except json.JSONDecodeError as e:
        pytest.fail(f"Invalid JSON for {config.__class__.__name__}: {e}")


@pytest.mark.parametrize("config", configs, ids=get_config_ids(configs))
def test_from_dict_deserialization(config):
    """Test that all configs can be deserialized from a dictionary."""
    # Get the class of the instance
    cls = config.__class__

    # Serialize to dict
    data = config.to_dict()

    # Test that the from_dict class method exists
    assert hasattr(cls, "from_dict"), f"{cls.__name__} missing from_dict class method"

    # Deserialize back to an instance
    deserialized = cls.from_dict(data)

    # Check it's the right class
    assert isinstance(deserialized, cls)

    # Compare key attributes
    for attr_name in config.__dict__:
        if not attr_name.startswith("_"):  # Skip private attributes
            original_value = getattr(config, attr_name)
            deserialized_value = getattr(deserialized, attr_name)

            # Special handling for torch dtypes, compared by their string form
            if (
                hasattr(original_value, "__module__")
                and original_value.__module__ == "torch"
            ):
                assert str(original_value) == str(
                    deserialized_value
                ), f"Attribute {attr_name} mismatch for {cls.__name__}"
            else:
                assert (
                    original_value == deserialized_value
                ), f"Attribute {attr_name} mismatch for {cls.__name__}"


@pytest.mark.parametrize("config", configs, ids=get_config_ids(configs))
def test_from_json_deserialization(config):
    """Test that all configs can be deserialized from JSON."""
    # Get the class of the instance
    cls = config.__class__

    # Serialize to JSON
    json_str = config.to_json()

    # Test that the from_json class method exists
    assert hasattr(cls, "from_json"), f"{cls.__name__} missing from_json class method"

    # Deserialize back to an instance
    deserialized = cls.from_json(json_str)

    # Check it's the right class
    assert isinstance(deserialized, cls)

    # Verify the instance is equivalent to the original
    # (this relies on __eq__ being properly implemented)
    assert (
        config == deserialized
    ), f"Deserialized instance doesn't match original for {cls.__name__}"


@pytest.mark.parametrize("config", configs, ids=get_config_ids(configs))
def test_round_trip_equivalence(config):
    """Test complete serialization and deserialization round trip."""
    # JSON round trip
    json_str = config.to_json()
    deserialized_from_json = config.__class__.from_json(json_str)
    assert (
        config == deserialized_from_json
    ), f"JSON round trip failed for {config.__class__.__name__}"

    # Dict round trip
    data_dict = config.to_dict()
    deserialized_from_dict = config.__class__.from_dict(data_dict)
    assert (
        config == deserialized_from_dict
    ), f"Dict round trip failed for {config.__class__.__name__}"

torchao/core/config.py

Lines changed: 51 additions & 24 deletions
@@ -1,29 +1,56 @@
-import abc
+from typing import Any, Dict
 
+import torch
+from pydantic import BaseModel, field_serializer, model_validator
 
-class AOBaseConfig(abc.ABC):
-    """
-    If a workflow config inherits from this then `quantize_` knows
-    how to a apply it to a model. For example::
-
-        # user facing code
-        class WorkflowFooConfig(AOBaseConfig): ...
-        # configuration for workflow `Foo` is defined here
-        bar = 'baz'
-
-        # non user facing code
-        @register_quantize_module_handler(WorkflowFooConfig)
-        def _transform(
-            mod: torch.nn.Module,
-            config: WorkflowFooConfig,
-        ) -> torch.nn.Module:
-            # the transform is implemented here, usually a tensor sublass
-            # weight swap or a module swap
-            ...
-
-        # then, the user calls `quantize_` with a config, and `_transform` is called
-        # under the hood by `quantize_.
 
+class AOBaseConfig(BaseModel):
+    """
+    Base configuration class with native Pydantic handling for torch.dtype.
     """
 
-    pass
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "validate_assignment": True,
+        "extra": "forbid",
+        "validate_default": True,
+        "populate_by_name": True,
+    }
+
+    @field_serializer("*")
+    def serialize_torch_dtype(self, v, _info):
+        if isinstance(v, torch.dtype):
+            return str(v)
+        return v
+
+    @model_validator(mode="before")
+    @classmethod
+    def convert_dtypes(cls, data: Any) -> Any:
+        """Simple converter for torch dtype strings"""
+        if isinstance(data, str) and data.startswith("torch."):
+            dtype_name = data.split("torch.")[1]
+            if hasattr(torch, dtype_name):
+                return getattr(torch, dtype_name)
+        elif isinstance(data, dict):
+            return {k: cls.convert_dtypes(v) for k, v in data.items()}
+        elif isinstance(data, list):
+            return [cls.convert_dtypes(item) for item in data]
+        return data
+
+    def to_dict(self) -> dict:
+        """Convert the configuration to a dictionary"""
+        return self.model_dump()
+
+    def to_json(self) -> str:
+        """Convert the configuration to a JSON string."""
+        return self.model_dump_json()
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AOBaseConfig":
+        """Create a configuration from a dictionary."""
+        return cls.model_validate(data)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "AOBaseConfig":
+        """Create a configuration from a JSON string."""
+        return cls.model_validate_json(json_str)
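
Any config class that inherits from AOBaseConfig picks this behavior up for free. A minimal sketch with a hypothetical subclass follows (MyExampleConfig is illustrative only, not part of this commit):

import torch

from torchao.core.config import AOBaseConfig


class MyExampleConfig(AOBaseConfig):
    # Pydantic collects these annotated attributes as validated fields;
    # arbitrary_types_allowed in model_config permits the torch.dtype annotation.
    group_size: int = 32
    weight_dtype: torch.dtype = torch.float8_e4m3fn


config = MyExampleConfig()

# serialize_torch_dtype turns the dtype into a plain string in the dump
data = config.to_dict()  # {'group_size': 32, 'weight_dtype': 'torch.float8_e4m3fn'}

# convert_dtypes (the mode="before" validator) restores the real dtype object
restored = MyExampleConfig.from_dict(data)
assert restored.weight_dtype is torch.float8_e4m3fn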

torchao/quantization/quant_api.py

Lines changed: 0 additions & 7 deletions
@@ -600,7 +600,6 @@ def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     )
 
 
-@dataclass
 class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
     """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
     This is used to produce a model for executorch backend, but currently executorch did not
@@ -681,7 +680,6 @@ def _int8_dynamic_activation_int4_weight_transform(
     return module
 
 
-@dataclass
 class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
     """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear
 
@@ -787,7 +785,6 @@ def _gemlite_uintx_weight_only_transform(
     return module
 
 
-@dataclass
 class Int4WeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -893,7 +890,6 @@ def _int4_weight_only_transform(
     return module
 
 
-@dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers.
@@ -1007,7 +1003,6 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor:
     )
 
 
-@dataclass
 class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     """
     Configuration for applying int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -1092,7 +1087,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
 
-@dataclass
 class Float8WeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.
@@ -1408,7 +1402,6 @@ def _float8_static_activation_float8_weight_transform(
     return module
 
 
-@dataclass
 class UIntXWeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
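
With AOBaseConfig now a Pydantic BaseModel, the @dataclass decorator is no longer needed on these config classes: Pydantic itself turns the annotated class attributes into validated fields. A rough sketch of the resulting shape is below; the class name, base class, and docstring follow the diff context above, while the field default is an assumption used only for illustration.

class Float8WeightOnlyConfig(AOBaseConfig):
    """
    Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.
    """

    # A plain annotated attribute; Pydantic, not @dataclass, turns it into a
    # constructor argument and validated field (default assumed for illustration).
    weight_dtype: torch.dtype = torch.float8_e4m3fn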
