Skip to content

Commit 8d0405e

Browse files
authored
[ENH] URI Data Loader (#1294)
## Description of changes This PR adds URIs and DataLoaders into Chroma. - `DataLoader` works like `EmbeddingFunction`, except it takes a `URIs` and outputs the specified datatype. - A `DataLoader` using `pillow` for image file loading. - Adds `uris` as a field on `add`, `query`, as well as an `include` field - Adds `data` as an `include` field URIs specify a place where data can be loaded from, and can be used to load data for embedding, or as the result of retrieval. This makes multimodal retrieval with data stored externally as files seamless and extensible. This PR is stacked on #1293 ## Test Integration tests pass. A new unit test for data loaders: https://github.com/chroma-core/chroma/blob/c71c3efa15d2a9252db26470b9ff1cdb10a2b681/chromadb/test/data_loader/test_data_loader.py Try the notebook: https://github.com/chroma-core/chroma/blob/c71c3efa15d2a9252db26470b9ff1cdb10a2b681/examples/multimodal/multimodal_retrieval.ipynb ## Documentation Documentation for this and #1293 chroma-core/docs#157 ## TODOs - [x] Concurrent Loading - [x] Tests - [x] Wiring through FastAPI - [x] Documentation
1 parent d2b3bdd commit 8d0405e

File tree

14 files changed

+1096
-217
lines changed

14 files changed

+1096
-217
lines changed

chromadb/api/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
Documents,
1111
Embeddable,
1212
EmbeddingFunction,
13+
DataLoader,
1314
Embeddings,
1415
IDs,
1516
Include,
17+
Loadable,
1618
Metadatas,
19+
URIs,
1720
Where,
1821
QueryResult,
1922
GetResult,
@@ -62,6 +65,7 @@ def create_collection(
6265
embedding_function: Optional[
6366
EmbeddingFunction[Embeddable]
6467
] = ef.DefaultEmbeddingFunction(), # type: ignore
68+
data_loader: Optional[DataLoader[Loadable]] = None,
6569
get_or_create: bool = False,
6670
) -> Collection:
6771
"""Create a new collection with the given name and metadata.
@@ -98,6 +102,7 @@ def get_collection(
98102
embedding_function: Optional[
99103
EmbeddingFunction[Embeddable]
100104
] = ef.DefaultEmbeddingFunction(), # type: ignore
105+
data_loader: Optional[DataLoader[Loadable]] = None,
101106
) -> Collection:
102107
"""Get a collection with the given name.
103108
Args:
@@ -127,6 +132,7 @@ def get_or_create_collection(
127132
embedding_function: Optional[
128133
EmbeddingFunction[Embeddable]
129134
] = ef.DefaultEmbeddingFunction(), # type: ignore
135+
data_loader: Optional[DataLoader[Loadable]] = None,
130136
) -> Collection:
131137
"""Get or create a collection with the given name and metadata.
132138
Args:
@@ -193,6 +199,7 @@ def _add(
193199
embeddings: Embeddings,
194200
metadatas: Optional[Metadatas] = None,
195201
documents: Optional[Documents] = None,
202+
uris: Optional[URIs] = None,
196203
) -> bool:
197204
"""[Internal] Add embeddings to a collection specified by UUID.
198205
If (some) ids already exist, only the new embeddings will be added.
@@ -203,6 +210,7 @@ def _add(
203210
embedding: The sequence of embeddings to add.
204211
metadata: The metadata to associate with the embeddings. Defaults to None.
205212
documents: The documents to associate with the embeddings. Defaults to None.
213+
uris: URIs of data sources for each embedding. Defaults to None.
206214
207215
Returns:
208216
True if the embeddings were added successfully.
@@ -217,6 +225,7 @@ def _update(
217225
embeddings: Optional[Embeddings] = None,
218226
metadatas: Optional[Metadatas] = None,
219227
documents: Optional[Documents] = None,
228+
uris: Optional[URIs] = None,
220229
) -> bool:
221230
"""[Internal] Update entries in a collection specified by UUID.
222231
@@ -226,7 +235,7 @@ def _update(
226235
embeddings: The sequence of embeddings to update. Defaults to None.
227236
metadatas: The metadata to associate with the embeddings. Defaults to None.
228237
documents: The documents to associate with the embeddings. Defaults to None.
229-
238+
uris: URIs of data sources for each embedding. Defaults to None.
230239
Returns:
231240
True if the embeddings were updated successfully.
232241
"""
@@ -240,6 +249,7 @@ def _upsert(
240249
embeddings: Embeddings,
241250
metadatas: Optional[Metadatas] = None,
242251
documents: Optional[Documents] = None,
252+
uris: Optional[URIs] = None,
243253
) -> bool:
244254
"""[Internal] Add or update entries in the a collection specified by UUID.
245255
If an entry with the same id already exists, it will be updated,
@@ -251,6 +261,7 @@ def _upsert(
251261
embeddings: The sequence of embeddings to add
252262
metadatas: The metadata to associate with the embeddings. Defaults to None.
253263
documents: The documents to associate with the embeddings. Defaults to None.
264+
uris: URIs of data sources for each embedding. Defaults to None.
254265
"""
255266
pass
256267

@@ -496,6 +507,7 @@ def create_collection(
496507
embedding_function: Optional[
497508
EmbeddingFunction[Embeddable]
498509
] = ef.DefaultEmbeddingFunction(), # type: ignore
510+
data_loader: Optional[DataLoader[Loadable]] = None,
499511
get_or_create: bool = False,
500512
tenant: str = DEFAULT_TENANT,
501513
database: str = DEFAULT_DATABASE,
@@ -511,6 +523,7 @@ def get_collection(
511523
embedding_function: Optional[
512524
EmbeddingFunction[Embeddable]
513525
] = ef.DefaultEmbeddingFunction(), # type: ignore
526+
data_loader: Optional[DataLoader[Loadable]] = None,
514527
tenant: str = DEFAULT_TENANT,
515528
database: str = DEFAULT_DATABASE,
516529
) -> Collection:
@@ -525,6 +538,7 @@ def get_or_create_collection(
525538
embedding_function: Optional[
526539
EmbeddingFunction[Embeddable]
527540
] = ef.DefaultEmbeddingFunction(), # type: ignore
541+
data_loader: Optional[DataLoader[Loadable]] = None,
528542
tenant: str = DEFAULT_TENANT,
529543
database: str = DEFAULT_DATABASE,
530544
) -> Collection:

chromadb/api/client.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
66
from chromadb.api.types import (
77
CollectionMetadata,
8+
DataLoader,
89
Documents,
10+
Embeddable,
911
EmbeddingFunction,
1012
Embeddings,
1113
GetResult,
1214
IDs,
1315
Include,
16+
Loadable,
1417
Metadatas,
1518
QueryResult,
19+
URIs,
1620
)
1721
from chromadb.config import Settings, System
1822
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
@@ -173,13 +177,17 @@ def create_collection(
173177
self,
174178
name: str,
175179
metadata: Optional[CollectionMetadata] = None,
176-
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
180+
embedding_function: Optional[
181+
EmbeddingFunction[Embeddable]
182+
] = ef.DefaultEmbeddingFunction(), # type: ignore
183+
data_loader: Optional[DataLoader[Loadable]] = None,
177184
get_or_create: bool = False,
178185
) -> Collection:
179186
return self._server.create_collection(
180187
name=name,
181188
metadata=metadata,
182189
embedding_function=embedding_function,
190+
data_loader=data_loader,
183191
tenant=self.tenant,
184192
database=self.database,
185193
get_or_create=get_or_create,
@@ -188,14 +196,18 @@ def create_collection(
188196
@override
189197
def get_collection(
190198
self,
191-
name: Optional[str] = None,
199+
name: str,
192200
id: Optional[UUID] = None,
193-
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
201+
embedding_function: Optional[
202+
EmbeddingFunction[Embeddable]
203+
] = ef.DefaultEmbeddingFunction(), # type: ignore
204+
data_loader: Optional[DataLoader[Loadable]] = None,
194205
) -> Collection:
195206
return self._server.get_collection(
196207
id=id,
197208
name=name,
198209
embedding_function=embedding_function,
210+
data_loader=data_loader,
199211
tenant=self.tenant,
200212
database=self.database,
201213
)
@@ -205,12 +217,16 @@ def get_or_create_collection(
205217
self,
206218
name: str,
207219
metadata: Optional[CollectionMetadata] = None,
208-
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
220+
embedding_function: Optional[
221+
EmbeddingFunction[Embeddable]
222+
] = ef.DefaultEmbeddingFunction(), # type: ignore
223+
data_loader: Optional[DataLoader[Loadable]] = None,
209224
) -> Collection:
210225
return self._server.get_or_create_collection(
211226
name=name,
212227
metadata=metadata,
213228
embedding_function=embedding_function,
229+
data_loader=data_loader,
214230
tenant=self.tenant,
215231
database=self.database,
216232
)
@@ -251,13 +267,15 @@ def _add(
251267
embeddings: Embeddings,
252268
metadatas: Optional[Metadatas] = None,
253269
documents: Optional[Documents] = None,
270+
uris: Optional[URIs] = None,
254271
) -> bool:
255272
return self._server._add(
256273
ids=ids,
257274
collection_id=collection_id,
258275
embeddings=embeddings,
259276
metadatas=metadatas,
260277
documents=documents,
278+
uris=uris,
261279
)
262280

263281
@override
@@ -268,13 +286,15 @@ def _update(
268286
embeddings: Optional[Embeddings] = None,
269287
metadatas: Optional[Metadatas] = None,
270288
documents: Optional[Documents] = None,
289+
uris: Optional[URIs] = None,
271290
) -> bool:
272291
return self._server._update(
273292
collection_id=collection_id,
274293
ids=ids,
275294
embeddings=embeddings,
276295
metadatas=metadatas,
277296
documents=documents,
297+
uris=uris,
278298
)
279299

280300
@override
@@ -285,13 +305,15 @@ def _upsert(
285305
embeddings: Embeddings,
286306
metadatas: Optional[Metadatas] = None,
287307
documents: Optional[Documents] = None,
308+
uris: Optional[URIs] = None,
288309
) -> bool:
289310
return self._server._upsert(
290311
collection_id=collection_id,
291312
ids=ids,
292313
embeddings=embeddings,
293314
metadatas=metadatas,
294315
documents=documents,
316+
uris=uris,
295317
)
296318

297319
@override

chromadb/api/fastapi.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
from chromadb.api import ServerAPI
1414
from chromadb.api.models.Collection import Collection
1515
from chromadb.api.types import (
16+
DataLoader,
1617
Documents,
1718
Embeddable,
1819
Embeddings,
1920
EmbeddingFunction,
2021
IDs,
2122
Include,
23+
Loadable,
2224
Metadatas,
25+
URIs,
2326
Where,
2427
WhereDocument,
2528
GetResult,
@@ -223,6 +226,7 @@ def create_collection(
223226
embedding_function: Optional[
224227
EmbeddingFunction[Embeddable]
225228
] = ef.DefaultEmbeddingFunction(), # type: ignore
229+
data_loader: Optional[DataLoader[Loadable]] = None,
226230
get_or_create: bool = False,
227231
tenant: str = DEFAULT_TENANT,
228232
database: str = DEFAULT_DATABASE,
@@ -246,6 +250,7 @@ def create_collection(
246250
id=resp_json["id"],
247251
name=resp_json["name"],
248252
embedding_function=embedding_function,
253+
data_loader=data_loader,
249254
metadata=resp_json["metadata"],
250255
)
251256

@@ -255,7 +260,10 @@ def get_collection(
255260
self,
256261
name: str,
257262
id: Optional[UUID] = None,
258-
embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore
263+
embedding_function: Optional[
264+
EmbeddingFunction[Embeddable]
265+
] = ef.DefaultEmbeddingFunction(), # type: ignore
266+
data_loader: Optional[DataLoader[Loadable]] = None,
259267
tenant: str = DEFAULT_TENANT,
260268
database: str = DEFAULT_DATABASE,
261269
) -> Collection:
@@ -276,6 +284,7 @@ def get_collection(
276284
name=resp_json["name"],
277285
id=resp_json["id"],
278286
embedding_function=embedding_function,
287+
data_loader=data_loader,
279288
metadata=resp_json["metadata"],
280289
)
281290

@@ -287,16 +296,20 @@ def get_or_create_collection(
287296
self,
288297
name: str,
289298
metadata: Optional[CollectionMetadata] = None,
290-
embedding_function: Optional[EmbeddingFunction[Embeddable]] = ef.DefaultEmbeddingFunction(), # type: ignore
299+
embedding_function: Optional[
300+
EmbeddingFunction[Embeddable]
301+
] = ef.DefaultEmbeddingFunction(), # type: ignore
302+
data_loader: Optional[DataLoader[Loadable]] = None,
291303
tenant: str = DEFAULT_TENANT,
292304
database: str = DEFAULT_DATABASE,
293305
) -> Collection:
294306
return cast(
295307
Collection,
296308
self.create_collection(
297-
name,
298-
metadata,
299-
embedding_function,
309+
name=name,
310+
metadata=metadata,
311+
embedding_function=embedding_function,
312+
data_loader=data_loader,
300313
get_or_create=True,
301314
tenant=tenant,
302315
database=database,
@@ -403,6 +416,8 @@ def _get(
403416
embeddings=body.get("embeddings", None),
404417
metadatas=body.get("metadatas", None),
405418
documents=body.get("documents", None),
419+
data=None,
420+
uris=body.get("uris", None),
406421
)
407422

408423
@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
@@ -429,7 +444,11 @@ def _delete(
429444
def _submit_batch(
430445
self,
431446
batch: Tuple[
432-
IDs, Optional[Embeddings], Optional[Metadatas], Optional[Documents]
447+
IDs,
448+
Optional[Embeddings],
449+
Optional[Metadatas],
450+
Optional[Documents],
451+
Optional[URIs],
433452
],
434453
url: str,
435454
) -> requests.Response:
@@ -444,6 +463,7 @@ def _submit_batch(
444463
"embeddings": batch[1],
445464
"metadatas": batch[2],
446465
"documents": batch[3],
466+
"uris": batch[4],
447467
}
448468
),
449469
)
@@ -458,12 +478,13 @@ def _add(
458478
embeddings: Embeddings,
459479
metadatas: Optional[Metadatas] = None,
460480
documents: Optional[Documents] = None,
481+
uris: Optional[URIs] = None,
461482
) -> bool:
462483
"""
463484
Adds a batch of embeddings to the database
464485
- pass in column oriented data lists
465486
"""
466-
batch = (ids, embeddings, metadatas, documents)
487+
batch = (ids, embeddings, metadatas, documents, uris)
467488
validate_batch(batch, {"max_batch_size": self.max_batch_size})
468489
resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
469490
raise_chroma_error(resp)
@@ -478,12 +499,13 @@ def _update(
478499
embeddings: Optional[Embeddings] = None,
479500
metadatas: Optional[Metadatas] = None,
480501
documents: Optional[Documents] = None,
502+
uris: Optional[URIs] = None,
481503
) -> bool:
482504
"""
483505
Updates a batch of embeddings in the database
484506
- pass in column oriented data lists
485507
"""
486-
batch = (ids, embeddings, metadatas, documents)
508+
batch = (ids, embeddings, metadatas, documents, uris)
487509
validate_batch(batch, {"max_batch_size": self.max_batch_size})
488510
resp = self._submit_batch(
489511
batch, "/collections/" + str(collection_id) + "/update"
@@ -500,12 +522,13 @@ def _upsert(
500522
embeddings: Embeddings,
501523
metadatas: Optional[Metadatas] = None,
502524
documents: Optional[Documents] = None,
525+
uris: Optional[URIs] = None,
503526
) -> bool:
504527
"""
505528
Upserts a batch of embeddings in the database
506529
- pass in column oriented data lists
507530
"""
508-
batch = (ids, embeddings, metadatas, documents)
531+
batch = (ids, embeddings, metadatas, documents, uris)
509532
validate_batch(batch, {"max_batch_size": self.max_batch_size})
510533
resp = self._submit_batch(
511534
batch, "/collections/" + str(collection_id) + "/upsert"
@@ -547,6 +570,8 @@ def _query(
547570
embeddings=body.get("embeddings", None),
548571
metadatas=body.get("metadatas", None),
549572
documents=body.get("documents", None),
573+
uris=body.get("uris", None),
574+
data=None,
550575
)
551576

552577
@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)

0 commit comments

Comments
 (0)