Skip to content

Commit d8c3e56

Browse files
feat(app): optimize boards queries
Use SQL instead of python to retrieve image count, asset count and board cover image. This reduces the number of SQL queries needed to list all boards. Previously, we did `1 + 2 * board_count` queries:: - 1 query to get the list of board records - 1 query per board to get its total count - 1 query per board to get its cover image Then, on the frontend, we made two additional network requests to get each board's counts: - 1 request (== 1 SQL query) for image count - 1 request (== 1 SQL query) for asset count All of this information is now retrieved in a single SQL query, and provided via single network request. As part of this change, `BoardRecord` now includes `image_count`, `asset_count` and `cover_image_name`. This makes `BoardDTO` redundant, but removing it is a deeper change...
1 parent 5303f48 commit d8c3e56

File tree

5 files changed

+157
-123
lines changed

5 files changed

+157
-123
lines changed

Diff for: invokeai/app/services/board_records/board_records_common.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Optional, Union
2+
from typing import Any, Optional, Union
33

44
from pydantic import BaseModel, Field
55

@@ -26,21 +26,25 @@ class BoardRecord(BaseModelExcludeNull):
2626
"""Whether or not the board is archived."""
2727
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
2828
"""Whether the board is private."""
29+
image_count: int = Field(description="The number of images in the board.")
30+
asset_count: int = Field(description="The number of assets in the board.")
2931

3032

31-
def deserialize_board_record(board_dict: dict) -> BoardRecord:
33+
def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
3234
"""Deserializes a board record."""
3335

3436
# Retrieve all the values, setting "reasonable" defaults if they are not present.
3537

3638
board_id = board_dict.get("board_id", "unknown")
3739
board_name = board_dict.get("board_name", "unknown")
38-
cover_image_name = board_dict.get("cover_image_name", "unknown")
40+
cover_image_name = board_dict.get("cover_image_name", None)
3941
created_at = board_dict.get("created_at", get_iso_timestamp())
4042
updated_at = board_dict.get("updated_at", get_iso_timestamp())
4143
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
4244
archived = board_dict.get("archived", False)
4345
is_private = board_dict.get("is_private", False)
46+
image_count = board_dict.get("image_count", 0)
47+
asset_count = board_dict.get("asset_count", 0)
4448

4549
return BoardRecord(
4650
board_id=board_id,
@@ -51,6 +55,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
5155
deleted_at=deleted_at,
5256
archived=archived,
5357
is_private=is_private,
58+
image_count=image_count,
59+
asset_count=asset_count,
5460
)
5561

5662

@@ -63,21 +69,21 @@ class BoardChanges(BaseModel, extra="forbid"):
6369
class BoardRecordNotFoundException(Exception):
6470
"""Raised when an board record is not found."""
6571

66-
def __init__(self, message="Board record not found"):
72+
def __init__(self, message: str = "Board record not found"):
6773
super().__init__(message)
6874

6975

7076
class BoardRecordSaveException(Exception):
7177
"""Raised when an board record cannot be saved."""
7278

73-
def __init__(self, message="Board record not saved"):
79+
def __init__(self, message: str = "Board record not saved"):
7480
super().__init__(message)
7581

7682

7783
class BoardRecordDeleteException(Exception):
7884
"""Raised when an board record cannot be deleted."""
7985

80-
def __init__(self, message="Board record not deleted"):
86+
def __init__(self, message: str = "Board record not deleted"):
8187
super().__init__(message)
8288

8389

Diff for: invokeai/app/services/board_records/board_records_sqlite.py

+116-58
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,114 @@
1616
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
1717
from invokeai.app.util.misc import uuid_string
1818

