Commit f008edc

Integrate Torchao into VLLM
Signed-off-by: drisspg <[email protected]>
1 parent 27df519 commit f008edc

File tree: 7 files changed (+212 −1 lines changed)

docs/source/features/quantization/index.md
Lines changed: 1 addition & 0 deletions

```diff
@@ -17,4 +17,5 @@ int4
 int8
 fp8
 quantized_kvcache
+torchao
 :::
```
docs/source/features/quantization/torchao.md
Lines changed: 31 additions & 0 deletions

(torchao)=

# TorchAO

TorchAO is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and composes with native PyTorch features such as torch.compile and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

We recommend installing the latest torchao nightly with

```console
pip install --pre torchao>=0.10.0 --index-url https://download.pytorch.org/whl/nightly/cu126 # or other CUDA versions, e.g. cu128
```

You can quantize your own Hugging Face model with torchao, e.g. with [transformers](https://huggingface.co/docs/transformers/main/en/quantization/torchao) and [diffusers](https://huggingface.co/docs/diffusers/en/quantization/torchao), and save the checkpoint to the Hugging Face Hub like [this one](https://huggingface.co/jerryzh168/llama3-8b-int8wo), either with the following example code:

```python
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

hub_repo = "YOUR_HUB_REPO_ID"  # replace with your Hub repo ID
tokenizer.push_to_hub(hub_repo)
quantized_model.push_to_hub(hub_repo, safe_serialization=False)
```

or with [spaces](https://huggingface.co/spaces/medmekk/TorchAO_Quantization).
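Once a quantized checkpoint is on the Hub, vLLM can load it by passing `quantization="torchao"`. A minimal serving sketch (using the pre-quantized `jerryzh168/llama3-8b-int8wo` checkpoint linked above; substitute your own repo ID):

```python
from vllm import LLM, SamplingParams

# Pre-quantized checkpoint from the link above; swap in your own Hub repo ID.
llm = LLM(model="jerryzh168/llama3-8b-int8wo",
          quantization="torchao",
          dtype="bfloat16")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["What are we having for dinner?"], params)
print(outputs[0].outputs[0].text)
```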

tests/quantization/test_torchao.py
Lines changed: 28 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
"""Test model set-up and inference for TorchAO quantized HF models supported
on the CPU/GPU backend.

Validates the configuration and prints results for manual checking.

Run `pytest tests/quantization/test_torchao.py`.
"""
import pytest

from vllm.config import CompilationLevel

DTYPE = ["bfloat16"]


def test_pre_quantized_model(vllm_runner):
    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
                     quantization="torchao",
                     dtype="bfloat16",
                     compilation_config=CompilationLevel.PIECEWISE) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

        assert output
        print(output)


if __name__ == "__main__":
    pytest.main([__file__])
```

vllm/engine/arg_utils.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -139,6 +139,7 @@ class EngineArgs:
     hf_overrides: Optional[HfOverrides] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
+
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = False
```

vllm/model_executor/layers/quantization/__init__.py
Lines changed: 4 additions & 1 deletion

```diff
@@ -31,7 +31,8 @@
     "neuron_quant",
     "ipex",
     "quark",
-    "moe_wna16"
+    "moe_wna16",
+    "torchao",
 ]

 # The customized quantization methods which will be added to this dict.
@@ -103,6 +104,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .neuron_quant import NeuronQuantConfig
     from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
+    from .torchao import TorchAOConfig
     from .tpu_int8 import Int8TpuConfig

     method_to_config: Dict[str, Type[QuantizationConfig]] = {
@@ -132,6 +134,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "ipex": IPEXConfig,
         "quark": QuarkConfig,
         "moe_wna16": MoeWNA16Config,
+        "torchao": TorchAOConfig,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
```
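With the method registered, `"torchao"` resolves through the standard lookup. A minimal sketch of the resolution path:

```python
from vllm.model_executor.layers.quantization import get_quantization_config

# "torchao" now resolves to TorchAOConfig via method_to_config.
torchao_cls = get_quantization_config("torchao")
print(torchao_cls.get_min_capability())  # 75, i.e. Turing or newer
```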
vllm/model_executor/layers/quantization/torchao.py
Lines changed: 115 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
import importlib.metadata
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from packaging import version
from torch.nn.parameter import Parameter
from torchao.core.config import AOBaseConfig

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.torchao_utils import (
    torchao_quantize_param_data)
from vllm.model_executor.utils import set_weight_attrs


class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    torchao_config: AOBaseConfig

    def __init__(self, torchao_config: AOBaseConfig) -> None:
        self.torchao_config = torchao_config

        # Determine based on torchao version
        torchao_version = version.parse(importlib.metadata.version("torchao"))
        assert torchao_version >= version.parse(
            "0.9.0"
        ), "torchao version must be >= 0.9.0 to load quantized models"

    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"

    def get_name(self) -> str:
        return "torchao"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
        """Create the quant config from an hf model config"""
        from torchao.core.config import config_from_dict
        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["TorchAOLinearMethod"]:
        if isinstance(layer, LinearBase):
            return TorchAOLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class TorchAOLinearMethod(LinearMethodBase):
    """Linear method for torchao.

    Args:
        quant_config: The torchao quantization config, holding an
            AOBaseConfig that encodes the type of quantization and
            all relevant arguments.
    """

    def __init__(self, quant_config: TorchAOConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate the unquantized weight, then quantize it with the torchao
        # config before registering it on the layer.
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        weight = torchao_quantize_param_data(weight,
                                             self.quant_config.torchao_config)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The torchao tensor subclass intercepts F.linear and runs the
        # quantized matmul.
        return F.linear(x, layer.weight, bias)
```
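For reference, a minimal sketch of the `quant_type` structure that `from_config` expects in a checkpoint's quantization config, assembled here with torchao's `config_to_dict` (the serialization counterpart of the `config_from_dict` call above); the shape is inferred from the parsing logic, not an official schema:

```python
from torchao.core.config import config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

from vllm.model_executor.layers.quantization.torchao import TorchAOConfig

# Hypothetical round trip: serialize an AOBaseConfig the way a
# TorchAo-quantized HF checkpoint stores it, then parse it back.
hf_quant_config = {
    "quant_type": {"default": config_to_dict(Int8WeightOnlyConfig())},
}
torchao_config = TorchAOConfig.from_config(hf_quant_config)
print(torchao_config)  # TorchAOConfig(Int8WeightOnlyConfig(...))
```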
vllm/model_executor/layers/quantization/torchao_utils.py
Lines changed: 32 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
"""
Common utilities for torchao.
"""

from typing import Any

import torch


def torchao_quantize_param_data(param: torch.Tensor,
                                torchao_config: Any) -> torch.nn.Parameter:
    """Quantize a Tensor with the torchao quantization specified by
    torchao_config.

    Args:
        `param`: weight parameter of the linear module
        `torchao_config`: type of quantization and its arguments to use
            when quantizing the Tensor
    """
    try:
        from torchao.core.config import AOBaseConfig
        from torchao.quantization import quantize_
    except ImportError as err:
        raise ImportError(
            "torchao is required for this quantization method. "
            "Please install torchao with `pip install torchao>=0.10.0`."
        ) from err
    assert isinstance(torchao_config, AOBaseConfig)
    # quantize_ mutates a module in place, so wrap the weight in a dummy
    # linear layer, quantize it, and return the (now quantized) parameter.
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
    quantize_(dummy_linear, torchao_config)
    return dummy_linear.weight
```
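A minimal usage sketch of the helper, assuming torchao is installed and using the same `Int8WeightOnlyConfig` as the docs above:

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig

# Quantize a bare weight tensor the same way
# TorchAOLinearMethod.create_weights does.
weight = torch.nn.Parameter(torch.randn(256, 128, dtype=torch.bfloat16),
                            requires_grad=False)
quantized = torchao_quantize_param_data(weight, Int8WeightOnlyConfig())
print(type(quantized.data))  # typically a torchao quantized tensor subclass
```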
