[Codegen][ORT][Static Seq Length] TextGenerationPipeline #946

Merged
2 changes: 2 additions & 0 deletions src/deepsparse/tasks.py
@@ -82,6 +82,7 @@ class SupportedTasks:
             "token_classification",
             "zero_shot_text_classification",
             "transformers_embedding_extraction",
+            "text_generation",
         ],
     )(
         question_answering=AliasedTask("question_answering", ["qa"]),
@@ -93,6 +94,7 @@ class SupportedTasks:
         transformers_embedding_extraction=AliasedTask(
             "transformers_embedding_extraction", []
         ),
+        text_generation=AliasedTask("text_generation", ["codegen"]),
     )

     image_classification = namedtuple("image_classification", ["image_classification"])(
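
With the task name and its "codegen" alias registered above, the new pipeline becomes reachable through the standard factory. A minimal sketch, assuming an exported CodeGen ONNX model directory; the exact constructor arguments accepted by the new TextGenerationPipeline live in text_generation.py, which is not shown in this diff:

```python
from deepsparse import Pipeline

# "codegen" resolves to "text_generation" through the AliasedTask added above.
# The model_path is a placeholder; point it at a directory containing
# model.onnx plus the tokenizer and config files.
pipeline = Pipeline.create(
    task="codegen",
    model_path="/path/to/codegen-onnx-export",
)
```
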
6 changes: 4 additions & 2 deletions src/deepsparse/transformers/helpers.py
@@ -159,8 +159,10 @@ def overwrite_transformer_onnx_model_inputs(
     ]
     input_names = []
     for external_input in external_inputs:
-        external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
-        external_input.type.tensor_type.shape.dim[1].dim_value = max_length
+        # Commenting this out for now, as it is not needed for the ORT backend
+        # Will be crucial for DeepSparse backend
+        # external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
+        # external_input.type.tensor_type.shape.dim[1].dim_value = max_length
         input_names.append(external_input.name)

     # Save modified model

Review comment (Contributor): this is not to be merged to main right?

Reply (Contributor Author): This should not be here. Let me update the status of this PR.
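
For context, the helper above rewrites the ONNX graph's external inputs to static dimensions before the model is compiled. A standalone sketch of that logic (function and variable names are assumed, not taken from the repository), showing the two assignments this PR disables for the ORT backend:

```python
import onnx


def overwrite_static_input_shapes(model_path: str, batch_size: int, max_length: int):
    """Sketch of the shape-overwriting step. For ORT the two dim_value
    assignments stay commented out, as in the diff above; the DeepSparse
    engine would need them restored to get static shapes."""
    model = onnx.load(model_path)
    initializer_names = {init.name for init in model.graph.initializer}
    # External inputs are graph inputs that are not weight initializers
    external_inputs = [
        inp for inp in model.graph.input if inp.name not in initializer_names
    ]
    input_names = []
    for external_input in external_inputs:
        # Disabled for ORT in this PR; required for the DeepSparse backend:
        # external_input.type.tensor_type.shape.dim[0].dim_value = batch_size
        # external_input.type.tensor_type.shape.dim[1].dim_value = max_length
        input_names.append(external_input.name)
    return model, input_names
```
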
1 change: 1 addition & 0 deletions src/deepsparse/transformers/pipelines/__init__.py
@@ -21,3 +21,4 @@
 from .token_classification import *
 from .zero_shot_text_classification import *
 from .embedding_extraction import *
+from .text_generation import *
14 changes: 9 additions & 5 deletions src/deepsparse/transformers/pipelines/pipeline.py
@@ -109,7 +109,8 @@ def setup_onnx_file_path(self) -> str:
             config_path, finetuning_task=self.task if hasattr(self, "task") else None
         )
         self.tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_path, model_max_length=self.sequence_length
+            tokenizer_path,
+            model_max_length=self.sequence_length,
         )
         self.config_path = os.path.join(config_path, "config.json")
         self.tokenizer_config_path = os.path.join(tokenizer_path, "tokenizer.json")
@@ -126,19 +127,22 @@ def setup_onnx_file_path(self) -> str:
         return onnx_path

     def tokens_to_engine_input(
-        self, tokens: Mapping[Any, numpy.ndarray]
+        self,
+        tokens: Mapping[Any, numpy.ndarray],
+        onnx_input_names: Optional[List[str]] = None,
     ) -> List[numpy.ndarray]:
         """
         :param tokens: outputs of the pipeline tokenizer
         :return: list of numpy arrays in expected order for model input
         """
-        if not all(name in tokens for name in self.onnx_input_names):
+        onnx_input_names = onnx_input_names or self.onnx_input_names
+        if not all(name in tokens for name in onnx_input_names):
             raise ValueError(
-                f"pipeline expected arrays with names {self.onnx_input_names}, "
+                f"pipeline expected arrays with names {onnx_input_names}, "
                 f"received inputs: {list(tokens.keys())}"
             )

-        return [tokens[name] for name in self.onnx_input_names]
+        return [tokens[name] for name in onnx_input_names]

     @staticmethod
     def should_bucket(*args, **kwargs) -> bool:
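
The new optional onnx_input_names argument lets a caller, such as the text generation pipeline, override the input names and ordering expected by the engine while the default behavior stays unchanged. An illustrative call, assuming a pipeline instance whose model takes input_ids and attention_mask at sequence length 128 (names and shapes are assumptions, not from the diff):

```python
import numpy

# Tokenizer-style outputs keyed by ONNX input name (assumed names/shapes)
tokens = {
    "input_ids": numpy.zeros((1, 128), dtype=numpy.int64),
    "attention_mask": numpy.ones((1, 128), dtype=numpy.int64),
}

# Without the argument this falls back to self.onnx_input_names, exactly
# as before; passing it overrides the expected names and their order.
engine_inputs = pipeline.tokens_to_engine_input(
    tokens, onnx_input_names=["input_ids", "attention_mask"]
)
```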