Skip to content

Commit 78e0d9b

Browse files
committed
fix lint
1 parent 5f6b4e9 commit 78e0d9b

File tree

6 files changed

+90
-51
lines changed

6 files changed

+90
-51
lines changed

examples/chat_with_X/repo.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ async def upsert_documents(documents: list[Document]):
3434

3535
@flow(flow_run_name="{repo}")
3636
async def ingest_repo(repo: str):
37-
"""repo should be in the format 'owner/repo'"""
3837
documents = await gather_documents(repo)
3938
await upsert_documents(documents)
4039

@@ -69,4 +68,6 @@ async def chat_with_repo(initial_message: str | None = None, clean_up: bool = Tr
6968

7069
if __name__ == "__main__":
7170
warnings.filterwarnings("ignore", category=UserWarning)
72-
run_coro_as_sync(chat_with_repo("lets chat about zzstoatzz/prefect-bot"))
71+
run_coro_as_sync(
72+
chat_with_repo("lets chat about zzstoatzz/prefect-bot - please ingest it")
73+
)

examples/refresh_chroma/refresh_collection.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Literal
33

44
from bs4 import BeautifulSoup
5+
from chromadb.api.models.Collection import Document as ChromaDocument
56
from prefect import flow, task
67
from prefect.tasks import task_input_hash
78

@@ -50,6 +51,18 @@ async def run_loader(loader: Loader) -> list[Document]:
5051
return await loader.load()
5152

5253

54+
@task
55+
async def add_documents(
56+
chroma: Chroma, documents: list[Document], mode: Literal["upsert", "reset"]
57+
) -> list[ChromaDocument]:
58+
if mode == "reset":
59+
await chroma.reset_collection()
60+
docs = await chroma.add(documents)
61+
elif mode == "upsert":
62+
docs = await chroma.upsert(documents)
63+
return docs
64+
65+
5366
@flow(name="Update Knowledge", log_prints=True)
5467
async def refresh_chroma(
5568
collection_name: str = "default",
@@ -68,13 +81,7 @@ async def refresh_chroma(
6881
async with Chroma(
6982
collection_name=collection_name, client_type=chroma_client_type
7083
) as chroma:
71-
if mode == "reset":
72-
await chroma.reset_collection()
73-
docs = await chroma.add(documents)
74-
elif mode == "upsert":
75-
docs = await task(chroma.upsert)(documents)
76-
else:
77-
raise ValueError(f"Unknown mode: {mode!r} (expected 'upsert' or 'reset')")
84+
docs = await add_documents(chroma, documents, mode)
7885

7986
print(f"Added {len(docs)} documents to the {collection_name} collection.") # type: ignore
8087

@@ -83,5 +90,5 @@ async def refresh_chroma(
8390
import asyncio
8491

8592
asyncio.run(
86-
refresh_chroma(collection_name="docs", chroma_client_type="cloud", mode="reset")
93+
refresh_chroma(collection_name="test", chroma_client_type="cloud", mode="reset") # type: ignore
8794
)
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
chromadb
2-
git+https://github.com/prefecthq/prefect.git@main
2+
prefect
33
trafilatura

src/raggy/documents.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import asyncio
22
import inspect
33
from functools import partial
4-
from typing import Annotated, Iterable
4+
from typing import Annotated
55

66
from jinja2 import Environment, Template
77
from pydantic import BaseModel, ConfigDict, Field, model_validator
88

9-
from raggy.utilities.collections import distinct
109
from raggy.utilities.ids import generate_prefixed_uuid
1110
from raggy.utilities.text import count_tokens, extract_keywords, hash_text, split_text
1211

@@ -39,7 +38,7 @@ class Document(BaseModel):
3938
keywords: list[str] = Field(default_factory=list)
4039

4140
@model_validator(mode="after")
42-
def validate_tokens(self):
41+
def ensure_tokens(self):
4342
if self.tokens is None:
4443
self.tokens = count_tokens(self.text)
4544
return self
@@ -51,7 +50,7 @@ def __hash__(self) -> int:
5150

5251
EXCERPT_TEMPLATE = jinja_env.from_string(
5352
inspect.cleandoc(
54-
"""The following is an excerpt from a document
53+
"""This is an excerpt from a document
5554
{% if document.metadata %}\n\n# Document metadata
5655
{{ document.metadata }}
5756
{% endif %}
@@ -126,8 +125,3 @@ async def _create_excerpt(
126125
metadata=document.metadata if document.metadata else {},
127126
tokens=count_tokens(excerpt_text),
128127
)
129-
130-
131-
def get_distinct_documents(documents: Iterable["Document"]) -> Iterable["Document"]:
132-
"""Return a list of distinct documents."""
133-
return distinct(documents, key=lambda doc: hash(doc.text))

src/raggy/utilities/embeddings.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,30 @@
1+
from typing import overload
2+
13
from openai import APIConnectionError, AsyncOpenAI
24
from openai.types import CreateEmbeddingResponse
35
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
46

57
import raggy
68

79

10+
@overload
11+
async def create_openai_embeddings(
12+
input_: str,
13+
timeout: int = 60,
14+
model: str = raggy.settings.openai_embeddings_model,
15+
) -> list[float]:
16+
...
17+
18+
19+
@overload
20+
async def create_openai_embeddings(
21+
input_: list[str],
22+
timeout: int = 60,
23+
model: str = raggy.settings.openai_embeddings_model,
24+
) -> list[list[float]]:
25+
...
26+
27+
828
@retry(
929
retry=retry_if_exception_type(APIConnectionError),
1030
stop=stop_after_attempt(3),

src/raggy/vectorstores/chroma.py

+48-31
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
import asyncio
12
import re
2-
from typing import Iterable, Literal
3+
from typing import Literal
4+
5+
from raggy.utilities.collections import distinct
36

47
try:
58
from chromadb import Client, CloudClient, HttpClient
69
from chromadb.api import ClientAPI
710
from chromadb.api.models.Collection import Collection
8-
from chromadb.api.types import Include, QueryResult
11+
from chromadb.api.models.Collection import Document as ChromaDocument
12+
from chromadb.api.types import QueryResult
13+
from chromadb.utils.batch_utils import create_batches
914
except ImportError:
1015
raise ImportError(
1116
"You must have `chromadb` installed to use the Chroma vector store. "
1217
"Install it with `pip install 'raggy[chroma]'`."
1318
)
1419

15-
from raggy.documents import Document, get_distinct_documents
20+
from raggy.documents import Document as RaggyDocument
1621
from raggy.settings import settings
1722
from raggy.utilities.asyncutils import run_sync_in_worker_thread
1823
from raggy.utilities.embeddings import create_openai_embeddings
@@ -69,7 +74,7 @@ async def delete(
6974
self,
7075
ids: list[str] | None = None,
7176
where: dict | None = None,
72-
where_document: Document | None = None,
77+
where_document: ChromaDocument | None = None,
7378
):
7479
await run_sync_in_worker_thread(
7580
self.collection.delete,
@@ -78,27 +83,36 @@ async def delete(
7883
where_document=where_document,
7984
)
8085

81-
async def add(self, documents: list[Document]) -> Iterable[Document]:
82-
documents = list(get_distinct_documents(documents))
83-
kwargs = dict(
84-
ids=[document.id for document in documents],
85-
documents=[document.text for document in documents],
86-
metadatas=[
87-
document.metadata.model_dump(exclude_none=True) or None
88-
for document in documents
89-
],
90-
embeddings=await create_openai_embeddings(
91-
[document.text for document in documents]
92-
),
93-
)
86+
async def add(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
87+
unique_documents = list(distinct(documents, key=lambda doc: doc.text))
9488

95-
await run_sync_in_worker_thread(self.collection.add, **kwargs)
89+
ids = [doc.id for doc in unique_documents]
90+
texts = [doc.text for doc in unique_documents]
91+
metadatas = [
92+
doc.metadata.model_dump(exclude_none=True) for doc in unique_documents
93+
]
9694

97-
get_result = await run_sync_in_worker_thread(
98-
self.collection.get, ids=kwargs["ids"]
95+
embeddings = await create_openai_embeddings(texts)
96+
97+
data = {
98+
"ids": ids,
99+
"documents": texts,
100+
"metadatas": metadatas,
101+
"embeddings": embeddings,
102+
}
103+
104+
batched_data: list[tuple] = create_batches(
105+
get_client(self.client_type),
106+
**data,
107+
)
108+
109+
await asyncio.gather(
110+
*(asyncio.to_thread(self.collection.add, *batch) for batch in batched_data)
99111
)
100112

101-
return get_result.get("documents")
113+
get_result = await asyncio.to_thread(self.collection.get, ids=ids)
114+
115+
return get_result.get("documents") or []
102116

103117
async def query(
104118
self,
@@ -107,7 +121,7 @@ async def query(
107121
n_results: int = 10,
108122
where: dict | None = None,
109123
where_document: dict | None = None,
110-
include: "Include" = ["metadatas"],
124+
include: list[str] = ["metadatas"],
111125
**kwargs,
112126
) -> "QueryResult":
113127
return await run_sync_in_worker_thread(
@@ -124,8 +138,8 @@ async def query(
124138
async def count(self) -> int:
125139
return await run_sync_in_worker_thread(self.collection.count)
126140

127-
async def upsert(self, documents: list[Document]):
128-
documents = list(get_distinct_documents(documents))
141+
async def upsert(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
142+
documents = list(distinct(documents, key=lambda doc: hash(doc.text)))
129143
kwargs = dict(
130144
ids=[document.id for document in documents],
131145
documents=[document.text for document in documents],
@@ -143,7 +157,7 @@ async def upsert(self, documents: list[Document]):
143157
self.collection.get, ids=kwargs["ids"]
144158
)
145159

146-
return get_result.get("documents")
160+
return get_result.get("documents") or []
147161

148162
async def reset_collection(self):
149163
client = get_client(self.client_type)
@@ -160,7 +174,7 @@ async def reset_collection(self):
160174

161175
def ok(self) -> bool:
162176
try:
163-
version = self.client.get_version()
177+
version = get_client(self.client_type).get_version()
164178
except Exception as e:
165179
self.logger.error_kv("Connection error", f"Cannot connect to Chroma: {e}")
166180
if re.match(r"^\d+\.\d+\.\d+$", version):
@@ -177,6 +191,7 @@ async def query_collection(
177191
where: dict | None = None,
178192
where_document: dict | None = None,
179193
max_tokens: int = 500,
194+
client_type: ChromaClientType = "base",
180195
) -> str:
181196
"""Query a Chroma collection.
182197
@@ -194,7 +209,9 @@ async def query_collection(
194209
print(await query_collection("How to create a flow in Prefect?"))
195210
```
196211
"""
197-
async with Chroma(collection_name=collection_name) as chroma:
212+
async with Chroma(
213+
collection_name=collection_name, client_type=client_type
214+
) as chroma:
198215
query_embedding = query_embedding or await create_openai_embeddings(query_text)
199216

200217
query_result = await chroma.query(
@@ -205,8 +222,8 @@ async def query_collection(
205222
include=["documents"],
206223
)
207224

208-
concatenated_result = "\n".join(
209-
doc for doc in query_result.get("documents", [])
210-
)
225+
assert (
226+
result := query_result.get("documents")
227+
) is not None, "No documents found"
211228

212-
return slice_tokens(concatenated_result, max_tokens)
229+
return slice_tokens("\n".join(result[0]), max_tokens)

0 commit comments

Comments
 (0)