Skip to content

Commit 629ebf2

Browse files
committed
Config serde
stack-info: PR: #1806, branch: drisspg/stack/41
1 parent 3210819 commit 629ebf2

File tree

6 files changed

+296
-60
lines changed

6 files changed

+296
-60
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ def get_extensions():
327327
version=version + version_suffix,
328328
packages=find_packages(),
329329
include_package_data=True,
330+
install_requires=[
331+
"pydantic>=2",
332+
],
330333
package_data={
331334
"torchao.kernel.configs": ["*.pkl"],
332335
},

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_quantization_functions(
6060
if device == "cuda":
6161
base_functions.append(
6262
int8_dynamic_activation_int4_weight(
63-
group_size=None,
63+
group_size=32,
6464
mapping_type=MappingType.SYMMETRIC,
6565
act_mapping_type=MappingType.SYMMETRIC,
6666
layout=CutlassInt4PackedLayout(),
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import json
2+
3+
import pytest
4+
import torch
5+
6+
from torchao.quantization.quant_api import (
7+
Float8DynamicActivationFloat8WeightConfig,
8+
Float8WeightOnlyConfig,
9+
FPXWeightOnlyConfig,
10+
GemliteUIntXWeightOnlyConfig,
11+
Int4DynamicActivationInt4WeightConfig,
12+
Int4WeightOnlyConfig,
13+
Int8DynamicActivationInt4WeightConfig,
14+
Int8DynamicActivationInt8WeightConfig,
15+
Int8WeightOnlyConfig,
16+
PerRow,
17+
UIntXWeightOnlyConfig,
18+
)
19+
20+
# Module-level list of config instances; every parametrized test below runs
# once per entry. (These are plain objects, not pytest fixtures.)
configs = [
    Float8DynamicActivationFloat8WeightConfig(),
    Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn),
    UIntXWeightOnlyConfig(dtype=torch.uint1),
    Int4DynamicActivationInt4WeightConfig(),
    Int4WeightOnlyConfig(group_size=32),
    Int8DynamicActivationInt4WeightConfig(group_size=64),
    Int8DynamicActivationInt8WeightConfig(),
    # NOTE(review): the semi-sparse layout variant was left disabled by the
    # author; the reason is not recorded here.
    # Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()),
    Int8WeightOnlyConfig(group_size=128),
    UIntXWeightOnlyConfig(dtype=torch.uint3, group_size=32, use_hqq=True),
    GemliteUIntXWeightOnlyConfig(
        group_size=128,  # Optional, has default of 64
        bit_width=8,  # Optional, has default of 4
        packing_bitwidth=8,  # Optional, has default of 32
        contiguous=True,  # Optional, has default of None
    ),
    FPXWeightOnlyConfig(ebits=4, mbits=8),
]
53+
54+
55+
def get_config_ids(configs):
    """Return readable pytest ids (class names) for parametrized configs.

    pytest invokes a callable passed as ``ids=`` once per parameter value,
    not once with the whole list, so this helper supports both call styles.

    Args:
        configs: Either a list/tuple of config objects, or a single config
            object (the form pytest uses when this function is given
            directly as ``ids=get_config_ids``).

    Returns:
        A list of class names when given a list/tuple, otherwise the class
        name of the single object as a string.
    """
    if isinstance(configs, (list, tuple)):
        return [config.__class__.__name__ for config in configs]
    # Single value: return a plain string so pytest uses it as the id instead
    # of silently falling back to auto-generated names such as "config0".
    return configs.__class__.__name__
58+
59+
60+
# Parametrized tests
61+
@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_to_dict_serialization(config):
    """Check that ``to_dict`` exists, returns a dict, and keeps every public attribute."""
    config_name = config.__class__.__name__
    assert hasattr(config, "to_dict"), f"{config_name} missing to_dict method"

    serialized = config.to_dict()
    assert isinstance(serialized, dict)

    # Names with a leading underscore are treated as private and are not
    # required to appear in the serialized form.
    public_attrs = (name for name in config.__dict__ if not name.startswith("_"))
    for name in public_attrs:
        assert name in serialized, f"{name} missing in serialized dict"
75+
76+
77+
@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_to_json_serialization(config):
    """Check that ``to_json`` exists and returns a string that parses to a JSON object."""
    config_name = config.__class__.__name__
    assert hasattr(config, "to_json"), f"{config_name} missing to_json method"

    payload = config.to_json()
    assert isinstance(payload, str)

    # A decode error is reported through pytest.fail so the message names the
    # offending config class.
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError as err:
        pytest.fail(f"Invalid JSON for {config_name}: {err}")
    else:
        assert isinstance(decoded, dict)
93+
94+
95+
@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_from_dict_deserialization(config):
    """Test that all configs can be deserialized from a dictionary.

    The config is serialized with ``to_dict`` and rebuilt with the class's
    ``from_dict``; the result must be an instance of the same class and every
    public attribute must compare equal to the original.
    """
    cls = config.__class__

    data = config.to_dict()

    assert hasattr(cls, "from_dict"), f"{cls.__name__} missing from_dict class method"

    deserialized = cls.from_dict(data)
    assert isinstance(deserialized, cls)

    for attr_name in config.__dict__:
        if attr_name.startswith("_"):  # Skip private attributes
            continue

        original_value = getattr(config, attr_name)
        deserialized_value = getattr(deserialized, attr_name)

        # Compare the values themselves rather than their string forms: a
        # string comparison would also accept an un-converted "torch.float32"
        # string in place of the real torch.dtype, hiding a broken round trip.
        assert (
            original_value == deserialized_value
        ), f"Attribute {attr_name} mismatch for {cls.__name__}"
131+
132+
133+
@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_from_json_deserialization(config):
    """Check that a config rebuilt via ``from_json`` equals the original instance."""
    config_cls = config.__class__
    payload = config.to_json()

    assert hasattr(
        config_cls, "from_json"
    ), f"{config_cls.__name__} missing from_json class method"

    restored = config_cls.from_json(payload)
    assert isinstance(restored, config_cls)

    # Relies on the config class providing a meaningful __eq__.
    assert (
        config == restored
    ), f"Deserialized instance doesn't match original for {config_cls.__name__}"
156+
157+
158+
@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_round_trip_equivalence(config):
    """Check that both the JSON and the dict round trip reproduce an equal config."""
    config_cls = config.__class__
    config_name = config_cls.__name__

    # JSON: to_json -> from_json
    from_json_copy = config_cls.from_json(config.to_json())
    assert config == from_json_copy, f"JSON round trip failed for {config_name}"

    # Dict: to_dict -> from_dict
    from_dict_copy = config_cls.from_dict(config.to_dict())
    assert config == from_dict_copy, f"Dict round trip failed for {config_name}"
174+
175+
176+
if __name__ == "__main__":
177+
pytest.main([__file__])

torchao/core/config.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,84 @@
1-
import abc
1+
from typing import Any, Dict
22

3+
import torch
4+
from pydantic import BaseModel, ConfigDict, field_serializer, model_validator
35

4-
class AOBaseConfig(abc.ABC):
6+
7+
class AOBaseConfig(BaseModel):
    """
    Base configuration class for TorchAO quantization workflows, with
    Pydantic-based serialization that round-trips ``torch.dtype`` values.

    When a workflow configuration inherits from AOBaseConfig, the `quantize_` function can automatically
    apply the appropriate transformation to a model based on the configuration type.

    Usage example::

        # 1. Define a configuration class for your workflow
        class WorkflowFooConfig(AOBaseConfig):
            # Configuration parameters for workflow 'Foo'
            bar: str = 'baz'

        # 2. Register a handler for this configuration (internal implementation)
        @register_quantize_module_handler(WorkflowFooConfig)
        def _transform(
            mod: torch.nn.Module,
            config: WorkflowFooConfig,
        ) -> torch.nn.Module:
            # Implementation of the transformation logic
            # Typically performs tensor subclass weight swapping or module replacement
            ...

        # 3. Apply the configuration to a model
        # The user simply calls `quantize_` with a model and config instance
        # The appropriate handler is automatically selected based on the config type
        model = ...
        quantized_model = quantize_(model, WorkflowFooConfig(bar='custom_value'))

    Note on serialization: if you add a new AOBaseConfig and want to support serialization,
    please add a test in test/quantization/test_config_serialization.py
    """

    # Pydantic model settings shared by every config subclass.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        validate_default=True,
        populate_by_name=True,
    )

    @field_serializer("*")
    def serialize_torch_dtype(self, v, _info):
        """Serialize ``torch.dtype`` field values as their string form.

        Applied to every field; values that are not a ``torch.dtype`` are
        returned unchanged.
        """
        if isinstance(v, torch.dtype):
            return str(v)
        return v

    @model_validator(mode="before")
    @classmethod
    def convert_dtypes(cls, data: Any) -> Any:
        """Convert serialized dtype strings (e.g. ``"torch.float32"``) back to ``torch.dtype``.

        Dicts and lists are walked recursively. A string is converted only
        when it starts with ``"torch."`` and the remainder names an attribute
        of ``torch`` that is a ``torch.dtype``. Any other value is returned
        unchanged.

        Args:
            data: Raw input handed to validation (a dict of field values, or a
                nested value reached through recursion).

        Returns:
            ``data`` with dtype strings replaced by ``torch.dtype`` objects.
        """
        if isinstance(data, str):
            if data.startswith("torch."):
                # Slice off the prefix instead of splitting, and accept the
                # attribute only if it really is a dtype. Otherwise a string
                # field holding e.g. "torch.nn" or "torch.load" would be
                # silently replaced by a module or a function.
                candidate = getattr(torch, data[len("torch."):], None)
                if isinstance(candidate, torch.dtype):
                    return candidate
            return data
        if isinstance(data, dict):
            return {k: cls.convert_dtypes(v) for k, v in data.items()}
        if isinstance(data, list):
            return [cls.convert_dtypes(item) for item in data]
        return data

    def to_dict(self) -> dict:
        """Convert the configuration to a dictionary (via ``model_dump``)."""
        return self.model_dump()

    def to_json(self) -> str:
        """Convert the configuration to a JSON string (via ``model_dump_json``)."""
        return self.model_dump_json()

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AOBaseConfig":
        """Create a configuration from a dictionary (via ``model_validate``)."""
        return cls.model_validate(data)

    @classmethod
    def from_json(cls, json_str: str) -> "AOBaseConfig":
        """Create a configuration from a JSON string (via ``model_validate_json``)."""
        return cls.model_validate_json(json_str)

0 commit comments

Comments
 (0)