diff --git a/src/deepsparse/transformers/engines/nl_decoder_engine.py b/src/deepsparse/transformers/engines/nl_decoder_engine.py index 223d4f0a60..5ec5001c2f 100644 --- a/src/deepsparse/transformers/engines/nl_decoder_engine.py +++ b/src/deepsparse/transformers/engines/nl_decoder_engine.py @@ -11,21 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy -from transformers import AutoTokenizer from deepsparse.engine import Context from deepsparse.pipeline import DEEPSPARSE_ENGINE, create_engine from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache -from deepsparse.transformers.utils.helpers import generate_session_id from deepsparse.transformers.utils.timings import TextGenerationTimings from deepsparse.utils import TimerManager from deepsparse.utils.onnx import ( CACHE_INPUT_PREFIX, - CACHE_OUTPUT_PREFIX, overwrite_onnx_model_inputs_for_kv_cache_models, ) @@ -37,9 +35,9 @@ class NLDecoderEngine: """ - The NLDecoderEngine (NaturalLanguageDecoderEngine) handles the + The NLDecoderEngine (Natural Language Decoder Engine) handles the logic around the inference for Natural Language pipeline, - including batching and kv cache logic. + including batching and kv cache manipulation logic. :param onnx_file_path: The path to the onnx model file :param engine_type: The type of engine to use for the inference @@ -47,10 +45,6 @@ class NLDecoderEngine: :param sequence_length: The maximum sequence length to run the engine for :param input_ids_length: The maximum input ids length to run the engine for :param engine_context: The context to run the engine in - :param sampling_temperature: The temperature to use for sampling - :param deterministic: Whether to use deterministic sampling - :param tokenizer: The tokenizer to used for engine inputs - :param engine_context: The context to run the engine in :param internal_kv_cache: Whether to use the deepsparse kv cache in the DecoderKVCache object or not """ @@ -62,9 +56,6 @@ def __init__( engine_args: Dict[str, Any], sequence_length: int, input_ids_length: int, - tokenizer: AutoTokenizer, - sampling_temperature: float = 1.0, - deterministic: bool = True, engine_context: Optional[Context] = None, internal_kv_cache=False, timer_manager: TimerManager = None, @@ -82,9 +73,7 @@ def __init__( input_ids_length=input_ids_length, ) - kv_cache_enabled = False if any(output_indices_to_be_cached): - kv_cache_enabled = True self.kv_cache_data_type = kv_cache_data_type if internal_kv_cache and engine_type == DEEPSPARSE_ENGINE: # inform the engine, that are using the kv cache @@ -98,30 +87,10 @@ def __init__( ) self.timer_manager = timer_manager or TimerManager() self.sequence_length = sequence_length - self.sampling_temperature = sampling_temperature - self.deterministic = deterministic self.input_ids_length = input_ids_length self.cache_length = sequence_length - input_ids_length - self.kv_cache_enabled = kv_cache_enabled - self.kv_cache = DecoderKVCache(internal_kv_cache) if kv_cache_enabled else None - self._freeze_first_position = self._should_freeze_first_position(tokenizer) - self._session_id = generate_session_id() self._engine_type = engine_type - @property - def session_id(self) -> str: - """ - :return: The session id for the kv_cache if enabled - """ - return self._session_id - 
- @session_id.setter - def session_id(self, session_id: str): - """ - :param session_id: The session id to set for the kv_cache - """ - self._session_id = session_id - @property def onnx_input_names_no_cache(self) -> List[str]: """ @@ -135,25 +104,43 @@ def onnx_input_names_no_cache(self) -> List[str]: ] @property - def num_non_blank_cache_entries(self) -> int: + def onnx_input_names_cached(self) -> List[str]: + """ + :return: The cached input names for the onnx model + """ + return [ + name + for name in self.engine.input_names + if name.startswith(CACHE_INPUT_PREFIX) + ] + + @property + def cache_shape(self) -> Tuple[int, int, int, int]: """ - :return A number of non-blank entries in the - kv cache + :return: The shape of the kv cache inputs + for the onnx model. The shape is + (batch_size, num_heads, sequence_length, hidden_size) """ - return self.kv_cache.num_non_blank_entries + cache_engine_input_index = next( + i + for i, name in enumerate(self.engine.input_names) + if CACHE_INPUT_PREFIX in name + ) + return self.engine.input_shapes[cache_engine_input_index] @property - def internal_cache_active(self) -> bool: + def output_names(self) -> List[str]: """ - :return: Whether the internal kv cache is active + :return: The output names for the onnx model """ - return self.kv_cache_enabled and self.kv_cache.engine_internal_cache is not None + return self.engine.output_names - def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray]: + def run( + self, inputs: List[numpy.ndarray], val_inp: bool, kv_cache: DecoderKVCache + ) -> List[numpy.ndarray]: """ Run the engine with the given inputs. - - If the self.internal_cache_active=True, the internal + If the kv_cache.engine_internal_cache=True, the internal deepsparse kv cache management is enabled. In this case the LIB.kv_cache class object will be passed to the engine call as well. In this scenario also the inputs will not be @@ -163,10 +150,11 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray] :param inputs: The inputs to run the engine with :param val_inp: Whether the input is for validation or not + :param kv_cache: The kv cache object to use for the inference + :return: The output of the engine """ - - if self.internal_cache_active: + if bool(kv_cache.engine_internal_cache): # conventionally, before dispatching # inputs to the engine, we validate them # if val_inp=True. However, in this case @@ -174,7 +162,7 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray] # (batch_size=0) to the engine. Therefore, # we skip the validation return self.engine._eng_net.execute_list_out( - inputs, self.kv_cache.engine_internal_cache + inputs, kv_cache.engine_internal_cache ) # run the engine without the LIB.kv_cache object return self.engine.run(inputs, val_inp) @@ -182,6 +170,7 @@ def run(self, inputs: List[numpy.ndarray], val_inp: bool) -> List[numpy.ndarray] def __call__( self, inp: List[numpy.ndarray], + kv_cache: Optional[DecoderKVCache] = None, val_inp: bool = True, ) -> numpy.ndarray: """ @@ -190,23 +179,28 @@ def __call__( :param inp: The input to run the engine with. 
We expect a list of numpy arrays that contain the input ids, attention mask, and position ids (optionally) + :param kv_cache: The DecoderKVCache object that contains + the kv cache state :param val_inp: Whether the input is for validation or not + :return: The generated token and corresponding logits """ timer = self.timer_manager.current - if self.kv_cache: + if kv_cache: # if model has kv cache enabled, we need # to add the kv cache state to the input - inp = self.add_kv_cache_to_input(inp) + inp = self.add_kv_cache_to_input(inp, kv_cache) with timer.time(f"EXECUTE_ENGINE_SEQ_LEN_{self.sequence_length}"): - out = self.run(inp, val_inp) + out = self.run(inp, val_inp, kv_cache) - if self.kv_cache: + if kv_cache: with timer.time(TextGenerationTimings.KV_CACHE_UPDATE): logits, *kv_cache_state = out self.update_kv_cache( - kv_cache_state=kv_cache_state, input_ids_len=self.input_ids_length + kv_cache_state=kv_cache_state, + input_ids_len=self.input_ids_length, + kv_cache=kv_cache, ) else: logits = out[0] @@ -219,36 +213,11 @@ def __str__(self): def __repr__(self): return str(self) - def transfer_cache_state(self, cache: DecoderKVCache): + def add_kv_cache_to_input( + self, inp: List[numpy.ndarray], kv_cache: DecoderKVCache + ) -> List[numpy.ndarray]: """ - Transfers the kv cache state and the number of tokens processed - information from another NLDecoderEngine. Call this method when - you want to transfer the kv cache state from one engine to another. - - This method will also automatically set the kv cache capacity to - the appropriate value for the new engine. - - :param cache: The `DecoderKVCache` object to transfer to the engine - from - """ - cache.set_capacity(self.cache_length) - self.kv_cache = cache - - def reset_kv_cache(self): - """ - Resets the kv cache state. - """ - kv_cache_state = self._initialize_kv_cache_state(self.cache_length) - self.kv_cache.setup( - session_id=self._session_id, - state=kv_cache_state, - num_processed_tokens=0, - freeze_first_position=self._freeze_first_position, - ) - - def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]: - """ - Takes the input and adds the past kv cache state to it. + Takes the input and adds the kv cache state to it. If the internal kv cache is enabled, the kv cache state will always be an empty array. This is just to make sure @@ -262,17 +231,11 @@ def add_kv_cache_to_input(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray] :param inp: The input to the model + :param kv_cache: The kv cache object + :return The input with the kv cache state added to it """ - if self.internal_cache_active: - kv_cache_state = self._initialize_kv_cache_state( - self.cache_length, empty=True - ) - else: - kv_cache_state = self.kv_cache.cached_inputs - if kv_cache_state is None: - self.reset_kv_cache() - kv_cache_state = self.kv_cache.cached_inputs + kv_cache_state = copy.copy(kv_cache.cached_inputs) for idx, input_name in enumerate(self.onnx_input_names_no_cache): kv_cache_state[input_name] = inp[idx] @@ -284,75 +247,29 @@ def update_kv_cache( self, kv_cache_state: List[numpy.ndarray], input_ids_len: int, + kv_cache: DecoderKVCache, ): """ - Updates the state of the kv cache + Updates the kv cache using the new kv cache state. If the internal kv cache is enabled, we refrain from updating the kv cache state as it is being tracked internally inside the engine. We only update the number of tokens processed. 
- :param kv_cache_state: The state of the kv cache storage + :param kv_cache_state: The new state of the kv cache storage :param input_ids_len: The length of input_ids + :param kv_cache: The kv cache object to update """ - if self.internal_cache_active: - self.kv_cache.total_num_processed_tokens += input_ids_len + if bool(kv_cache.engine_internal_cache): + kv_cache.total_num_processed_tokens += input_ids_len return - cache_onnx_names = [ - name - for name in self.engine.input_names - if name.startswith(CACHE_INPUT_PREFIX) - ] kv_cache_state = { - name: array for name, array in zip(cache_onnx_names, kv_cache_state) + name: array + for name, array in zip(self.onnx_input_names_cached, kv_cache_state) } - self.kv_cache.update( + kv_cache.update( state=kv_cache_state, input_ids_len=input_ids_len, ) - - def _initialize_kv_cache_state( - self, length: int, empty: bool = False - ) -> Dict[str, numpy.ndarray]: - # initialize empty kv cache of size - # (batch_size, num_attention_heads, length, hidden_dims) - # if empty is True, we initialize empty kv_cache - # and set the batch_size to 0 - - cache_engine_input_index = next( - i - for i, name in enumerate(self.engine.input_names) - if CACHE_INPUT_PREFIX in name - ) - batch_size, num_attention_heads, _, hidden_dims = self.engine.input_shapes[ - cache_engine_input_index - ] - - empty_kv_cache_tensor = numpy.zeros( - ( - batch_size if not empty else 0, - num_attention_heads, - length, - hidden_dims, - ), - dtype=self.kv_cache_data_type, - ) - - cache_keys = [ - output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX) - for output_name in self.engine.output_names - if output_name.startswith(CACHE_OUTPUT_PREFIX) - ] - return {key: empty_kv_cache_tensor for key in cache_keys} - - @staticmethod - def _should_freeze_first_position(tokenizer) -> bool: - # use tokenizer to find out whether we should freeze the first position - # (True if tokenizer has a prefix for a BOS token) - if tokenizer is None: - return False - if hasattr(tokenizer, "add_bos_token"): - return True - return False diff --git a/src/deepsparse/transformers/pipelines/__init__.py b/src/deepsparse/transformers/pipelines/__init__.py index 3e2e88381d..fdc32b17ee 100644 --- a/src/deepsparse/transformers/pipelines/__init__.py +++ b/src/deepsparse/transformers/pipelines/__init__.py @@ -19,5 +19,6 @@ from .question_answering import * from .text_classification import * from .token_classification import * +from .text_generation import * from .zero_shot_text_classification import * from .embedding_extraction import * diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 37fcbc249b..20e0d546ad 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -39,9 +39,12 @@ from deepsparse.pipeline import DEEPSPARSE_ENGINE from deepsparse.transformers.engines import NLDecoderEngine from deepsparse.transformers.pipelines import TransformersPipeline +from deepsparse.transformers.utils import DecoderKVCache from deepsparse.transformers.utils.helpers import ( create_causal_mask, + initialize_kv_cache_state, pad_to_fixed_length, + prepends_bos_token, repeat_inputs, ) from deepsparse.transformers.utils.timings import TextGenerationTimings @@ -93,13 +96,7 @@ class Config: "Note: This flag is only applicable when return_logits " "is `True`.", ) - session_id: Optional[str] = Field( - default=None, - description="A user may set a string identifier " - "for the 
kv cache session. If None, " - "and the model is using kv cache, it " - "will be set to a random uuid.", - ) + fixed_sequences_length: bool = Field( default=False, description="A flag that indicates whether to modify " @@ -183,12 +180,10 @@ class TextGenerationOutput(BaseModel): "prompt provided. If streamng is enabled, the next generated token is returned." "Otherwise, the full generated sequence is returned." ) - session_id: Optional[str] = Field( - default=None, description="A string identifier for the kv cache session." - ) class Config: arbitrary_types_allowed = True + extra = "allow" @Pipeline.register( @@ -274,6 +269,9 @@ def __init__( self.engine, self.multitoken_engine = self.initialize_engines() + # auxiliary flag for devs to enable debug mode for the pipeline + self._debug = False + def initialize_engines( self, ) -> Tuple[Optional[NLDecoderEngine], Optional[NLDecoderEngine]]: @@ -349,11 +347,8 @@ def initialize_engines( engine_type=self.engine_type, engine_args=self.engine_args, engine_context=self.context, - sampling_temperature=self.sampling_temperature, - deterministic=self.deterministic, sequence_length=self.sequence_length, input_ids_length=input_ids_length, - tokenizer=self.tokenizer, internal_kv_cache=self.internal_kv_cache, timer_manager=self.timer_manager, ) @@ -364,11 +359,8 @@ def initialize_engines( engine_type=self.engine_type, engine_args=self.engine_args, engine_context=self.context, - sampling_temperature=self.sampling_temperature, - deterministic=self.deterministic, sequence_length=self.sequence_length, input_ids_length=1, - tokenizer=self.tokenizer, internal_kv_cache=self.internal_kv_cache, timer_manager=self.timer_manager, ) @@ -479,11 +471,6 @@ def process_inputs(self, inputs: TextGenerationInput) -> List[numpy.ndarray]: ) engine_input = self.tokens_to_engine_input(input_tokens, onnx_input_names) - if inputs.session_id is not None: - # if session_id is provided, we need to set it in engines - self.engine.session_id = inputs.session_id - self.multitoken_engine.session_id = inputs.session_id - context = dict( prompts=original_inputs, num_generated_predictions=inputs.num_generated_predictions, @@ -510,7 +497,7 @@ def process_engine_outputs( :param engine_outputs: the outputs from the engine :return: the output schema for the pipeline """ - generated_tokens, generated_logits, finished_reason = engine_outputs + generated_tokens, generated_logits, finished_reason, *debug = engine_outputs finished_reason = [f[0] for f in finished_reason] sequences = self.tokenizer.batch_decode( @@ -560,13 +547,26 @@ def _create_generated_text_output( ] generations = grouped_generations - return TextGenerationOutput( + outputs = dict( created=datetime.datetime.now(), prompts=prompts, generations=generations ) + if debug: + kv_cache_state, total_num_processed_tokens = debug + debug_params = dict( + kv_cache_state=kv_cache_state, + total_num_processed_tokens=total_num_processed_tokens, + ) + outputs.update(debug_params) + + return TextGenerationOutput(**outputs) + def engine_forward( self, engine_inputs: List[numpy.ndarray], context: Dict - ) -> Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]]: + ) -> Union[ + Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason]], + Tuple[numpy.ndarray, numpy.ndarray, List[FinishReason], DecoderKVCache], + ]: """ Run the forward pass on the engine. 
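
Note on the change above: the _debug flag together with the extra = "allow" output config replaces the old pattern of reaching into pipeline.engine.kv_cache after a run. When the flag is set, engine_forward also returns the DecoderKVCache session, and join_engine_outputs attaches kv_cache_state and total_num_processed_tokens to the pipeline output. A minimal usage sketch, assuming a deployable text-generation model (the model stub and the sequence-length values below are placeholders; the attribute access mirrors the updated tests later in this diff):

    from deepsparse import Pipeline

    pipeline = Pipeline.create(
        task="text_generation",
        model_path="zoo:example/text_generation/model/stub",  # placeholder stub
        sequence_length=128,
        prompt_sequence_length=16,
    )

    # dev-only flag introduced in this diff; not part of the public input schema
    pipeline._debug = True

    output = pipeline(
        sequences="The quick brown fox",
        return_logits=True,
        include_prompt_logits=True,
        max_tokens=16,
    )

    print(output.generations[0].text)
    # per-sequence debug payloads carried by the (extra="allow") output schema
    print(output.total_num_processed_tokens[0])
    print(list(output.kv_cache_state[0].keys()))
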
@@ -599,7 +599,7 @@ def engine_forward( else: # run the prompt through with timer.time(TextGenerationTimings.PROMPT_PREFILL): - prompt_logits = self.prompt_inference(engine_inputs) + prompt_logits, session = self.prompt_inference(engine_inputs) tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist() token_generator = TokenGenerator( @@ -632,7 +632,7 @@ def engine_forward( while len(generated_tokens) < max_tokens: with timer.time(TextGenerationTimings.TOKEN_GENERATION_SINGLE): logits = self.autoregressive_inference( - tokens=token_generator.tokens + tokens=token_generator.tokens, kv_cache=session ) token = token_generator.generate(logits=logits[0, -1, :]) generated_tokens.append(token) @@ -670,16 +670,21 @@ def engine_forward( if streamer is not None: streamer.end() - return ( + returns = ( numpy.array([generated_tokens]), numpy.concatenate(generated_logits, axis=1), finished_reason, ) + if self._debug is True: + return *returns, session + + return returns + def prompt_inference( self, engine_inputs: List[numpy.ndarray], - ) -> Tuple[List[int], List[numpy.ndarray]]: + ) -> Tuple[List[numpy.ndarray], DecoderKVCache]: """ An inference run that processes the prompt through the model to generate the new token and logits @@ -687,9 +692,9 @@ def prompt_inference( :param engine_inputs: the prompt (context) represented by a list of numpy inputs to the engine :return: A tuple of: - - The list of prompt tokens plus the new, generated token - The logits generated from the prompt (with dimensions ['batch_size', 'num_tokens', 'vocab_size']) + - The kv cache session for this inference run """ # get tokens by attention mask tokens = engine_inputs[0][engine_inputs[1].nonzero()].tolist() @@ -697,18 +702,17 @@ def prompt_inference( prompt_logits = [] num_tokens_processed = 0 + session = self.get_kv_cache_decoder() + if len(tokens) > self.prompt_sequence_length and self.enable_multitoken_prefill: - self.multitoken_engine.reset_kv_cache() - for engine_inputs in self.engine_inputs_for_prefill(tokens): - new_logits = self.multitoken_engine(engine_inputs) + for engine_inputs in self.engine_inputs_for_prefill( + tokens, kv_cache=session + ): + new_logits = self.multitoken_engine(engine_inputs, kv_cache=session) num_tokens_processed += self.prompt_sequence_length prompt_logits.append(new_logits) - if num_tokens_processed: - # transfer the cache state from the multi-token engine to the main engine - self.engine.transfer_cache_state(cache=self.multitoken_engine.kv_cache) - else: - self.engine.reset_kv_cache() + session.set_capacity(self.sequence_length - 1) # prompt size is small, run autoregressive inference to populate kv cache run_tokens = [] if num_tokens_processed == 0 else tokens[:num_tokens_processed] @@ -718,15 +722,16 @@ def prompt_inference( with self.timer_manager.current.time( TextGenerationTimings.PROMPT_PREFILL_SINGLE ): - new_logits = self.autoregressive_inference(run_tokens) + new_logits = self.autoregressive_inference(run_tokens, session) prompt_logits.append(new_logits) - return prompt_logits + return prompt_logits, session def autoregressive_inference( self, tokens: List[int], + kv_cache: DecoderKVCache, ) -> Tuple[int, numpy.ndarray]: """ An inference run that processes the last token to generate @@ -737,14 +742,16 @@ def autoregressive_inference( (with dimensions ['batch_size', 'num_tokens', 'vocab_size']) """ + num_total_processed_tokens = kv_cache.total_num_processed_tokens new_token = tokens[-1] # padding is added to left, so attention mask is 1s from the # right up to the number of 
total tokens (prompt + generated) attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64) - num_tokens_processed = min(len(tokens), self.sequence_length) # cap by seq len - attention_mask[:, -num_tokens_processed:] = 1 - positions = numpy.array([[len(tokens)]], dtype=numpy.int64) - positions -= 1 + num_attention_entries_to_unmask = min( + num_total_processed_tokens + 1, self.sequence_length + ) # cap by seq len + attention_mask[:, -num_attention_entries_to_unmask:] = 1 + positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64) input_ids = numpy.array([[new_token]]) causal_mask = create_causal_mask(input_ids, attention_mask) @@ -758,13 +765,12 @@ def autoregressive_inference( engine_inputs = [ engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache ] - - generated_logits = self.engine(engine_inputs) + generated_logits = self.engine(engine_inputs, kv_cache) return generated_logits def engine_inputs_for_prefill( - self, tokens: List[int] + self, tokens: List[int], kv_cache: DecoderKVCache ) -> Generator[List[numpy.ndarray], None, None]: """ Takes a list of tokens and creates a generator @@ -784,11 +790,14 @@ def engine_inputs_for_prefill( the sum of: a) the number of tokens in the batch (self.prompt_sequence_length) - b) the number of non-blank cache entries - (num_non_blank_cache_entries) + b) the number of processed tokens so far + (num_total_processed_tokens) so that the attention_mask properly attends to the current input tokens, as well as the previous cache entries. + Note: the aformentioned sum must be capped + by the sequence length, as the maximum shape of the + attention mask is [batch_size, sequence_length]. - positions: derived directly from the input_ids @@ -797,7 +806,6 @@ def engine_inputs_for_prefill( :param tokens: the list of tokens to process :return: a generator of engine inputs """ - num_batches = len(tokens) // self.prompt_sequence_length token_batches = [ @@ -808,8 +816,8 @@ def engine_inputs_for_prefill( ] for idx, token_batch in enumerate(token_batches): + num_total_processed_tokens = kv_cache.total_num_processed_tokens engine_inputs = [] - num_cached_entries = self.multitoken_engine.num_non_blank_cache_entries for name in self.multitoken_engine.onnx_input_names_no_cache: if name == "input_ids": engine_input = numpy.array([token_batch]) @@ -819,25 +827,21 @@ def engine_inputs_for_prefill( engine_input = numpy.zeros( (1, self.sequence_length), dtype=numpy.int64 ) - # fill it out with 1s (from the right), so that the number - # of unmasked entries is equal to the sum of: - engine_input[ - :, - -( - # ...the number of current input tokens... 
- self.prompt_sequence_length - # ...and the number of the previous cache entries - + num_cached_entries - ) :, - ] = 1 + # calculate the number of entries in attention mask + # that should be set to 1 + num_attention_entries_to_unmask = min( + num_total_processed_tokens + self.prompt_sequence_length, + self.sequence_length, + ) + engine_input[:, -num_attention_entries_to_unmask:] = 1 elif name == "causal_mask": # delay creation of the causal mask continue elif name == "positions": engine_input = ( numpy.arange( - num_cached_entries, - num_cached_entries + self.prompt_sequence_length, + num_total_processed_tokens, + num_total_processed_tokens + self.prompt_sequence_length, ) .reshape(1, -1) .astype(numpy.int64) @@ -877,7 +881,7 @@ def join_engine_outputs( :param orig_batch_size: The original batch size :return: A list of joined outputs """ - tokens, logits, finish_reason = zip(*batch_outputs) + tokens, logits, finish_reason, *debug = zip(*batch_outputs) if self.cache_support_enabled: # if the model has kv cache, we need to account for # the fact that the predicted outputs may have @@ -909,6 +913,15 @@ def join_engine_outputs( tokens = numpy.concatenate(tokens, axis=0) logits = numpy.concatenate(logits, axis=0) + if debug: + sessions = debug[0] + kv_cache_state = numpy.stack(session.cached_inputs for session in sessions) + num_processed_tokens = numpy.stack( + session.total_num_processed_tokens for session in sessions + ) + + return [tokens, logits, finish_reason, kv_cache_state, num_processed_tokens] + return [tokens, logits, finish_reason] @staticmethod @@ -933,6 +946,29 @@ def causal_mask_input_present(model_path: str) -> bool: return is_causal_mask_input + def get_kv_cache_decoder(self) -> DecoderKVCache: + """ + Initialize the kv cache decoder for the inference + + :return: the initialized kv cache decoder + """ + engine = self.multitoken_engine or self.engine + + kv_cache_state = initialize_kv_cache_state( + cache_shape=engine.cache_shape, + kv_cache_data_type=engine.kv_cache_data_type, + output_names=engine.output_names, + length=self.sequence_length - self.prompt_sequence_length, + empty=bool(self.internal_kv_cache), + ) + + kv_cache = DecoderKVCache(self.internal_kv_cache) + kv_cache.setup( + state=kv_cache_state, + freeze_first_position=prepends_bos_token(self.tokenizer), + ) + return kv_cache + def _stop_token_generated( self, token, stop_tokens: Union[None, str, Sequence[str]] ) -> bool: diff --git a/src/deepsparse/transformers/utils/decoder_kv_cache.py b/src/deepsparse/transformers/utils/decoder_kv_cache.py index bc182f4931..08a03394e2 100644 --- a/src/deepsparse/transformers/utils/decoder_kv_cache.py +++ b/src/deepsparse/transformers/utils/decoder_kv_cache.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, Dict import numpy @@ -42,14 +42,12 @@ def __init__(self, internal_kv_cache: bool = False): # [batch_size, num_heads, sequence_length, hidden_size] self._sequence_len_axis = SEQUENCE_LENGTH_AXIS self._internal_kv_cache = internal_kv_cache - self._session_id = None self._freeze_first_position = None self._state = None self.engine_internal_cache = None def setup( self, - session_id: str, state: Dict[str, Any], num_processed_tokens: int = 0, freeze_first_position: bool = False, @@ -58,8 +56,6 @@ def setup( Setup the session - a level of abstraction that allocates the resources to store and manipulate the kv cache. 
- :param session_id: The session id to use for the current - session. Used to identify the kv cache state :param state: The state of the cache. This is a dictionary that maps the name of the cache array to the cache array. The cache tensor is a numpy array of shape @@ -74,7 +70,6 @@ def setup( that corresponds to the BOS token in the sequence. By default, is set to False. """ - self._session_id = session_id self._state = state self._freeze_first_position = freeze_first_position self.total_num_processed_tokens = num_processed_tokens @@ -90,6 +85,7 @@ def update( self, state: Dict[str, Any], input_ids_len: int, + increment_total_num_processed_tokens: int = True, ): """ Updating the session is identical with taking the kv cache @@ -103,8 +99,12 @@ def update( :param input_ids_len: The number of input ids in the current input batch: (batch_size, length). Corresponds to `input_ids.shape[1]` + :param increment_total_num_processed_tokens: If set to True, + the total number of processed tokens will be incremented + by the input_ids_len. """ - self.total_num_processed_tokens += input_ids_len + if increment_total_num_processed_tokens: + self.total_num_processed_tokens += input_ids_len input_state_capacity = state[list(state.keys())[0]].shape[ self._sequence_len_axis @@ -131,7 +131,7 @@ def update( ) if num_non_padded_entries_to_delete: cache_array = self.remove_non_padded_entries( - cache_array, num_entries_to_delete + cache_array, num_non_padded_entries_to_delete ) state[name] = numpy.ascontiguousarray(cache_array) @@ -167,7 +167,7 @@ def remove_non_padded_entries( new_cache_array = cache_array[ :, :, - bool(self._freeze_first_position) + num_non_padded_entries_to_delete :, + int(self._freeze_first_position) + num_non_padded_entries_to_delete :, :, ] if self._freeze_first_position: @@ -198,50 +198,32 @@ def set_capacity(self, capacity: int): state = self.cached_inputs if capacity_difference > 0: - raise NotImplementedError( - "The scenario when capacity" - "needs to be expanded is not yet" - "supported." 
+ self.update( + state, + input_ids_len=capacity_difference, + increment_total_num_processed_tokens=False, ) elif capacity_difference < 0: - indices = [0] * abs(capacity_difference) - state = self._add_entries(state, indices=indices) - + state = self._expand_capacity( + state, num_additional_entries=abs(capacity_difference) + ) + self._state = state else: - return + pass - self._state = state + return - def _add_entries( - self, state: Dict[str, Any], indices: List[int], padding_value: int = 0 + def _expand_capacity( + self, state: Dict[str, Any], num_additional_entries: int, padding_value: int = 0 ) -> Dict[str, Any]: for key, value in state.items(): - # required to make sure that both - # quantized and non quantized caches - # are supported - state_dtype = value.dtype - # change padding_value dtype to match the state dtype - padding_value = numpy.array(padding_value, dtype=state_dtype) - - state[key] = numpy.insert( - value, indices, padding_value, axis=self._sequence_len_axis + zeros = numpy.zeros_like( + (value[:, :, :num_additional_entries, :]), dtype=value.dtype ) + state[key] = numpy.concatenate([zeros, value], axis=self._sequence_len_axis) return state - @property - def id(self): - if self._session_id is None: - raise ValueError("Attempted to access session_id before setting up session") - return self._session_id - - @property - def num_non_blank_entries(self): - """ - :return: the number of non-blank entries in the kv cache - """ - return min(self.capacity, self.total_num_processed_tokens) - @property def capacity(self) -> int: """ @@ -256,10 +238,6 @@ def capacity(self) -> int: self._sequence_len_axis ] - @id.setter - def id(self, session_id: str): - self._session_id = session_id - @property def cached_inputs(self): return self._state diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py index b904570492..57d09e309e 100644 --- a/src/deepsparse/transformers/utils/helpers.py +++ b/src/deepsparse/transformers/utils/helpers.py @@ -13,9 +13,12 @@ # limitations under the License. import logging import uuid -from typing import List, Union +from typing import Dict, List, Optional, Tuple, Union import numpy +from transformers import AutoTokenizer + +from deepsparse.utils.onnx import CACHE_INPUT_PREFIX, CACHE_OUTPUT_PREFIX __all__ = [ @@ -23,11 +26,79 @@ "pad_to_fixed_length", "create_causal_mask", "repeat_inputs", + "initialize_kv_cache_state", + "prepends_bos_token", ] _LOGGER = logging.getLogger(__name__) +def prepends_bos_token(tokenizer: AutoTokenizer) -> bool: + """ + Check whether the tokenizer prepends a BOS token to the input sequence. + + :param tokenizer: tokenizer to check + :return: True if the tokenizer prepends a BOS token to the input sequence, + False otherwise + """ + if hasattr(tokenizer, "add_bos_token"): + return bool(tokenizer.add_bos_token) + return False + + +def initialize_kv_cache_state( + cache_shape: Tuple[int, int, int, int], + kv_cache_data_type: numpy.dtype, + output_names: List[str], + length: Optional[int] = None, + empty: bool = False, +) -> Dict[str, numpy.ndarray]: + """ + Initialize the kv cache state for the given set of arguments. + + :param cache_shape: shape of the kv cache tensor. Should be + (batch_size, num_attention_heads, length, hidden_dims) + :param kv_cache_data_type: data type of the kv cache tensor + :param output_names: list of output names from the engine + :param length: length of the input sequence. 
If None, the length + is taken from the cache_shape + :param empty: if True, initialize an empty kv cache tensor + with batch_size set to 0. Otherwise, initialize a kv cache + tensor with zeros + + :return: dictionary of kv cache tensors, where the keys are the + output names of the kv cache tensors and the values are the + kv cache tensors of shape + (batch_size, num_attention_heads, length, hidden_dims) + """ + batch_size, num_attention_heads, length_, hidden_dims = cache_shape + + # new kv cache tensor is either + # - non-empty tensor of zeros with shape + # (batch_size, num_attention_heads, length, hidden_dims), + # required for the external kv cache management + # or + # - empty tensor with shape + # (0, num_attention_heads, length, hidden_dims) + # required for the internal kv cache management + kv_cache_tensor = numpy.zeros( + ( + batch_size if not empty else 0, + num_attention_heads, + length if length is not None else length_, + hidden_dims, + ), + dtype=kv_cache_data_type, + ) + + cache_keys = [ + output_name.replace(CACHE_OUTPUT_PREFIX, CACHE_INPUT_PREFIX) + for output_name in output_names + if output_name.startswith(CACHE_OUTPUT_PREFIX) + ] + return {key: kv_cache_tensor for key in cache_keys} + + def generate_session_id() -> str: """ Generate uuid for session id. This is used to diff --git a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py index 7d80aa6ada..ff5b67aa75 100644 --- a/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py +++ b/tests/deepsparse/transformers/engine/test_nl_decoder_engine.py @@ -47,9 +47,10 @@ def test_add_kv_cache_to_input(): with patch.object(NLDecoderEngine, "__init__", lambda x, y, z: None): nl_decoder_engine = NLDecoderEngine(None, None) nl_decoder_engine.engine = DummyEngine() - nl_decoder_engine.kv_cache = DummyKVCacheDecoder() nl_decoder_engine.kv_cache_enabled = True - result = nl_decoder_engine.add_kv_cache_to_input(inp) + result = nl_decoder_engine.add_kv_cache_to_input( + inp, kv_cache=DummyKVCacheDecoder + ) for (x, y) in zip(result, expected_result): assert np.array_equal(x, y) diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py index 9eeacfb4ce..d8c4fde2a1 100644 --- a/tests/deepsparse/transformers/pipelines/test_text_generation.py +++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py @@ -18,7 +18,7 @@ import pytest from deepsparse import Pipeline -from deepsparse.transformers.utils.decoder_kv_cache import DecoderKVCache +from deepsparse.transformers.utils.helpers import prepends_bos_token from tests.deepsparse.transformers.pipelines.helpers import TorchGroundTruthSource @@ -153,7 +153,7 @@ def test_freeze_first_position(self, setup): # the kv cache is full _, uses_bos_token, _ = setup pipeline = self.get_pipeline() - assert pipeline.engine._freeze_first_position == uses_bos_token + assert prepends_bos_token(pipeline.tokenizer) == uses_bos_token def test_ort_single_token_prefill(self, setup): # Test the pipeline that uses ORT engine. 
The test covers the @@ -172,20 +172,18 @@ def test_ort_single_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=1, - force_max_tokens=True, engine_type="onnxruntime", ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -206,20 +204,19 @@ def test_ort_multi_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=self.prompt_sequence_length, - force_max_tokens=True, engine_type="onnxruntime", ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, ) @@ -243,14 +240,15 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): force_max_tokens=True, engine_type="onnxruntime", ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + + assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " "the total number of processed tokens should be " "greater than the sequence length" @@ -258,7 +256,6 @@ def test_ort_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 ) @@ -276,20 +273,19 @@ def test_deepsparse_single_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=1, - force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < self.sequence_length + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, ) @@ -307,20 +303,20 @@ def test_deepsparse_multi_token_prefill(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length, prompt_sequence_length=self.prompt_sequence_length, - force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) + pipeline._debug = True + output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens < 
self.sequence_length + + assert output.total_num_processed_tokens[0] < self.sequence_length self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, ) @@ -338,17 +334,17 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): model_path=self.model_stub, sequence_length=self.sequence_length_short, prompt_sequence_length=self.prompt_sequence_length, - force_max_tokens=True, internal_kv_cache=self.internal_kv_cache, ) + pipeline._debug = True output = pipeline( sequences=self.prompt, return_logits=True, include_prompt_logits=True, max_tokens=self.num_tokens_generate, ) - cache_session = pipeline.engine.kv_cache - assert cache_session.total_num_processed_tokens > self.sequence_length_short, ( + + assert output.total_num_processed_tokens[0] > self.sequence_length_short, ( "for this scenario, the kv cache should be full: " "the total number of processed tokens should be " "greater than the sequence length" @@ -356,7 +352,6 @@ def test_deepsparse_generation_after_kv_cache_has_been_filled(self, setup): self._test_output( output=output, - cache_session=cache_session, torch_ground_truth=torch_ground_truth, run_cache_validation=not self.internal_kv_cache, max_logits_difference_threshold=self.logits_max_diff_kv_cache_has_been_filled, # noqa E501 @@ -428,13 +423,10 @@ def test_num_generated_predictions(self, setup): def _test_output( self, output: "TextGenerationOutput", # noqa F821 - cache_session: DecoderKVCache, torch_ground_truth: Tuple[numpy.ndarray, ...], max_logits_difference_threshold: Optional[float] = None, run_cache_validation: bool = True, ): - # extract numpy arrays from cached_inputs - kv_cache_array = list(cache_session.cached_inputs.values()) ( generated_logits, @@ -466,10 +458,13 @@ def _test_output( assert self.prompt + output.generations[0].text == generated_text if run_cache_validation: + # extract numpy arrays from cached_inputs + kv_cache_array = list(output.kv_cache_state[0].values()) + total_num_processed_tokens = output.total_num_processed_tokens[0] self._test_kv_cache_state( expected_cache=kv_cache_array, target_cache=torch_ground_truth[2], - total_num_processed_tokens=cache_session.total_num_processed_tokens, + total_num_processed_tokens=total_num_processed_tokens, ) @staticmethod diff --git a/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py b/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py index 8fcc40dd87..e93890548e 100644 --- a/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py +++ b/tests/deepsparse/transformers/utils/test_decoder_kv_cache.py @@ -81,7 +81,6 @@ def setup( # initialize a session session.setup( - session_id="dummy_id", state=state, num_processed_tokens=num_processed_tokens, freeze_first_position=freeze_first_position, @@ -93,11 +92,10 @@ def test_session_attributes(self, setup): # check if the session attributes are set correctly state = session.cached_inputs - assert session.num_non_blank_entries == np.count_nonzero( + assert session.total_num_processed_tokens == np.count_nonzero( state["dummy_cache_name"].flatten() ) assert session.capacity == state["dummy_cache_name"].shape[2] - assert session.id == "dummy_id" def test_set_capacity(self, setup): session, _, _ = setup @@ -134,7 +132,19 @@ def _test_increase_capacity(session_): @staticmethod def _test_decrease_capacity(session_): - pass + session = copy.deepcopy(session_) + capacity = session.capacity + # decrease capacity by 3 + 
session.set_capacity(capacity - 3) + kv_cache_state = session.cached_inputs + # check if the capacity has been decreased by 3 + if session_._freeze_first_position: + bos_token = session_.cached_inputs["dummy_cache_name"][:, :, :1, :] + new_array = session_.cached_inputs["dummy_cache_name"][:, :, 4:, :] + target_array = np.concatenate([bos_token, new_array], axis=2) + else: + target_array = session_.cached_inputs["dummy_cache_name"][:, :, 3:, :] + assert np.array_equal(target_array, kv_cache_state["dummy_cache_name"]) @staticmethod def _test_constant_capacity(session_): diff --git a/tests/deepsparse/transformers/utils/test_helpers.py b/tests/deepsparse/transformers/utils/test_helpers.py index c96a004d6b..bc0fabb3d4 100644 --- a/tests/deepsparse/transformers/utils/test_helpers.py +++ b/tests/deepsparse/transformers/utils/test_helpers.py @@ -15,7 +15,10 @@ import numpy import pytest -from deepsparse.transformers.utils.helpers import create_causal_mask +from deepsparse.transformers.utils.helpers import ( + create_causal_mask, + initialize_kv_cache_state, +) @pytest.mark.parametrize( @@ -85,3 +88,44 @@ def test_create_causal_mask(input_ids, attention_mask, expected_causal_mask): causal_mask = create_causal_mask(input_ids, attention_mask) assert numpy.array_equal(causal_mask, expected_causal_mask[None, None, ...]) + + +@pytest.mark.parametrize( + "cache_shape, kv_cache_data_type, output_names, length, empty, expected_result", + [ + ( + (1, 2, 3, 4), + numpy.float32, + ["present.1", "present.2", "present.3"], + None, + False, + { + "past_key_values.1": numpy.zeros((1, 2, 3, 4)), + "past_key_values.2": numpy.zeros((1, 2, 3, 4)), + "past_key_values.3": numpy.zeros((1, 2, 3, 4)), + }, + ), + ( + (5, 6, 7, 8), + numpy.int8, + ["present.1", "present.2"], + 10, + True, + { + "past_key_values.1": numpy.zeros((0, 6, 10, 8), dtype=numpy.int8), + "past_key_values.2": numpy.zeros((0, 6, 10, 8), dtype=numpy.int8), + }, + ), + ], +) +def test_initialize_kv_cache_state( + cache_shape, kv_cache_data_type, output_names, length, empty, expected_result +): + # make sure that resulting Dict[str, numpy.ndarray] is the same + # as the expected_result + result = initialize_kv_cache_state( + cache_shape, kv_cache_data_type, output_names, length, empty + ) + assert result.keys() == expected_result.keys() + for key in result.keys(): + assert numpy.array_equal(result[key], expected_result[key])
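
For reference, a minimal sketch of how the new helper composes with DecoderKVCache outside the pipeline, mirroring what get_kv_cache_decoder now does. The shapes and tensor names are illustrative and match the dummy values used in the tests above; the exact trimming behavior of update() is inferred from its docstring and from the set_capacity tests, so treat the shape assertions as assumptions rather than guaranteed behavior:

    import numpy
    from deepsparse.transformers.utils import DecoderKVCache
    from deepsparse.transformers.utils.helpers import initialize_kv_cache_state

    # shape convention: (batch_size, num_attention_heads, length, hidden_dims)
    state = initialize_kv_cache_state(
        cache_shape=(1, 2, 3, 4),
        kv_cache_data_type=numpy.float32,
        output_names=["present.1", "present.2"],
        length=None,   # fall back to the length baked into cache_shape
        empty=False,   # external kv cache management -> real zero tensors
    )

    kv_cache = DecoderKVCache(internal_kv_cache=False)
    kv_cache.setup(state=state, freeze_first_position=False)

    # simulate one engine step: the returned cache tensors are input_ids_len
    # entries longer than the cache that was fed in; update() is expected to
    # trim them back to the session capacity and bump the processed counter
    longer_state = {
        name: numpy.concatenate(
            [array, numpy.ones((1, 2, 1, 4), dtype=array.dtype)], axis=2
        )
        for name, array in kv_cache.cached_inputs.items()
    }
    kv_cache.update(state=longer_state, input_ids_len=1)
    assert kv_cache.total_num_processed_tokens == 1
    assert kv_cache.cached_inputs["past_key_values.1"].shape[2] == 3  # assumed

    # capacity can now also grow (previously a NotImplementedError):
    # _expand_capacity left-pads the cache with zero entries
    kv_cache.set_capacity(kv_cache.capacity + 2)
    assert kv_cache.cached_inputs["past_key_values.1"].shape[2] == 5
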