|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +from typing import Any, Dict, List, Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn.functional as F |
| 6 | +from torch.nn.parameter import Parameter |
| 7 | + |
| 8 | +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase |
| 9 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 10 | + QuantizationConfig) |
| 11 | +from vllm.model_executor.utils import set_weight_attrs |
| 12 | + |
| 13 | + |
| 14 | +class TorchAOConfig(QuantizationConfig): |
| 15 | + """Config class for torchao.""" |
| 16 | + |
| 17 | + def __init__(self, torchao_config) -> None: |
| 18 | + self.torchao_config = torchao_config |
| 19 | + |
| 20 | + def __repr__(self) -> str: |
| 21 | + return f"TorchAOConfig({self.torchao_config})" |
| 22 | + |
| 23 | + def get_name(self) -> str: |
| 24 | + return "torchao" |
| 25 | + |
| 26 | + def get_supported_act_dtypes(self) -> List[torch.dtype]: |
| 27 | + return [torch.float32, torch.float16, torch.bfloat16] |
| 28 | + |
| 29 | + @classmethod |
| 30 | + def get_min_capability(cls) -> int: |
| 31 | + return 75 |
| 32 | + |
| 33 | + @staticmethod |
| 34 | + def get_config_filenames() -> List[str]: |
| 35 | + return ["config.json"] |
| 36 | + |
| 37 | + @classmethod |
| 38 | + def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig": |
| 39 | + """Create the quant config from an hf model config""" |
| 40 | + try: |
| 41 | + from torchao.core.config import config_from_dict |
| 42 | + except ImportError as err: |
| 43 | + raise ImportError( |
| 44 | + "Please install torchao>=0.10.0 via " |
| 45 | + "`pip install torchao>=0.10.0` to use torchao quantization." |
| 46 | + ) from err |
| 47 | + |
| 48 | + hf_config = cls.get_from_keys_or(config, ["quant_type"], None) |
| 49 | + assert hf_config is not None, "quant_type must be specified" |
| 50 | + assert (len(hf_config) == 1 and "default" in hf_config |
| 51 | + ), "Expected only one key 'default' in quant_type dictionary" |
| 52 | + quant_type = hf_config["default"] |
| 53 | + ao_config = config_from_dict(quant_type) |
| 54 | + return cls(ao_config) |
| 55 | + |
| 56 | + def get_quant_method(self, layer: torch.nn.Module, |
| 57 | + prefix: str) -> Optional["TorchAOLinearMethod"]: |
| 58 | + if isinstance(layer, LinearBase): |
| 59 | + return TorchAOLinearMethod(self) |
| 60 | + return None |
| 61 | + |
| 62 | + def get_scaled_act_names(self) -> List[str]: |
| 63 | + return [] |
| 64 | + |
| 65 | + |
| 66 | +def torchao_quantize_param_data(param: torch.Tensor, |
| 67 | + torchao_config: Any) -> torch.nn.Parameter: |
| 68 | + """Quantize a Tensor with torchao quantization specified by torchao_config |
| 69 | +
|
| 70 | + Args: |
| 71 | + `param`: weight parameter of the linear module |
| 72 | + `torchao_config`: type of quantization and their arguments we want to |
| 73 | + use to quantize the Tensor |
| 74 | + """ |
| 75 | + from torchao.core.config import AOBaseConfig |
| 76 | + from torchao.quantization import quantize_ |
| 77 | + assert isinstance(torchao_config, AOBaseConfig) |
| 78 | + dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) |
| 79 | + dummy_linear.weight = param |
| 80 | + quantize_(dummy_linear, torchao_config) |
| 81 | + return dummy_linear.weight |
| 82 | + |
| 83 | + |
| 84 | +class TorchAOLinearMethod(LinearMethodBase): |
| 85 | + """Linear method for torchao. |
| 86 | +
|
| 87 | + Args: |
| 88 | + torchao_config: The torchao quantization config, a string |
| 89 | + that encodes the type of quantization and all relevant arguments. |
| 90 | + """ |
| 91 | + |
| 92 | + def __init__(self, quant_config: TorchAOConfig): |
| 93 | + self.quant_config = quant_config |
| 94 | + |
| 95 | + def create_weights( |
| 96 | + self, |
| 97 | + layer: torch.nn.Module, |
| 98 | + input_size_per_partition: int, |
| 99 | + output_partition_sizes: List[int], |
| 100 | + input_size: int, |
| 101 | + output_size: int, |
| 102 | + params_dtype: torch.dtype, |
| 103 | + **extra_weight_attrs, |
| 104 | + ): |
| 105 | + weight = Parameter( |
| 106 | + torch.empty( |
| 107 | + sum(output_partition_sizes), |
| 108 | + input_size_per_partition, |
| 109 | + dtype=params_dtype, |
| 110 | + ), |
| 111 | + requires_grad=False, |
| 112 | + ) |
| 113 | + weight = torchao_quantize_param_data(weight, |
| 114 | + self.quant_config.torchao_config) |
| 115 | + |
| 116 | + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) |
| 117 | + |
| 118 | + layer.register_parameter("weight", weight) |
| 119 | + set_weight_attrs(weight, extra_weight_attrs) |
| 120 | + |
| 121 | + def apply( |
| 122 | + self, |
| 123 | + layer: torch.nn.Module, |
| 124 | + x: torch.Tensor, |
| 125 | + bias: Optional[torch.Tensor] = None, |
| 126 | + ) -> torch.Tensor: |
| 127 | + return F.linear(x, layer.weight, bias) |
0 commit comments