Commit a49ab47
[TextGeneration] max token refactor (#1217)
* Update TextGeneration Inputs and Constructor
  - Remove max_generated_tokens from the constructor and add it to the TextGenerationInput schema
  - Add num_generated_predictions to the TextGenerationInput which, if > 1, repeats the input sequence and turns off deterministic prediction. If a sequence is already provided multiple times, the sequence is not repeated.
* update description
* quality
* rebase
* add comment explaining num_generated_predictions
* update and add tests
* rebase
* move to helpers
* remove extra check for multiple prompts
* group generated outputs if num_generated_predictions > 1
* update docstring
* facepalm: updated input, not output
* update tests
1 parent fc418b3 commit a49ab47
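The net effect is that the generation limit moves from pipeline construction to the request: callers no longer pass max_generated_tokens to the constructor and instead pass max_tokens (and optionally num_generated_predictions) with each input. A minimal call-site sketch, not part of the diff; the model stub is a placeholder and a working deepsparse install is assumed:

```python
from deepsparse import Pipeline

# Before this commit the limit was fixed at construction time via
# max_generated_tokens; now it travels with each request.
pipeline = Pipeline.create(
    task="text_generation",
    model_path="zoo:placeholder/text_generation/model",  # placeholder stub
    sequence_length=32,
)

output = pipeline(
    sequences="def fib(a, b, accumulator=0)",
    max_tokens=16,                # per-request limit (TextGenerationInput.max_tokens)
    num_generated_predictions=1,  # >1 repeats the prompt and disables deterministic sampling
)
print(output.sequences)
```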

3 files changed: +94 -25


Diff for: src/deepsparse/transformers/pipelines/text_generation.py

+52 -22
@@ -40,6 +40,7 @@
 from deepsparse.transformers.utils.helpers import (
     create_causal_mask,
     pad_to_fixed_length,
+    repeat_inputs,
 )
 from deepsparse.transformers.utils.timings import TextGenerationTimings
 from deepsparse.utils.onnx import default_cached_outputs
@@ -57,6 +58,18 @@ class Config:
     sequences: Union[str, List[str]] = Field(
         description="The input sequences to generate the text from.",
     )
+    num_generated_predictions: int = Field(
+        default=1,
+        description="The number of text generations to create from a single prompt. "
+        "If the same sequence is given as an input multiple times, the number of "
+        "generated predictions is equivalent to the number of times the "
+        "sequence is repeated.",
+    )
+    max_tokens: int = Field(
+        default=1024,
+        description="Maximum number of tokens to generate per output sequence. If no "
+        "value is provided, will default to 1024.",
+    )
     return_logits: bool = Field(
         default=False,
         description="A flag that indicates whether to return "
@@ -110,7 +123,7 @@ class Config:


 class TextGenerationOutput(BaseModel):
-    sequences: Union[str, List[str]] = Field(
+    sequences: Union[str, List[str], List[List[str]]] = Field(
         description="The generated text sequences.",
     )
     logits: Optional[Any] = Field(  # numpy array, set to Any for FastAPI compatibility
@@ -143,11 +156,6 @@ class TextGenerationPipeline(TransformersPipeline):
         from the probability distribution computed from the logits.
         Higher values will result in more random samples. Should
         be greater than 0.0.
-    :param max_generated_tokens: the maximum number of tokens to generate
-        given the input sequence. If None, the model will generate
-        tokens until the end of the sequence is reached.
-        Otherwise, it will generate up to the maximum number of tokens or end of
-        sequence is reached.
     :param sequence_length: sequence length to compile model and tokenizer for.
         This controls the maximum context length of the pipeline. Default is 512
     :param prompt_sequence_length: For large prompts, the prompt is
@@ -164,7 +172,6 @@ def __init__(
         self,
         deterministic: bool = True,
         sampling_temperature: float = 1.0,
-        max_generated_tokens: Optional[int] = 1024,
         prompt_sequence_length: int = 64,
         sequence_length: int = 512,
         force_max_tokens: bool = False,
@@ -203,16 +210,8 @@ def __init__(
         if "WAND_OPT_FLAGS" not in os.environ:
             os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

-        if not self.cache_support_enabled and max_generated_tokens > 1:
-            raise ValueError(
-                "The model used for inference does not support kv cache. It is "
-                "assumed that it maps from the token sequence to predicted logits."
-                "Set `max_generated_tokens` to 1 to support that scenario."
-            )
-
         self.deterministic = deterministic
         self.sampling_temperature = sampling_temperature
-        self.max_generated_tokens = max_generated_tokens
         self.prompt_sequence_length = prompt_sequence_length
         self.force_max_tokens = force_max_tokens
         self.internal_kv_cache = internal_kv_cache
@@ -369,6 +368,26 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
         :param inputs: the input schema for the pipeline
         :return: the inputs for the engine
         """
+        if not self.cache_support_enabled and inputs.max_tokens > 1:
+            raise ValueError(
+                "The model used for inference does not support kv cache. It is "
+                "assumed that it maps from the token sequence to predicted logits."
+                "Set `max_tokens` to 1 to support that scenario."
+            )
+
+        # If the num_generated_predictions > 1, repeat the prompt
+        # num_generated_predictions times. Also, update the engine so that deterministic
+        # is set to False.
+        if inputs.num_generated_predictions > 1:
+            if isinstance(inputs.sequences, str):
+                inputs.sequences = [inputs.sequences]
+            inputs.sequences = repeat_inputs(
+                inputs.sequences, inputs.num_generated_predictions
+            )
+            if self.engine:
+                self.engine.deterministic = False
+            if self.multitoken_engine:
+                self.multitoken_engine.deterministic = False

         if inputs.fixed_sequences_length or not self.cache_support_enabled:
             # to enforce a fixed sequence length, we need to
@@ -414,14 +433,16 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
             self.engine.session_id = inputs.session_id
             self.multitoken_engine.session_id = inputs.session_id

-        postprocessing_kwargs = dict(
+        context = dict(
+            num_generated_predictions=inputs.num_generated_predictions,
             return_logits=inputs.return_logits,
             streamer=inputs.streamer,
             include_prompt_logits=inputs.include_prompt_logits,
             callback=inputs.callback,
             stop=inputs.stop,
+            max_tokens=inputs.max_tokens,
         )
-        return engine_input, postprocessing_kwargs
+        return engine_input, context

     def process_engine_outputs(
         self, engine_outputs: List[numpy.ndarray], **kwargs
@@ -436,6 +457,18 @@ def process_engine_outputs(
         sequences = self.tokenizer.batch_decode(
             generated_tokens, skip_special_tokens=True
         )
+        num_preds = kwargs.get("num_generated_predictions", 1)
+        # If the num_generated_predictions > 1, group the generated sequences and return
+        # the sequences as a list of lists where each list consists of the generated
+        # predictions for a given prompt, and all the lists are in the order matching
+        # the order that the prompts were given as inputs.
+        if num_preds > 1:
+            grouped_seq = [
+                sequences[n : n + num_preds]
+                for n in range(0, len(sequences), num_preds)
+            ]
+            sequences = grouped_seq
+
         logits = generated_logits if kwargs.get("return_logits") else None

         return TextGenerationOutput(sequences=sequences, logits=logits)
@@ -472,11 +505,8 @@ def engine_forward(
                 streamer.put(numpy.array(tokens))

             # create the generated output
-            max_tokens = (
-                self.max_generated_tokens
-                if self.max_generated_tokens and self.max_generated_tokens > 0
-                else 100 * self.sequence_length
-            )  # set safety for absolute max generation
+            max_tokens = context.get("max_tokens", 0)
+            max_tokens = max_tokens if max_tokens > 0 else (100 * self.sequence_length)

             # last prompt token is the first generated token
             # add it to generated tokens, and the logits

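When num_generated_predictions > 1, process_inputs repeats each prompt and process_engine_outputs regroups the decoded strings, so TextGenerationOutput.sequences becomes a list of lists with one inner list per original prompt, in prompt order. A standalone sketch of that regrouping step (the sample strings are illustrative only):

```python
# Mirrors the list comprehension added to process_engine_outputs above:
# six decoded sequences for three prompts, each repeated num_preds times,
# are chunked back into one group per prompt.
num_preds = 2
sequences = ["a1", "a2", "b1", "b2", "c1", "c2"]

grouped_seq = [
    sequences[n : n + num_preds]
    for n in range(0, len(sequences), num_preds)
]
assert grouped_seq == [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]
```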
Diff for: src/deepsparse/transformers/utils/helpers.py

+20
@@ -22,6 +22,7 @@
     "generate_session_id",
     "pad_to_fixed_length",
     "create_causal_mask",
+    "repeat_inputs",
 ]

 _LOGGER = logging.getLogger(__name__)
@@ -36,6 +37,25 @@ def generate_session_id() -> str:
     return session_id


+def repeat_inputs(
+    input_sequences: List[str], num_generated_predictions: int
+) -> List[str]:
+    """
+    :param input_sequences: List of input sequences to repeat
+    :param num_generated_predictions: number of times to repeat each sequence
+
+    :return: a list of input sequences, where sequences have been repeated
+        num_generated_predictions times if the sequence appears in input_sequences just
+        once. If the sequence appears multiple times in input_sequences, the
+        num_generated_predictions for the sequence is ignored.
+    """
+    repeated_seq = []
+
+    for seq in input_sequences:
+        repeated_seq.extend(numpy.repeat([seq], num_generated_predictions))
+    return repeated_seq
+
+
 def pad_to_fixed_length(
     array: numpy.ndarray, max_len: int, axis: int = 0, value: int = 0
 ) -> numpy.ndarray:

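The new repeat_inputs helper simply tiles each prompt num_generated_predictions times, preserving prompt order. A self-contained sketch of its behavior, copying the helper from the diff above; the example inputs are hypothetical:

```python
from typing import List

import numpy


def repeat_inputs(
    input_sequences: List[str], num_generated_predictions: int
) -> List[str]:
    # Repeat every sequence num_generated_predictions times, keeping prompt order.
    repeated_seq = []
    for seq in input_sequences:
        repeated_seq.extend(numpy.repeat([seq], num_generated_predictions))
    return repeated_seq


out = repeat_inputs(["hello", "world"], 2)
# numpy.repeat yields numpy string scalars, so compare their str() values.
assert [str(s) for s in out] == ["hello", "hello", "world", "world"]
```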
Diff for: tests/deepsparse/transformers/pipelines/test_text_generation.py

+22 -3
@@ -93,7 +93,6 @@ def setup(self, model_stub, model_name, uses_bos_token, internal_kv_cache):
             model_path=model_stub,
             sequence_length=32,
             prompt_sequence_length=4,
-            max_generated_tokens=self.max_generated_tokens,
             internal_kv_cache=self.internal_kv_cache,
         )
         short_prompt = "this"
@@ -126,7 +125,9 @@ def test_model_output_sequences(self, setup):
         # test model output against sources of truth
         pipeline, model_name, _, short_prompt, long_prompt = setup

-        output_sequences = pipeline(sequences=[short_prompt, long_prompt])
+        output_sequences = pipeline(
+            sequences=[short_prompt, long_prompt], max_tokens=self.max_generated_tokens
+        )

         # test against huggingface model
         output_hugging_face = self._get_output_huggingface(
@@ -135,6 +136,23 @@ def test_model_output_sequences(self, setup):
         assert short_prompt + output_sequences.sequences[0] == output_hugging_face[0]
         assert long_prompt + output_sequences.sequences[1] == output_hugging_face[1]

+    def test_num_generated_predictions(self, setup):
+        pipeline = setup[0]
+        short_prompt = setup[3]
+
+        output_sequences = pipeline(
+            sequences=[short_prompt], num_generated_predictions=2
+        )
+
+        assert len(output_sequences.sequences[0]) == 2
+
+        output_sequences = pipeline(
+            sequences=[short_prompt, short_prompt], num_generated_predictions=2
+        )
+        assert len(output_sequences.sequences) == 2
+        for sequences in output_sequences.sequences:
+            assert len(sequences) == 2
+
     def test_model_output_cache(self, setup):
         pipeline, model_name, _, short_prompt, long_prompt = setup
         if self.internal_kv_cache:
@@ -158,6 +176,7 @@ def dummy_callback(token):
             "sequences": "def fib(a, b, accumulator=0)",
             "callback": dummy_callback,
             "return_logits": True,
+            "max_tokens": self.max_generated_tokens,
         }

         outs = pipeline(**inputs)
@@ -167,7 +186,7 @@ def _test_cache_state(self, prompt, pipeline, model_name):
         # make sure that the cache state after running a prompt
         # is correct

-        pipeline(sequences=prompt)
+        pipeline(sequences=prompt, max_tokens=self.max_generated_tokens)
         cache_state_dict = pipeline.engine.kv_cache.cached_inputs
         cache_state_list = [cache_state_dict[key] for key in cache_state_dict.keys()]