1
1
from typing import Optional
2
2
3
- from qdrant_client import AsyncQdrantClient
3
+ from qdrant_client import AsyncQdrantClient , models
4
+
5
+ from .embeddings .base import EmbeddingProvider
4
6
5
7
6
8
class QdrantConnector :
@@ -9,7 +11,7 @@ class QdrantConnector:
9
11
:param qdrant_url: The URL of the Qdrant server.
10
12
:param qdrant_api_key: The API key to use for the Qdrant server.
11
13
:param collection_name: The name of the collection to use.
12
- :param fastembed_model_name : The name of the FastEmbed model to use.
14
+ :param embedding_provider : The embedding provider to use.
13
15
:param qdrant_local_path: The path to the storage directory for the Qdrant client, if local mode is used.
14
16
"""
15
17
@@ -18,29 +20,54 @@ def __init__(
18
20
qdrant_url : Optional [str ],
19
21
qdrant_api_key : Optional [str ],
20
22
collection_name : str ,
21
- fastembed_model_name : str ,
23
+ embedding_provider : EmbeddingProvider ,
22
24
qdrant_local_path : Optional [str ] = None ,
23
25
):
24
26
self ._qdrant_url = qdrant_url .rstrip ("/" ) if qdrant_url else None
25
27
self ._qdrant_api_key = qdrant_api_key
26
28
self ._collection_name = collection_name
27
- self ._fastembed_model_name = fastembed_model_name
28
- # For the time being, FastEmbed models are the only supported ones.
29
- # A list of all available models can be found here:
30
- # https://qdrant.github.io/fastembed/examples/Supported_Models/
29
+ self ._embedding_provider = embedding_provider
31
30
self ._client = AsyncQdrantClient (
32
31
location = qdrant_url , api_key = qdrant_api_key , path = qdrant_local_path
33
32
)
34
- self ._client .set_model (fastembed_model_name )
33
+
34
async def _ensure_collection_exists(self):
    """
    Create the configured collection when it is not present yet.

    The vector size is taken from the length of the embedding the provider
    returns for a short probe string; cosine distance is used.
    """
    # Nothing to do when the collection is already there.
    if await self._client.collection_exists(self._collection_name):
        return

    # Embed a probe text once to learn the dimensionality of the vectors.
    probe_embedding = await self._embedding_provider.embed_query("sample text")
    vector_params = models.VectorParams(
        size=len(probe_embedding),
        distance=models.Distance.COSINE,
    )

    await self._client.create_collection(
        collection_name=self._collection_name,
        vectors_config=vector_params,
    )
35
50
36
51
async def store_memory(self, information: str):
    """
    Store a memory in the Qdrant collection.

    The collection is created first if needed, the text is embedded with the
    configured embedding provider, and the result is upserted as one point
    whose payload holds the original text under the ``"document"`` key.

    :param information: The information to store.
    """
    # Function-scope import keeps this method self-contained.
    import uuid

    await self._ensure_collection_exists()

    # Embed the document
    embeddings = await self._embedding_provider.embed_documents([information])

    # Derive the point ID from the content with UUIDv5. The built-in hash()
    # of a str is salted per process and may be negative, so it is neither
    # stable across restarts nor a valid point ID (unsigned int or UUID).
    # A deterministic UUID makes re-storing the same text overwrite the
    # existing point.
    point_id = str(uuid.uuid5(uuid.NAMESPACE_OID, information))

    # Add to Qdrant
    await self._client.upsert(
        collection_name=self._collection_name,
        points=[
            models.PointStruct(
                id=point_id,
                vector=embeddings[0],
                payload={"document": information},
            )
        ],
    )
45
72
46
73
async def find_memories (self , query : str ) -> list [str ]:
@@ -53,9 +80,14 @@ async def find_memories(self, query: str) -> list[str]:
53
80
if not collection_exists :
54
81
return []
55
82
56
- search_results = await self ._client .query (
57
- self ._collection_name ,
58
- query_text = query ,
83
+ # Embed the query
84
+ query_vector = await self ._embedding_provider .embed_query (query )
85
+
86
+ # Search in Qdrant
87
+ search_results = await self ._client .search (
88
+ collection_name = self ._collection_name ,
89
+ query_vector = query_vector ,
59
90
limit = 10 ,
60
91
)
61
- return [result .document for result in search_results ]
92
+
93
+ return [result .payload ["document" ] for result in search_results ]
0 commit comments