Skip to content

Commit 0ae6aa7

Browse files
alexm-redhatmgoin
andauthored
Add kvcache support to debug_analysis.py and engine.py (#1132)
* add kvcache support to debug_analysis.py and engine.py * remove print * make defaults for kvcache run * review comments * Update src/deepsparse/utils/onnx.py --------- Co-authored-by: Michael Goin <[email protected]>
1 parent fb9d131 commit 0ae6aa7

File tree

3 files changed

+186
-63
lines changed

3 files changed

+186
-63
lines changed

src/deepsparse/debug_analysis.py

+60-10
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,9 @@
6262
import json
6363
import os
6464

65-
from deepsparse import model_debug_analysis
65+
from deepsparse import KVCacheParams, model_debug_analysis
6666
from deepsparse.utils import (
67+
default_cached_outputs,
6768
generate_random_inputs,
6869
model_to_path,
6970
override_onnx_input_shapes,
@@ -140,6 +141,29 @@ def parse_args():
140141
type=str,
141142
default="",
142143
)
144+
parser.add_argument(
145+
"--disable-batch-override",
146+
help="Ignores the batch_size parameter",
147+
action="store_true",
148+
default=False,
149+
)
150+
parser.add_argument(
151+
"--use-kvcache", help="Enable KVCache", action="store_true", default=False
152+
)
153+
parser.add_argument(
154+
"--kv-cache-prev-num-tokens",
155+
help="KVCache: The amount of previous tokens that will be read"
156+
" from the external KV cache on the first inference",
157+
type=int,
158+
default=None,
159+
)
160+
parser.add_argument(
161+
"--kv-cache-num-frozen-tokens",
162+
help="KVCache: The amount of first tokens that we want to keep"
163+
" permanently in the KV cache",
164+
type=int,
165+
default=None,
166+
)
143167
parser.add_argument(
144168
"-q",
145169
"--quiet",
@@ -186,14 +210,10 @@ def construct_layer_table(result):
186210
"{: >#08.4f} | {: >#08.4f} | {: >#08.4f} | {:12}"
187211
)
188212
for li in result["layer_info"]:
189-
table_str += layer_info_to_string(
190-
li,
191-
"{:28}| " + info_format_base + "\n",
192-
)
213+
table_str += layer_info_to_string(li, "{:28}| " + info_format_base + "\n")
193214
for sub_li in li["sub_layer_info"]:
194215
table_str += layer_info_to_string(
195-
sub_li,
196-
" {:26}| " + info_format_base + "\n",
216+
sub_li, " {:26}| " + info_format_base + "\n"
197217
)
198218

199219
table_str += "Total Time(MS): {:05f}\n".format(result["average_total_time"])
@@ -295,11 +315,39 @@ def main():
295315

296316
print("Analyzing model: {}".format(orig_model_path))
297317

318+
batch_size = args.batch_size
319+
if args.disable_batch_override:
320+
batch_size = None
321+
os.environ["NM_DISABLE_BATCH_OVERRIDE"] = "1"
322+
print("Disable batch override: ON")
323+
298324
if input_shapes:
299325
with override_onnx_input_shapes(model_path, input_shapes) as tmp_path:
300-
input_list = generate_random_inputs(tmp_path, args.batch_size)
326+
input_list = generate_random_inputs(tmp_path, batch_size)
301327
else:
302-
input_list = generate_random_inputs(model_path, args.batch_size)
328+
input_list = generate_random_inputs(model_path, batch_size)
329+
330+
kv_cache_params = None
331+
if args.use_kvcache:
332+
kv_cache_prev_num_tokens = 0
333+
if args.kv_cache_prev_num_tokens is not None:
334+
kv_cache_prev_num_tokens = args.kv_cache_prev_num_tokens
335+
336+
kv_cache_num_frozen_tokens = 0
337+
if args.kv_cache_num_frozen_tokens is not None:
338+
kv_cache_num_frozen_tokens = args.kv_cache_num_frozen_tokens
339+
340+
kv_cache_params = KVCacheParams(
341+
default_cached_outputs(model_path),
342+
kv_cache_prev_num_tokens,
343+
kv_cache_num_frozen_tokens,
344+
)
345+
346+
print(
347+
"Enable KVCache: prev_num_tokens = {}, num_frozen_tokens = {}".format(
348+
kv_cache_params.prev_num_tokens, kv_cache_params.num_frozen_tokens
349+
)
350+
)
303351

304352
result = model_debug_analysis(
305353
model_path,
@@ -308,9 +356,11 @@ def main():
308356
num_cores=args.num_cores,
309357
num_iterations=args.num_iterations,
310358
num_warmup_iterations=args.num_warmup_iterations,
311-
optimization_level=args.optimization,
359+
optimization_level=int(args.optimization),
360+
disable_batch_override=args.disable_batch_override,
312361
imposed_ks=imposed_kernel_sparsity,
313362
input_shapes=input_shapes,
363+
kv_cache_params=kv_cache_params,
314364
)
315365

316366
if not args.quiet:

0 commit comments

Comments
 (0)