
Commit 5527c21

metascroy authored and facebook-github-bot committed
Add quant api + python test for shared embedding (#1937)
Summary: Adds shared embedding to the quantizer and adds a unit test for it in Python.

Reviewed By: digantdesai
Differential Revision: D71234378
1 parent ea37ff7 commit 5527c21
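
The quantizer and test files themselves are not among the diffs shown below, so as orientation only: "shared embedding" here refers to an embedding table whose weight is shared with the model's output projection, and the new unit test presumably checks that quantization preserves that sharing. A hypothetical toy setup (the module and names below are illustrative assumptions, not code from this commit):

```python
import torch


class TinyLM(torch.nn.Module):
    """Toy model whose embedding and output projection share one weight."""

    def __init__(self, vocab_size: int = 128, dim: int = 16):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, dim)
        self.unembedding = torch.nn.Linear(dim, vocab_size, bias=False)
        # Tie the weights: a shared-embedding quantizer must keep both
        # modules reading from the same (packed) buffer after quantization.
        self.unembedding.weight = self.embedding.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.unembedding(self.embedding(tokens))


model = TinyLM()
assert model.embedding.weight.data_ptr() == model.unembedding.weight.data_ptr()
```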

File tree: 4 files changed, +526 -117 lines


.github/workflows/torchao_experimental_test.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,7 @@ jobs:
       # Install executorch first because it installs its own version
       # of torch and torchao, which we do not want to use
       pip install executorch
-      pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
+      pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
       pip install numpy
       pip install pytest
       pip install parameterized
```

torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py

Lines changed: 26 additions & 16 deletions
```diff
@@ -35,21 +35,32 @@ class Target(Enum):
 
     # AUTO target will automatically select a packing format
     # based on the available hardware.
-    # TODO: in future, add the ability to specify specific
-    # hardware targets
     AUTO = auto()
+    UNIVERSAL = auto()
+    KLEIDIAI = auto()
 
     # ATEN target will use the ATen operator
     ATEN = auto()
 
 
+_TARGET_AND_STR = [
+    (Target.AUTO, "auto"),
+    (Target.ATEN, "aten"),
+    (Target.UNIVERSAL, "universal"),
+    (Target.KLEIDIAI, "kleidiai"),
+]
+
+
+def target_to_str(target: Target) -> str:
+    target_to_str = {t: s for t, s in _TARGET_AND_STR}
+    return target_to_str[target]
+
+
 def target_from_str(target: str) -> Target:
-    if target.lower() == "auto":
-        return Target.AUTO
-    elif target.lower() == "aten":
-        return Target.ATEN
-    else:
-        raise ValueError(f"Invalid target: {target}")
+    str_to_target = {s: t for t, s in _TARGET_AND_STR}
+    if target.lower() in str_to_target:
+        return str_to_target[target.lower()]
+    raise ValueError(f"Invalid target: {target}")
 
 
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
```
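
Viewed outside the diff, the new `_TARGET_AND_STR` table acts as a single source of truth for both conversion directions. A self-contained sketch of the pattern, with the enum and table copied from the hunk above (the `target_to_str` body is simplified to avoid shadowing the function name):

```python
from enum import Enum, auto


class Target(Enum):
    AUTO = auto()
    UNIVERSAL = auto()
    KLEIDIAI = auto()
    ATEN = auto()


# One (enum, string) table drives both conversion directions, so adding a
# new target is a one-line change instead of a new if/elif branch.
_TARGET_AND_STR = [
    (Target.AUTO, "auto"),
    (Target.ATEN, "aten"),
    (Target.UNIVERSAL, "universal"),
    (Target.KLEIDIAI, "kleidiai"),
]


def target_to_str(target: Target) -> str:
    return {t: s for t, s in _TARGET_AND_STR}[target]


def target_from_str(target: str) -> Target:
    str_to_target = {s: t for t, s in _TARGET_AND_STR}
    if target.lower() in str_to_target:
        return str_to_target[target.lower()]
    raise ValueError(f"Invalid target: {target}")


# Round-trip check: lookups are case-insensitive on the string side.
assert target_from_str("KleidiAI") is Target.KLEIDIAI
assert target_to_str(Target.UNIVERSAL) == "universal"
```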
```diff
@@ -146,10 +157,9 @@ def from_plain(
 ):
     assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
     assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
-    assert layout.target in {
-        Target.AUTO,
-        Target.ATEN,
-    }, f"Unexpected target: {layout.target}"
+    assert layout.target in [
+        t for t, _ in _TARGET_AND_STR
+    ], f"Unexpected target: {layout.target}"
 
     n, k = int_data.shape
     if layout.target == Target.ATEN:
```
```diff
@@ -174,7 +184,7 @@ def from_plain(
         zero_point.reshape(-1).to(torch.int8) if layout.has_weight_zeros else None,
         layout.group_size,
         bias if layout.has_bias else None,
-        None, # target, if not passed a packing format will be chosen on C++ side
+        target_to_str(layout.target) if layout.target != Target.AUTO else None,
     ]
 
     packed_weight = getattr(
```
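
Behavioral note: before this hunk the packing op always received `None` and the C++ side chose a packing format; now an explicit `UNIVERSAL` or `KLEIDIAI` target is forwarded as its string name, and only `AUTO` still defers the choice to C++.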
```diff
@@ -223,7 +233,7 @@ def _linear_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-    def _impl_2d_auto(input_tensor, weight_tensor):
+    def _impl_2d_non_aten(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
 
```
```diff
@@ -272,8 +282,8 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     if target == Target.ATEN:
         assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
         _impl_2d = _impl_2d_aten
-    elif target == Target.AUTO:
-        _impl_2d = _impl_2d_auto
+    else:
+        _impl_2d = _impl_2d_non_aten
 
     if input_tensor.dim() == 2:
         res = _impl_2d(input_tensor, weight_tensor)
```
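
With the rename, the `else` branch covers every non-ATen target (`AUTO`, `UNIVERSAL`, `KLEIDIAI`) through the same kernel path, which is why `_impl_2d_auto` becomes `_impl_2d_non_aten`.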
