Skip to content

Commit f252489

Browse files
Merge pull request #12 from qdrant/feature/embedding-providers
Abstract the embedding providers
2 parents 3e88bae + 406725c commit f252489

File tree

11 files changed

+200
-29
lines changed

11 files changed

+200
-29
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,7 @@ cython_debug/
159159
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160160
# and can be added to the global gitignore or merged into this file. For a more nuclear
161161
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
162-
#.idea/
162+
.idea/
163+
164+
# Project-specific settings
165+
.aider*

README.md

+23-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ uv run mcp-server-qdrant \
3838
--qdrant-url "http://localhost:6333" \
3939
--qdrant-api-key "your_api_key" \
4040
--collection-name "my_collection" \
41-
--fastembed-model-name "sentence-transformers/all-MiniLM-L6-v2"
41+
--embedding-model "sentence-transformers/all-MiniLM-L6-v2"
4242
```
4343

4444
### Installing via Smithery
@@ -78,7 +78,7 @@ This MCP server will automatically create a collection with the specified name i
7878

7979
By default, the server will use the `sentence-transformers/all-MiniLM-L6-v2` embedding model to encode memories.
8080
For the time being, only [FastEmbed](https://qdrant.github.io/fastembed/) models are supported, and you can change it
81-
by passing the `--fastembed-model-name` argument to the server.
81+
by passing the `--embedding-model` argument to the server.
8282

8383
### Using the local mode of Qdrant
8484

@@ -108,11 +108,31 @@ The configuration of the server can be also done using environment variables:
108108
- `QDRANT_URL`: URL of the Qdrant server, e.g. `http://localhost:6333`
109109
- `QDRANT_API_KEY`: API key for the Qdrant server
110110
- `COLLECTION_NAME`: Name of the collection to use
111-
- `FASTEMBED_MODEL_NAME`: Name of the FastEmbed model to use
111+
- `EMBEDDING_MODEL`: Name of the embedding model to use
112+
- `EMBEDDING_PROVIDER`: Embedding provider to use (currently only "fastembed" is supported)
112113
- `QDRANT_LOCAL_PATH`: Path to the local Qdrant database
113114

114115
You cannot provide `QDRANT_URL` and `QDRANT_LOCAL_PATH` at the same time.
115116

117+
## Contributing
118+
119+
If you have suggestions for how mcp-server-qdrant could be improved, or want to report a bug, open an issue!
120+
We welcome any and all contributions.
121+
122+
### Testing `mcp-server-qdrant` locally
123+
124+
The [MCP inspector](https://github.com/modelcontextprotocol/inspector) is a developer tool for testing and debugging MCP
125+
servers. It runs both a client UI (default port 5173) and an MCP proxy server (default port 3000). Open the client UI in
126+
your browser to use the inspector.
127+
128+
```shell
129+
npx @modelcontextprotocol/inspector uv run mcp-server-qdrant \
130+
--collection-name test \
131+
--qdrant-local-path /tmp/qdrant-local-test
132+
```
133+
134+
Once started, open your browser to http://localhost:5173 to access the inspector interface.
135+
116136
## License
117137

118138
This MCP server is licensed under the MIT License. This means you are free to use, modify, and distribute the software,

pyproject.toml

+8-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,15 @@ dev-dependencies = [
1818
"pre-commit>=4.1.0",
1919
"pyright>=1.1.389",
2020
"pytest>=8.3.3",
21-
"ruff>=0.8.0",
21+
"pytest-asyncio>=0.23.0",
22+
"ruff>=0.8.0"
2223
]
2324

2425
[project.scripts]
2526
mcp-server-qdrant = "mcp_server_qdrant:main"
27+
28+
[tool.pytest.ini_options]
29+
testpaths = ["tests"]
30+
python_files = "test_*.py"
31+
python_functions = "test_*"
32+
asyncio_mode = "auto"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .base import EmbeddingProvider
2+
from .factory import create_embedding_provider
3+
from .fastembed import FastEmbedProvider
4+
5+
__all__ = ["EmbeddingProvider", "FastEmbedProvider", "create_embedding_provider"]
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List
3+
4+
5+
class EmbeddingProvider(ABC):
    """Interface that concrete embedding backends must implement."""

    @abstractmethod
    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """
        Embed a list of documents into vectors.

        :param documents: The texts to embed.
        :return: A list of float vectors.
        """
        ...

    @abstractmethod
    async def embed_query(self, query: str) -> List[float]:
        """
        Embed a query into a vector.

        :param query: The query text to embed.
        :return: A single float vector.
        """
        ...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from mcp_server_qdrant.embeddings import EmbeddingProvider
2+
3+
4+
def create_embedding_provider(provider_type: str, **kwargs) -> EmbeddingProvider:
    """
    Build an embedding provider for the requested provider type.

    :param provider_type: Name of the provider; compared case-insensitively.
                          Only "fastembed" is handled.
    :param kwargs: Provider options. For "fastembed" only ``model_name`` is read
                   (default "sentence-transformers/all-MiniLM-L6-v2").
    :return: An instance of the selected embedding provider.
    :raises ValueError: If ``provider_type`` is not a supported provider.
    """
    # Reject unknown providers up front so the happy path stays flat.
    if provider_type.lower() != "fastembed":
        raise ValueError(f"Unsupported embedding provider: {provider_type}")

    # The import is deferred until the provider is actually requested.
    from .fastembed import FastEmbedProvider

    return FastEmbedProvider(
        kwargs.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
    )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import asyncio
2+
from typing import List
3+
4+
from fastembed import TextEmbedding
5+
6+
from .base import EmbeddingProvider
7+
8+
9+
class FastEmbedProvider(EmbeddingProvider):
    """FastEmbed implementation of the embedding provider."""

    def __init__(self, model_name: str):
        """
        Initialize the FastEmbed provider.

        :param model_name: The name of the FastEmbed model to use.
        """
        self.model_name = model_name
        self.embedding_model = TextEmbedding(model_name)

    async def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """
        Embed a list of documents into vectors.

        :param documents: The texts to embed.
        :return: One list of floats per embedding yielded by the model.
        """
        # FastEmbed is synchronous, so run it on the loop's default executor.
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine; get_event_loop() has legacy fallback behavior.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, lambda: list(self.embedding_model.passage_embed(documents))
        )
        return [embedding.tolist() for embedding in embeddings]

    async def embed_query(self, query: str) -> List[float]:
        """
        Embed a query into a vector.

        :param query: The query text to embed.
        :return: The embedding of the query as a list of floats.
        """
        # Same executor hand-off as embed_documents; the model takes a list,
        # so the single query is wrapped and the first result unwrapped.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, lambda: list(self.embedding_model.query_embed([query]))
        )
        return embeddings[0].tolist()

