@@ -40,11 +40,11 @@
                                   Whether to display the token generation
                                   speed or not  [default:
                                   no_show_tokens_per_sec]
-  --history / --no_history       Whether to include history during prompt
-                                  generation or not  [default: history]
+  --task TEXT                     The task to use for the pipeline. Choose any
+                                  of `chat`, `codegen`, `text-generation`
+                                  [default: chat]
   --help                          Show this message and exit.
 
-
 Installation: pip install deepsparse[transformers]
 Examples:
 
@@ -61,11 +61,12 @@
 
 4) Disable history
 python chatbot.py models/llama/deployment \
-    --no_history
+    --task text-generation
 """
 import click
 
 from deepsparse import Pipeline
+from deepsparse.tasks import SupportedTasks
 
 
 @click.command(
@@ -103,18 +104,19 @@
     help="Whether to display the token generation speed or not",
 )
 @click.option(
-    "--history/--no_history",
-    is_flag=True,
-    default=True,
-    help="Whether to include history during prompt generation or not",
+    "--task",
+    default="chat",
+    type=str,
+    help="The task to use for the pipeline. Choose any of "
+    "`chat`, `codegen`, `text-generation`",
 )
 def main(
     model_path: str,
     sequence_length: int,
     sampling_temperature: float,
     prompt_sequence_length: int,
     show_tokens_per_sec: bool,
-    history: bool,
+    task: str,
 ):
     """
     Command Line utility to interact with a text generation LLM in a chatbot style
@@ -123,21 +125,25 @@ def main(
 
     python chatbot.py [OPTIONS] <MODEL_PATH>
     """
-    # chat pipeline, automatically adds history
-    task = "chat" if history else "text-generation"
-
+    session_ids = "chatbot_cli_session"
+
     pipeline = Pipeline.create(
-        task=task,
+        task=task,  # let pipeline determine if task is supported
         model_path=model_path,
         sequence_length=sequence_length,
         sampling_temperature=sampling_temperature,
         prompt_sequence_length=prompt_sequence_length,
     )
-
+
     # continue prompts until a keyboard interrupt
     while True:
         input_text = input("User: ")
-        response = pipeline(**{"sequences": [input_text]})
+        pipeline_inputs = {"prompt": [input_text]}
+
+        if SupportedTasks.is_chat(task):
+            pipeline_inputs["session_ids"] = session_ids
+
+        response = pipeline(**pipeline_inputs)
         print("Bot: ", response.generations[0].text)
         if show_tokens_per_sec:
             times = pipeline.timer_manager.times
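
For reference, a minimal sketch of the call pattern this commit introduces, assembled in one place. The model path is illustrative, taken from the docstring examples above; only identifiers that appear in the diff itself are used.

# Sketch of the updated flow: create a pipeline for the chosen task and
# attach a session id only when the task is a chat task.
from deepsparse import Pipeline
from deepsparse.tasks import SupportedTasks

task = "chat"  # or "codegen" / "text-generation"
session_ids = "chatbot_cli_session"

pipeline = Pipeline.create(task=task, model_path="models/llama/deployment")

pipeline_inputs = {"prompt": ["Hello!"]}
if SupportedTasks.is_chat(task):
    # only chat pipelines track history; reusing the same session id
    # carries prior turns into the next prompt
    pipeline_inputs["session_ids"] = session_ids

response = pipeline(**pipeline_inputs)
print("Bot: ", response.generations[0].text)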