diff --git a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py index f78d4306d1..1f631573ae 100644 --- a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py +++ b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py @@ -208,7 +208,9 @@ def run(self, inp: NLEngineInputs, **kwargs) -> NLEngineOutputs: inputs = list(map(self._add_kv_cache_to_input, engine_input, kv_cache)) inputs = join_engine_outputs(inputs, len(inputs)) - if bool(kv_cache[0].engine_internal_cache): + internal_kv_cache_present = bool(kv_cache[0].engine_internal_cache) + + if internal_kv_cache_present: # conventionally, before dispatching # inputs to the engine, we validate them # if val_inp=True. However, in this case @@ -235,18 +237,21 @@ def run(self, inp: NLEngineInputs, **kwargs) -> NLEngineOutputs: ) # logits should be stacked along batch dim - # kv_cache_state should be a list where each dim 0 is batch_size + # kv_cache_state should be a list where each item has dim 0 as batch_size logits, *kv_cache_state = out - kv_cache_state, _ = split_engine_inputs(kv_cache_state, 1) - if len(kv_cache_state) > 0: + if not internal_kv_cache_present: + # split along batch sizes; will give a list of lists where number of lists + # is equal to batch_size + kv_cache_state, _ = split_engine_inputs(kv_cache_state, 1) for i in range(len(kv_cache)): + # pass in a list and kv_cache object per _update_kv_cache call self._update_kv_cache( kv_cache_state=kv_cache_state[i], kv_cache=kv_cache[i] ) else: - # internal kv cache case - self._update_kv_cache(kv_cache=kv_cache[0]) + for i in range(len(kv_cache)): + self._update_kv_cache(kv_cache=kv_cache[i]) output = { "engine_outputs": logits,