
Commit e8195fb

Config serde
stack-info: PR: #1806, branch: drisspg/stack/41
1 parent 3210819 commit e8195fb

3 files changed: +242 -27 lines changed
test/quantization/test_config_serialization.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
import json

import pytest
import torch

from torchao.quantization.quant_api import (
    Float8WeightOnlyConfig,
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
)

# Define test configurations as fixtures
configs = [
    # Float8DynamicActivationFloat8WeightConfig(),
    Float8WeightOnlyConfig(
        weight_dtype=torch.float8_e4m3fn,
    ),
    UIntXWeightOnlyConfig(dtype=torch.uint1),
    # Int4DynamicActivationInt4WeightConfig(),
    Int4WeightOnlyConfig(
        group_size=32,
    ),
    # Int8DynamicActivationInt4WeightConfig(
    #     group_size=64,
    # ),
    # Int8DynamicActivationInt8WeightConfig(),
    Int8WeightOnlyConfig(
        group_size=128,
    ),
    UIntXWeightOnlyConfig(
        dtype=torch.uint3,
        group_size=32,
        use_hqq=True,
    ),
    GemliteUIntXWeightOnlyConfig(
        group_size=128,  # Optional, has default of 64
        bit_width=8,  # Optional, has default of 4
        packing_bitwidth=8,  # Optional, has default of 32
        contiguous=True,  # Optional, has default of None
    ),
    FPXWeightOnlyConfig(ebits=4, mbits=8),
    # Float8StaticActivationFloat8WeightConfig(
    #     activation_dtype=torch.float8_e4m3fn,
    #     weight_dtype=torch.float8_e4m3fn,
    # ),
]


# Create ids for better test naming
def get_config_ids(configs):
    return [config.__class__.__name__ for config in configs]


# Parametrized tests
@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_to_dict_serialization(config):
    """Test that all configs can be serialized to a dictionary."""
    # Test to_dict method exists and returns a dict
    assert hasattr(
        config, "to_dict"
    ), f"{config.__class__.__name__} missing to_dict method"
    result = config.to_dict()
    assert isinstance(result, dict)

    # Check that all essential attributes are present in the dict
    for attr_name in config.__dict__:
        if not attr_name.startswith("_"):  # Skip private attributes
            assert attr_name in result, f"{attr_name} missing in serialized dict"


@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_to_json_serialization(config):
    """Test that all configs can be serialized to JSON."""
    # Test to_json method exists and returns a string
    assert hasattr(
        config, "to_json"
    ), f"{config.__class__.__name__} missing to_json method"
    json_str = config.to_json()
    assert isinstance(json_str, str)

    # Verify it's valid JSON
    try:
        parsed = json.loads(json_str)
        assert isinstance(parsed, dict)
    except json.JSONDecodeError as e:
        pytest.fail(f"Invalid JSON for {config.__class__.__name__}: {e}")


@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_from_dict_deserialization(config):
    """Test that all configs can be deserialized from a dictionary."""
    # Get the class of the instance
    cls = config.__class__

    # Serialize to dict
    data = config.to_dict()

    # Test from_dict class method exists
    assert hasattr(cls, "from_dict"), f"{cls.__name__} missing from_dict class method"

    # Deserialize back to instance
    deserialized = cls.from_dict(data)

    # Check it's the right class
    assert isinstance(deserialized, cls)

    # Compare key attributes
    for attr_name in config.__dict__:
        if not attr_name.startswith("_"):  # Skip private attributes
            original_value = getattr(config, attr_name)
            deserialized_value = getattr(deserialized, attr_name)

            # Special handling for torch dtypes
            if (
                hasattr(original_value, "__module__")
                and original_value.__module__ == "torch"
            ):
                assert str(original_value) == str(
                    deserialized_value
                ), f"Attribute {attr_name} mismatch for {cls.__name__}"
            else:
                assert (
                    original_value == deserialized_value
                ), f"Attribute {attr_name} mismatch for {cls.__name__}"


@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_from_json_deserialization(config):
    """Test that all configs can be deserialized from JSON."""
    # Get the class of the instance
    cls = config.__class__

    # Serialize to JSON
    json_str = config.to_json()

    # Test from_json class method exists
    assert hasattr(cls, "from_json"), f"{cls.__name__} missing from_json class method"

    # Deserialize back to instance
    deserialized = cls.from_json(json_str)

    # Check it's the right class
    assert isinstance(deserialized, cls)

    # Verify the instance is equivalent to the original
    # This assumes __eq__ is properly implemented
    assert (
        config == deserialized
    ), f"Deserialized instance doesn't match original for {cls.__name__}"


@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_round_trip_equivalence(config):
    """Test complete serialization and deserialization round trip."""
    # JSON round trip
    json_str = config.to_json()
    deserialized_from_json = config.__class__.from_json(json_str)
    assert (
        config == deserialized_from_json
    ), f"JSON round trip failed for {config.__class__.__name__}"

    # Dict round trip
    data_dict = config.to_dict()
    deserialized_from_dict = config.__class__.from_dict(data_dict)
    assert (
        config == deserialized_from_dict
    ), f"Dict round trip failed for {config.__class__.__name__}"
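
For orientation, the round trip these tests exercise looks roughly like this in user code. A minimal sketch, not part of the commit, assuming the to_json/from_json methods that AOBaseConfig gains in this PR and the Int4WeightOnlyConfig fields used above:

from torchao.quantization.quant_api import Int4WeightOnlyConfig

config = Int4WeightOnlyConfig(group_size=32)

# Serialize to JSON, e.g. to persist quantization settings next to a checkpoint
json_str = config.to_json()

# Rebuild an equivalent config object from the stored JSON
restored = Int4WeightOnlyConfig.from_json(json_str)

# pydantic's BaseModel supplies field-wise equality, which the round-trip tests rely on
assert restored == config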

torchao/core/config.py

Lines changed: 69 additions & 14 deletions
@@ -1,29 +1,84 @@
-import abc
+from typing import Any, Dict
 
+import torch
+from pydantic import BaseModel, field_serializer, model_validator
 
-class AOBaseConfig(abc.ABC):
+
+class AOBaseConfig(BaseModel):
     """
-    If a workflow config inherits from this then `quantize_` knows
-    how to a apply it to a model. For example::
+    Base configuration class for TorchAO quantization workflows with native Pydantic handling for torch.dtype.
+
+    When a workflow configuration inherits from AOBaseConfig, the `quantize_` function can automatically
+    apply the appropriate transformation to a model based on the configuration type.
 
-    # user facing code
-    class WorkflowFooConfig(AOBaseConfig): ...
-    # configuration for workflow `Foo` is defined here
-    bar = 'baz'
+    Usage example:
+    # 1. Define a configuration class for your workflow
+    class WorkflowFooConfig(AOBaseConfig):
+        # Configuration parameters for workflow 'Foo'
+        bar: str = 'baz'
 
-    # non user facing code
+    # 2. Register a handler for this configuration (internal implementation)
     @register_quantize_module_handler(WorkflowFooConfig)
     def _transform(
         mod: torch.nn.Module,
         config: WorkflowFooConfig,
     ) -> torch.nn.Module:
-        # the transform is implemented here, usually a tensor sublass
-        # weight swap or a module swap
+        # Implementation of the transformation logic
+        # Typically performs tensor subclass weight swapping or module replacement
         ...
 
-    # then, the user calls `quantize_` with a config, and `_transform` is called
-    # under the hood by `quantize_.
+    # 3. Apply the configuration to a model
+    # The user simply calls `quantize_` with a model and config instance
+    # The appropriate handler is automatically selected based on the config type
+    model = ...
+    quantized_model = quantize_(model, WorkflowFooConfig(bar='custom_value'))
 
+    Note on serialization, if you add a new AOBaseConfig and want to support serialization,
+    please add a test in test/quantization/test_config_serialization.py
     """
 
-    pass
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "validate_assignment": True,
+        "extra": "forbid",
+        "validate_default": True,
+        "populate_by_name": True,
+    }
+
+    @field_serializer("*")
+    def serialize_torch_dtype(self, v, _info):
+        if isinstance(v, torch.dtype):
+            return str(v)
+        return v
+
+    @model_validator(mode="before")
+    @classmethod
+    def convert_dtypes(cls, data: Any) -> Any:
+        """Simple converter for torch dtype strings"""
+        if isinstance(data, str) and data.startswith("torch."):
+            dtype_name = data.split("torch.")[1]
+            if hasattr(torch, dtype_name):
+                return getattr(torch, dtype_name)
+        elif isinstance(data, dict):
+            return {k: cls.convert_dtypes(v) for k, v in data.items()}
+        elif isinstance(data, list):
+            return [cls.convert_dtypes(item) for item in data]
+        return data
+
+    def to_dict(self) -> dict:
+        """Convert the configuration to a dictionary"""
+        return self.model_dump()
+
+    def to_json(self) -> str:
+        """Convert the configuration to a JSON string."""
+        return self.model_dump_json()
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AOBaseConfig":
+        """Create a configuration from a dictionary."""
+        return cls.model_validate(data)
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "AOBaseConfig":
+        """Create a configuration from a JSON string."""
+        return cls.model_validate_json(json_str)
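
To make the dtype handling above concrete: a hedged sketch, not part of the commit, of how a subclass with a torch.dtype field round-trips through to_dict/from_dict. The class name MyWeightOnlyConfig is hypothetical; the behavior assumes pydantic v2 and the AOBaseConfig definition in this diff.

import torch

from torchao.core.config import AOBaseConfig


class MyWeightOnlyConfig(AOBaseConfig):  # hypothetical example config
    weight_dtype: torch.dtype = torch.float8_e4m3fn
    group_size: int = 128


cfg = MyWeightOnlyConfig()

# serialize_torch_dtype replaces the dtype with its string form so the dict is JSON-safe
data = cfg.to_dict()
# expected: {'weight_dtype': 'torch.float8_e4m3fn', 'group_size': 128}

# convert_dtypes (the mode="before" validator) maps "torch.*" strings back to real dtypes
restored = MyWeightOnlyConfig.from_dict(data)
assert restored.weight_dtype == torch.float8_e4m3fn
assert restored == cfg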

torchao/quantization/quant_api.py

Lines changed: 3 additions & 13 deletions
@@ -18,7 +18,6 @@
 import logging
 import types
 import warnings
-from dataclasses import dataclass
 from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
@@ -600,7 +599,6 @@ def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     )
 
 
-@dataclass
 class Int8DynamicActivationInt4WeightConfig(AOBaseConfig):
     """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear
     This is used to produce a model for executorch backend, but currently executorch did not
@@ -681,7 +679,6 @@ def _int8_dynamic_activation_int4_weight_transform(
     return module
 
 
-@dataclass
 class Int4DynamicActivationInt4WeightConfig(AOBaseConfig):
     """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear
 
@@ -738,7 +735,6 @@ def _int4_dynamic_activation_int4_weight_transform(
     return module
 
 
-@dataclass
 class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     """
     applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
@@ -787,7 +783,6 @@ def _gemlite_uintx_weight_only_transform(
     return module
 
 
-@dataclass
 class Int4WeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -811,7 +806,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     """
 
     group_size: int = 128
-    layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8)
+    layout: Optional[Union[TensorCoreTiledLayout, Int4CPULayout]] = (
+        TensorCoreTiledLayout(inner_k_tiles=8)
+    )
     use_hqq: bool = False
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
 
@@ -893,7 +890,6 @@ def _int4_weight_only_transform(
     return module
 
 
-@dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers.
@@ -1007,7 +1003,6 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor:
     )
 
 
-@dataclass
 class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     """
     Configuration for applying int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -1092,7 +1087,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
 
-@dataclass
 class Float8WeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.
@@ -1245,7 +1239,6 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     return is_compatible
 
 
-@dataclass
 class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     """
     Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.
@@ -1327,7 +1320,6 @@ def _float8_dynamic_activation_float8_weight_transform(
     return module
 
 
-@dataclass
 class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
     Configuration for applying float8 static symmetric quantization to
@@ -1408,7 +1400,6 @@ def _float8_static_activation_float8_weight_transform(
     return module
 
 
-@dataclass
 class UIntXWeightOnlyConfig(AOBaseConfig):
     """
     Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
@@ -1499,7 +1490,6 @@ def _uintx_weight_only_transform(
     return module
 
 
-@dataclass
 class FPXWeightOnlyConfig(AOBaseConfig):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
     e.g. fp6_e3_m2, fp6_e2_m3, ...
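
The substantive change in this file is that the config classes no longer need @dataclass: once they inherit from the pydantic-backed AOBaseConfig, annotated class attributes are collected as model fields, so keyword construction keeps working and the serde methods are inherited. A small hedged sketch of the resulting behavior, assuming the classes work as in this PR:

from torchao.quantization.quant_api import Int8WeightOnlyConfig

# Keyword construction works exactly as it did with @dataclass
cfg = Int8WeightOnlyConfig(group_size=128)

# The serde API inherited from AOBaseConfig is now available on every config
print(cfg.to_json())

# Because AOBaseConfig sets "extra": "forbid", a misspelled field name such as
# Int8WeightOnlyConfig(groupsize=128) now raises a pydantic ValidationError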
