
Commit 159145b

drisspg authored and nishith-fujitsu committed
Signed-off-by: drisspg <[email protected]>
1 parent 647bf2a commit 159145b

File tree

5 files changed: +191 −1 lines changed


docs/source/features/quantization/index.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,4 +18,5 @@ int8
 fp8
 quark
 quantized_kvcache
+torchao
 :::
```
docs/source/features/quantization/torchao.md

Lines changed: 34 additions & 0 deletions

# TorchAO

TorchAO is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

We recommend installing the latest torchao nightly with:

```console
# Install the latest TorchAO nightly build
# Choose the CUDA version that matches your system (cu126, cu128, etc.)
pip install --pre "torchao>=0.10.0" --index-url https://download.pytorch.org/whl/nightly/cu126
```
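
To confirm that the resolved nightly is new enough (the vLLM integration below requires torchao>=0.10.0), a quick sanity check:

```console
python -c "import torchao; print(torchao.__version__)"
```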

## Quantizing Hugging Face Models

You can quantize your own Hugging Face model with torchao, e.g. with [transformers](https://huggingface.co/docs/transformers/main/en/quantization/torchao) or [diffusers](https://huggingface.co/docs/diffusers/en/quantization/torchao), and push the checkpoint to the Hugging Face Hub like [this one](https://huggingface.co/jerryzh168/llama3-8b-int8wo) with the following example code:

```python
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sanity-check the quantized model before pushing it.
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output = quantized_model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))

hub_repo = "YOUR_HUB_REPO_ID"
tokenizer.push_to_hub(hub_repo)
quantized_model.push_to_hub(hub_repo, safe_serialization=False)
```
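
Note that `safe_serialization=False` is required here: torchao stores quantized weights as tensor subclasses, which the safetensors format does not support, so the checkpoint is pushed in the regular pickle-based format.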

Alternatively, you can use the [TorchAO Quantization space](https://huggingface.co/spaces/medmekk/TorchAO_Quantization) to quantize models through a simple UI.
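
Once a quantized checkpoint is on the Hub, it can be loaded back with vLLM. A minimal serving sketch (the repo name reuses the int8 example checkpoint linked above; `quantization="torchao"` selects the method this commit registers):

```python
from vllm import LLM, SamplingParams

# Load a torchao pre-quantized checkpoint; the quant_type is read from the
# checkpoint's config.json by TorchAOConfig.from_config.
llm = LLM(model="jerryzh168/llama3-8b-int8wo",
          quantization="torchao",
          dtype="bfloat16")

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["What are we having for dinner?"], params)
print(outputs[0].outputs[0].text)
```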

tests/quantization/test_torchao.py

Lines changed: 25 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
import importlib.util

import pytest

DTYPE = ["bfloat16"]

TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_pre_quantized_model(vllm_runner):
    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
                     quantization="torchao",
                     dtype="bfloat16",
                     enforce_eager=True) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)
    assert output
    print(output)


if __name__ == "__main__":
    pytest.main([__file__])
```
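
The `if __name__ == "__main__"` guard lets the file run standalone; either invocation below executes the test (assuming a CUDA machine with torchao and vLLM installed; `-s` surfaces the `print` output):

```console
pytest -s tests/quantization/test_torchao.py
# or
python tests/quantization/test_torchao.py
```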

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -31,7 +31,8 @@
     "neuron_quant",
     "ipex",
     "quark",
-    "moe_wna16"
+    "moe_wna16",
+    "torchao",
 ]

 # The customized quantization methods which will be added to this dict.
@@ -103,6 +104,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .neuron_quant import NeuronQuantConfig
     from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
+    from .torchao import TorchAOConfig
     from .tpu_int8 import Int8TpuConfig

     method_to_config: Dict[str, Type[QuantizationConfig]] = {
@@ -132,6 +134,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "ipex": IPEXConfig,
         "quark": QuarkConfig,
         "moe_wna16": MoeWNA16Config,
+        "torchao": TorchAOConfig,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
```
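
With these registrations in place, the string `"torchao"` resolves through the same lookup path as every other quantization method. A minimal sketch of exercising the registry (assuming the vLLM package layout shown above):

```python
from vllm.model_executor.layers.quantization import get_quantization_config

# "torchao" now resolves to the TorchAOConfig class defined below.
torchao_cls = get_quantization_config("torchao")
print(torchao_cls)                       # <class '...torchao.TorchAOConfig'>
print(torchao_cls.get_min_capability())  # 75, i.e. compute capability 7.5+
```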
vllm/model_executor/layers/quantization/torchao.py

Lines changed: 127 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs


class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    def __init__(self, torchao_config) -> None:
        self.torchao_config = torchao_config

    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"

    def get_name(self) -> str:
        return "torchao"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
        """Create the quant config from an HF model config."""
        try:
            from torchao.core.config import config_from_dict
        except ImportError as err:
            raise ImportError(
                "Please install torchao>=0.10.0 via "
                "`pip install torchao>=0.10.0` to use torchao quantization."
            ) from err

        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["TorchAOLinearMethod"]:
        if isinstance(layer, LinearBase):
            return TorchAOLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


def torchao_quantize_param_data(param: torch.Tensor,
                                torchao_config: Any) -> torch.nn.Parameter:
    """Quantize a Tensor with the torchao quantization given by torchao_config.

    Args:
        param: weight parameter of the linear module.
        torchao_config: type of quantization and its arguments, used to
            quantize the Tensor.
    """
    from torchao.core.config import AOBaseConfig
    from torchao.quantization import quantize_
    assert isinstance(torchao_config, AOBaseConfig)
    # Wrap the weight in a throwaway nn.Linear so torchao's module-level
    # quantize_ API can swap it for a quantized tensor subclass in place.
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
    quantize_(dummy_linear, torchao_config)
    return dummy_linear.weight


class TorchAOLinearMethod(LinearMethodBase):
    """Linear method for torchao.

    Args:
        quant_config: The torchao quantization config, which encodes the
            type of quantization and all of its arguments.
    """

    def __init__(self, quant_config: TorchAOConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate the full-precision weight, then replace it with the
        # torchao-quantized version before registering it on the layer.
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        weight = torchao_quantize_param_data(weight,
                                             self.quant_config.torchao_config)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The quantized weight is a tensor subclass; F.linear dispatches to
        # the torchao kernels via the subclass's torch overrides.
        return F.linear(x, layer.weight, bias)
```
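
For reference, `from_config` expects the checkpoint's `config.json` to carry a torchao-serialized config under `quant_type`, keyed by `"default"`. A minimal sketch of producing and round-tripping such a dict with torchao's serialization helpers (an assumption to verify against your torchao version: `config_to_dict` is the inverse of the `config_from_dict` used above in torchao>=0.10.0):

```python
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

# Serialize an AOBaseConfig the way HF checkpoints embed it ...
quant_type = {"default": config_to_dict(Int8WeightOnlyConfig())}

# ... and round-trip it exactly as TorchAOConfig.from_config does.
restored = config_from_dict(quant_type["default"])
assert isinstance(restored, Int8WeightOnlyConfig)
```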
