
Commit 7b72159

Merge pull request vllm-project#16 from ROCm/fp8_ingest_stage1_model
Generalizing KV scales JSON to updated schema
2 parents 12f7650 + 52df603 commit 7b72159


11 files changed: +259 -116 lines changed


3rdparty/quantizer/extract_scales.py

Lines changed: 149 additions & 31 deletions
@@ -7,7 +7,7 @@
 import os
 from safetensors.torch import safe_open
 import torch
-from typing import List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
 
 # Adapted from vllm/model_executor/weight_utils.py
@@ -90,12 +90,25 @@ def _hf_tensorfile_iterator(filename: str, load_format: str,
         torch.cuda.empty_cache()
 
 
-def main(args):
-    rank_tensors_map = {}
-    hf_tensor_files, use_safetensors = _prepare_hf_weights(args.quantized_model, args.load_format)
-    # Matches the number immediately after this keyword in the tensor filename to
-    # determine the TP rank corresponding to said tensor file
-    rank_keyword = "rank"
+def _kv_scales_extractor(hf_tensor_files: Iterable[str],
+                         use_safetensors: bool,
+                         rank_keyword: str = "rank",
+                         expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
+    """
+    Given a list of files containing tensor data, attempt to extract KV cache scales from
+    these files. Intended as a helper function taking in the output from _prepare_hf_weights.
+    Args:
+    rank_keyword        Matches the number immediately after this keyword in the tensor
+                        filename to determine the TP rank corresponding to said tensor file
+    expected_tp_size    If specified, the TP size of the tensor files is checked against
+                        this and an error is raised if they do not match.
+    Returns a dictionary mapping TP ranks to their relevant KV cache scaling factors. The
+    per-rank scaling factors are themselves represented as a dictionary of layer indices to the
+    respective per-layer scaling factor.
+    """
+    for char in rank_keyword:
+        assert not char.isdecimal(), f"Rank keyword {rank_keyword} contains a numeric character!"
+    rank_scales_map = {}
     for tensor_file in hf_tensor_files:
         try:
             rank_idx = tensor_file.find(rank_keyword)
@@ -118,9 +131,9 @@ def main(args):
                   f"corresponding to file '{tensor_file}'")
             raise
 
-        if rank not in rank_tensors_map:
+        if rank not in rank_scales_map:
             layer_scales_map = {}
-            rank_tensors_map[rank] = layer_scales_map
+            rank_scales_map[rank] = layer_scales_map
         else:
             raise RuntimeError(f"Tensor file '{tensor_file}' shares TP rank {rank} "
                                "with another tensor file.")
@@ -138,34 +151,138 @@ def main(args):
                     layer_scales_map[layer_idx] = param.item()
                 except RuntimeError:
                     print("This utility supports only per-tensor scalar scale factors "
-                          f"for now. The tensor\n {name} = {param} is an invalid "
+                          f"for now. The tensor\n {name} = {param} \nis an invalid "
                           "scale factor.")
                     raise
 
+    if all(len(layer_scales_map) == 0 for layer_scales_map in rank_scales_map.values()):
+        # Note: this is true even if the rank_scales_map is empty
+        print("WARNING: No KV cache scale factors found. No output saved.")
+        return None
+    empirical_tp_world_size = max(rank_scales_map.keys()) + 1
+    if expected_tp_size is not None:
+        assert expected_tp_size == empirical_tp_world_size, "User expected TP world size = " \
+            f"{expected_tp_size} from model but tool is expecting TP world size = " \
+            f"{empirical_tp_world_size} from model instead."
+    for i in range(empirical_tp_world_size):
+        assert i in rank_scales_map, f"Expected TP world size = {empirical_tp_world_size} " \
+            "but did not find KV cache scaling factors " \
+            f"for TP rank {i}"
+    print(f"Found TP world size = {empirical_tp_world_size} when extracting KV cache scales!")
+    return rank_scales_map
+
+
+def _metadata_extractor(quantized_model_dir: str,
+                        metadata_extract_fns: Dict[str, Callable[[Dict[str, Any]], Any]]) \
+                            -> Dict[str, Any]:
+    """
+    Given a directory containing quantized model files, this function aims to extract metadata
+    from the JSON files within this directory. Each JSON file is expected to represent a
+    dictionary in JSON format (referred to as a "JSON-dictionary"). Metadata extraction is
+    defined by a dictionary called metadata_extract_fns, where each metadata field name is
+    mapped to an extraction function.
+
+    These extraction functions are designed to take a JSON-dictionary as their only argument
+    and return the corresponding metadata. While extraction functions are permitted to raise
+    exceptions, they should only raise a KeyError or ValueError if the metadata field cannot
+    be extracted from the current JSON-dictionary, yet there's a possibility of finding it in
+    another JSON-dictionary.
+
+    The function returns a dictionary that maps metadata fields to their extracted data. The
+    keys of this dictionary correspond exactly to those in metadata_extract_fns. If any fields
+    fail to be extracted, their corresponding values are set to None, and a warning is printed.
+    """
+    if not os.path.isdir(quantized_model_dir):
+        raise FileNotFoundError(f"The quantized model directory `{quantized_model_dir}` "
+                                "does not exist.")
+    metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
+
+    result = {}
+    for file in metadata_files:
+        with open(file) as f:
+            try:
+                metadata = json.load(f)
+            except json.JSONDecodeError:
+                print(f"Could not parse `{file}` as a valid metadata file, skipping it.")
+                continue
+            if not isinstance(metadata, dict):
+                print(f"The file `{file}` does not correspond to a JSON-serialized "
+                      "dictionary, skipping it.")
+                continue
+            for metadata_name, extract_fn in metadata_extract_fns.items():
+                try:
+                    metadata_info = extract_fn(metadata)
+                    if metadata_name not in result:
+                        result[metadata_name] = metadata_info
+                    elif metadata_info != result[metadata_name]:
+                        raise RuntimeError("Metadata mismatch! Originally found "
+                                           f"{metadata_name} = {result[metadata_name]} but "
+                                           f"now found {metadata_name} = {metadata_info} in "
+                                           f"`{file}`")
+                except KeyError:
+                    # It is possible that a given file does not contain some of our selected
+                    # metadata as it could be located in some other metadata file.
+                    # 'EFINAE': extract_fn failure is not an error.
+                    pass
+                except ValueError:
+                    # See above.
+                    pass
+
+    # Warn if we cannot find any of the requested metadata
+    for metadata_name in metadata_extract_fns:
+        if metadata_name not in result:
+            print(f"WARNING: Unable to find requested metadata field `{metadata_name}`, "
+                  "setting it to None.")
+            result[metadata_name] = None
+
+    return result
+
+
+def main(args):
+    metadata_extract_fns = {
+        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
+        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
+        "model_dtype": lambda json_dict: json_dict["dtype"]
+    }
+    recovered_metadata = _metadata_extractor(args.quantized_model, metadata_extract_fns)
+    if args.tp_size is not None:
+        metadata_tp_size = recovered_metadata["tp_size"]
+        if metadata_tp_size is not None:
+            assert args.tp_size == metadata_tp_size, "User expected TP world size = " \
+                f"{args.tp_size} but found TP world size = {metadata_tp_size} from metadata!"
+    expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
+    rank_keyword = "rank"
+    hf_tensor_files, use_safetensors = _prepare_hf_weights(args.quantized_model, args.load_format)
+    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
+                                           rank_keyword, expected_tp_size)
+    # Postprocess: formatting to the current schema. Consider pulling it out into a dedicated
+    # function should it ever become more complicated.
+    rank_scales_map = { rank_keyword + str(rank) :
+                        {k: scale[k] for k in sorted(scale.keys())}
+                        for rank, scale in rank_scales_map.items() }
+
+    # Consider generalizing and formalizing this into its own class (and other necessary
+    # subclasses) in the future
+    schema = { "model_type": recovered_metadata["model_type"],
+               "kv_cache": {
+                   "dtype": "float8_e4m3fn" if len(rank_scales_map) > 0 \
+                       else recovered_metadata["model_dtype"],
+                   "scaling_factor": rank_scales_map
+               },
+               # TODO: Expand this with activation and weights scaling factors when they
+               # are used in the future
+             }
+
     if args.output_dir is None:
         output_file = os.path.join(args.quantized_model, args.output_name)
     else:
-        output_file = os.path.join(args.output_dir, args.output_name)
         if not os.path.isdir(args.output_dir):
             os.makedirs(args.output_dir, exist_ok=True)
-
-    if all(len(layer_scales_map) == 0 for layer_scales_map in rank_tensors_map.values()):
-        # Note: this is true even if the rank_tensors_map is empty
-        print("WARNING: No KV cache scale factors found. No output saved.")
-    else:
-        empirical_tp_world_size = max(rank_tensors_map.keys()) + 1
-        if args.tp_size is not None:
-            assert args.tp_size == empirical_tp_world_size, "User expected TP world size = " \
-                f"{args.tp_size} from model but tool is expecting TP world size = " \
-                f"{empirical_tp_world_size} from model instead."
-        for i in range(empirical_tp_world_size):
-            assert i in rank_tensors_map, f"Expected TP world size = {empirical_tp_world_size} " \
-                "but did not find KV cache scaling factors " \
-                f"for TP rank {i}"
-        with open(output_file, 'w') as f:
-            json.dump(rank_tensors_map, f, sort_keys=True, indent=4)
-        print(f"Completed! Found TP world size = {empirical_tp_world_size}.",
-              f"KV cache scaling factors saved to {output_file}")
+        output_file = os.path.join(args.output_dir, args.output_name)
+
+    with open(output_file, 'w') as f:
+        json.dump(schema, f, indent=4)
+    print(f"Completed! KV cache scaling factors saved to {output_file}")
 
 
 if __name__ == "__main__":
@@ -174,7 +291,7 @@ def main(args):
         "and saves them to a JSON file compatible with later "
         "use by vLLM (pass this file to the appropriate "
         "runtime typically using the argument "
-        "--kv_cache_scales_path <filename>). This is only used "
+        "--scales-path <filename>). This is only used "
         "if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
     parser.add_argument("--quantized_model",
                         help="Specify the directory containing a single quantized HF model. "
@@ -193,7 +310,8 @@ def main(args):
                         default=None)
     parser.add_argument("--output_name",
                         help="Optionally specify the output filename.",
-                        default="kv_cache_scales.json")
+                        # TODO: Change this once additional scaling factors are enabled
+                        default="kv_cache_scales.json")
     parser.add_argument("--tp_size",
                         help="Optionally specify the tensor-parallel (TP) size that the "
                              "quantized model should correspond to. If specified, during KV "

benchmarks/benchmark_latency.py

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
         kv_cache_dtype=args.kv_cache_dtype,
-        kv_cache_scales_path=args.kv_cache_scales_path,
+        scales_path=args.scales_path,
     )
 
     sampling_params = SamplingParams(
@@ -128,10 +128,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
         'FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. '
         'On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.')
     parser.add_argument(
-        '--kv-cache-scales-path',
+        '--scales-path',
         type=str,
         default=None,
-        help='Path to the JSON files containing the KV cache scaling factors. '
+        help='Path to the JSON file containing the KV cache scaling factors. '
         'This should generally be supplied, when KV cache dtype is FP8. Otherwise, '
        'KV cache scaling factors default to 1.0, which may cause accuracy issues. '
         'FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. '
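A minimal usage sketch of the renamed engine argument on this ROCm fork, assuming an FP8-quantized checkpoint; the model path, JSON path, and kv_cache_dtype value are placeholders and assumptions rather than something this diff pins down.

from vllm import LLM, SamplingParams

# scales_path is the keyword this PR renames from kv_cache_scales_path; per the help
# text above it is only honored for an FP8 KV cache on ROCm (AMD GPU).
llm = LLM(
    model="/path/to/quantized-model",               # placeholder path
    kv_cache_dtype="fp8",                           # assumed FP8 setting; match the fork's --kv-cache-dtype choices
    scales_path="/path/to/kv_cache_scales.json",    # file produced by extract_scales.py
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))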

benchmarks/benchmark_throughput.py

Lines changed: 5 additions & 5 deletions
@@ -72,7 +72,7 @@ def run_vllm(
     max_model_len: Optional[int],
     enforce_eager: bool,
     kv_cache_dtype: str,
-    kv_cache_scales_path: Optional[str],
+    scales_path: Optional[str],
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -86,7 +86,7 @@ def run_vllm(
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
         kv_cache_dtype=kv_cache_dtype,
-        kv_cache_scales_path=kv_cache_scales_path,
+        scales_path=scales_path,
     )
 
     # Add the requests to the engine.
@@ -211,7 +211,7 @@ def main(args: argparse.Namespace):
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
                                 args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype, args.kv_cache_scales_path)
+                                args.kv_cache_dtype, args.scales_path)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -298,10 +298,10 @@ def main(args: argparse.Namespace):
         'FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. '
         'On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.')
     parser.add_argument(
-        '--kv-cache-scales-path',
+        '--scales-path',
        type=str,
         default=None,
-        help='Path to the JSON files containing the KV cache scaling factors. '
+        help='Path to the JSON file containing the KV cache scaling factors. '
         'This should generally be supplied, when KV cache dtype is FP8. Otherwise, '
         'KV cache scaling factors default to 1.0, which may cause accuracy issues. '
         'FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. '
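As a sanity check on the generalized schema, a small consumer-side sketch follows; it is not part of this commit and not vLLM's actual loader, but shows how the per-rank, per-layer factors could be read back from the JSON that extract_scales.py now emits, assuming the default "rank" keyword.

import json
from typing import Dict

def load_kv_cache_scales(path: str, tp_rank: int) -> Dict[int, float]:
    """Return {layer_idx: scaling_factor} for one TP rank from the generalized schema."""
    with open(path) as f:
        schema = json.load(f)
    # json.dump stringifies the integer layer indices, so convert them back here.
    rank_scales = schema["kv_cache"]["scaling_factor"][f"rank{tp_rank}"]
    return {int(layer): float(scale) for layer, scale in rank_scales.items()}

# Example usage (the path is a placeholder):
# scales = load_kv_cache_scales("kv_cache_scales.json", tp_rank=0)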
