from deepsparse.transformers.helpers import setup_transformers_pipeline
from deepsparse.transformers.utils.helpers import process_generation_config
from deepsparse.utils import split_engine_inputs
+from deepsparse.utils.onnx import default_cached_outputs
from deepsparse.v2.pipeline import Pipeline
-from deepsparse.v2.routers import GraphRouter
+from deepsparse.v2.routers import GraphRouter, LinearRouter
from deepsparse.v2.schedulers import OperatorScheduler
from deepsparse.v2.text_generation import (
    AutoRegressiveOperatorPreprocess,
    CompileGeneratedTokens,
    CompileGenerations,
    CompilePromptLogits,
    GenerateNewTokenOperator,
    JoinOutput,
    KVCacheCreator,
    MultiEnginePrefill,
-    NLEngineOperator,
+    NlEngineOperator,
+    NlEngineOperatorNoCache,
    PrepareforPrefill,
    PrepareGeneration,
    ProcessInputsTextGeneration,
    ProcessOutputs,
    TokenGeneratorOperator,
)
from deepsparse.v2.utils import PipelineState


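+# Text-generation pipeline for ONNX models exported without a KV cache: every
+# forward pass recomputes attention over the full, fixed-length sequence.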
+class TextGenerationPipelineNoCache(Pipeline):
+    def __init__(
+        self,
+        model_path: str,
+        sequence_length: int = 1024,
+        engine_kwargs: Optional[Dict] = None,
+        onnx_model_name: Optional[str] = None,
+        generation_config=None,  # TODO: Typing here
+        **kwargs,
+    ):
+
+        (
+            self.model_path,
+            self.config,
+            self.tokenizer,
+            engine_kwargs,
+        ) = setup_transformers_pipeline(
+            model_path,
+            sequence_length,
+            onnx_model_name=onnx_model_name,
+            engine_kwargs=engine_kwargs,
+        )
+        self.verify_no_kv_cache_present()
+
+        token_generator = TokenGeneratorOperator()
+
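+        # With no KV cache there is no prefill/decode branching, so the
+        # operators can be chained linearly and executed once each, in order.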
+        ops = [
+            ProcessInputsTextGeneration(
+                generation_config=process_generation_config(generation_config),
+                sequence_length=sequence_length,
+                tokenizer=self.tokenizer,
+            ),
+            NlEngineOperatorNoCache(sequence_length=sequence_length, **engine_kwargs),
+            PrepareGeneration(
+                sequence_length=sequence_length,
+                prompt_sequence_length=1,
+                token_generator=token_generator,
+            ),
+            GenerateNewTokenOperator(tokenizer=self.tokenizer, force_max_tokens=True),
+            CompileGeneratedTokens(),
+            CompileGenerations(),
+            JoinOutput(tokenizer=self.tokenizer),
+            ProcessOutputs(tokenizer=self.tokenizer),
+        ]
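+        # LinearRouter walks the ops front to back exactly once; the KV-cache
+        # pipeline below keeps GraphRouter for its branched prefill/decode routes.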
+        router = LinearRouter(end_route=len(ops))
+        scheduler = [OperatorScheduler()]
+        super().__init__(
+            ops=ops,
+            router=router,
+            schedulers=scheduler,
+        )
+
+    def run(self, *args, **kwargs):
+        # we need to set the fixed_sequences_length flag to True
+        # for the non-kv cache pipeline
+        kwargs.update(dict(fixed_sequences_length=True))
+        return super().run(*args, **kwargs)
+
+    def verify_no_kv_cache_present(self) -> bool:
+        """
+        Verifies that the ONNX model does not have
+        KV cache inputs/outputs present.
+        :return: True if no KV cache is present; raises a ValueError otherwise
+        """
+        is_kv_cache_present = any(default_cached_outputs(self.model_path))
+        if is_kv_cache_present:
+            raise ValueError(
+                f"The model: {self.model_path} has KV cache inputs/outputs present. "
+                "Please use the TextGenerationPipeline instead."
+            )
+        return not is_kv_cache_present
+
+
class TextGenerationPipeline(Pipeline):
    def __init__(
        self,
@@ -65,14 +140,14 @@ def __init__(
        if internal_kv_cache and engine_kwargs.get("engine_type") == "onnxruntime":
            internal_kv_cache = False

-        single_engine_operator = NLEngineOperator(
+        single_engine_operator = NlEngineOperator(
            sequence_length=sequence_length,
            internal_kv_cache=internal_kv_cache,
            input_ids_length=1,
            **engine_kwargs,
        )

-        multi_engine_operator = NLEngineOperator(
+        multi_engine_operator = NlEngineOperator(
            sequence_length=sequence_length,
            internal_kv_cache=internal_kv_cache,
            input_ids_length=prompt_sequence_length,
@@ -194,5 +269,3 @@ def expand_inputs(self, items, batch_size):

    def condense_inputs(self, *args, **kwargs):
        return args[0], kwargs
-
-
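
A minimal usage sketch for the new no-cache pipeline (not part of this commit;
the model path and the prompt keyword are assumptions and may not match the
final input schema):

    # hypothetical: a directory containing an ONNX export *without* KV-cache I/O
    pipeline = TextGenerationPipelineNoCache(
        model_path="path/to/onnx/model/dir",
        sequence_length=256,
    )
    # run() sets fixed_sequences_length=True before delegating to Pipeline.run()
    output = pipeline.run(prompt="def fibonacci(n):")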