19+
_BASE_BOARD_RECORD_QUERY = """
20+
-- This query retrieves board records, joining with the board_images and images tables to get image counts and cover image names.
21+
-- It is not a complete query, as it is missing a GROUP BY or WHERE clause (and is unterminated).
22+
SELECT b.board_id,
23+
b.board_name,
24+
b.created_at,
25+
b.updated_at,
26+
b.archived,
27+
-- Count the number of images in the board, alias image_count
28+
COUNT(
29+
CASE
30+
WHEN i.image_category in ('general') -- "Images" are images in the 'general' category
31+
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
32+
END
33+
) AS image_count,
34+
-- Count the number of assets in the board, alias asset_count
35+
COUNT(
36+
CASE
37+
WHEN i.image_category in ('control', 'mask', 'user', 'other') -- "Assets" are images in any of the other categories ('control', 'mask', 'user', 'other')
38+
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
39+
END
40+
) AS asset_count,
41+
-- Get the name of the the most recent image in the board, alias cover_image_name
42+
(
43+
SELECT bi.image_name
44+
FROM board_images bi
45+
JOIN images i ON bi.image_name = i.image_name
46+
WHERE bi.board_id = b.board_id
47+
AND i.is_intermediate = 0 -- Intermediates cannot be cover images
48+
ORDER BY i.created_at DESC -- Sort by created_at to get the most recent image
49+
LIMIT 1
50+
) AS cover_image_name
51+
FROM boards b
52+
LEFT JOIN board_images bi ON b.board_id = bi.board_id
53+
LEFT JOIN images i ON bi.image_name = i.image_name
54+
"""
55+
56+
57+
def get_paginated_list_board_records_queries(include_archived: bool) -> str:
58+
"""Gets a query to retrieve a paginated list of board records. The query has placeholders for limit and offset.
59+
60+
Args:
61+
include_archived: Whether to include archived board records in the results.
62+
63+
Returns:
64+
A query to retrieve a paginated list of board records.
65+
"""
66+
67+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
68+
69+
# The GROUP BY must be added _after_ the WHERE clause!
70+
query = f"""
71+
{_BASE_BOARD_RECORD_QUERY}
72+
{archived_condition}
73+
GROUP BY b.board_id,
74+
b.board_name,
75+
b.created_at,
76+
b.updated_at
77+
ORDER BY b.created_at DESC
78+
LIMIT ? OFFSET ?;
79+
"""
80+
81+
return query
82+
83+
84+
def get_total_boards_count_query(include_archived: bool) -> str:
85+
"""Gets a query to retrieve the total count of board records.
86+
87+
Args:
88+
include_archived: Whether to include archived board records in the count.
89+
90+
Returns:
91+
A query to retrieve the total count of board records.
92+
"""
93+
94+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
95+
96+
return f"SELECT COUNT(*) FROM boards {archived_condition};"
97+
98+
99+
def get_list_all_board_records_query(include_archived: bool) -> str:
100+
"""Gets a query to retrieve all board records.
101+
102+
Args:
103+
include_archived: Whether to include archived board records in the results.
104+
105+
Returns:
106+
A query to retrieve all board records.
107+
"""
108+
109+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
110+
111+
return f"""
112+
{_BASE_BOARD_RECORD_QUERY}
113+
{archived_condition}
114+
GROUP BY b.board_id,
115+
b.board_name,
116+
b.created_at,
117+
b.updated_at
118+
ORDER BY b.created_at DESC;
119+
"""
120+
121+
122+
def get_board_record_query() -> str:
123+
"""Gets a query to retrieve a board record. The query has a placeholder for the board_id."""
124+
125+
return f"{_BASE_BOARD_RECORD_QUERY} WHERE b.board_id = ?;"
126+
19127

