62
62
import json
63
63
import os
64
64
65
- from deepsparse import model_debug_analysis
65
+ from deepsparse import KVCacheParams , model_debug_analysis
66
66
from deepsparse .utils import (
67
+ default_cached_outputs ,
67
68
generate_random_inputs ,
68
69
model_to_path ,
69
70
override_onnx_input_shapes ,
@@ -140,6 +141,29 @@ def parse_args():
140
141
type = str ,
141
142
default = "" ,
142
143
)
144
+ parser .add_argument (
145
+ "--disable-batch-override" ,
146
+ help = "Ignores the batch_size parameter" ,
147
+ action = "store_true" ,
148
+ default = False ,
149
+ )
150
+ parser .add_argument (
151
+ "--use-kvcache" , help = "Enable KVCache" , action = "store_true" , default = False
152
+ )
153
+ parser .add_argument (
154
+ "--kv-cache-prev-num-tokens" ,
155
+ help = "KVCache: The amount of previous tokens that will be read"
156
+ " from the external KV cache on the first inference" ,
157
+ type = int ,
158
+ default = None ,
159
+ )
160
+ parser .add_argument (
161
+ "--kv-cache-num-frozen-tokens" ,
162
+ help = "KVCache: The amount of first tokens that we want to keep"
163
+ " permanently in the KV cache" ,
164
+ type = int ,
165
+ default = None ,
166
+ )
143
167
parser .add_argument (
144
168
"-q" ,
145
169
"--quiet" ,
@@ -186,14 +210,10 @@ def construct_layer_table(result):
186
210
"{: >#08.4f} | {: >#08.4f} | {: >#08.4f} | {:12}"
187
211
)
188
212
for li in result ["layer_info" ]:
189
- table_str += layer_info_to_string (
190
- li ,
191
- "{:28}| " + info_format_base + "\n " ,
192
- )
213
+ table_str += layer_info_to_string (li , "{:28}| " + info_format_base + "\n " )
193
214
for sub_li in li ["sub_layer_info" ]:
194
215
table_str += layer_info_to_string (
195
- sub_li ,
196
- " {:26}| " + info_format_base + "\n " ,
216
+ sub_li , " {:26}| " + info_format_base + "\n "
197
217
)
198
218
199
219
table_str += "Total Time(MS): {:05f}\n " .format (result ["average_total_time" ])
@@ -295,11 +315,39 @@ def main():
295
315
296
316
print ("Analyzing model: {}" .format (orig_model_path ))
297
317
318
+ batch_size = args .batch_size
319
+ if args .disable_batch_override :
320
+ batch_size = None
321
+ os .environ ["NM_DISABLE_BATCH_OVERRIDE" ] = "1"
322
+ print ("Disable batch override: ON" )
323
+
298
324
if input_shapes :
299
325
with override_onnx_input_shapes (model_path , input_shapes ) as tmp_path :
300
- input_list = generate_random_inputs (tmp_path , args . batch_size )
326
+ input_list = generate_random_inputs (tmp_path , batch_size )
301
327
else :
302
- input_list = generate_random_inputs (model_path , args .batch_size )
328
+ input_list = generate_random_inputs (model_path , batch_size )
329
+
330
+ kv_cache_params = None
331
+ if args .use_kvcache :
332
+ kv_cache_prev_num_tokens = 0
333
+ if args .kv_cache_prev_num_tokens is not None :
334
+ kv_cache_prev_num_tokens = args .kv_cache_prev_num_tokens
335
+
336
+ kv_cache_num_frozen_tokens = 0
337
+ if args .kv_cache_num_frozen_tokens is not None :
338
+ kv_cache_num_frozen_tokens = args .kv_cache_num_frozen_tokens
339
+
340
+ kv_cache_params = KVCacheParams (
341
+ default_cached_outputs (model_path ),
342
+ kv_cache_prev_num_tokens ,
343
+ kv_cache_num_frozen_tokens ,
344
+ )
345
+
346
+ print (
347
+ "Enable KVCache: prev_num_tokens = {}, num_frozen_tokens = {}" .format (
348
+ kv_cache_params .prev_num_tokens , kv_cache_params .num_frozen_tokens
349
+ )
350
+ )
303
351
304
352
result = model_debug_analysis (
305
353
model_path ,
@@ -308,9 +356,11 @@ def main():
308
356
num_cores = args .num_cores ,
309
357
num_iterations = args .num_iterations ,
310
358
num_warmup_iterations = args .num_warmup_iterations ,
311
- optimization_level = args .optimization ,
359
+ optimization_level = int (args .optimization ),
360
+ disable_batch_override = args .disable_batch_override ,
312
361
imposed_ks = imposed_kernel_sparsity ,
313
362
input_shapes = input_shapes ,
363
+ kv_cache_params = kv_cache_params ,
314
364
)
315
365
316
366
if not args .quiet :
0 commit comments