src/mcp_server_qdrant/qdrant.py

+47-15
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Optional
22

3-
from qdrant_client import AsyncQdrantClient
3+
from qdrant_client import AsyncQdrantClient, models
4+
5+
from .embeddings.base import EmbeddingProvider
46

57

68
class QdrantConnector:
@@ -9,7 +11,7 @@ class QdrantConnector:
911
:param qdrant_url: The URL of the Qdrant server.
1012
:param qdrant_api_key: The API key to use for the Qdrant server.
1113
:param collection_name: The name of the collection to use.
12-
:param fastembed_model_name: The name of the FastEmbed model to use.
14+
:param embedding_provider: The embedding provider to use.
1315
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
1416
"""
1517

@@ -18,29 +20,54 @@ def __init__(
1820
qdrant_url: Optional[str],
1921
qdrant_api_key: Optional[str],
2022
collection_name: str,
21-
fastembed_model_name: str,
23+
embedding_provider: EmbeddingProvider,
2224
qdrant_local_path: Optional[str] = None,
2325
):
2426
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
2527
self._qdrant_api_key = qdrant_api_key
2628
self._collection_name = collection_name
27-
self._fastembed_model_name = fastembed_model_name
28-
# For the time being, FastEmbed models are the only supported ones.
29-
# A list of all available models can be found here:
30-
# https://qdrant.github.io/fastembed/examples/Supported_Models/
29+
self._embedding_provider = embedding_provider
3130
self._client = AsyncQdrantClient(
3231
location=qdrant_url, api_key=qdrant_api_key, path=qdrant_local_path
3332
)
34-
self._client.set_model(fastembed_model_name)
33+
34+
async def _ensure_collection_exists(self):
    """Create the configured collection when the client reports it missing."""
    if await self._client.collection_exists(self._collection_name):
        return

    # No vector size is passed in, so it is taken from the length of an
    # embedding the provider returns for a sample text.
    probe_vector = await self._embedding_provider.embed_query("sample text")
    dimension = len(probe_vector)

    vector_params = models.VectorParams(
        size=dimension,
        distance=models.Distance.COSINE,
    )
    await self._client.create_collection(
        collection_name=self._collection_name,
        vectors_config=vector_params,
    )
3550

3651
async def store_memory(self, information: str):
    """
    Store a memory in the Qdrant collection.

    The collection is created first if it does not exist. The text is embedded
    with the configured embedding provider and upserted as a single point whose
    payload keeps the original text under the "document" key.

    :param information: The information to store.
    """
    # Function-scope import so this method does not rely on a file-level change.
    import uuid

    await self._ensure_collection_exists()

    # Embed the document
    embeddings = await self._embedding_provider.embed_documents([information])

    # The built-in hash() is not a safe point ID: for str it is salted per
    # process, so IDs change between runs, and it can be negative, while Qdrant
    # accepts only unsigned integers or UUIDs. A UUIDv5 of the text is stable
    # across runs, so storing the same text again overwrites the same point.
    point_id = str(uuid.uuid5(uuid.NAMESPACE_OID, information))

    # Add to Qdrant
    await self._client.upsert(
        collection_name=self._collection_name,
        points=[
            models.PointStruct(
                id=point_id,
                vector=embeddings[0],
                payload={"document": information},
            )
        ],
    )
4572

4673
async def find_memories(self, query: str) -> list[str]:
@@ -53,9 +80,14 @@ async def find_memories(self, query: str) -> list[str]:
5380
if not collection_exists:
5481
return []
5582

56-
search_results = await self._client.query(
57-
self._collection_name,
58-
query_text=query,
83+
# Embed the query
84+
query_vector = await self._embedding_provider.embed_query(query)
85+
86+
# Search in Qdrant
87+
search_results = await self._client.search(
88+
collection_name=self._collection_name,
89+
query_vector=query_vector,
5990
limit=10,
6091
)
61-
return [result.document for result in search_results]
92+
93+
return [result.payload["document"] for result in search_results]

src/mcp_server_qdrant/server.py

+27-9
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,39 @@
77
from mcp.server import NotificationOptions, Server
88
from mcp.server.models import InitializationOptions
99

10+
from .embeddings.factory import create_embedding_provider
1011
from .qdrant import QdrantConnector
1112

1213

1314
def serve(
1415
qdrant_url: Optional[str],
1516
qdrant_api_key: Optional[str],
1617
collection_name: str,
17-
fastembed_model_name: str,
18+
embedding_provider_type: str,
19+
embedding_model_name: str,
1820
qdrant_local_path: Optional[str] = None,
1921
) -> Server:
2022
"""
2123
Instantiate the server and configure tools to store and find memories in Qdrant.
2224
:param qdrant_url: The URL of the Qdrant server.
2325
:param qdrant_api_key: The API key to use for the Qdrant server.
2426
:param collection_name: The name of the collection to use.
25-
:param fastembed_model_name: The name of the FastEmbed model to use.
27+
:param embedding_provider_type: The type of embedding provider to use.
28+
:param embedding_model_name: The name of the embedding model to use.
2629
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
2730
"""
2831
server = Server("qdrant")
2932

33+
# Create the embedding provider
34+
embedding_provider = create_embedding_provider(
35+
embedding_provider_type, model_name=embedding_model_name
36+
)
37+
3038
qdrant = QdrantConnector(
3139
qdrant_url,
3240
qdrant_api_key,
3341
collection_name,
34-
fastembed_model_name,
42+
embedding_provider,
3543
qdrant_local_path,
3644
)
3745

@@ -133,10 +141,18 @@ async def handle_tool_call(
133141
help="Collection name",
134142
)
135143
@click.option(
136-
"--fastembed-model-name",
137-
envvar="FASTEMBED_MODEL_NAME",
138-
required=True,
139-
help="FastEmbed model name",
144+
"--embedding-provider",
145+
envvar="EMBEDDING_PROVIDER",
146+
required=False,
147+
help="Embedding provider to use",
148+
default="fastembed",
149+
type=click.Choice(["fastembed"], case_sensitive=False),
150+
)
151+
@click.option(
152+
"--embedding-model",
153+
envvar="EMBEDDING_MODEL",
154+
required=False,
155+
help="Embedding model name",
140156
default="sentence-transformers/all-MiniLM-L6-v2",
141157
)
142158
@click.option(
@@ -149,7 +165,8 @@ def main(
149165
qdrant_url: Optional[str],
150166
qdrant_api_key: str,
151167
collection_name: Optional[str],
152-
fastembed_model_name: str,
168+
embedding_provider: str,
169+
embedding_model: str,
153170
qdrant_local_path: Optional[str],
154171
):
155172
# XOR of url and local path, since we accept only one of them
@@ -164,7 +181,8 @@ async def _run():
164181
qdrant_url,
165182
qdrant_api_key,
166183
collection_name,
167-
fastembed_model_name,
184+
embedding_provider,
185+
embedding_model,
168186
qdrant_local_path,
169187
)
170188
await server.run(

tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file can be empty; it only marks the directory as a Python package.

uv.lock

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)