Commit 8d32dc6

[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)
Signed-off-by: xinyuxiao <[email protected]> Co-authored-by: xinyuxiao <[email protected]>
1 parent c4ab9f3 commit 8d32dc6

File tree

15 files changed: +1864 −7 lines changed
Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION)

try:
    import bitblas
    if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
        raise ImportError("bitblas version is wrong. Please "
                          f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
    bitblas_import_exception = e
    raise ValueError("Trying to use the bitblas backend, but could not import "
                     f"with the following error: {bitblas_import_exception}. "
                     "Please install bitblas through the following command: "
                     f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
                     ) from bitblas_import_exception

from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target.")

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.",
)
parser.add_argument("--group_size",
                    type=int,
                    default=None,
                    help="Group size for grouped quantization.")
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],
    help="Data type of activation A.",
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=[
        "float16",
        "float32",
        "float64",
        "int32",
        "int8",
        "int4",
        "int2",
        "int1",
        "nf4",
        "fp4_e2m1",
    ],
    help="Data type of weight W.",
)
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],
    help="Data type for accumulation.",
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],
    help="Data type for output.",
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],
    help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument("--with_bias",
                    action="store_true",
                    help="Include bias in the benchmark.")
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization.",
)
parser.add_argument("--with_zeros",
                    action="store_true",
                    help="Include zeros in the quantization.")
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros.",
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

# Define a list of shared arguments that repeat in every config
shared_args = [
    A_dtype,
    W_dtype,
    out_dtype,
    accum_dtype,
    layout,
    with_bias,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
]

# Define just the (M, K, N) shapes in a more compact list
shapes = [
    # square test
    (1, 16384, 16384),
    # BLOOM-176B
    (1, 43008, 14336),
    (1, 14336, 14336),
    (1, 57344, 14336),
    (1, 14336, 57344),
    # OPT-65B
    (1, 9216, 9216),
    (1, 36864, 9216),
    (1, 9216, 36864),
    (1, 22016, 8192),
    # LLAMA-70B/65B
    (1, 8192, 22016),
    (1, 8192, 8192),
    (1, 28672, 8192),
    (1, 8192, 28672),
    # square test
    (16384, 16384, 16384),
    # BLOOM-176B
    (8192, 43008, 14336),
    (8192, 14336, 14336),
    (8192, 57344, 14336),
    (8192, 14336, 57344),
    # OPT-65B
    (8192, 9216, 9216),
    (8192, 36864, 9216),
    (8192, 9216, 36864),
    (8192, 22016, 8192),
    # LLAMA-70B/65B
    (8192, 8192, 22016),
    (8192, 8192, 8192),
    (8192, 28672, 8192),
    (8192, 8192, 28672),
]

# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
               for shape in shapes]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
    config = config_class(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }

    benchmark_results.update(profile_config)

# Define headers for the table
headers = [
    "PrimFunc",
    "Input Arguments",
    "BitBLAS Top20 Latency",
]

# Calculate column widths for pretty printing
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
    col_widths[1] = max(col_widths[1],
                        len(input_args_str) + 2,
                        len(headers[1]) + 2)
    col_widths[2] = max(col_widths[2],
                        len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
                        len(headers[2]) + 2)
    # break only if you want to measure widths from a single example;
    # otherwise, let it loop over all items.

# Print header
for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))

# Print rows
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    row = [
        func_name,
        input_args_str,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    row_str = "".join(
        [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
    print(row_str)
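
For orientation, the loop above builds one `MatmulConfig` per `(M, K, N)` shape together with the shared quantization arguments, tunes the kernel, and records `profile_latency()`. Below is a minimal single-shape sketch of the same flow; it is an illustration rather than part of the commit, and it assumes `bitblas>=0.1.0` is installed and an NVIDIA GPU is available (tuning can take several minutes).

```python
# Minimal single-shape sketch of the benchmark loop above.
# Assumptions: bitblas>=0.1.0 is installed and an NVIDIA GPU is present.
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

shape = (1, 16384, 16384)  # the first "square test" entry
shared_args = (
    "float16",  # A_dtype
    "int4",     # W_dtype
    "float16",  # out_dtype
    "float16",  # accum_dtype
    "nt",       # layout
    False,      # with_bias
    None,       # group_size
    False,      # with_scaling
    False,      # with_zeros
    None,       # zeros_mode
)
config = MatmulConfig(*shape, *shared_args)
matmul = Matmul(config, target=auto_detect_nvidia_target(), enable_tuning=True)
print("Time cost is: {:.3f} ms".format(matmul.profile_latency()))
```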
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# BitBLAS

vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations.

Below are the steps to utilize BitBLAS with vLLM.

```console
pip install bitblas>=0.1.0
```

vLLM reads the model's config file and supports pre-quantized checkpoints.

You can find pre-quantized models on:

- [Hugging Face (BitBLAS)](https://huggingface.co/models?other=bitblas)
- [Hugging Face (GPTQ)](https://huggingface.co/models?other=gptq)

Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section.
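
As a quick sanity check before loading a checkpoint, you can download and print that file. The sketch below is illustrative only; it assumes `huggingface_hub` is installed and that the repository actually ships a `quantize_config.json`, as most GPTQ-style repositories do.

```python
# Illustrative sketch: inspect the quantization settings of a checkpoint.
# Assumptions: huggingface_hub is installed and the repository ships a
# quantize_config.json (common for GPTQ-style repos).
import json

from huggingface_hub import hf_hub_download

path = hf_hub_download("hxbgsyxh/llama-13b-4bit-g-1", "quantize_config.json")
with open(path) as f:
    print(json.load(f))  # e.g. bits, group_size, and related settings
```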
## Read bitblas format checkpoint

```python
from vllm import LLM
import torch

# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas")
```

## Read gptq format checkpoint

```python
from vllm import LLM
import torch

# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1"
llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024)
```
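
Either way, the resulting `LLM` object is used like any other vLLM model. A short smoke test, with an arbitrary prompt and sampling settings chosen purely for illustration:

```python
# Quick smoke test after either of the snippets above; the prompt and
# sampling values are arbitrary illustrations.
from vllm import SamplingParams

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```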

docs/source/features/quantization/index.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ Quantization trades off model precision for smaller memory footprint, allowing l
 supported_hardware
 auto_awq
 bnb
+bitblas
 gguf
 gptqmodel
 int4

docs/source/features/quantization/supported_hardware.md

Lines changed: 11 additions & 0 deletions
@@ -74,6 +74,17 @@ The table below shows the compatibility of various quantization implementations
   *
   *
   *
+- * BitBLAS (GPTQ)
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  *
+  *
+  *
+  *
 - * AQLM
   * ✅︎
   * ✅︎

tests/models/test_bitblas.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.

Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.

Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass

import pytest

from .utils import check_logprobs_close


@dataclass
class ModelPair:
    model_bitblas: str
    model_gptq: str


model_pairs = [
    ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
              model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(model_pair.model_bitblas,
                     dtype=dtype,
                     quantization="bitblas") as bitblas_model:
        bitblas_outputs = bitblas_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_pair.model_gptq, dtype=dtype,
                     quantization="gptq") as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=bitblas_outputs,
        name_0="gptq",
        name_1="bitblas",
    )
