Commit c3b313c
Productionize Chat demo (#1235)
* Add chatbot example rebased from master
* Update history to task
  Add persistent session_id
  Delete init
* Delete init.py
  Update fire condition
* Move to src
  Add cli callable
1 parent 17758e9 commit c3b313c

3 files changed: +164 −1 lines

setup.py (+1)

@@ -298,6 +298,7 @@ def _setup_entry_points() -> Dict:
         "console_scripts": [
             f"deepsparse.transformers.run_inference={data_api_entrypoint}",
             f"deepsparse.transformers.eval_downstream={eval_downstream}",
+            "deepsparse.infer=deepsparse.transformers.infer:main",
             "deepsparse.debug_analysis=deepsparse.debug_analysis:main",
             "deepsparse.analyze=deepsparse.analyze:main",
             "deepsparse.check_hardware=deepsparse.cpu:print_hardware_capability",

src/deepsparse/tasks.py (+1 −1)

@@ -230,7 +230,7 @@ def is_chat(cls, task: str) -> bool:
         :param task: the name of the task to check whether it is a chat task
         :return: True if it is a chat task, False otherwise
         """
-        return any([chat_task.matches(task) for chat_task in cls.chat])
+        return any(chat_task.matches(task) for chat_task in cls.chat)

     @classmethod
     def is_text_generation(cls, task: str) -> bool:
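This one-line change swaps a list comprehension for a generator expression: `any` can now short-circuit on the first matching task instead of materializing the full list of results first. A standalone illustration of the difference, using a hypothetical `matches` predicate rather than the DeepSparse implementation:

    # hypothetical stand-in for the task-matching predicate
    def matches(task: str, name: str) -> bool:
        return task.replace("-", "_") == name

    chat_tasks = ["chat", "chatbot"]

    # list form: builds [True, False] in full before `any` sees anything
    any([matches("chat", name) for name in chat_tasks])

    # generator form: stops evaluating as soon as the first True appears
    any(matches("chat", name) for name in chat_tasks)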

src/deepsparse/transformers/infer.py (new file, +162)

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage: deepsparse.infer [OPTIONS] MODEL_PATH

Command line utility to interact with a text generation LLM in a chatbot
style

Example usage:

deepsparse.infer [OPTIONS] <MODEL_PATH>

Options:
  --sequence_length INTEGER       Sequence length to compile model and
                                  tokenizer for. This controls the maximum
                                  context length of the pipeline.  [default:
                                  512]
  --sampling_temperature FLOAT    The temperature to use when sampling from
                                  the probability distribution computed from
                                  the logits. Higher values will result in
                                  more random samples. Should be greater than
                                  0.0.  [default: 1.0]
  --prompt_sequence_length INTEGER
                                  Process the prompt in chunks of this
                                  length. This maximizes inference speed.
                                  [default: 64]
  --show_tokens_per_sec / --no_show_tokens_per_sec
                                  Whether to display the token generation
                                  speed or not  [default:
                                  no_show_tokens_per_sec]
  --task TEXT                     The task to use for the pipeline. Choose
                                  any of `chat`, `codegen`, `text-generation`
                                  [default: chat]
  --help                          Show this message and exit.

Installation: pip install deepsparse[transformers]
Examples:

1) Use a local deployment directory
deepsparse.infer models/llama/deployment

2) Use a SparseZoo stub
deepsparse.infer \
    zoo:nlg/text_generation/codegen_mono-350m/pytorch/huggingface/bigpython_bigquery_thepile/base-none  # noqa: E501

3) Display token generation speed
deepsparse.infer models/llama/deployment \
    --show_tokens_per_sec

4) Disable history
deepsparse.infer models/llama/deployment \
    --task text-generation
"""
import click

from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks


@click.command(
    context_settings=dict(
        token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
    )
)
@click.argument("model_path", type=str)
@click.option(
    "--sequence_length",
    type=int,
    default=512,
    help="Sequence length to compile model and tokenizer for. "
    "This controls the maximum context length of the pipeline.",
)
@click.option(
    "--sampling_temperature",
    type=float,
    default=1.0,
    help="The temperature to use when sampling "
    "from the probability distribution computed from the logits. "
    "Higher values will result in more random samples. Should "
    "be greater than 0.0.",
)
@click.option(
    "--prompt_sequence_length",
    type=int,
    default=64,
    help="Process the prompt in chunks of this length. "
    "This maximizes inference speed.",
)
@click.option(
    "--show_tokens_per_sec/--no_show_tokens_per_sec",
    default=False,
    help="Whether to display the token generation speed or not",
)
@click.option(
    "--task",
    default="chat",
    type=str,
    help="The task to use for the pipeline. Choose any of "
    "`chat`, `codegen`, `text-generation`",
)
def main(
    model_path: str,
    sequence_length: int,
    sampling_temperature: float,
    prompt_sequence_length: int,
    show_tokens_per_sec: bool,
    task: str,
):
    """
    Command line utility to interact with a text generation LLM in a chatbot style

    Example usage:

    deepsparse.infer [OPTIONS] <MODEL_PATH>
    """
    # fixed session id so a `chat` pipeline keeps conversation history
    # across prompts for the lifetime of this CLI session
    session_ids = "chatbot_cli_session"

    pipeline = Pipeline.create(
        task=task,  # let pipeline determine if task is supported
        model_path=model_path,
        sequence_length=sequence_length,
        sampling_temperature=sampling_temperature,
        prompt_sequence_length=prompt_sequence_length,
    )

    # continue prompts until a keyboard interrupt
    while True:
        input_text = input("User: ")
        pipeline_inputs = {"prompt": [input_text]}

        if SupportedTasks.is_chat(task):
            pipeline_inputs["session_ids"] = session_ids

        response = pipeline(**pipeline_inputs)
        print("Bot: ", response.generations[0].text)
        if show_tokens_per_sec:
            times = pipeline.timer_manager.times
            prefill_speed = (
                1.0 * prompt_sequence_length / times["engine_prompt_prefill_single"]
            )
            generation_speed = 1.0 / times["engine_token_generation_single"]
            print(
                f"[prefill: {prefill_speed:.2f} tokens/sec]",
                f"[decode: {generation_speed:.2f} tokens/sec]",
                sep="\n",
            )


if __name__ == "__main__":
    main()
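The CLI above is a thin read-eval-print loop over the same `Pipeline` API, so the chat behavior can be reproduced programmatically. A minimal sketch, assuming a local deployment directory at a placeholder path; reusing one `session_ids` value across calls is what carries conversation history, exactly as the CLI does:

    from deepsparse import Pipeline

    # placeholder path; point at a real text-generation deployment directory
    pipeline = Pipeline.create(task="chat", model_path="models/llama/deployment")

    # the same session id across calls keeps the chat history
    first = pipeline(prompt=["Hello!"], session_ids="demo_session")
    follow_up = pipeline(prompt=["Summarize what I said."], session_ids="demo_session")
    print(follow_up.generations[0].text)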
