
[Text Generation] Support for causal masks, internal KV cache, and initial testing framework #1172


Merged · 16 commits · Aug 10, 2023
147 changes: 111 additions & 36 deletions src/deepsparse/transformers/pipelines/text_generation.py
@@ -19,6 +19,7 @@
from typing import Dict, Generator, List, Optional, Tuple, Type, Union

import numpy
import onnx
from pydantic import BaseModel, Field
from transformers import TextStreamer

@@ -31,6 +32,7 @@
create_causal_mask,
pad_to_fixed_length,
)
from deepsparse.utils.onnx import default_cached_outputs


_LOGGER = logging.getLogger(__name__)
@@ -164,62 +166,114 @@ def __init__(
super().__init__(
**kwargs, _delay_engine_initialize=True, _delay_overwriting_inputs=True
)
self.enable_multitoken_prefill = self.causal_mask_input_present(
model_path=self.onnx_file_path
)
self.cache_support_enabled = self.is_cache_support_enabled()

if self.engine_type == DEEPSPARSE_ENGINE:
if "WAND_OPT_FLAGS" not in os.environ:
os.environ["WAND_OPT_FLAGS"] = "default,~pyramids"

if not self.cache_support_enabled and self.max_generated_tokens > 1:
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
"Set `max_generated_tokens` to 1 to support that scenario."
)

self.deterministic = deterministic
self.sampling_temperature = sampling_temperature
self.max_generated_tokens = max_generated_tokens
self.prompt_processing_sequence_length = prompt_processing_sequence_length
self.force_max_tokens = force_max_tokens
self.use_deepsparse_cache = use_deepsparse_cache

# override tokenizer to pad to left
self.tokenizer.padding_side = "left"
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = self.tokenizer.eos_token

self.engine = None

self.multitoken_engine = NLDecoderEngine(
onnx_file_path=self.onnx_file_path,
engine_type=self.engine_type,
engine_args=self.engine_args,
engine_context=self.context,
sampling_temperature=self.sampling_temperature,
deterministic=self.deterministic,
sequence_length=self.sequence_length,
input_ids_length=prompt_processing_sequence_length,
tokenizer=self.tokenizer,
use_deepsparse_cache=use_deepsparse_cache,
)
self.engine, self.multitoken_engine = self.initialize_engines()

def initialize_engines(
self,
) -> Tuple[Optional[NLDecoderEngine], Optional[NLDecoderEngine]]:
"""
        Initializes a pair of engines for the pipeline.
        The first engine (`engine`) is used for processing the tokens
        one by one (in the autoregressive fashion).
        The second engine (`multitoken_engine`) is used for processing the
        prompt tokens in a single pass (in the multitoken fashion).

        The engines are initialized as follows:
        - if the model does not support kv cache, then only the
        `multitoken_engine` is initialized. The `engine` is set to None.
        - if the model supports kv cache but does not support the
        multitoken prefill scenario (i.e. self.enable_multitoken_prefill = False),
        then only the `engine` is initialized. The `multitoken_engine`
        is set to None.
        - if the model supports both kv cache and multitoken prefill,
        both engines are initialized.

:return: a pair of engines (`engine`, `multitoken_engine`)
            Note: depending on the scenario, one of the engines may be None
"""

if self.multitoken_engine.kv_cache_enabled:
# unless kv cache is enabled, we don't
# need to initialize the single token engine
self.engine = NLDecoderEngine(
engine, multitoken_engine = None, None

if self.cache_support_enabled:
            # emit the appropriate user message depending on whether we are
            # instantiating the multitoken engine or not
if not self.enable_multitoken_prefill:
warnings.warn(
"This ONNX graph does not support processing the prompt in "
"with processing length > 1. Creation of an auxiliary engine for "
"processing the prompt at a larger processing length is disabled. "
"The prompt will be processed in with processing length 1."
)
else:
_LOGGER.info(
"Compiling an auxiliary engine to process a prompt with a "
"larger processing length. This improves performance, but "
"may result in additional memory consumption."
)

if (
self.cache_support_enabled and self.enable_multitoken_prefill
) or not self.cache_support_enabled:

multitoken_engine = NLDecoderEngine(
onnx_file_path=self.onnx_file_path,
engine_type=self.engine_type,
engine_args=self.engine_args,
engine_context=self.context,
sampling_temperature=self.sampling_temperature,
deterministic=self.deterministic,
sequence_length=self.sequence_length,
input_ids_length=1,
input_ids_length=self.prompt_processing_sequence_length,
tokenizer=self.tokenizer,
use_deepsparse_cache=use_deepsparse_cache,
use_deepsparse_cache=self.use_deepsparse_cache,
)
if (
not self.multitoken_engine.kv_cache_enabled
and self.max_generated_tokens > 1
):
raise ValueError(
"The model used for inference does not support kv cache. It is "
"assumed that it maps from the token sequence to predicted logits."
"Set `max_generated_tokens` to 1 to support that scenario."

if self.cache_support_enabled:

engine = NLDecoderEngine(
onnx_file_path=self.onnx_file_path,
engine_type=self.engine_type,
engine_args=self.engine_args,
engine_context=self.context,
sampling_temperature=self.sampling_temperature,
deterministic=self.deterministic,
sequence_length=self.sequence_length,
input_ids_length=1,
tokenizer=self.tokenizer,
use_deepsparse_cache=self.use_deepsparse_cache,
)

assert (engine is not None) or (
multitoken_engine is not None
), "At least one of the engines must be initialized for the pipeline!"
return engine, multitoken_engine

@staticmethod
def route_input_to_bucket(
*args, input_schema: BaseModel, pipelines: List[Pipeline], **kwargs
@@ -293,7 +347,11 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]:
input_tokens = dict(
**input_tokens, positions=positions, causal_mask=causal_mask
)
onnx_input_names = self.multitoken_engine.onnx_input_names_no_cache
onnx_input_names = (
self.multitoken_engine.onnx_input_names_no_cache
if self.multitoken_engine
else self.engine.onnx_input_names_no_cache
)
engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names)

if inputs.session_id is not None:
@@ -342,7 +400,7 @@ def engine_forward(
with self.timer_manager.new_timer_context(total_inference=False) as timer:
streamer = context.get("streamer")

if not self.multitoken_engine.kv_cache_enabled:
if not self.cache_support_enabled:
tokens, prompt_logits = self.multitoken_engine(engine_inputs)
return numpy.array([tokens]), prompt_logits

@@ -414,7 +472,10 @@ def prompt_inference(
# to refrain from resetting if session id is being passed
self._reset_engines_cache()

if len(tokens) > self.prompt_processing_sequence_length:
if (
len(tokens) > self.prompt_processing_sequence_length
and self.enable_multitoken_prefill
):
for engine_inputs in self.engine_inputs_for_prefill(tokens):
new_token, new_logits = self.multitoken_engine(engine_inputs)
num_tokens_processed += self.prompt_processing_sequence_length
@@ -580,14 +641,13 @@ def engine_inputs_for_prefill(

yield engine_inputs

@property
def has_cache(self) -> bool:
def is_cache_support_enabled(self) -> bool:
"""
        Returns whether the model being run supports kv cache or not

:return: True if the model has kv cache, False otherwise
"""
return self.multitoken_engine.kv_cache_enabled
return any(default_cached_outputs(self.onnx_file_path))

def join_engine_outputs(
self, batch_outputs: List[List[numpy.ndarray]], orig_batch_size: int
@@ -603,7 +663,7 @@ def join_engine_outputs(
:return: A list of joined outputs
"""
tokens, logits = zip(*batch_outputs)
if self.has_cache:
if self.cache_support_enabled:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
# different lengths
@@ -636,6 +696,21 @@

return [tokens, logits]

@staticmethod
def causal_mask_input_present(model_path: str) -> bool:
"""
        Check whether the model has a causal_mask input or not.
        In general, the absence of a causal_mask input means that the model
        cannot be run through the multitoken engine.

:param model_path: path to the model
:return: True if causal_mask input is present, False otherwise
"""
return any(
inp.name == "causal_mask"
for inp in onnx.load(model_path, load_external_data=False).graph.input
)

def _reset_engines_cache(self):
self.engine.reset_kv_cache()
self.multitoken_engine.reset_kv_cache()
self.multitoken_engine.reset_kv_cache() if self.multitoken_engine else None
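
For reference, the engine setup introduced above reduces to a small decision table over the two new flags. Below is a minimal sketch of that logic, not code from this PR; `select_engines` is a hypothetical helper and its boolean return values are illustrative only.

# Sketch of the decision logic in TextGenerationPipeline.initialize_engines.
# `select_engines` is a hypothetical helper, not part of this PR.
from typing import Tuple


def select_engines(
    cache_support_enabled: bool, enable_multitoken_prefill: bool
) -> Tuple[bool, bool]:
    """Return (build_engine, build_multitoken_engine)."""
    # The multitoken engine prefills the prompt in chunks of
    # prompt_processing_sequence_length tokens. It is built when the model has
    # no kv cache at all (single forward pass) or when the cached model also
    # exposes a causal_mask input (multitoken prefill supported).
    build_multitoken_engine = (not cache_support_enabled) or enable_multitoken_prefill

    # The single-token engine performs autoregressive decoding and is only
    # needed when the model carries an internal kv cache.
    build_engine = cache_support_enabled

    assert build_engine or build_multitoken_engine
    return build_engine, build_multitoken_engine


# (kv cache, causal_mask) -> engines built:
#   (True, True)   -> both engines
#   (True, False)  -> single-token engine only; prompt processed token by token
#   (False, _)     -> multitoken engine only; max_generated_tokens must be 1
print(select_engines(True, False))  # (True, False)
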
24 changes: 12 additions & 12 deletions src/deepsparse/utils/onnx.py
@@ -50,10 +50,13 @@
"truncate_onnx_model",
"truncate_onnx_embedding_model",
"default_cached_outputs",
"CACHE_OUTPUT_PREFIX",
]

_LOGGER = logging.getLogger(__name__)

CACHE_OUTPUT_PREFIX = "present"


@contextlib.contextmanager
def save_onnx_to_temp_files(model: onnx.ModelProto, with_external_data=False) -> str:
@@ -477,18 +480,15 @@ def truncate_onnx_embedding_model(

def default_cached_outputs(model_path: str) -> List[bool]:
"""
:param model_path: Path to a model
:return A list of bools that indicates caching of all outputs except the first one.
"""
Get a list of bools that indicate which outputs should be cached.
The elements that are set to True correspond to cached outputs,
the rest are set to False.

outputs = get_output_names(model_path)
assert len(outputs) > 0

# Create a boolean list of every output of the
# model [logits, key0, value0, key1, value1, ..., keyN, valueN]
cached_outputs = [True for i in range(len(outputs))]
:param model_path: Path to the model.
    :return: A list of bools that indicate which outputs should be cached.
"""

# Assume first input is logits and logits ought not to be cached
cached_outputs[0] = False
output_names = get_output_names(model_path)
assert len(output_names) > 0

return cached_outputs
return [name.startswith(CACHE_OUTPUT_PREFIX) for name in output_names]
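
The rewritten `default_cached_outputs` keys caching off the output name rather than position: every output whose name starts with `CACHE_OUTPUT_PREFIX` ("present") is treated as a kv-cache output, instead of assuming that everything after the first output (logits) is cached. A small usage sketch follows; the output names below are illustrative of the common kv-cache naming convention, not taken from a specific model.

# Illustrative output names for a decoder graph with an internal kv cache.
CACHE_OUTPUT_PREFIX = "present"

output_names = [
    "logits",
    "present.0.key",
    "present.0.value",
    "present.1.key",
    "present.1.value",
]

# Mirrors the new default_cached_outputs: flag cache outputs by name prefix.
cached = [name.startswith(CACHE_OUTPUT_PREFIX) for name in output_names]
print(cached)  # [False, True, True, True, True]

In the pipeline, `is_cache_support_enabled` then simply checks `any(...)` over this list to decide whether the model carries a kv cache.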