diff --git a/src/deepsparse/transformers/pipelines/text_generation.py b/src/deepsparse/transformers/pipelines/text_generation.py index 813a7fa700..d5b08728d9 100644 --- a/src/deepsparse/transformers/pipelines/text_generation.py +++ b/src/deepsparse/transformers/pipelines/text_generation.py @@ -433,6 +433,7 @@ def autoregressive_inference( :return: The new, generated token and the logits for the new token (with dimensions ['batch_size', 'num_tokens', 'vocab_size']) """ + new_token = tokens[-1] # padding is added to left, so attention mask is 1s from the # right up to the number of total tokens (prompt + generated) @@ -444,7 +445,17 @@ def autoregressive_inference( positions -= 1 input_ids = numpy.array([[new_token]]) causal_mask = create_causal_mask(input_ids, attention_mask) - engine_inputs = [input_ids, attention_mask, positions, causal_mask] + + # filter out the inputs that are not needed by the engine + engine_inputs_map = dict( + input_ids=input_ids, + attention_mask=attention_mask, + causal_mask=causal_mask, + positions=positions, + ) + engine_inputs = [ + engine_inputs_map[name] for name in self.engine.onnx_input_names_no_cache + ] generated_token, generated_logits = self.engine(engine_inputs)