-
Notifications
You must be signed in to change notification settings - Fork 274
/
Copy path: router_embedding_cohere.py
59 lines (52 loc) · 1.93 KB
/
router_embedding_cohere.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Callable, List, Optional, TYPE_CHECKING

from mcp_agent.agents.agent import Agent
from mcp_agent.workflows.embedding.embedding_cohere import CohereEmbeddingModel
from mcp_agent.workflows.router.router_embedding import EmbeddingRouter

if TYPE_CHECKING:
    from mcp_agent.context import Context

class CohereEmbeddingRouter(EmbeddingRouter):
    """
    A router that uses Cohere embedding similarity to route requests to appropriate categories.

    This class helps to route an input to a specific MCP server, an Agent (an aggregation
    of MCP servers), or a function (any Callable).

    It specializes ``EmbeddingRouter`` only by supplying a default embedding model:
    when no model is given, a ``CohereEmbeddingModel()`` with default settings is used.
    All other behavior is inherited from ``EmbeddingRouter``.
    """

    def __init__(
        self,
        server_names: List[str] | None = None,
        agents: List[Agent] | None = None,
        functions: List[Callable] | None = None,
        embedding_model: CohereEmbeddingModel | None = None,
        context: Optional["Context"] = None,
        **kwargs,
    ):
        """
        Construct the router without running async initialization.

        Prefer the ``create`` factory, which also awaits ``initialize()``.

        Args:
            server_names: MCP server names that are candidate routes.
            agents: Agents that are candidate routes.
            functions: Callables that are candidate routes.
            embedding_model: Embedding model used for similarity. When ``None``,
                a ``CohereEmbeddingModel()`` with default settings is created.
            context: Optional context, forwarded to ``EmbeddingRouter``.
            **kwargs: Additional keyword arguments forwarded unchanged to
                ``EmbeddingRouter.__init__``.
        """
        # Explicit None check: a caller-supplied model must never be swapped out
        # merely because the object happens to evaluate as falsy.
        if embedding_model is None:
            # NOTE(review): the default model is built without `context`; confirm
            # CohereEmbeddingModel resolves its own configuration (e.g. API key).
            embedding_model = CohereEmbeddingModel()

        super().__init__(
            embedding_model=embedding_model,
            server_names=server_names,
            agents=agents,
            functions=functions,
            context=context,
            **kwargs,
        )

    @classmethod
    async def create(
        cls,
        embedding_model: CohereEmbeddingModel | None = None,
        server_names: List[str] | None = None,
        agents: List[Agent] | None = None,
        functions: List[Callable] | None = None,
        context: Optional["Context"] = None,
        **kwargs,
    ) -> "CohereEmbeddingRouter":
        """
        Factory method to create and initialize a router.

        Use this instead of the constructor, because the router needs async
        initialization (``await instance.initialize()``) before it is used.

        Args:
            embedding_model: Embedding model used for similarity. When ``None``,
                the constructor creates a default ``CohereEmbeddingModel()``.
            server_names: MCP server names that are candidate routes.
            agents: Agents that are candidate routes.
            functions: Callables that are candidate routes.
            context: Optional context, forwarded to the constructor.
            **kwargs: Additional keyword arguments forwarded to the constructor
                and from there to ``EmbeddingRouter.__init__``. This lets the
                factory accept the same extra options as direct construction.

        Returns:
            A new instance of ``cls`` on which ``initialize()`` has already been awaited.
        """
        instance = cls(
            server_names=server_names,
            agents=agents,
            functions=functions,
            embedding_model=embedding_model,
            context=context,
            **kwargs,
        )
        await instance.initialize()
        return instance