diff --git a/setup.py b/setup.py index d61c1fa312..6e7ca1bb96 100644 --- a/setup.py +++ b/setup.py @@ -298,7 +298,7 @@ def _setup_entry_points() -> Dict: "console_scripts": [ f"deepsparse.transformers.run_inference={data_api_entrypoint}", f"deepsparse.transformers.eval_downstream={eval_downstream}", - "deepsparse.infer=deepsparse.transformers.infer:main", + "deepsparse.infer=deepsparse.transformers.inference.infer:main", "deepsparse.debug_analysis=deepsparse.debug_analysis:main", "deepsparse.analyze=deepsparse.analyze:main", "deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability", diff --git a/src/deepsparse/transformers/inference/__init__.py b/src/deepsparse/transformers/inference/__init__.py new file mode 100644 index 0000000000..0c44f887a4 --- /dev/null +++ b/src/deepsparse/transformers/inference/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/src/deepsparse/transformers/infer.py b/src/deepsparse/transformers/inference/infer.py similarity index 66% rename from src/deepsparse/transformers/infer.py rename to src/deepsparse/transformers/inference/infer.py index e4f8ad26f3..460b8499c4 100644 --- a/src/deepsparse/transformers/infer.py +++ b/src/deepsparse/transformers/inference/infer.py @@ -63,10 +63,14 @@ deepsparse.infer models/llama/deployment \ --task text-generation """ + +from typing import Optional + import click from deepsparse import Pipeline from deepsparse.tasks import SupportedTasks +from deepsparse.transformers.inference.prompt_parser import PromptParser @click.command( @@ -75,6 +79,14 @@ ) ) @click.argument("model_path", type=str) +@click.option( + "--data", + type=str, + default=None, + help="Path to .txt, .csv, .json, or .jsonl file to load data from" + "If provided, runs inference over the entire dataset. If not provided " + "runs an interactive inference session in the console. Default None.", +) @click.option( "--sequence_length", type=int, @@ -112,6 +124,7 @@ ) def main( model_path: str, + data: Optional[str], sequence_length: int, sampling_temperature: float, prompt_sequence_length: int, @@ -128,34 +141,76 @@ def main( session_ids = "chatbot_cli_session" pipeline = Pipeline.create( - task=task, # let pipeline determine if task is supported + task=task, # let the pipeline determine if task is supported model_path=model_path, sequence_length=sequence_length, - sampling_temperature=sampling_temperature, prompt_sequence_length=prompt_sequence_length, ) - # continue prompts until a keyboard interrupt - while True: - input_text = input("User: ") - pipeline_inputs = {"prompt": [input_text]} - - if SupportedTasks.is_chat(task): - pipeline_inputs["session_ids"] = session_ids - - response = pipeline(**pipeline_inputs) - print("Bot: ", response.generations[0].text) - if show_tokens_per_sec: - times = pipeline.timer_manager.times - prefill_speed = ( - 1.0 * prompt_sequence_length / 
times["engine_prompt_prefill_single"] - ) - generation_speed = 1.0 / times["engine_token_generation_single"] - print( - f"[prefill: {prefill_speed:.2f} tokens/sec]", - f"[decode: {generation_speed:.2f} tokens/sec]", - sep="\n", + if data: + prompt_parser = PromptParser(data) + default_prompt_kwargs = { + "sequence_length": sequence_length, + "sampling_temperature": sampling_temperature, + "prompt_sequence_length": prompt_sequence_length, + "show_tokens_per_sec": show_tokens_per_sec, + } + + for prompt_kwargs in prompt_parser.parse_as_iterable(**default_prompt_kwargs): + _run_inference( + task=task, + pipeline=pipeline, + session_ids=session_ids, + **prompt_kwargs, ) + return + + # continue prompts until a keyboard interrupt + while data is None: # always True in interactive Mode + prompt = input(">>> ") + _run_inference( + pipeline, + sampling_temperature, + task, + session_ids, + show_tokens_per_sec, + prompt_sequence_length, + prompt, + ) + + +def _run_inference( + pipeline, + sampling_temperature, + task, + session_ids, + show_tokens_per_sec, + prompt_sequence_length, + prompt, + **kwargs, +): + pipeline_inputs = dict( + prompt=[prompt], + temperature=sampling_temperature, + **kwargs, + ) + if SupportedTasks.is_chat(task): + pipeline_inputs["session_ids"] = session_ids + + response = pipeline(**pipeline_inputs) + print("\n", response.generations[0].text) + + if show_tokens_per_sec: + times = pipeline.timer_manager.times + prefill_speed = ( + 1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"] + ) + generation_speed = 1.0 / times["engine_token_generation_single"] + print( + f"[prefill: {prefill_speed:.2f} tokens/sec]", + f"[decode: {generation_speed:.2f} tokens/sec]", + sep="\n", + ) if __name__ == "__main__": diff --git a/src/deepsparse/transformers/inference/prompt_parser.py b/src/deepsparse/transformers/inference/prompt_parser.py new file mode 100644 index 0000000000..35c433b11f --- /dev/null +++ 
b/src/deepsparse/transformers/inference/prompt_parser.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import csv
+import json
+import os
+from enum import Enum
+from typing import Iterator
+
+
+class InvalidPromptSourceDirectoryException(Exception):
+    pass
+
+
+class UnableToParseExtensionException(Exception):
+    pass
+
+
+def parse_value_to_appropriate_type(value: str):
+    if value.isdigit():
+        return int(value)
+    if "." in str(value) and all(part.isdigit() for part in value.split(".", 1)):
+        return float(value)
+    if value.lower() == "true":
+        return True
+    if value.lower() == "false":
+        return False
+    return value
+
+
+class PromptParser:
+    class Extensions(Enum):
+        TEXT = ".txt"
+        CSV = ".csv"
+        JSON = ".json"
+        JSONL = ".jsonl"
+
+    def __init__(self, filename: str):
+        self.extension: self.Extensions = self._validate_and_return_extension(filename)
+        self.filename: str = filename
+
+    def parse_as_iterable(self, **kwargs) -> Iterator:
+        if self.extension == self.Extensions.TEXT:
+            return self._parse_text(**kwargs)
+        if self.extension == self.Extensions.CSV:
+            return self._parse_csv(**kwargs)
+        if self.extension == self.Extensions.JSON:
+            return self._parse_json_list(**kwargs)
+        if self.extension == self.Extensions.JSONL:
+            return self._parse_jsonl(**kwargs)
+
+        raise UnableToParseExtensionException(
+            f"Parser for {self.extension} does not exist"
+        )
+
+    def _parse_text(self, **kwargs):
+        with open(self.filename, "r") as file:
+            for line in file:
+                kwargs["prompt"] = line.strip()
+                yield kwargs
+
+    def _parse_csv(self, **kwargs):
+        with open(self.filename, "r", newline="", encoding="utf-8-sig") as file:
+            reader = csv.DictReader(file)
+            for row in reader:
+                for key, value in row.items():
+                    kwargs.update({key: parse_value_to_appropriate_type(value)})
+                yield kwargs
+
+    def _parse_json_list(self, **kwargs):
+        with open(self.filename, "r") as file:
+            json_list = json.load(file)
+            for json_object in json_list:
+                kwargs.update(json_object)
+                yield kwargs
+
+    def _parse_jsonl(self, **kwargs):
+        with open(self.filename, "r") as file:
+            for jsonl in file:
+                jsonl_object = json.loads(jsonl)
+                kwargs.update(jsonl_object)
+                yield kwargs
+
+    def _validate_and_return_extension(self, filename: str):
+        if os.path.exists(filename):
+
+            for extension in self.Extensions:
+                if filename.endswith(extension.value):
+                    return extension
+
+            raise InvalidPromptSourceDirectoryException(
+                f"{filename} is not compatible. Select file that has "
+                "extension from "
+                f"{[key.name for key in self.Extensions]}"
+            )
+        raise FileNotFoundError