1
+ import asyncio
1
2
import re
2
- from typing import Iterable , Literal
3
+ from typing import Literal
4
+
5
+ from raggy .utilities .collections import distinct
3
6
4
7
try :
5
8
from chromadb import Client , CloudClient , HttpClient
6
9
from chromadb .api import ClientAPI
7
10
from chromadb .api .models .Collection import Collection
8
- from chromadb .api .types import Include , QueryResult
11
+ from chromadb .api .models .Collection import Document as ChromaDocument
12
+ from chromadb .api .types import QueryResult
13
+ from chromadb .utils .batch_utils import create_batches
9
14
except ImportError :
10
15
raise ImportError (
11
16
"You must have `chromadb` installed to use the Chroma vector store. "
12
17
"Install it with `pip install 'raggy[chroma]'`."
13
18
)
14
19
15
- from raggy .documents import Document , get_distinct_documents
20
+ from raggy .documents import Document as RaggyDocument
16
21
from raggy .settings import settings
17
22
from raggy .utilities .asyncutils import run_sync_in_worker_thread
18
23
from raggy .utilities .embeddings import create_openai_embeddings
@@ -69,7 +74,7 @@ async def delete(
69
74
self ,
70
75
ids : list [str ] | None = None ,
71
76
where : dict | None = None ,
72
- where_document : Document | None = None ,
77
+ where_document : ChromaDocument | None = None ,
73
78
):
74
79
await run_sync_in_worker_thread (
75
80
self .collection .delete ,
@@ -78,27 +83,36 @@ async def delete(
78
83
where_document = where_document ,
79
84
)
80
85
81
- async def add (self , documents : list [Document ]) -> Iterable [Document ]:
82
- documents = list (get_distinct_documents (documents ))
83
- kwargs = dict (
84
- ids = [document .id for document in documents ],
85
- documents = [document .text for document in documents ],
86
- metadatas = [
87
- document .metadata .model_dump (exclude_none = True ) or None
88
- for document in documents
89
- ],
90
- embeddings = await create_openai_embeddings (
91
- [document .text for document in documents ]
92
- ),
93
- )
86
+ async def add (self , documents : list [RaggyDocument ]) -> list [ChromaDocument ]:
87
+ unique_documents = list (distinct (documents , key = lambda doc : doc .text ))
94
88
95
- await run_sync_in_worker_thread (self .collection .add , ** kwargs )
89
+ ids = [doc .id for doc in unique_documents ]
90
+ texts = [doc .text for doc in unique_documents ]
91
+ metadatas = [
92
+ doc .metadata .model_dump (exclude_none = True ) for doc in unique_documents
93
+ ]
96
94
97
- get_result = await run_sync_in_worker_thread (
98
- self .collection .get , ids = kwargs ["ids" ]
95
+ embeddings = await create_openai_embeddings (texts )
96
+
97
+ data = {
98
+ "ids" : ids ,
99
+ "documents" : texts ,
100
+ "metadatas" : metadatas ,
101
+ "embeddings" : embeddings ,
102
+ }
103
+
104
+ batched_data : list [tuple ] = create_batches (
105
+ get_client (self .client_type ),
106
+ ** data ,
107
+ )
108
+
109
+ await asyncio .gather (
110
+ * (asyncio .to_thread (self .collection .add , * batch ) for batch in batched_data )
99
111
)
100
112
101
- return get_result .get ("documents" )
113
+ get_result = await asyncio .to_thread (self .collection .get , ids = ids )
114
+
115
+ return get_result .get ("documents" ) or []
102
116
103
117
async def query (
104
118
self ,
@@ -107,7 +121,7 @@ async def query(
107
121
n_results : int = 10 ,
108
122
where : dict | None = None ,
109
123
where_document : dict | None = None ,
110
- include : "Include" = ["metadatas" ],
124
+ include : list [ str ] = ["metadatas" ],
111
125
** kwargs ,
112
126
) -> "QueryResult" :
113
127
return await run_sync_in_worker_thread (
@@ -124,8 +138,8 @@ async def query(
124
138
async def count (self ) -> int :
125
139
return await run_sync_in_worker_thread (self .collection .count )
126
140
127
- async def upsert (self , documents : list [Document ]) :
128
- documents = list (get_distinct_documents (documents ))
141
+ async def upsert (self , documents : list [RaggyDocument ]) -> list [ ChromaDocument ] :
142
+ documents = list (distinct (documents , key = lambda doc : hash ( doc . text ) ))
129
143
kwargs = dict (
130
144
ids = [document .id for document in documents ],
131
145
documents = [document .text for document in documents ],
@@ -143,7 +157,7 @@ async def upsert(self, documents: list[Document]):
143
157
self .collection .get , ids = kwargs ["ids" ]
144
158
)
145
159
146
- return get_result .get ("documents" )
160
+ return get_result .get ("documents" ) or []
147
161
148
162
async def reset_collection (self ):
149
163
client = get_client (self .client_type )
@@ -160,7 +174,7 @@ async def reset_collection(self):
160
174
161
175
def ok (self ) -> bool :
162
176
try :
163
- version = self .client .get_version ()
177
+ version = get_client ( self .client_type ) .get_version ()
164
178
except Exception as e :
165
179
self .logger .error_kv ("Connection error" , f"Cannot connect to Chroma: { e } " )
166
180
if re .match (r"^\d+\.\d+\.\d+$" , version ):
@@ -177,6 +191,7 @@ async def query_collection(
177
191
where : dict | None = None ,
178
192
where_document : dict | None = None ,
179
193
max_tokens : int = 500 ,
194
+ client_type : ChromaClientType = "base" ,
180
195
) -> str :
181
196
"""Query a Chroma collection.
182
197
@@ -194,7 +209,9 @@ async def query_collection(
194
209
print(await query_collection("How to create a flow in Prefect?"))
195
210
```
196
211
"""
197
- async with Chroma (collection_name = collection_name ) as chroma :
212
+ async with Chroma (
213
+ collection_name = collection_name , client_type = client_type
214
+ ) as chroma :
198
215
query_embedding = query_embedding or await create_openai_embeddings (query_text )
199
216
200
217
query_result = await chroma .query (
@@ -205,8 +222,8 @@ async def query_collection(
205
222
include = ["documents" ],
206
223
)
207
224
208
- concatenated_result = " \n " . join (
209
- doc for doc in query_result .get ("documents" , [] )
210
- )
225
+ assert (
226
+ result := query_result .get ("documents" )
227
+ ) is not None , "No documents found"
211
228
212
- return slice_tokens (concatenated_result , max_tokens )
229
+ return slice_tokens (" \n " . join ( result [ 0 ]) , max_tokens )
0 commit comments