import os
from safetensors.torch import safe_open
import torch
- from typing import List, Optional, Tuple
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple


# Adapted from vllm/model_executor/weight_utils.py
@@ -90,12 +90,25 @@ def _hf_tensorfile_iterator(filename: str, load_format: str,
        torch.cuda.empty_cache()


- def main(args):
-     rank_tensors_map = {}
-     hf_tensor_files, use_safetensors = _prepare_hf_weights(args.quantized_model, args.load_format)
-     # Matches the number immediately after this keyword in the tensor filename to
-     # determine the TP rank corresponding to said tensor file
-     rank_keyword = "rank"
+ def _kv_scales_extractor(hf_tensor_files: Iterable[str],
+                          use_safetensors: bool,
+                          rank_keyword: str = "rank",
+                          expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
+     """
+     Given a list of files containing tensor data, attempt to extract KV cache scales from
+     these files. Intended as a helper function taking in the output from _prepare_hf_weights.
+     Args:
+         rank_keyword      Matches the number immediately after this keyword in the tensor
+                           filename to determine the TP rank corresponding to said tensor file
+         expected_tp_size  If specified, the TP size of the tensor files is checked against
+                           this and an error is raised if they do not match.
+     Returns a dictionary mapping TP ranks to their relevant KV cache scaling factors. The
+     per-rank scaling factors are themselves represented as a dictionary of layer indices to the
+     respective per-layer scaling factor.
+     """
+     for char in rank_keyword:
+         assert not char.isdecimal(), f"Rank keyword {rank_keyword} contains a numeric character!"
+     rank_scales_map = {}
    for tensor_file in hf_tensor_files:
        try:
            rank_idx = tensor_file.find(rank_keyword)
@@ -118,9 +131,9 @@ def main(args):
                  f"corresponding to file '{tensor_file}'")
            raise

-         if rank not in rank_tensors_map:
+         if rank not in rank_scales_map:
            layer_scales_map = {}
-             rank_tensors_map[rank] = layer_scales_map
+             rank_scales_map[rank] = layer_scales_map
        else:
            raise RuntimeError(f"Tensor file '{tensor_file}' shares TP rank {rank} "
                               "with another tensor file.")
@@ -138,34 +151,138 @@ def main(args):
                    layer_scales_map[layer_idx] = param.item()
                except RuntimeError:
                    print("This utility supports only per-tensor scalar scale factors "
-                           f"for now. The tensor\n {name} = {param} is an invalid "
+                           f"for now. The tensor\n {name} = {param} \nis an invalid "
                          "scale factor.")
                    raise

+     if all(len(layer_scales_map) == 0 for layer_scales_map in rank_scales_map.values()):
+         # Note: this is true even if the rank_scales_map is empty
+         print("WARNING: No KV cache scale factors found. No output saved.")
+         return None
+     empirical_tp_world_size = max(rank_scales_map.keys()) + 1
+     if expected_tp_size is not None:
+         assert expected_tp_size == empirical_tp_world_size, "User expected TP world size = " \
+             f"{expected_tp_size} from model but tool is expecting TP world size = " \
+             f"{empirical_tp_world_size} from model instead."
+     for i in range(empirical_tp_world_size):
+         assert i in rank_scales_map, f"Expected TP world size = {empirical_tp_world_size} " \
+             "but did not find KV cache scaling factors " \
+             f"for TP rank {i}"
+     print(f"Found TP world size = {empirical_tp_world_size} when extracting KV cache scales!")
+     return rank_scales_map
+
+
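+ # Example (hypothetical values, not taken from any real checkpoint): for a 2-way TP
+ # checkpoint with two attention layers, the helper above is expected to return a
+ # nested mapping of the form
+ #     {0: {0: 0.0123, 1: 0.0456}, 1: {0: 0.0135, 1: 0.0479}}
+ # i.e. TP rank -> layer index -> per-tensor KV cache scaling factor.
+
+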
+ def _metadata_extractor(quantized_model_dir: str,
+                         metadata_extract_fns: Dict[str, Callable[[Dict[str, Any]], Any]]) \
+         -> Dict[str, Any]:
+     """
+     Given a directory containing quantized model files, this function aims to extract metadata
+     from the JSON files within this directory. Each JSON file is expected to represent a
+     dictionary in JSON format (referred to as a "JSON-dictionary"). Metadata extraction is
+     defined by a dictionary called metadata_extract_fns, where each metadata field name is
+     mapped to an extraction function.
+
+     These extraction functions are designed to take a JSON-dictionary as their only argument
+     and return the corresponding metadata. While extraction functions are permitted to raise
+     exceptions, they should only raise a KeyError or ValueError if the metadata field cannot
+     be extracted from the current JSON-dictionary, yet there's a possibility of finding it in
+     another JSON-dictionary.
+
+     The function returns a dictionary that maps metadata fields to their extracted data. The
+     keys of this dictionary correspond exactly to those in metadata_extract_fns. If any fields
+     fail to be extracted, their corresponding values are set to None, and a warning is printed.
+     """
+     if not os.path.isdir(quantized_model_dir):
+         raise FileNotFoundError(f"The quantized model directory `{quantized_model_dir}` "
+                                 "does not exist.")
+     metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
+
+     result = {}
+     for file in metadata_files:
+         with open(file) as f:
+             try:
+                 metadata = json.load(f)
+             except json.JSONDecodeError:
+                 print(f"Could not parse `{file}` as a valid metadata file, skipping it.")
+                 continue
+             if not isinstance(metadata, dict):
+                 print(f"The file `{file}` does not correspond to a JSON-serialized "
+                       "dictionary, skipping it.")
+                 continue
+             for metadata_name, extract_fn in metadata_extract_fns.items():
+                 try:
+                     metadata_info = extract_fn(metadata)
+                     if metadata_name not in result:
+                         result[metadata_name] = metadata_info
+                     elif metadata_info != result[metadata_name]:
+                         raise RuntimeError("Metadata mismatch! Originally found "
+                                            f"{metadata_name} = {result[metadata_name]} but "
+                                            f"now found {metadata_name} = {metadata_info} in "
+                                            f"`{file}`")
+                 except KeyError:
+                     # It is possible that a given file does not contain some of our selected
+                     # metadata as it could be located in some other metadata file.
+                     # 'EFINAE': extract_fn failure is not an error.
+                     pass
+                 except ValueError:
+                     # See above.
+                     pass
+
+     # Warn if we cannot find any of the requested metadata
+     for metadata_name in metadata_extract_fns:
+         if metadata_name not in result:
+             print(f"WARNING: Unable to find requested metadata field `{metadata_name}`, "
+                   "setting it to None.")
+             result[metadata_name] = None
+
+     return result
+
+
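+ # Example (hypothetical file contents): given a JSON file in the model directory such as
+ #     {"tensor_parallel": 2, "dtype": "float16"}
+ # a call like
+ #     _metadata_extractor(model_dir, {"tp_size": lambda d: int(d["tensor_parallel"])})
+ # would be expected to return {"tp_size": 2}; requested fields that cannot be found in
+ # any JSON file come back as None (with a warning).
+
+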
+ def main(args):
+     metadata_extract_fns = {
+         "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
+         "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
+         "model_dtype": lambda json_dict: json_dict["dtype"]
+     }
+     recovered_metadata = _metadata_extractor(args.quantized_model, metadata_extract_fns)
+     if args.tp_size is not None:
+         metadata_tp_size = recovered_metadata["tp_size"]
+         if metadata_tp_size is not None:
+             assert args.tp_size == metadata_tp_size, "User expected TP world size = " \
+                 f"{args.tp_size} but found TP world size = {metadata_tp_size} from metadata!"
+     expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
+     rank_keyword = "rank"
+     hf_tensor_files, use_safetensors = _prepare_hf_weights(args.quantized_model, args.load_format)
+     rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
+                                            rank_keyword, expected_tp_size)
+     # Postprocess: formatting to the current schema. Consider pulling it out into a dedicated
+     # function should it ever become more complicated.
+     rank_scales_map = {rank_keyword + str(rank):
+                        {k: scale[k] for k in sorted(scale.keys())}
+                        for rank, scale in rank_scales_map.items()}
+
+     # Consider generalizing and formalizing this into its own class (and other necessary
+     # subclasses) in the future
+     schema = {"model_type": recovered_metadata["model_type"],
+               "kv_cache": {
+                   "dtype": "float8_e4m3fn" if len(rank_scales_map) > 0 \
+                       else recovered_metadata["model_dtype"],
+                   "scaling_factor": rank_scales_map
+               },
+               # TODO: Expand this with activation and weights scaling factors when they
+               # are used in the future
+               }
+
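+     # Example (hypothetical values) of what this schema could look like once serialized:
+     #     {
+     #         "model_type": "llama",
+     #         "kv_cache": {
+     #             "dtype": "float8_e4m3fn",
+     #             "scaling_factor": {
+     #                 "rank0": {"0": 0.0123, "1": 0.0456}
+     #             }
+     #         }
+     #     }
+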
    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
-         output_file = os.path.join(args.output_dir, args.output_name)
        if not os.path.isdir(args.output_dir):
            os.makedirs(args.output_dir, exist_ok=True)
-
-     if all(len(layer_scales_map) == 0 for layer_scales_map in rank_tensors_map.values()):
-         # Note: this is true even if the rank_tensors_map is empty
-         print("WARNING: No KV cache scale factors found. No output saved.")
-     else:
-         empirical_tp_world_size = max(rank_tensors_map.keys()) + 1
-         if args.tp_size is not None:
-             assert args.tp_size == empirical_tp_world_size, "User expected TP world size = " \
-                 f"{args.tp_size} from model but tool is expecting TP world size = " \
-                 f"{empirical_tp_world_size} from model instead."
-         for i in range(empirical_tp_world_size):
-             assert i in rank_tensors_map, f"Expected TP world size = {empirical_tp_world_size} " \
-                 "but did not find KV cache scaling factors " \
-                 f"for TP rank {i}"
-         with open(output_file, 'w') as f:
-             json.dump(rank_tensors_map, f, sort_keys=True, indent=4)
-         print(f"Completed! Found TP world size = {empirical_tp_world_size}.",
-               f"KV cache scaling factors saved to {output_file}")
+         output_file = os.path.join(args.output_dir, args.output_name)
+
+     with open(output_file, 'w') as f:
+         json.dump(schema, f, indent=4)
+     print(f"Completed! KV cache scaling factors saved to {output_file}")


if __name__ == "__main__":
@@ -174,7 +291,7 @@ def main(args):
        "and saves them to a JSON file compatible with later "
        "use by vLLM (pass this file to the appropriate "
        "runtime typically using the argument "
-         "--kv_cache_scales_path <filename>). This is only used "
+         "--scales-path <filename>). This is only used "
        "if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
    parser.add_argument("--quantized_model",
                        help="Specify the directory containing a single quantized HF model. "
@@ -193,7 +310,8 @@ def main(args):
                        default=None)
    parser.add_argument("--output_name",
                        help="Optionally specify the output filename.",
-                         default="kv_cache_scales.json")
+                         # TODO: Change this once additional scaling factors are enabled
+                         default="kv_cache_scales.json")
    parser.add_argument("--tp_size",
                        help="Optionally specify the tensor-parallel (TP) size that the "
                        "quantized model should correspond to. If specified, during KV "