Skip to content

Commit 15f9c59

Browse files
authored
Add paginate and async_paginate method (#197)
This PR adds `paginate` and `async_paginate` methods that let you iterate over a paginated list of resources. **sync** ```python import replicate for page in replicate.paginate(replicate.collections.list): for collection in page: print(collection.name) ``` **async** ```python import replicate async for page in replicate.async_paginate(replicate.collections.async_list): for collection in page: print(collection.name) ``` --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 45e3020 commit 15f9c59

File tree

8 files changed

+96
-11
lines changed

8 files changed

+96
-11
lines changed

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,29 @@ You can the models you've created:
223223
replicate.models.list()
224224
```
225225

226-
Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method. Here's how you can get all the models you've created:
226+
Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method, or you can use the `paginate` method to fetch pages automatically.
227227

228228
```python
229+
# Automatic pagination using `replicate.paginate` (recommended)
229230
models = []
230-
page = replicate.models.list()
231+
for page in replicate.paginate(replicate.models.list):
232+
models.extend(page.results)
233+
if len(models) > 100:
234+
break
231235

236+
# Manual pagination using `next` cursors
237+
page = replicate.models.list()
232238
while page:
233239
models.extend(page.results)
240+
if len(models) > 100:
241+
break
234242
page = replicate.models.list(page.next) if page.next else None
235243
```
236244

237245
You can also find collections of featured models on Replicate:
238246

239247
```python
240-
>>> collections = replicate.collections.list()
248+
>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
241249
>>> collections[0].slug
242250
"vision-models"
243251
>>> collections[0].description

replicate/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from replicate.client import Client
2+
from replicate.pagination import async_paginate as _async_paginate
3+
from replicate.pagination import paginate as _paginate
24

35
default_client = Client()
46

57
run = default_client.run
68
async_run = default_client.async_run
79

10+
paginate = _paginate
11+
async_paginate = _async_paginate
12+
813
collections = default_client.collections
914
hardware = default_client.hardware
1015
deployments = default_client.deployments

replicate/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class Collections(Namespace):
5555

5656
def list(
5757
self,
58-
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
58+
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
5959
) -> Page[Collection]:
6060
"""
6161
List collections of models.
@@ -82,7 +82,7 @@ def list(
8282

8383
async def async_list(
8484
self,
85-
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
85+
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
8686
) -> Page[Collection]:
8787
"""
8888
List collections of models.

replicate/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class Models(Namespace):
140140

141141
model = Model
142142

143-
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
143+
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821
144144
"""
145145
List all public models.
146146
@@ -164,7 +164,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F8
164164

165165
return Page[Model](**obj)
166166

167-
async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
167+
async def async_list(
168+
self,
169+
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
170+
) -> Page[Model]:
168171
"""
169172
List all public models.
170173

replicate/pagination.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
from typing import (
22
TYPE_CHECKING,
3+
AsyncGenerator,
4+
Awaitable,
5+
Callable,
6+
Generator,
37
Generic,
48
List,
59
Optional,
610
TypeVar,
11+
Union,
712
)
813

914
try:
@@ -41,3 +46,35 @@ def __getitem__(self, index: int) -> T:
4146

4247
def __len__(self) -> int:
4348
return len(self.results)
49+
50+
51+
def paginate(
52+
list_method: Callable[[Union[str, "ellipsis", None]], Page[T]], # noqa: F821
53+
) -> Generator[Page[T], None, None]:
54+
"""
55+
Iterate over all items using the provided list method.
56+
57+
Args:
58+
list_method: A method that takes a cursor argument and returns a Page of items.
59+
"""
60+
cursor: Union[str, "ellipsis", None] = ... # noqa: F821
61+
while cursor is not None:
62+
page = list_method(cursor)
63+
yield page
64+
cursor = page.next
65+
66+
67+
async def async_paginate(
68+
list_method: Callable[[Union[str, "ellipsis", None]], Awaitable[Page[T]]], # noqa: F821
69+
) -> AsyncGenerator[Page[T], None]:
70+
"""
71+
Asynchronously iterate over all items using the provided list method.
72+
73+
Args:
74+
list_method: An async method that takes a cursor argument and returns a Page of items.
75+
"""
76+
cursor: Union[str, "ellipsis", None] = ... # noqa: F821
77+
while cursor is not None:
78+
page = await list_method(cursor)
79+
yield page
80+
cursor = page.next

replicate/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class Predictions(Namespace):
172172
Namespace for operations related to predictions.
173173
"""
174174

175-
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noqa: F821
175+
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Prediction]: # noqa: F821
176176
"""
177177
List your predictions.
178178
@@ -200,7 +200,7 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noq
200200

201201
async def async_list(
202202
self,
203-
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
203+
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
204204
) -> Page[Prediction]:
205205
"""
206206
List your predictions.

replicate/training.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class Trainings(Namespace):
101101
Namespace for operations related to trainings.
102102
"""
103103

104-
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
104+
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Training]: # noqa: F821
105105
"""
106106
List your trainings.
107107
@@ -127,7 +127,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa:
127127

128128
return Page[Training](**obj)
129129

130-
async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
130+
async def async_list(
131+
self,
132+
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
133+
) -> Page[Training]:
131134
"""
132135
List your trainings.
133136

tests/test_pagination.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,32 @@
77
async def test_paginate_with_none_cursor(mock_replicate_api_token):
88
with pytest.raises(ValueError):
99
replicate.models.list(None)
10+
11+
12+
@pytest.mark.vcr("collections-list.yaml")
13+
@pytest.mark.asyncio
14+
@pytest.mark.parametrize("async_flag", [True, False])
15+
async def test_paginate(async_flag):
16+
found = False
17+
18+
if async_flag:
19+
async for page in replicate.async_paginate(replicate.collections.async_list):
20+
assert page.next is None
21+
assert page.previous is None
22+
23+
for collection in page:
24+
if collection.slug == "text-to-image":
25+
found = True
26+
break
27+
28+
else:
29+
for page in replicate.paginate(replicate.collections.list):
30+
assert page.next is None
31+
assert page.previous is None
32+
33+
for collection in page:
34+
if collection.slug == "text-to-image":
35+
found = True
36+
break
37+
38+
assert found

0 commit comments

Comments
 (0)