Skip to content

Commit 51c4ee6

Browse files
committed
pipeline runs, but incorrectly
1 parent d1683b4 commit 51c4ee6

File tree

8 files changed

+135
-15
lines changed

8 files changed

+135
-15
lines changed

src/deepsparse/transformers/utils/token_generator.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,16 @@ def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
7777
:param logits: the logits from the model with shape (vocab_size,)
7878
:return: the sampled token
7979
"""
80-
if self.top_k:
81-
logits = self.apply_top_k(logits)
82-
if self.top_p:
83-
logits = self.apply_top_p(logits)
84-
8580
if self.deterministic:
8681
token = numpy.argmax(logits)
8782
self.tokens.append(token)
8883
return token
8984

85+
if self.top_k:
86+
logits = self.apply_top_k(logits)
87+
if self.top_p:
88+
logits = self.apply_top_p(logits)
89+
9090
if self.sampling_temperature != 1.0:
9191
logits /= self.sampling_temperature
9292

src/deepsparse/v2/text_generation/join_output.py

+3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def __init__(self, tokenizer):
3333
self.tokenizer = tokenizer
3434

3535
def run(self, inp: List[CompileGenerationsOutput], **kwargs):
36+
37+
if not isinstance(inp, list):
38+
inp = [[inp]]
3639
batch_outputs = [x for x in inp[0]]
3740
generated_tokens = [x.generated_tokens for x in batch_outputs]
3841
generated_logits = [x.generated_logits for x in batch_outputs]

src/deepsparse/v2/text_generation/nl_engine_operator.py

+43-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from pydantic import BaseModel, Field
2020

21+
from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs
2122
from deepsparse.utils.onnx import (
2223
CACHE_INPUT_PREFIX,
2324
overwrite_onnx_model_inputs_for_kv_cache_models,
@@ -29,7 +30,12 @@
2930
)
3031

3132

32-
__all__ = ["NLEngineOperator", "NlEngineInput"]
33+
__all__ = [
34+
"NlEngineOperator",
35+
"NlEngineOperatorNoCache",
36+
"NlEngineInputNoCache",
37+
"NlEngineInput",
38+
]
3339

3440

3541
class NlEngineInput(BaseModel):
@@ -39,7 +45,12 @@ class NlEngineInput(BaseModel):
3945
in_generation: bool = Field(description="in_generation", default=None)
4046

4147

42-
class NLEngineOperator(EngineOperator):
48+
class NlEngineInputNoCache(BaseModel):
49+
input_ids: Any
50+
attention_mask: Any
51+
52+
53+
class NlEngineOperator(EngineOperator):
4354

4455
"""
4556
Operator for the NL Decoder Engine. This Operator inherits from the EngineOperator.
@@ -195,3 +206,33 @@ def output_names(self) -> List[str]:
195206
:return: The output names for the onnx model
196207
"""
197208
return self.engine.output_names
209+
210+
211+
class NlEngineOperatorNoCache(EngineOperator):
212+
213+
input_schema = NlEngineInputNoCache
214+
output_schema = None
215+
216+
def __init__(self, sequence_length, **kwargs):
217+
model_path, *_ = overwrite_transformer_onnx_model_inputs(
218+
path=kwargs.get("model_path"),
219+
max_length=sequence_length,
220+
batch_size=kwargs.get("batch_size", 1),
221+
)
222+
super().__init__(**kwargs)
223+
224+
def run(self, inp: NlEngineInputNoCache, **kwargs) -> Any:
225+
engine_inputs = [inp.input_ids, inp.attention_mask]
226+
logits = (
227+
super()
228+
.run(EngineOperatorInputs(engine_inputs=engine_inputs), **kwargs)
229+
.get("engine_outputs")
230+
)
231+
return {
232+
"logits": logits,
233+
"logits_shape": None,
234+
"deterministic": None,
235+
"kv_cache": None,
236+
"tokens": None,
237+
"sampling_temperature": None,
238+
}, {"prompt_logits": logits}

src/deepsparse/v2/text_generation/pipeline.py

+79-6
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from deepsparse.transformers.helpers import setup_transformers_pipeline
1818
from deepsparse.transformers.utils.helpers import process_generation_config
1919
from deepsparse.utils import split_engine_inputs
20+
from deepsparse.utils.onnx import default_cached_outputs
2021
from deepsparse.v2.pipeline import Pipeline
21-
from deepsparse.v2.routers import GraphRouter
22+
from deepsparse.v2.routers import GraphRouter, LinearRouter
2223
from deepsparse.v2.schedulers import OperatorScheduler
2324
from deepsparse.v2.text_generation import (
2425
AutoRegressiveOperatorPreprocess,
@@ -29,7 +30,8 @@
2930
JoinOutput,
3031
KVCacheCreator,
3132
MultiEnginePrefill,
32-
NLEngineOperator,
33+
NlEngineOperator,
34+
NlEngineOperatorNoCache,
3335
PrepareforPrefill,
3436
PrepareGeneration,
3537
ProcessInputsTextGeneration,
@@ -39,6 +41,79 @@
3941
from deepsparse.v2.utils import PipelineState
4042

4143

44+
class TextGenerationPipelineNoCache(Pipeline):
45+
def __init__(
46+
self,
47+
model_path: str,
48+
sequence_length: int = 1024,
49+
engine_kwargs: Optional[Dict] = None,
50+
onnx_model_name: Optional[str] = None,
51+
generation_config=None, # TODO: Typing here
52+
**kwargs,
53+
):
54+
55+
(
56+
self.model_path,
57+
self.config,
58+
self.tokenizer,
59+
engine_kwargs,
60+
) = setup_transformers_pipeline(
61+
model_path,
62+
sequence_length,
63+
onnx_model_name=onnx_model_name,
64+
engine_kwargs=engine_kwargs,
65+
)
66+
self.verify_no_kv_cache_present()
67+
68+
token_generator = TokenGeneratorOperator()
69+
70+
ops = [
71+
ProcessInputsTextGeneration(
72+
generation_config=process_generation_config(generation_config),
73+
sequence_length=sequence_length,
74+
tokenizer=self.tokenizer,
75+
),
76+
NlEngineOperatorNoCache(sequence_length=sequence_length, **engine_kwargs),
77+
PrepareGeneration(
78+
sequence_length=sequence_length,
79+
prompt_sequence_length=1,
80+
token_generator=token_generator,
81+
),
82+
GenerateNewTokenOperator(tokenizer=self.tokenizer, force_max_tokens=True),
83+
CompileGeneratedTokens(),
84+
CompileGenerations(),
85+
JoinOutput(tokenizer=self.tokenizer),
86+
ProcessOutputs(tokenizer=self.tokenizer),
87+
]
88+
router = LinearRouter(end_route=len(ops))
89+
scheduler = [OperatorScheduler()]
90+
super().__init__(
91+
ops=ops,
92+
router=router,
93+
schedulers=scheduler,
94+
)
95+
96+
def run(self, *args, **kwargs):
97+
# we need to set the fixed_sequences_length flag to True
98+
# for the non-kv cache pipeline
99+
kwargs.update(dict(fixed_sequences_length=True))
100+
return super().run(*args, **kwargs)
101+
102+
def verify_no_kv_cache_present(self) -> bool:
103+
"""
104+
Verifies that the ONNX model does not have
105+
KV cache inputs/outputs present.
106+
:return: True if compatible, False otherwise
107+
"""
108+
is_kv_cache_present = any(default_cached_outputs(self.model_path))
109+
if is_kv_cache_present:
110+
raise ValueError(
111+
f"The model: {self.model_path} has KV cache inputs/outputs present. "
112+
"Please use the TextGenerationPipeline instead."
113+
)
114+
return not is_kv_cache_present
115+
116+
42117
class TextGenerationPipeline(Pipeline):
43118
def __init__(
44119
self,
@@ -65,14 +140,14 @@ def __init__(
65140
if internal_kv_cache and engine_kwargs.get("engine_type") == "onnxruntime":
66141
internal_kv_cache = False
67142

68-
single_engine_operator = NLEngineOperator(
143+
single_engine_operator = NlEngineOperator(
69144
sequence_length=sequence_length,
70145
internal_kv_cache=internal_kv_cache,
71146
input_ids_length=1,
72147
**engine_kwargs,
73148
)
74149

75-
multi_engine_operator = NLEngineOperator(
150+
multi_engine_operator = NlEngineOperator(
76151
sequence_length=sequence_length,
77152
internal_kv_cache=internal_kv_cache,
78153
input_ids_length=prompt_sequence_length,
@@ -194,5 +269,3 @@ def expand_inputs(self, items, batch_size):
194269

195270
def condense_inputs(self, *args, **kwargs):
196271
return args[0], kwargs
197-
198-

src/deepsparse/v2/text_generation/prep_for_generation.py

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def run(
9191
"token_generator": token_generator,
9292
}
9393
output = {
94+
"logits": prompt_logits,
9495
"tokens": token_generator.tokens,
9596
"kv_cache": kv_cache,
9697
"in_generation": True,

tests/deepsparse/v2/unit/text_generation/conftest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from deepsparse.v2 import InferenceState, PipelineState
2626
from deepsparse.v2.text_generation import (
2727
GenerationDefaults,
28-
NLEngineOperator,
28+
NlEngineOperator,
2929
TokenGeneratorOperator,
3030
)
3131

@@ -61,7 +61,7 @@ def single_token_engine_no_internal_cache(text_generation_attributes, model_attr
6161
seq_length, _ = text_generation_attributes
6262
_, model_path = model_attributes
6363

64-
nl_engine_operator = NLEngineOperator(
64+
nl_engine_operator = NlEngineOperator(
6565
sequence_length=seq_length, input_ids_length=1, model_path=model_path
6666
)
6767
return nl_engine_operator

tests/testdata/gsm8k-v0-greedy_until

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3b4bf5c7d1504339aa06bcb50212dba05ff761d30de6faf720fdc818b16316ad

tests/testdata/gsm8k-v0-res.json

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"results": {"gsm8k": {"acc": 0.0, "acc_stderr": 0.0}}, "versions": {"gsm8k": 0}}

0 commit comments

Comments (0)