Skip to content

Commit 09eec2d

Browse files
committed
Replace the --history flag with a --task option
Add a persistent session_id; delete init
1 parent 8773d27 commit 09eec2d

File tree

2 files changed

+22
-16
lines changed

2 files changed

+22
-16
lines changed

Diff for: examples/chatbot-llm/chatbot.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@
4040
Whether to display the token generation
4141
speed or not [default:
4242
no_show_tokens_per_sec]
43-
--history / --no_history Whether to include history during prompt
44-
generation or not [default: history]
43+
--task TEXT The task to use for the pipeline. Choose any
44+
of `chat`, `codegen`, `text-generation`
45+
[default: chat]
4546
--help Show this message and exit.
4647
47-
4848
Installation: pip install deepsparse[transformers]
4949
Examples:
5050
@@ -61,11 +61,12 @@
6161
6262
4) Disable history
6363
python chatbot.py models/llama/deployment \
64-
--no_history
64+
--task text-generation
6565
"""
6666
import click
6767

6868
from deepsparse import Pipeline
69+
from deepsparse.tasks import SupportedTasks
6970

7071

7172
@click.command(
@@ -103,18 +104,19 @@
103104
help="Whether to display the token generation speed or not",
104105
)
105106
@click.option(
106-
"--history/--no_history",
107-
is_flag=True,
108-
default=True,
109-
help="Whether to include history during prompt generation or not",
107+
"--task",
108+
default="chat",
109+
type=str,
110+
help="The task to use for the pipeline. Choose any of "
111+
"`chat`, `codegen`, `text-generation`",
110112
)
111113
def main(
112114
model_path: str,
113115
sequence_length: int,
114116
sampling_temperature: float,
115117
prompt_sequence_length: int,
116118
show_tokens_per_sec: bool,
117-
history: bool,
119+
task: str,
118120
):
119121
"""
120122
Command Line utility to interact with a text generation LLM in a chatbot style
@@ -123,21 +125,25 @@ def main(
123125
124126
python chatbot.py [OPTIONS] <MODEL_PATH>
125127
"""
126-
# chat pipeline, automatically adds history
127-
task = "chat" if history else "text-generation"
128-
128+
session_ids = "chatbot_cli_session"
129+
129130
pipeline = Pipeline.create(
130-
task=task,
131+
task=task, # let pipeline determine if task is supported
131132
model_path=model_path,
132133
sequence_length=sequence_length,
133134
sampling_temperature=sampling_temperature,
134135
prompt_sequence_length=prompt_sequence_length,
135136
)
136-
137+
137138
# continue prompts until a keyboard interrupt
138139
while True:
139140
input_text = input("User: ")
140-
response = pipeline(**{"sequences": [input_text]})
141+
pipeline_inputs = {"prompt": [input_text]}
142+
143+
if SupportedTasks.is_chat(task):
144+
pipeline_inputs["session_ids"] = session_ids
145+
146+
response = pipeline(**pipeline_inputs)
141147
print("Bot: ", response.generations[0].text)
142148
if show_tokens_per_sec:
143149
times = pipeline.timer_manager.times

Diff for: src/deepsparse/tasks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def is_chat(cls, task: str) -> bool:
223223
:param task: the name of the task to check whether it is a chat task
224224
:return: True if it is a chat task, False otherwise
225225
"""
226-
return any([chat_task.matches(task) for chat_task in cls.chat])
226+
return any(chat_task.matches(task) for chat_task in cls.chat)
227227

228228
@classmethod
229229
def is_text_generation(cls, task: str) -> bool:

0 commit comments

Comments (0)