
Commit e4e5b7a

Merge pull request #8 from zzstoatzz/more-concurrent: async utils

2 parents e2b8c0b + 949b51c, commit e4e5b7a

7 files changed: +120, -63 lines

examples/refresh_tpuf/refresh_namespace.py (+18, -5)

```diff
@@ -1,7 +1,7 @@
 # /// script
 # dependencies = [
 #   "prefect",
-#   "raggy[tpuf]@git+https://github.com/zzstoatzz/raggy@improve-ingest",
+#   "raggy[tpuf]@git+https://github.com/zzstoatzz/raggy",
 #   "trafilatura",
 # ]
 # ///
@@ -100,11 +100,12 @@ async def refresh_tpuf_namespace(
     namespace_loaders: list[Loader],
     reset: bool = False,
     batch_size: int = 100,
+    max_concurrent: int = 8,
 ):
     """Flow updating vectorstore with info from the Prefect community."""
     documents: list[Document] = [
         doc
-        for future in run_loader.map(quote(namespace_loaders))  # type: ignore
+        for future in run_loader.map(quote(namespace_loaders))
         for doc in future.result()
     ]
 
@@ -115,20 +116,32 @@ async def refresh_tpuf_namespace(
         await task(tpuf.reset)()
         print(f"RESETTING: Deleted all documents from tpuf ns {namespace!r}.")
 
-    await task(tpuf.upsert_batched)(documents=documents, batch_size=batch_size)
+    await task(tpuf.upsert_batched)(
+        documents=documents, batch_size=batch_size, max_concurrent=max_concurrent
+    )
 
     print(f"Updated tpuf ns {namespace!r} with {len(documents)} documents.")
 
 
 @flow(name="Refresh Namespaces", log_prints=True)
-async def refresh_tpuf(reset: bool = False, batch_size: int = 100):
+async def refresh_tpuf(
+    reset: bool = False, batch_size: int = 100, test_mode: bool = False
+):
     for namespace, namespace_loaders in loaders.items():
+        if test_mode:
+            namespace = f"TESTING-{namespace}"
         await refresh_tpuf_namespace(
             namespace, namespace_loaders, reset=reset, batch_size=batch_size
         )
 
 
 if __name__ == "__main__":
     import asyncio
+    import sys
+
+    if len(sys.argv) > 1:
+        test_mode = sys.argv[1] != "prod"
+    else:
+        test_mode = True
 
-    asyncio.run(refresh_tpuf(reset=True))
+    asyncio.run(refresh_tpuf(reset=True, test_mode=test_mode))
```
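
As wired above, the example only runs against real namespaces when invoked with a first argument of `prod`; any other invocation (including no argument) sets `test_mode=True`, and the flow then prefixes each namespace with `TESTING-` so a trial run cannot overwrite production data.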

src/raggy/documents.py (+9, -2)

```diff
@@ -4,7 +4,7 @@
 from typing import Annotated
 
 from jinja2 import Environment, Template
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 from raggy.utilities.ids import generate_prefixed_uuid
 from raggy.utilities.text import count_tokens, extract_keywords, hash_text, split_text
@@ -32,11 +32,18 @@ class Document(BaseModel):
     text: str = Field(..., description="Document text content.")
 
     embedding: list[float] | None = Field(default=None)
-    metadata: DocumentMetadata = Field(default_factory=DocumentMetadata)
+    metadata: DocumentMetadata | dict = Field(default_factory=DocumentMetadata)
 
     tokens: int | None = Field(default=None)
     keywords: list[str] = Field(default_factory=list)
 
+    @field_validator("metadata", mode="before")
+    @classmethod
+    def ensure_metadata(cls, v):
+        if isinstance(v, dict):
+            return DocumentMetadata(**v)
+        return v
+
     @model_validator(mode="after")
     def ensure_tokens(self):
         if self.tokens is None:
```
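
Since `metadata` is now typed `DocumentMetadata | dict` with a `mode="before"` validator, callers can pass a plain dict and still get a `DocumentMetadata` instance back. A minimal sketch; the `link` key mirrors the `metadata=dict(...)` construction in `web.py`, and is an assumption about `DocumentMetadata`'s fields:

```python
from raggy.documents import Document

# A plain dict is coerced to DocumentMetadata by the new ensure_metadata validator.
doc = Document(
    text="raggy now accepts dict metadata",
    metadata={"link": "https://example.com"},  # assumes DocumentMetadata has a `link` field
)
print(type(doc.metadata).__name__)  # -> "DocumentMetadata"
```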

src/raggy/loaders/web.py (+31, -29)

````diff
@@ -1,4 +1,3 @@
-import asyncio
 import re
 from typing import Callable, Self
 from urllib.parse import urljoin
@@ -11,12 +10,11 @@
 import raggy
 from raggy.documents import Document, document_to_excerpts
 from raggy.loaders.base import Loader, MultiLoader
+from raggy.utilities.asyncutils import run_concurrent_tasks
 from raggy.utilities.collections import batched
 
 user_agent = UserAgent()
 
-URL_CONCURRENCY = asyncio.Semaphore(30)
-
 
 def ensure_http(url):
     if not url.startswith(("http://", "https://")):
@@ -30,7 +28,6 @@ async def sitemap_search(sitemap_url) -> list[str]:
     response.raise_for_status()
 
     soup = BeautifulSoup(response.content, "xml")
-
     return [loc.text for loc in soup.find_all("loc")]
 
 
@@ -51,28 +48,33 @@ class URLLoader(WebLoader):
     """
 
     source_type: str = "url"
-
     urls: list[str] = Field(default_factory=list)
 
     async def load(self) -> list[Document]:
         headers = await self.get_headers()
         async with AsyncClient(
             headers=headers, timeout=30, follow_redirects=True
         ) as client:
-            documents = await asyncio.gather(
-                *[self.load_url(u, client) for u in self.urls], return_exceptions=True
+
+            async def load_url_task(url):
+                try:
+                    return await self.load_url(url, client)
+                except Exception as e:
+                    self.logger.error(e)
+                    return None
+
+            documents = await run_concurrent_tasks(
+                [lambda u=url: load_url_task(u) for url in self.urls], max_concurrent=30
             )
+
         final_documents = []
         for d in documents:
-            if isinstance(d, Exception):
-                self.logger.error(d)
-            elif d is not None:
-                final_documents.extend(await document_to_excerpts(d))  # type: ignore
+            if d is not None:
+                final_documents.extend(await document_to_excerpts(d))
         return final_documents
 
     async def load_url(self, url, client) -> Document | None:
-        async with URL_CONCURRENCY:
-            response = await client.get(url, follow_redirects=True)
+        response = await client.get(url, follow_redirects=True)
 
         if not response.status_code == 200:
             self.logger.warning_style(
@@ -84,16 +86,17 @@ async def load_url(self, url, client) -> Document | None:
         meta_refresh = soup.find(
             "meta", attrs={"http-equiv": re.compile(r"refresh", re.I)}
         )
-        if meta_refresh:
-            refresh_content = meta_refresh.get("content")
-            redirect_url_match = re.search(r"url=([\S]+)", refresh_content, re.I)
-            if redirect_url_match:
-                redirect_url = redirect_url_match.group(1)
-                # join base url with relative url
-                redirect_url = urljoin(str(response.url), redirect_url)
-                # Now ensure the URL includes the protocol
-                redirect_url = ensure_http(redirect_url)
-                response = await client.get(redirect_url, follow_redirects=True)
+        if meta_refresh and isinstance(meta_refresh, BeautifulSoup.Tag):
+            content = meta_refresh.get("content", "")
+            if isinstance(content, str):
+                redirect_url_match = re.search(r"url=([\S]+)", content, re.I)
+                if redirect_url_match:
+                    redirect_url = redirect_url_match.group(1)
+                    # join base url with relative url
+                    redirect_url = urljoin(str(response.url), redirect_url)
+                    # Now ensure the URL includes the protocol
+                    redirect_url = ensure_http(redirect_url)
+                    response = await client.get(redirect_url, follow_redirects=True)
 
         document = await self.response_to_document(response)
         if document:
@@ -103,6 +106,7 @@ async def load_url(self, url, client) -> Document | None:
         return document
 
     async def response_to_document(self, response: Response) -> Document:
+        """Convert an HTTP response to a Document."""
         return Document(
             text=await self.get_document_text(response),
             metadata=dict(
@@ -128,17 +132,15 @@ async def get_document_text(self, response: Response) -> str:
 
 class SitemapLoader(URLLoader):
     """A loader that loads URLs from a sitemap.
-
     Attributes:
         include: A list of strings or regular expressions. Only URLs that match one of these will be included.
         exclude: A list of strings or regular expressions. URLs that match one of these will be excluded.
         url_loader: The loader to use for loading the URLs.
-
     Examples:
         Load all URLs from a sitemap:
         ```python
         from raggy.loaders.web import SitemapLoader
-        loader = SitemapLoader(urls=["https://askmarvin.ai/sitemap.xml"])
+        loader = SitemapLoader(urls=["https://controlflow.ai/sitemap.xml"])
         documents = await loader.load()
         print(documents)
         ```
@@ -147,11 +149,12 @@ class SitemapLoader(URLLoader):
     include: list[str | re.Pattern] = Field(default_factory=list)
     exclude: list[str | re.Pattern] = Field(default_factory=list)
     url_loader: URLLoader = Field(default_factory=HTMLLoader)
-
     url_processor: Callable[[str], str] = lambda x: x  # noqa: E731
 
     async def _get_loader(self: Self) -> MultiLoader:
-        urls = await asyncio.gather(*[self.load_sitemap(url) for url in self.urls])
+        urls = await run_concurrent_tasks(
+            [lambda u=url: self.load_sitemap(u) for url in self.urls], max_concurrent=5
+        )
         return MultiLoader(
             loaders=[
                 type(self.url_loader)(urls=url_batch, headers=await self.get_headers())  # type: ignore
@@ -169,7 +172,6 @@ async def load_sitemap(self, url: str) -> list[str]:
         def is_included(url: str) -> bool:
             if not self.include:
                 return True
-
             return any(
                 (isinstance(i, str) and i in url)
                 or (isinstance(i, re.Pattern) and re.search(i, url))
````
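
For context, the `include`/`exclude` fields shown in the last hunk accept both plain substrings and compiled patterns. A hedged sketch of a filtered sitemap crawl; the URL and patterns are illustrative only:

```python
import re

from raggy.loaders.web import SitemapLoader

loader = SitemapLoader(
    urls=["https://controlflow.ai/sitemap.xml"],
    include=[re.compile(r"/docs/")],  # keep only documentation pages
    exclude=["/changelog"],           # drop changelog entries
)
# documents = await loader.load()  # run inside an event loop
```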

src/raggy/settings.py (+3, -1)

```diff
@@ -68,7 +68,9 @@ class Settings(BaseSettings):
         extra="allow",
         validate_assignment=True,
     )
-
+    max_concurrent_tasks: int = Field(
+        default=50, gt=3, description="The maximum number of concurrent tasks to run."
+    )
     html_parser: Callable[[str], str] = default_html_parser
 
     log_level: str = Field(
```
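
The new `max_concurrent_tasks` setting becomes the default concurrency limit for `run_concurrent_tasks` (see the asyncutils diff below). Because the model sets `validate_assignment=True` and the field is constrained with `gt=3`, overrides are validated at assignment time; a minimal sketch:

```python
from raggy.settings import settings

settings.max_concurrent_tasks = 20  # accepted: an integer greater than 3
# settings.max_concurrent_tasks = 2  # would raise a ValidationError (gt=3)
```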

src/raggy/utilities/asyncutils.py (+28, -15)

```diff
@@ -2,6 +2,9 @@
 from typing import Any, Callable, TypeVar
 
 import anyio
+from anyio import create_task_group, to_thread
+
+from raggy import settings
 
 T = TypeVar("T")
 
@@ -21,23 +24,33 @@ async def run_sync_in_worker_thread(
     __fn: Callable[..., T], *args: Any, **kwargs: Any
 ) -> T:
     """Runs a sync function in a new worker thread so that the main thread's event loop
-    is not blocked
+    is not blocked."""
+    call = partial(__fn, *args, **kwargs)
+    return await to_thread.run_sync(
+        call, cancellable=True, limiter=get_thread_limiter()
+    )
 
-    Unlike the anyio function, this defaults to a cancellable thread and does not allow
-    passing arguments to the anyio function so users can pass kwargs to their function.
 
-    Note that cancellation of threads will not result in interrupted computation, the
-    thread may continue running — the outcome will just be ignored.
+async def run_concurrent_tasks(
+    tasks: list[Callable],
+    max_concurrent: int = settings.max_concurrent_tasks,
+):
+    """Run multiple tasks concurrently with a limit on concurrent execution.
 
     Args:
-        __fn: The function to run in a worker thread
-        *args: Positional arguments to pass to the function
-        **kwargs: Keyword arguments to pass to the function
-
-    Returns:
-        The result of the function
+        tasks: List of async callables to execute
+        max_concurrent: Maximum number of tasks to run concurrently
     """
-    call = partial(__fn, *args, **kwargs)
-    return await anyio.to_thread.run_sync(
-        call, cancellable=True, limiter=get_thread_limiter()
-    )
+    semaphore = anyio.Semaphore(max_concurrent)
+    results = []
+
+    async def _run_task(task: Callable):
+        async with semaphore:
+            result = await task()
+            results.append(result)
+
+    async with create_task_group() as tg:
+        for task in tasks:
+            tg.start_soon(_run_task, task)
+
+    return results
```
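
Since `run_concurrent_tasks` takes zero-argument callables, loop variables have to be bound eagerly, which is why the loaders use the `lambda u=url: ...` default-argument idiom. A minimal usage sketch, assuming only what this diff defines:

```python
import asyncio

from raggy.utilities.asyncutils import run_concurrent_tasks


async def work(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for real I/O
    return i * i


async def main():
    # Default arguments bind each i at definition time, avoiding late binding.
    tasks = [lambda i=i: work(i) for i in range(10)]
    results = await run_concurrent_tasks(tasks, max_concurrent=3)
    # Results are appended as tasks finish, so order is completion order.
    print(sorted(results))


asyncio.run(main())
```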

src/raggy/vectorstores/chroma.py (+8, -2)

```diff
@@ -18,6 +18,7 @@
 )
 
 from raggy.documents import Document as RaggyDocument
+from raggy.documents import DocumentMetadata
 from raggy.settings import settings
 from raggy.utilities.asyncutils import run_sync_in_worker_thread
 from raggy.utilities.embeddings import create_openai_embeddings
@@ -90,7 +91,10 @@ async def add(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
         ids = [doc.id for doc in unique_documents]
         texts = [doc.text for doc in unique_documents]
         metadatas = [
-            doc.metadata.model_dump(exclude_none=True) for doc in unique_documents
+            doc.metadata.model_dump(exclude_none=True)
+            if isinstance(doc.metadata, DocumentMetadata)
+            else None
+            for doc in unique_documents
         ]
 
         embeddings = await create_openai_embeddings(texts)
@@ -145,7 +149,9 @@ async def upsert(self, documents: list[RaggyDocument]) -> list[ChromaDocument]:
             ids=[document.id for document in documents],
             documents=[document.text for document in documents],
             metadatas=[
-                document.metadata.model_dump(exclude_none=True) or None
+                document.metadata.model_dump(exclude_none=True)
+                if isinstance(document.metadata, DocumentMetadata)
+                else None
                 for document in documents
            ],
            embeddings=await create_openai_embeddings(
```
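
A note on these guards: with the `ensure_metadata` before-validator added in `documents.py`, metadata arriving as a plain dict is already coerced to `DocumentMetadata` at construction time, so the `isinstance` checks here presumably serve to narrow the widened `DocumentMetadata | dict` annotation before calling `model_dump`, falling back to `None` otherwise.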

src/raggy/vectorstores/tpuf.py (+23, -9)

```diff
@@ -11,7 +11,7 @@
 from turbopuffer.vectors import VectorResult
 
 from raggy.documents import Document
-from raggy.utilities.asyncutils import run_sync_in_worker_thread
+from raggy.utilities.asyncutils import run_concurrent_tasks, run_sync_in_worker_thread
 from raggy.utilities.embeddings import create_openai_embeddings
 from raggy.utilities.text import slice_tokens
 from raggy.vectorstores.base import Vectorstore
@@ -27,8 +27,10 @@ class TurboPufferSettings(BaseSettings):
         extra="ignore",
     )
 
-    api_key: SecretStr
-    default_namespace: str = "raggy"
+    api_key: SecretStr = Field(
+        default=..., description="The API key for the TurboPuffer instance."
+    )
+    default_namespace: str = Field(default="raggy")
 
     @model_validator(mode="after")
     def set_api_key(self):
@@ -151,20 +153,32 @@ async def upsert_batched(
         self,
         documents: Iterable[Document],
         batch_size: int = 100,
+        max_concurrent: int = 25,
     ):
-        """Upsert documents in batches to avoid memory issues with large datasets.
+        """Upsert documents in batches concurrently.
 
         Args:
             documents: Iterable of documents to upsert
-            batch_size: Maximum number of documents to upsert in each batch
+            batch_size: Maximum number of documents per batch
+            max_concurrent: Maximum number of concurrent upsert operations
         """
         document_list = list(documents)
-        total_docs = len(document_list)
+        batches = [
+            document_list[i : i + batch_size]
+            for i in range(0, len(document_list), batch_size)
+        ]
 
-        for i in range(0, total_docs, batch_size):
-            batch = document_list[i : i + batch_size]
+        async def process_batch(batch: list[Document], batch_num: int):
             await self.upsert(documents=batch)
-            print(f"Upserted batch {i//batch_size + 1} ({len(batch)} documents)")
+            print(
+                f"Upserted batch {batch_num + 1}/{len(batches)} ({len(batch)} documents)"
+            )
+
+        tasks = [
+            lambda b=batch, i=i: process_batch(b, i) for i, batch in enumerate(batches)
+        ]
+
+        await run_concurrent_tasks(tasks, max_concurrent=max_concurrent)
 
 
 async def query_namespace(
```
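
With batches now dispatched through `run_concurrent_tasks`, callers can bound how many upserts run at once. A hedged sketch of the new call, assuming the vectorstore class is `TurboPuffer` and supports async-context-manager use as in raggy's examples; constructor details may differ:

```python
from raggy.documents import Document
from raggy.vectorstores.tpuf import TurboPuffer


async def refresh(documents: list[Document]):
    # Assumed constructor: namespace selection as in the refresh example script.
    async with TurboPuffer(namespace="raggy") as tpuf:
        # 100-document batches, at most 8 upserts in flight at a time.
        await tpuf.upsert_batched(
            documents=documents, batch_size=100, max_concurrent=8
        )
```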
