Commit 9f00c61

a-r-r-o-w, sayakpaul, and stevhliu authored
[core] TorchAO Quantizer (#10009)
* torchao quantizer

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
1 parent aafed3f commit 9f00c61

File tree: 16 files changed, +1374 −16 lines


docs/source/en/_toctree.yml (+2)

```diff
@@ -157,6 +157,8 @@
     title: Getting Started
   - local: quantization/bitsandbytes
     title: bitsandbytes
+  - local: quantization/torchao
+    title: torchao
   title: Quantization Methods
 - sections:
   - local: optimization/fp16
```

docs/source/en/api/quantization.md (+4)

```diff
@@ -28,6 +28,10 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
 
 [[autodoc]] BitsAndBytesConfig
 
+## TorchAoConfig
+
+[[autodoc]] TorchAoConfig
+
 ## DiffusersQuantizer
 
 [[autodoc]] quantizers.base.DiffusersQuantizer
```

docs/source/en/quantization/overview.md (+1 −1)

```diff
@@ -32,4 +32,4 @@ If you are new to the quantization field, we recommend you to check out these be
 
 ## When to use what?
 
-This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
```

docs/source/en/quantization/torchao.md (new file, +92 lines)

<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# torchao

[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.

Before you begin, make sure you have PyTorch 2.5+ and TorchAO installed.

```bash
pip install -U torch torchao
```

Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.

The example below only quantizes the weights to int8.

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)
pipe = FluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=dtype,
)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
image.save("output.png")
```

TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.

```python
# In the above code, add the following after initializing the transformer
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
```

For speed and memory benchmarks on Flux and CogVideoX, refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmark](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.

torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components; Diffusers will also expose an autoquant configuration option in the future.

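As a rough sketch of what that can look like today, the snippet below applies torchao's `autoquant` entry point directly to the transformer before building the pipeline. It reuses the Flux setup from the example above and assumes the top-level `torchao.autoquant` API from your installed torchao version; it is illustrative, not the eventual Diffusers API.

```python
import torch
import torchao

from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/Flux.1-Dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

# Sketch: autoquant profiles several quantization techniques on the shapes it
# sees at runtime and picks the best-performing one per layer. Wrapping the
# compiled module follows the pattern suggested in the torchao README.
transformer = torchao.autoquant(torch.compile(transformer, mode="max-autotune"))

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28, guidance_scale=0.0).images[0]
```
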
The `TorchAoConfig` class accepts three parameters:
- `quant_type`: A string specifying one of the quantization types listed below.
- `modules_to_not_convert`: A list of full or partial module names that should not be quantized. For example, to skip quantization of the [`FluxTransformer2DModel`]'s first block, specify `modules_to_not_convert=["single_transformer_blocks.0"]`.
- `kwargs`: A dict of keyword arguments passed to the underlying quantization method selected by `quant_type`, as illustrated in the sketch after this list.

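The snippet below is a sketch of how the three parameters can be combined; the `group_size` keyword is an assumption about what the underlying `int4_weight_only` method accepts and is shown only to illustrate how `kwargs` are forwarded.

```python
from diffusers import TorchAoConfig

# Hypothetical combination of all three parameters:
# "int4wo" picks the quantization type, the first single transformer block is
# left unquantized, and group_size (assumed to be a valid int4_weight_only
# keyword) is forwarded to the underlying torchao method.
quantization_config = TorchAoConfig(
    "int4wo",
    modules_to_not_convert=["single_transformer_blocks.0"],
    group_size=64,
)
```

Pass this config to `from_pretrained` exactly as in the int8 example earlier.
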
## Supported quantization types

torchao supports weight-only quantization as well as combined weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.

Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.

Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, it may come with a quality tradeoff at times, so it is recommended to test different models thoroughly.

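To make the tradeoff concrete, here is a minimal sketch contrasting the two modes for the same model, using the `int8wo` and `int8dq` shorthands from the table below.

```python
from diffusers import TorchAoConfig

# Weight-only: int8 weights, activations stay in the compute dtype (e.g. bfloat16).
weight_only_config = TorchAoConfig("int8wo")

# Dynamic activation quantization: int8 weights plus on-the-fly int8 activations.
dynamic_activation_config = TorchAoConfig("int8dq")
```

Swapping one config for the other in the `from_pretrained` call is enough to compare memory use and output quality for a given model.
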
The quantization methods supported are as follows:

| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB`, where `X` is the number of bits (1-7), `A` is the number of exponent bits, and `B` is the number of mantissa bits; constraint: `X == A + B + 1` |
| **Unsigned integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |

Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.

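As a small sketch of the naming scheme, the pair below uses a full function name and its shorthand, and the last line works out the `fpX_eAwB` bit budget (`X == A + B + 1`, so a 6-bit float splits into 3 exponent and 2 mantissa bits); passing `ebits`/`mbits` through `kwargs` is an assumption based on the `fpx_weight_only` signature in torchao.

```python
from diffusers import TorchAoConfig

# A full function name and its shorthand are expected to select the same method.
config_full = TorchAoConfig("int8_weight_only")
config_short = TorchAoConfig("int8wo")

# fpX_eAwB bit budget: X == A + B + 1, e.g. 6 == 3 + 2 + 1 for a 6-bit float.
# Here the exponent/mantissa split is passed as kwargs to fpx_weight_only
# (an assumption about how the kwargs are forwarded).
config_fp6 = TorchAoConfig("fpx_weight_only", ebits=3, mbits=2)
```
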
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options.

## Resources

- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)

src/diffusers/__init__.py (+2 −2)

```diff
@@ -31,7 +31,7 @@
     "loaders": ["FromOriginalModelMixin"],
     "models": [],
     "pipelines": [],
-    "quantizers.quantization_config": ["BitsAndBytesConfig"],
+    "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
     "schedulers": [],
     "utils": [
         "OptionalDependencyNotAvailable",
@@ -569,7 +569,7 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .configuration_utils import ConfigMixin
-    from .quantizers.quantization_config import BitsAndBytesConfig
+    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
 
     try:
         if not is_onnx_available():
```

src/diffusers/models/model_loading_utils.py (+2 −4)

```diff
@@ -25,7 +25,6 @@
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
-from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -182,7 +181,6 @@ def load_model_dict_into_meta(
     device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -215,12 +213,12 @@
         # bnb params are flattened.
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
                 hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
                     f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
```

src/diffusers/models/modeling_utils.py (+5 −6)

```diff
@@ -700,10 +700,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            if device_map is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
                 )
+
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
 
@@ -858,13 +860,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if device_map is None and not is_sharded:
             # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
             # It would error out during the `validate_environment()` call above in the absence of cuda.
-            is_quant_method_bnb = (
-                getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-            )
             if hf_quantizer is None:
                 param_device = "cpu"
             # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-            elif is_quant_method_bnb:
+            else:
                 param_device = torch.device(torch.cuda.current_device())
             state_dict = load_state_dict(model_file, variant=variant)
             model._convert_deprecated_attention_blocks(state_dict)
```

src/diffusers/quantizers/auto.py (+4 −1)

```diff
@@ -19,17 +19,20 @@
 from typing import Dict, Optional, Union
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
+from .torchao import TorchAoHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+    "torchao": TorchAoHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
+    "torchao": TorchAoConfig,
 }
```