
Move config out of experimental #1954

Merged: 5 commits merged on Mar 25, 2025

Conversation

metascroy (Contributor)

Moves q_dq_layout and packed_linear_int8_dynamic_activation_intx_weight_layout.py out of experimental.
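Concretely, after this move the packed layout is imported from torchao.dtypes rather than from the experimental namespace, as the usage examples later in this thread show:

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout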

pytorch-bot bot commented Mar 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1954

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit c8d7871 with merge base f3ff2e5, 3 jobs reported new failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Mar 25, 2025
@metascroy requested review from jerryzh168 and drisspg on March 25, 2025, 21:52
@jerryzh168 (Contributor) left a comment

@metascroy (Contributor, Author)

maybe move https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#int8_dynamic_activation_intx_weight-quantization to the supported section as well

This PR doesn't move the config/quant API yet; it only moves the layouts. The config is still in torchao/experimental and will be moved in a follow-up PR.

@metascroy merged commit 40a6867 into main on Mar 25, 2025
16 of 19 checks passed
@nikhil-arm (Contributor)

Hello @metascroy, where is the documentation for using the new to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight API?
How are we supposed to use the ATen KleidiAI kernels after this refactoring?

@metascroy (Contributor, Author) commented Apr 8, 2025

Hello @metascroy, where is the documentation for using the new to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight API? How are we supposed to use the ATen KleidiAI kernels after this refactoring?

That API is not intended for use by most people. The ATen KleidiAI kernels can be used from the quantize_ config as they could before (https://github.com/pytorch/ao/blob/main/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py#L177-L186):

import torch

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# model is the eager-mode model being quantized in place
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=False,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
        round_weight_scale_to_bf16=True,
    ),
)

But note that we're also moving int8_dynamic_activation_intx_weight out of experimental and changing its API slightly to align with torchao's QAT routines (#1968). After that, the call site will be:

import torch

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    ZeroPointDomain,
    quantize_,
)

quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        # Using ASYMMETRIC also works
        weight_mapping_type=MappingType.SYMMETRIC,
        # The aten target does not have zeros
        weight_zero_point_domain=ZeroPointDomain.NONE,
        # The aten target wants bfloat16 for groupwise
        weight_scale_dtype=torch.bfloat16,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)

Also note that users can specify target as "auto" (the default), "kleidiai", or "universal". When KleidiAI is chosen, it dispatches to the KleidiAI kernels in torchao (https://github.com/pytorch/ao/blob/main/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h#L187). The nice thing about "auto" is that it chooses KleidiAI when the scheme is supported (torch.int4, ZeroPointDomain.NONE) and falls back to neondot GEMV kernels when a quantization option that KleidiAI does not support is chosen.
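As a rough sketch of that target selection, reusing the Int8DynamicActivationIntxWeightConfig call site from the example above (the exact field values here are illustrative assumptions, not a recommended configuration):

import torch

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    ZeroPointDomain,
    quantize_,
)

# target="auto" (the default) lets torchao pick the kernel: KleidiAI when the
# scheme is supported (torch.int4 weights, ZeroPointDomain.NONE), otherwise the
# neondot GEMV fallback.
auto_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_mapping_type=MappingType.SYMMETRIC,
    weight_zero_point_domain=ZeroPointDomain.NONE,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="auto"),
)

# target="kleidiai" or target="universal" pins the kernel family instead of
# letting torchao decide.
universal_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    weight_mapping_type=MappingType.SYMMETRIC,
    weight_zero_point_domain=ZeroPointDomain.NONE,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"),
)

quantize_(model, auto_config)  # model is the user's nn.Module

With "auto", the same call site works whether or not KleidiAI kernels are available for the chosen quantization options.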
