Commit 5b8f494

sydnash authored and Alvant committed
[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model (vllm-project#8405)
Signed-off-by: Alvant <[email protected]>
1 parent f7607ce commit 5b8f494

13 files changed: +533 −46 lines

docs/requirements-docs.txt
Lines changed: 2 additions & 1 deletion

```diff
@@ -12,4 +12,5 @@ torch
 py-cpuinfo
 transformers
 mistral_common >= 1.3.4
-openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
+openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
+partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
```

docs/source/serving/openai_compatible_server.md
Lines changed: 72 additions & 2 deletions

```diff
@@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
 To enable this feature, you should set the following flags:
 * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it
 deems appropriate.
-* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers
-will continue to be added in the future.
+* `--tool-call-parser` -- select the tool parser to use - currently `hermes`, `mistral`, `llama3_json` or `internlm`. Additional tool parsers
+will continue to be added in the future, and you can also register your own tool parsers via `--tool-parser-plugin`.
+* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user-defined tool parsers into vLLM; parser names registered by the plugin can then be passed to `--tool-call-parser`.
 * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
 that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their
 `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
```
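
Those flags are all the server needs; on the client side, auto tool choice then looks like an ordinary OpenAI tools request. A minimal sketch, assuming a vLLM server is already running with the flags above on localhost:8000 (the model name, port, and weather-tool schema below are illustrative, not part of this diff):

```python
# Hedged sketch: exercise auto tool choice through the OpenAI client.
# Assumes the server was started with --enable-auto-tool-choice and a
# matching --tool-call-parser; the tool schema here is made up.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # hypothetical example tool
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"}
            },
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="internlm/internlm2_5-7b-chat",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    tool_choice="auto",  # the model decides whether to emit a tool call
)
print(response.choices[0].message.tool_calls)
```
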
````diff
@@ -218,4 +219,73 @@ it works better with vLLM.
 
 Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`
 
+#### Internlm Models
+Supported models:
+* `internlm/internlm2_5-7b-chat` (confirmed)
+* Additional internlm2.5 function-calling models are compatible as well
+
+Known issues:
+* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model.
+
+Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
+
+
+### How to write a tool parser plugin
+
+A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
+
+Here is a summary of a plugin file:
+
+```python
+
+# import the required packages
+from typing import Sequence, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage,
+                                              ExtractedToolCallInformation)
+from vllm.entrypoints.openai.tool_parsers import (ToolParser,
+                                                  ToolParserManager)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+# define a tool parser and register it to vllm
+# the name list in register_module can be used
+# in --tool-call-parser. you can define as many
+# tool parsers as you want here.
+@ToolParserManager.register_module(["example"])
+class ExampleToolParser(ToolParser):
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+    # adjust request. e.g.: set skip special tokens
+    # to False for tool call output.
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        return request
+
+    # implement the tool call parse for stream call
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        # return a DeltaMessage with the newly parsed tool-call fragment,
+        # or None when this delta carries no tool-call content
+        return None
+
+    # implement the tool parse for non-stream call
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+        return ExtractedToolCallInformation(tools_called=False,
+                                            tool_calls=[],
+                                            content=model_output)
+
+```
+
+Then you can use this plugin on the command line like this:
+```
+--enable-auto-tool-choice \
+--tool-parser-plugin <absolute path of the plugin file> \
+--tool-call-parser example \
+--chat-template <your chat template> \
+```
 
````
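
Before wiring a plugin into the server, it can be sanity-checked directly against the `ToolParserManager` API this commit introduces (`import_tool_parser`, `tool_parsers`, and `get_tool_parser`, all visible in the diffs below). A small sketch; the plugin path is illustrative, and the registered name "example" assumes the plugin above:

```python
# Hedged sketch: load a plugin file and confirm its parser registered.
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

# illustrative path; point this at your own plugin file
ToolParserManager.import_tool_parser("/abs/path/to/example_tool_parser.py")

print(ToolParserManager.tool_parsers.keys())   # should now include "example"
parser_cls = ToolParserManager.get_tool_parser("example")
```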

examples/tool_chat_template_internlm2_tool.jinja
Lines changed: 60 additions & 0 deletions

```diff
@@ -0,0 +1,60 @@
+{%- if messages[0]["role"] == "system" %}
+    {%- set system_message = messages[0]["content"] %}
+    {%- set loop_messages = messages[1:] %}
+{%- else %}
+    {%- set loop_messages = messages %}
+{%- endif %}
+
+{%- if not tools is defined %}
+    {%- set tools = none %}
+{%- endif %}
+
+{{- bos_token }}
+{%- if system_message is defined %}
+    {{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
+{%- endif %}
+
+{%- if tools is not none %}
+    {{- "<|im_start|>system name=<|plugin|>\n[" }}
+    {%- for tool in tools %}
+        {{- tool.function|tojson }}
+        {%- if not loop.last %}
+            {{- ", " }}
+        {%- else %}
+            {{- "]" }}
+        {%- endif %}
+    {%- endfor %}
+    {{- "<|im_end|>\n" }}
+{%- endif %}
+
+{%- for message in loop_messages %}
+    {%- if message["role"] == "user" %}
+        {{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
+    {%- elif message.tool_calls is defined and message.tool_calls is not none %}
+        {%- set content = message["content"] if message["content"] else "" %}
+        {{- "<|im_start|>assistant\n" + content }}
+        {%- for tool_call in message.tool_calls %}
+            {%- set function=tool_call.function %}
+            {{- "<|action_start|><|plugin|>\n" }}
+            {{- '{"name": "' + function.name + '", '}}
+            {{- '"arguments": ' + function.arguments|tojson + '}' }}
+            {{- "<|action_end|>" }}
+        {%- endfor %}
+        {{- "<|im_end|>\n" }}
+    {%- elif message["role"] == "assistant" %}
+        {{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
+    {%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
+        {%- if message.content is defined and message.content.content is defined %}
+            {%- set content = message.content.content %}
+        {%- else %}
+            {%- set content = message.content %}
+        {%- endif %}
+        {{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
+    {%- else %}
+        {{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
+    {%- endif %}
+{%- endfor %}
+
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+{%- endif %}
```

tests/tool_use/utils.py
Lines changed: 13 additions & 1 deletion

```diff
@@ -87,6 +87,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
         "call the tool. Otherwise, answer the user's query directly "
         "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
         "to the user's question - just respond to it normally."
+    },
+    "internlm": {
+        "model":
+        "internlm/internlm2_5-7b-chat",
+        "arguments": [
+            "--tool-call-parser", "internlm", "--chat-template",
+            str(VLLM_PATH /
+                "examples/tool_chat_template_internlm2_tool.jinja"),
+            "--trust_remote_code"
+        ],
+        "supports_parallel":
+        False,
     }
 }
 

@@ -109,7 +121,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
             "type":
             "string",
             "description":
-            "the two-letter abbreviation for the state "
+            "must be the two-letter abbreviation for the state "
             "that the city is in, e.g. 'CA' which would "
            "mean 'California'"
         },
```

vllm/entrypoints/openai/api_server.py
Lines changed: 10 additions & 0 deletions

```diff
@@ -53,6 +53,7 @@
 from vllm.entrypoints.openai.serving_engine import BaseModelPath
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path

@@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
+    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
+        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+
+    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
+    if args.enable_auto_tool_choice \
+            and args.tool_call_parser not in valid_tool_parsers:
+        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
+                       f"(choose from {{ {','.join(valid_tool_parsers)} }})")
+
     temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     temp_socket.bind(("", args.port))
```

vllm/entrypoints/openai/cli_args.py
Lines changed: 13 additions & 1 deletion

```diff
@@ -12,6 +12,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.utils import FlexibleArgumentParser
 
 

@@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "Enable auto tool choice for supported models. Use --tool-call-parser"
         "to specify which parser to use")
 
+    valid_tool_parsers = ToolParserManager.tool_parsers.keys()
     parser.add_argument(
         "--tool-call-parser",
         type=str,
-        choices=["mistral", "hermes", "llama3_json"],
+        metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
+        "--tool-parser-plugin",
         default=None,
         help=
         "Select the tool call parser depending on the model that you're using."
         " This is used to parse the model-generated tool call into OpenAI API "
         "format. Required for --enable-auto-tool-choice.")
 
+    parser.add_argument(
+        "--tool-parser-plugin",
+        type=str,
+        default="",
+        help=
+        "Specify the tool parser plugin used to parse the model-generated tool"
+        " calls into OpenAI API format; the parser names registered in this "
+        "plugin can be used in --tool-call-parser.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
 
     parser.add_argument('--max-log-len',
```

vllm/entrypoints/openai/serving_chat.py
Lines changed: 20 additions & 18 deletions

```diff
@@ -29,10 +29,7 @@
                                               OpenAIServing,
                                               PromptAdapterPath,
                                               TextTokensPrompt)
-from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
-                                                  Llama3JsonToolParser,
-                                                  MistralToolParser,
-                                                  ToolParser)
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput

@@ -82,15 +79,13 @@ def __init__(self,
 
         self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
         if self.enable_auto_tools:
-            if tool_parser == "mistral":
-                self.tool_parser = MistralToolParser
-            elif tool_parser == "hermes":
-                self.tool_parser = Hermes2ProToolParser
-            elif tool_parser == "llama3_json":
-                self.tool_parser = Llama3JsonToolParser
-            else:
+            try:
+                self.tool_parser = ToolParserManager.get_tool_parser(
+                    tool_parser)
+            except Exception as e:
                 raise TypeError("Error: --enable-auto-tool-choice requires "
-                                "--tool-call-parser")
+                                f"tool_parser:'{tool_parser}' which has not "
+                                "been registered") from e
 
     async def create_chat_completion(
         self,

@@ -187,6 +182,10 @@ async def create_chat_completion(
         raw_request.state.request_metadata = request_metadata
 
         try:
+            if self.enable_auto_tools and self.tool_parser:
+                request = self.tool_parser(tokenizer).adjust_request(
+                    request=request)
+
             if isinstance(prompt, str):
                 prompt_inputs = self._tokenize_prompt_input(
                     request,

@@ -282,11 +281,11 @@ async def chat_completion_stream_generator(
         num_choices = 1 if request.n is None else request.n
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
-
         num_prompt_tokens = 0
 
-        tool_parser: Optional[ToolParser] = self.tool_parser(
-            tokenizer) if self.tool_parser else None
+        tool_parsers: List[Optional[ToolParser]] = [
+            self.tool_parser(tokenizer) if self.tool_parser else None
+        ] * num_choices
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name

@@ -324,7 +323,7 @@
             # NOTE num_choices defaults to 1 so this usually executes
             # once per request
             for i in range(num_choices):
-
+                tool_parser = tool_parsers[i]
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=i,
                     delta=DeltaMessage(

@@ -399,6 +398,7 @@
 
                 for output in res.outputs:
                     i = output.index
+                    tool_parser = tool_parsers[i]
 
                     if finish_reason_sent[i]:
                         continue

@@ -446,7 +446,8 @@
                             delta_text=delta_text,
                             previous_token_ids=previous_token_ids,
                             current_token_ids=current_token_ids,
-                            delta_token_ids=output.token_ids))
+                            delta_token_ids=output.token_ids,
+                            request=request))
 
                     # update the previous values for the next iteration
                     previous_texts[i] = current_text

@@ -685,7 +686,8 @@ async def chat_completion_full_generator(
                 and self.tool_parser:
 
                 tool_parser = self.tool_parser(tokenizer)
-                tool_call_info = tool_parser.extract_tool_calls(output.text)
+                tool_call_info = tool_parser.extract_tool_calls(
+                    output.text, request=request)
                 tools_called = tool_call_info.tools_called
                 if tool_call_info.tools_called:
                     message = ChatMessage(role=role,
```
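
A note on the streaming change above: the single `tool_parser` becomes a per-choice `tool_parsers` list because streaming extraction is stateful across deltas. Be aware that Python's `[x] * n` fills the list with references to one object rather than constructing `n` objects; the standalone sketch below (not vLLM code) illustrates the difference, in case fully independent per-choice state is wanted:

```python
# Standalone illustration: [obj] * n shares ONE instance across all
# slots, while a comprehension builds a fresh instance per slot.
from typing import List, Optional

class StatefulParser:
    def __init__(self) -> None:
        self.deltas_seen = 0  # stateful: counts deltas consumed so far

shared: List[Optional[StatefulParser]] = [StatefulParser()] * 2
independent: List[Optional[StatefulParser]] = [
    StatefulParser() for _ in range(2)
]

shared[0].deltas_seen += 1
print(shared[1].deltas_seen)       # 1 -- state leaks across choices
independent[0].deltas_seen += 1
print(independent[1].deltas_seen)  # 0 -- each choice keeps its own state
```
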
vllm/entrypoints/openai/tool_parsers/__init__.py
Lines changed: 4 additions & 3 deletions

```diff
@@ -1,9 +1,10 @@
-from .abstract_tool_parser import ToolParser
+from .abstract_tool_parser import ToolParser, ToolParserManager
 from .hermes_tool_parser import Hermes2ProToolParser
+from .internlm2_tool_parser import Internlm2ToolParser
 from .llama_tool_parser import Llama3JsonToolParser
 from .mistral_tool_parser import MistralToolParser
 
 __all__ = [
-    "ToolParser", "Hermes2ProToolParser", "MistralToolParser",
-    "Llama3JsonToolParser"
+    "ToolParser", "ToolParserManager", "Hermes2ProToolParser",
+    "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
 ]
```