20128
class SqliteBoardRecordStorage(BoardRecordStorageBase):
21129
_conn: sqlite3.Connection
@@ -77,11 +185,7 @@ def get(
77185
try:
78186
self._lock.acquire()
79187
self._cursor.execute(
80-
"""--sql
81-
SELECT *
82-
FROM boards
83-
WHERE board_id = ?;
84-
""",
188+
get_board_record_query(),
85189
(board_id,),
86190
)
87191

@@ -93,7 +197,7 @@ def get(
93197
self._lock.release()
94198
if result is None:
95199
raise BoardRecordNotFoundException
96-
return BoardRecord(**dict(result))
200+
return deserialize_board_record(dict(result))
97201

98202
def update(
99203
self,
@@ -150,45 +254,15 @@ def get_many(
150254
try:
151255
self._lock.acquire()
152256

153-
# Build base query
154-
base_query = """
155-
SELECT *
156-
FROM boards
157-
{archived_filter}
158-
ORDER BY created_at DESC
159-
LIMIT ? OFFSET ?;
160-
"""
161-
162-
# Determine archived filter condition
163-
if include_archived:
164-
archived_filter = ""
165-
else:
166-
archived_filter = "WHERE archived = 0"
257+
main_query = get_paginated_list_board_records_queries(include_archived=include_archived)
167258

168-
final_query = base_query.format(archived_filter=archived_filter)
169-
170-
# Execute query to fetch boards
171-
self._cursor.execute(final_query, (limit, offset))
259+
self._cursor.execute(main_query, (limit, offset))
172260

173261
result = cast(list[sqlite3.Row], self._cursor.fetchall())
174262
boards = [deserialize_board_record(dict(r)) for r in result]
175263

176-
# Determine count query
177-
if include_archived:
178-
count_query = """
179-
SELECT COUNT(*)
180-
FROM boards;
181-
"""
182-
else:
183-
count_query = """
184-
SELECT COUNT(*)
185-
FROM boards
186-
WHERE archived = 0;
187-
"""
188-
189-
# Execute count query
190-
self._cursor.execute(count_query)
191-
264+
total_query = get_total_boards_count_query(include_archived=include_archived)
265+
self._cursor.execute(total_query)
192266
count = cast(int, self._cursor.fetchone()[0])
193267

194268
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
@@ -202,26 +276,10 @@ def get_many(
202276
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
203277
try:
204278
self._lock.acquire()
205-
206-
base_query = """
207-
SELECT *
208-
FROM boards
209-
{archived_filter}
210-
ORDER BY created_at DESC
211-
"""
212-
213-
if include_archived:
214-
archived_filter = ""
215-
else:
216-
archived_filter = "WHERE archived = 0"
217-
218-
final_query = base_query.format(archived_filter=archived_filter)
219-
220-
self._cursor.execute(final_query)
221-
279+
query = get_list_all_board_records_query(include_archived=include_archived)
280+
self._cursor.execute(query)
222281
result = cast(list[sqlite3.Row], self._cursor.fetchall())
223282
boards = [deserialize_board_record(dict(r)) for r in result]
224-
225283
return boards
226284

227285
except sqlite3.Error as e:

Diff for: invokeai/app/services/boards/boards_common.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,8 @@
1-
from typing import Optional
2-
3-
from pydantic import Field
4-
51
from invokeai.app.services.board_records.board_records_common import BoardRecord
62

73

4+
# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
85
class BoardDTO(BoardRecord):
9-
"""Deserialized board record with cover image URL and image count."""
10-
11-
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
12-
"""The URL of the thumbnail of the most recent image in the board."""
13-
image_count: int = Field(description="The number of images in the board.")
14-
"""The number of images in the board."""
15-
6+
"""Deserialized board record."""
167

17-
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
18-
"""Converts a board record to a board DTO."""
19-
return BoardDTO(
20-
**board_record.model_dump(exclude={"cover_image_name"}),
21-
cover_image_name=cover_image_name,
22-
image_count=image_count,
23-
)
8+
pass

Diff for: invokeai/app/services/boards/boards_default.py

+6-39
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from invokeai.app.services.board_records.board_records_common import BoardChanges
22
from invokeai.app.services.boards.boards_base import BoardServiceABC
3-
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
3+
from invokeai.app.services.boards.boards_common import BoardDTO
44
from invokeai.app.services.invoker import Invoker
55
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
66

@@ -16,32 +16,19 @@ def create(
1616
board_name: str,
1717
) -> BoardDTO:
1818
board_record = self.__invoker.services.board_records.save(board_name)
19-
return board_record_to_dto(board_record, None, 0)
19+
return BoardDTO.model_validate(board_record.model_dump())
2020

2121
def get_dto(self, board_id: str) -> BoardDTO:
2222
board_record = self.__invoker.services.board_records.get(board_id)
23-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
24-
if cover_image:
25-
cover_image_name = cover_image.image_name
26-
else:
27-
cover_image_name = None
28-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
29-
return board_record_to_dto(board_record, cover_image_name, image_count)
23+
return BoardDTO.model_validate(board_record.model_dump())
3024

3125
def update(
3226
self,
3327
board_id: str,
3428
changes: BoardChanges,
3529
) -> BoardDTO:
3630
board_record = self.__invoker.services.board_records.update(board_id, changes)
37-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
38-
if cover_image:
39-
cover_image_name = cover_image.image_name
40-
else:
41-
cover_image_name = None
42-
43-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
44-
return board_record_to_dto(board_record, cover_image_name, image_count)
31+
return BoardDTO.model_validate(board_record.model_dump())
4532

4633
def delete(self, board_id: str) -> None:
4734
self.__invoker.services.board_records.delete(board_id)
@@ -50,30 +37,10 @@ def get_many(
5037
self, offset: int = 0, limit: int = 10, include_archived: bool = False
5138
) -> OffsetPaginatedResults[BoardDTO]:
5239
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
53-
board_dtos = []
54-
for r in board_records.items:
55-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
56-
if cover_image:
57-
cover_image_name = cover_image.image_name
58-
else:
59-
cover_image_name = None
60-
61-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
62-
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
63-
40+
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items]
6441
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
6542

6643
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
6744
board_records = self.__invoker.services.board_records.get_all(include_archived)
68-
board_dtos = []
69-
for r in board_records:
70-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
71-
if cover_image:
72-
cover_image_name = cover_image.image_name
73-
else:
74-
cover_image_name = None
75-
76-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
77-
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
78-
45+
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records]
7946
return board_dtos

0 commit comments

Comments
 (0)