
[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model #8405


Merged
merged 25 commits into main from add-internlm2-for-tool-use on Oct 4, 2024
Changes from 16 commits
Commits (25)
87b6352
[add] add tools call for internlm2
sydnash Sep 12, 2024
5355659
Merge branch 'main' into add-internlm2-for-tool-use
sydnash Sep 12, 2024
68cd89d
[add] add some comments
sydnash Sep 12, 2024
d17f006
[add] add some comments
sydnash Sep 12, 2024
2d7d9d4
[fix] fix internlm2 tool chat template, fix the internlm2 tool call o…
sydnash Sep 13, 2024
12352e7
[add] add tool parser plugin doc
sydnash Sep 13, 2024
11bed0d
[add] add tool parser plugin doc
sydnash Sep 13, 2024
8a8b840
[fix] fix the stream tool call for internlm2
sydnash Sep 13, 2024
00c5da2
[fix] comment
sydnash Sep 13, 2024
882c764
[merge] resolve conflict
sydnash Sep 13, 2024
12b1035
[fix] use metavar to display the help info for --tool-call-parser, ad…
sydnash Sep 14, 2024
ed5b3fd
[add] got valid tool parsers from ToolParserManager
sydnash Sep 14, 2024
ea2c089
[fix] fix build for docs
sydnash Sep 14, 2024
36ad5d0
[fix] internlm's tool call out may arguments or parameters
sydnash Sep 15, 2024
cf981c0
[merge] resolve conflict
sydnash Sep 18, 2024
647db0d
refactor the tool parser to internlm, fix the test case of streamed_args
sydnash Sep 26, 2024
064ca1f
merge main
sydnash Sep 27, 2024
106909c
[fix] fix internlm parallel test, remove vllm/version.py
sydnash Sep 28, 2024
e242501
[format]
sydnash Sep 29, 2024
0a5ddf4
[format]
sydnash Sep 29, 2024
1db530d
[fix] fix the mistral tool call error. recover vllm/version.py and de…
sydnash Sep 29, 2024
dc94a22
[fix] change vocab property to get_vocab method in mistral_tool_parse…
sydnash Sep 29, 2024
3048233
Merge remote-tracking branch 'origin/main' into add-internlm2-for-too…
sydnash Sep 29, 2024
a2f938f
[fix] remove --tokenizer-mode mistral for mistral test. fix the syste…
sydnash Oct 3, 2024
4b619a2
[merge] merge from main
sydnash Oct 3, 2024
3 changes: 2 additions & 1 deletion docs/requirements-docs.txt
@@ -12,4 +12,5 @@ torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
75 changes: 73 additions & 2 deletions docs/source/serving/openai_compatible_server.md
@@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter
To enable this feature, you should set the following flags:
* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. Tells vLLM that you want to enable the model to generate its own tool calls when it
deems appropriate.
* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral`, or `internlm`. Additional tool parsers
will continue to be added in the future, and you can also register your own tool parsers with `--tool-parser-plugin`.
* `--tool-parser-plugin` -- **optional** a tool parser plugin used to register user-defined tool parsers into vLLM; the tool parser names registered by the plugin can then be passed to `--tool-call-parser`.
* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages
that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their
`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat
@@ -197,3 +198,73 @@ when tools are provided, that results in much better reliability when working wi


Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`

#### Internlm Models
Supported models:
* `internlm/internlm2_5-7b-chat` (confirmed)
* Additional internlm2.5 function-calling models are compatible as well

Known issues:
* Although this implementation also supports InternLM2, the tool call results are not stable when tested with the `internlm/internlm2-chat-7b` model.

Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja`
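
As a quick sanity check, a minimal client-side request might look like the sketch below (this assumes a locally running server started with the flags above; the `get_current_weather` tool schema is illustrative, mirroring the one used in the tool-use tests):

```python
from openai import OpenAI

# Assumes the server was started with something like:
#   vllm serve internlm/internlm2_5-7b-chat --trust-remote-code \
#     --enable-auto-tool-choice --tool-call-parser internlm \
#     --chat-template examples/tool_chat_template_internlm2_tool.jinja
client = OpenAI(base_url="http://localhost:8000/v1", api_key="empty")

# An illustrative OpenAI-style function schema.
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
            },
            "required": ["city", "state"],
        },
    },
}]

response = client.chat.completions.create(
    model="internlm/internlm2_5-7b-chat",
    messages=[{"role": "user", "content": "What's the weather in Dallas, TX?"}],
    tools=tools,
    tool_choice="auto",
)
# If the model decided to call the tool, the parsed call appears here.
print(response.choices[0].message.tool_calls)
```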


### How to write a tool parser plugin

A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.

Here is a summary of a plugin file:

```python

# import the required packages
from typing import Sequence, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.transformers_utils.tokenizer import AnyTokenizer


# define a tool parser and register it to vllm.
# the name list in register_module can be used
# in --tool-call-parser. you can define as many
# tool parsers as you want here.
@ToolParserManager.register_module(["example"])
class ExampleToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    # adjust the request, e.g. set skip_special_tokens
    # to False so that tool call tokens are not dropped
    # from the output
    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        return request

    # implement tool call parsing for streaming calls; return a
    # DeltaMessage for the newly parsed chunk, or None if there is
    # nothing new to emit yet
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        return DeltaMessage(content=delta_text)

    # implement tool call parsing for non-streaming calls
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

```

Then you can enable this plugin on the command line like this:
```
--enable-auto-tool-choice \
--tool-parser-plugin <absolute path of the plugin file> \
--tool-call-parser example \
--chat-template <your chat template> \
```
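
Putting it together, a full launch command might look like the following sketch (the model, plugin path, and chat template are placeholders to swap for your own):
```
vllm serve internlm/internlm2_5-7b-chat \
    --trust-remote-code \
    --enable-auto-tool-choice \
    --tool-parser-plugin /path/to/example_tool_parser.py \
    --tool-call-parser example \
    --chat-template examples/tool_chat_template_internlm2_tool.jinja
```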
60 changes: 60 additions & 0 deletions examples/tool_chat_template_internlm2_tool.jinja
@@ -0,0 +1,60 @@
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}

{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}

{{- bos_token }}
{%- if system_message is defined %}
{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }}
{%- endif %}

{%- if tools is not none %}
{{- "<|im_start|>system name=<|plugin|>\n[" }}
{%- for tool in tools %}
{{- tool.function|tojson }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- endif %}

{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message.tool_calls is defined and message.tool_calls is not none %}
{%- set content = message["content"] if message["content"] else "" %}
{{- "<|im_start|>assistant\n" + content }}
{%- for tool_call in message.tool_calls %}
{%- set function=tool_call.function %}
{{- "<|action_start|><|plugin|>\n" }}
{{- '{"name": "' + function.name + '", '}}
{{- '"arguments": ' + function.arguments|tojson + '}' }}
{{- "<|action_end|>" }}
{%- endfor %}
{{- "<|im_end|>\n" }}
{%- elif message["role"] == "assistant" %}
{{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }}
{%- else %}
{{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
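
For orientation, an assistant turn that carries a tool call renders to roughly the following under this template (a hand-derived illustration, assuming `function.arguments` is already a JSON object rather than a string):

```
<|im_start|>assistant
<|action_start|><|plugin|>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX"}}<|action_end|><|im_end|>
```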
12 changes: 9 additions & 3 deletions tests/tool_use/test_parallel_tool_calls.py
@@ -6,17 +6,20 @@

from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL)
WEATHER_TOOL, ServerConfig)


# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
server_config: ServerConfig):
models = await client.models.list()
model_name: str = models.data[0].id
if server_config.get("skip_parallel", False):
pytest.skip(f"skip parallel test for {model_name}")
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
temperature=0,
@@ -136,9 +139,12 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI):
# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI):
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
server_config: ServerConfig):
models = await client.models.list()
model_name: str = models.data[0].id
if server_config.get("skip_parallel", False):
pytest.skip(f"skip parallel test for {model_name}")
chat_completion = await client.chat.completions.create(
messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
temperature=0,
25 changes: 21 additions & 4 deletions tests/tool_use/utils.py
@@ -1,4 +1,4 @@
from typing import Dict, List
from typing import Dict, List, Optional

from openai.types.chat import (ChatCompletionMessageParam,
ChatCompletionToolParam)
@@ -10,6 +10,7 @@
class ServerConfig(TypedDict):
model: str
arguments: List[str]
skip_parallel: Optional[bool]


# universal args for all models go here. also good if you need to test locally
@@ -23,7 +24,9 @@ class ServerConfig(TypedDict):
"arguments": [
"--tool-call-parser", "hermes", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
]
],
"skip_parallel":
False
},
"mistral": {
"model":
@@ -32,7 +35,21 @@
"--tool-call-parser", "mistral", "--chat-template",
str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
"--ignore-patterns=\"consolidated.safetensors\""
]
],
"skip_parallel":
False
},
"internlm": {
"model":
"internlm/internlm2_5-7b-chat",
"arguments": [
"--tool-call-parser", "internlm", "--chat-template",
str(VLLM_PATH /
"examples/tool_chat_template_internlm2_tool.jinja"),
"--trust_remote_code"
],
"skip_parallel":
True
}
}

@@ -55,7 +72,7 @@ class ServerConfig(TypedDict):
"type":
"string",
"description":
"the two-letter abbreviation for the state "
"must the two-letter abbreviation for the state "
"that the city is in, e.g. 'CA' which would "
"mean 'California'"
},
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -52,6 +52,7 @@
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
@@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

valid_tool_parsers = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parsers:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(choose from {{ {','.join(valid_tool_parsers)} }})")

temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))

14 changes: 13 additions & 1 deletion vllm/entrypoints/openai/cli_args.py
@@ -12,6 +12,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser


@@ -171,16 +172,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")

valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes"],
Contributor:

I think we still want to specify the choices of ["mistral", "hermes", "internlm2_5"], but make this optional in the event that --enable-auto-tool-choice is called with --tool-parser-plugin.

It would be good for people to know which tool call parsers are available by default, and this makes sure that the expected values get into the auto-generated documentation.

sydnash (Contributor Author) · Sep 13, 2024:

If we specify `choices` in `add_argument`, the user cannot pass a `--tool-call-parser` that was registered via `--tool-parser-plugin`.

Maybe we can put the default choices into the help information instead.

Contributor:

I think what I'm trying to say is that you could keep the choices of ["mistral", "hermes", "internlm2_5"] and then do something like this in serving_chat.py:

Current state:

if self.enable_auto_tools:
    try:
        self.tool_parser = ToolParserManager.get_tool_parser(tool_parser)
    except Exception as e:
        raise TypeError("Error: --enable-auto-tool-choice requires tool_parser:'{tool_parser}' which has not  been registered") from e

Possible changes:

# if a plugin is not specified; we can do this already
if self.enable_auto_tools and not self.tool_parser_plugin:
    plugin_name = tool_parser # one of the options from the CLI argument, e.g. hermes or mistral

# if a plugin is specified - this may require some refactoring to get the tool parser plugin loaded in serving chat
elif self.enable_auto_tools and self.tool_parser_plugin:
    # get the name of the plugin loaded from `--tool-parser-plugin`
    plugin_name = get_plugin_name_somehow_from_loaded_plugin()

# handle additional cases here
try: 
    self.tool_parser = ToolParserManager.get_tool_parser(plugin_name)
except Exception as e:
    raise TypeError("You must specify a valid value for --tool-call-parser OR a value tool parser plugin"

sydnash (Contributor Author) · Sep 13, 2024:

But in my design, a plugin can register any number of tool parsers into vLLM, and the user can use --tool-call-parser to select whichever one they want, just like the default tool parsers shipped with vLLM.

I added some documentation in docs/source/serving/openai_compatible_server.md; maybe you can take a look at that.

Contributor:

Ohhhh, I see. Hmm. I'm not sure what the best pattern would be for the arguments here, then. @DarkLight1337 @mgoin do y'all have any thoughts?

sydnash (Contributor Author):

Maybe we can use metavar instead of choices to display help information.

    valid_tool_parsers = ["mistral", "hermes", "internlm2", "internlm2_5"]
    parser.add_argument(
        "--tool-call-parser",
        type=str,
        metavar=
        "{" + ",".join(valid_tool_parsers) +  "} or name registered in "
        "--tool-parser-plugin",
        default=None,
        help=
        "Select the tool call parser depending on the model that you're using."
        " This is used to parse the model-generated tool call into OpenAI API "
        "format. Required for --enable-auto-tool-choice.")

the help will look like this:

--tool-call-parser {mistral,hermes,internlm2,internlm2_5} or name registered in --tool-parser-plugin
                        Select the tool call parser depending on the model
                        that you're using. This is used to parse the model-
                        generated tool call into OpenAI API format. Required
                        for --enable-auto-tool-choice.

We can also move the plugin import and the tool call parser check into run_server, so that an invalid tool call parser name is caught quickly (before the model loads):

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
    ToolParserManager.import_tool_parser(args.tool_parser_plugin)

if args.enable_auto_tool_choice:
    if args.tool_call_parser not in ToolParserManager.tool_parsers.keys():
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"chose from {{ {','.join(ToolParserManager.tool_parsers.keys())} }}")

The error info looks like this:

Traceback (most recent call last):
  File "/LocalRun/jun.dai/conda/envs/vllm_env/bin/vllm", line 8, in <module>
    sys.exit(main())
  File "/LocalRun/jun.dai/code/github/sydnash/vllm/vllm/scripts.py", line 165, in main
    args.dispatch_function(args)
  File "/LocalRun/jun.dai/code/github/sydnash/vllm/vllm/scripts.py", line 37, in serve
    asyncio.run(run_server(args))
  File "/LocalRun/jun.dai/conda/envs/vllm_env/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/LocalRun/jun.dai/conda/envs/vllm_env/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/LocalRun/jun.dai/code/github/sydnash/vllm/vllm/entrypoints/openai/api_server.py", line 505, in run_server
    raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
KeyError: 'invalid tool call parser: internlm3 chose from { hermes,internlm2,internlm2_5,mistral,internlm }'

metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")

parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help=
"Specify the tool parser plugin used to parse the model-generated tool"
" calls into OpenAI API format; the names registered in this plugin "
"can be used in --tool-call-parser.")

parser = AsyncEngineArgs.add_cli_args(parser)

parser.add_argument('--max-log-len',