Skip to content

Commit be403f0

Browse files
committed
Make get_tools sync, add separate async initialize.
1 parent 94fb24e commit be403f0

File tree

5 files changed

+82
-35
lines changed

5 files changed

+82
-35
lines changed

pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dev = [
2929
"ruff~=0.8.0",
3030
"mypy~=1.13.0",
3131
"typing-extensions~=4.12.2",
32+
"langchain-groq~=0.2.1",
3233
]
3334

3435
[project.urls]
@@ -81,4 +82,4 @@ warn_unused_ignores = true
8182
strict_equality = true
8283
no_implicit_optional = true
8384
show_error_codes = true
84-
files = "src/**/*.py"
85+
files = ["src/**/*.py", "tests/demo.py"]

src/langchain_mcp/toolkit.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pydantic_core
1010
import typing_extensions as t
1111
from langchain_core.tools.base import BaseTool, BaseToolkit, ToolException
12-
from mcp import ClientSession
12+
from mcp import ClientSession, ListToolsResult
1313

1414

1515
class MCPToolkit(BaseToolkit):
@@ -20,25 +20,30 @@ class MCPToolkit(BaseToolkit):
2020
session: ClientSession
2121
"""The MCP session used to obtain the tools"""
2222

23-
_initialized: bool = False
23+
_tools: ListToolsResult | None = None
2424

2525
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
2626

27-
@t.override
28-
async def get_tools(self) -> list[BaseTool]: # type: ignore[override]
29-
if not self._initialized:
27+
async def initialize(self) -> None:
28+
"""Initialize the session and retrieve tools list"""
29+
if self._tools is None:
3030
await self.session.initialize()
31-
self._initialized = True
31+
self._tools = await self.session.list_tools()
32+
33+
@t.override
34+
def get_tools(self) -> list[BaseTool]:
35+
if self._tools is None:
36+
raise RuntimeError("Must initialize the toolkit first")
3237

3338
return [
3439
MCPTool(
35-
toolkit=self,
40+
session=self.session,
3641
name=tool.name,
3742
description=tool.description or "",
3843
args_schema=create_schema_model(tool.inputSchema),
3944
)
4045
# list_tools returns a PaginatedResult, but I don't see a way to pass the cursor to retrieve more tools
41-
for tool in (await self.session.list_tools()).tools
46+
for tool in self._tools.tools
4247
]
4348

4449

@@ -67,19 +72,20 @@ class MCPTool(BaseTool):
6772
MCP server tool
6873
"""
6974

70-
toolkit: MCPToolkit
75+
session: ClientSession
7176
handle_tool_error: bool | str | Callable[[ToolException], str] | None = True
7277

7378
@t.override
7479
def _run(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
7580
warnings.warn(
76-
"Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy tests.", stacklevel=1
81+
"Invoke this tool asynchronousely using `ainvoke`. This method exists only to satisfy standard tests.",
82+
stacklevel=1,
7783
)
7884
return asyncio.run(self._arun(*args, **kwargs))
7985

8086
@t.override
8187
async def _arun(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
82-
result = await self.toolkit.session.call_tool(self.name, arguments=kwargs)
88+
result = await self.session.call_tool(self.name, arguments=kwargs)
8389
content = pydantic_core.to_json(result.content).decode()
8490
if result.isError:
8591
raise ToolException(content)

tests/conftest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def mcptoolkit(request):
4646

4747
@pytest.fixture(scope="class")
4848
async def mcptool(request, mcptoolkit):
49-
tool = (await mcptoolkit.get_tools())[0]
49+
await mcptoolkit.initialize()
50+
tool = mcptoolkit.get_tools()[0]
5051
request.cls.tool = tool
5152
yield tool

tests/demo.py

+20-22
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,46 @@
11
# Copyright (C) 2024 Andrew Wason
22
# SPDX-License-Identifier: MIT
33

4-
# /// script
5-
# requires-python = ">=3.10"
6-
# dependencies = [
7-
# "langchain-mcp",
8-
# "langchain-groq",
9-
# ]
10-
# ///
11-
12-
134
import asyncio
145
import pathlib
156
import sys
7+
import typing as t
168

17-
from langchain_core.messages import HumanMessage
9+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
1810
from langchain_core.output_parsers import StrOutputParser
11+
from langchain_core.tools import BaseTool
1912
from langchain_groq import ChatGroq
2013
from mcp import ClientSession, StdioServerParameters
2114
from mcp.client.stdio import stdio_client
2215

2316
from langchain_mcp import MCPToolkit
2417

2518

19+
async def run(tools: list[BaseTool], prompt: str) -> str:
20+
model = ChatGroq(model="llama-3.1-8b-instant", stop_sequences=None) # requires GROQ_API_KEY
21+
tools_map = {tool.name: tool for tool in tools}
22+
tools_model = model.bind_tools(tools)
23+
messages: list[BaseMessage] = [HumanMessage(prompt)]
24+
ai_message = t.cast(AIMessage, await tools_model.ainvoke(messages))
25+
messages.append(ai_message)
26+
for tool_call in ai_message.tool_calls:
27+
selected_tool = tools_map[tool_call["name"].lower()]
28+
tool_msg = await selected_tool.ainvoke(tool_call)
29+
messages.append(tool_msg)
30+
return await (tools_model | StrOutputParser()).ainvoke(messages)
31+
32+
2633
async def main(prompt: str) -> None:
27-
model = ChatGroq(model="llama-3.1-8b-instant") # requires GROQ_API_KEY
2834
server_params = StdioServerParameters(
2935
command="npx",
3036
args=["-y", "@modelcontextprotocol/server-filesystem", str(pathlib.Path(__file__).parent.parent)],
3137
)
3238
async with stdio_client(server_params) as (read, write):
3339
async with ClientSession(read, write) as session:
3440
toolkit = MCPToolkit(session=session)
35-
tools = await toolkit.get_tools()
36-
tools_map = {tool.name: tool for tool in tools}
37-
tools_model = model.bind_tools(tools)
38-
messages = [HumanMessage(prompt)]
39-
messages.append(await tools_model.ainvoke(messages))
40-
for tool_call in messages[-1].tool_calls:
41-
selected_tool = tools_map[tool_call["name"].lower()]
42-
tool_msg = await selected_tool.ainvoke(tool_call)
43-
messages.append(tool_msg)
44-
result = await (tools_model | StrOutputParser()).ainvoke(messages)
45-
print(result)
41+
await toolkit.initialize()
42+
response = await run(toolkit.get_tools(), prompt)
43+
print(response)
4644

4745

4846
if __name__ == "__main__":

uv.lock

+41
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)