Skip to content

Commit 71b0312

Browse files
feat: read rows query model class (#752)
1 parent 507da99 commit 71b0312

File tree

3 files changed

+554
-13
lines changed

3 files changed

+554
-13
lines changed

google/cloud/bigtable/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from google.cloud.bigtable.client import Table
2323

2424
from google.cloud.bigtable.read_rows_query import ReadRowsQuery
25+
from google.cloud.bigtable.read_rows_query import RowRange
2526
from google.cloud.bigtable.row_response import RowResponse
2627
from google.cloud.bigtable.row_response import CellResponse
2728

@@ -43,6 +44,7 @@
4344
"Table",
4445
"RowKeySamples",
4546
"ReadRowsQuery",
47+
"RowRange",
4648
"MutationsBatcher",
4749
"Mutation",
4850
"BulkMutationsEntry",

google/cloud/bigtable/read_rows_query.py

Lines changed: 193 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,192 @@
1313
# limitations under the License.
1414
#
1515
from __future__ import annotations
16-
from typing import TYPE_CHECKING
16+
from typing import TYPE_CHECKING, Any
17+
from .row_response import row_key
18+
from dataclasses import dataclass
19+
from google.cloud.bigtable.row_filters import RowFilter
1720

1821
if TYPE_CHECKING:
19-
from google.cloud.bigtable.row_filters import RowFilter
2022
from google.cloud.bigtable import RowKeySamples
2123

2224

25+
@dataclass
26+
class _RangePoint:
27+
"""Model class for a point in a row range"""
28+
29+
key: row_key
30+
is_inclusive: bool
31+
32+
33+
@dataclass
34+
class RowRange:
35+
start: _RangePoint | None
36+
end: _RangePoint | None
37+
38+
def __init__(
39+
self,
40+
start_key: str | bytes | None = None,
41+
end_key: str | bytes | None = None,
42+
start_is_inclusive: bool | None = None,
43+
end_is_inclusive: bool | None = None,
44+
):
45+
# check for invalid combinations of arguments
46+
if start_is_inclusive is None:
47+
start_is_inclusive = True
48+
elif start_key is None:
49+
raise ValueError("start_is_inclusive must be set with start_key")
50+
if end_is_inclusive is None:
51+
end_is_inclusive = False
52+
elif end_key is None:
53+
raise ValueError("end_is_inclusive must be set with end_key")
54+
# ensure that start_key and end_key are bytes
55+
if isinstance(start_key, str):
56+
start_key = start_key.encode()
57+
elif start_key is not None and not isinstance(start_key, bytes):
58+
raise ValueError("start_key must be a string or bytes")
59+
if isinstance(end_key, str):
60+
end_key = end_key.encode()
61+
elif end_key is not None and not isinstance(end_key, bytes):
62+
raise ValueError("end_key must be a string or bytes")
63+
64+
self.start = (
65+
_RangePoint(start_key, start_is_inclusive)
66+
if start_key is not None
67+
else None
68+
)
69+
self.end = (
70+
_RangePoint(end_key, end_is_inclusive) if end_key is not None else None
71+
)
72+
73+
def _to_dict(self) -> dict[str, bytes]:
74+
"""Converts this object to a dictionary"""
75+
output = {}
76+
if self.start is not None:
77+
key = "start_key_closed" if self.start.is_inclusive else "start_key_open"
78+
output[key] = self.start.key
79+
if self.end is not None:
80+
key = "end_key_closed" if self.end.is_inclusive else "end_key_open"
81+
output[key] = self.end.key
82+
return output
83+
84+
2385
class ReadRowsQuery:
2486
"""
2587
Class to encapsulate details of a read row request
2688
"""
2789

2890
def __init__(
29-
self, row_keys: list[str | bytes] | str | bytes | None = None, limit=None
91+
self,
92+
row_keys: list[str | bytes] | str | bytes | None = None,
93+
row_ranges: list[RowRange] | RowRange | None = None,
94+
limit: int | None = None,
95+
row_filter: RowFilter | None = None,
3096
):
31-
pass
97+
"""
98+
Create a new ReadRowsQuery
3299
33-
def set_limit(self, limit: int) -> ReadRowsQuery:
34-
raise NotImplementedError
100+
Args:
101+
- row_keys: row keys to include in the query
102+
a query can contain multiple keys, but ranges should be preferred
103+
- row_ranges: ranges of rows to include in the query
104+
- limit: the maximum number of rows to return. None or 0 means no limit
105+
default: None (no limit)
106+
- row_filter: a RowFilter to apply to the query
107+
"""
108+
self.row_keys: set[bytes] = set()
109+
self.row_ranges: list[RowRange | dict[str, bytes]] = []
110+
if row_ranges:
111+
if isinstance(row_ranges, RowRange):
112+
row_ranges = [row_ranges]
113+
for r in row_ranges:
114+
self.add_range(r)
115+
if row_keys:
116+
if not isinstance(row_keys, list):
117+
row_keys = [row_keys]
118+
for k in row_keys:
119+
self.add_key(k)
120+
self.limit: int | None = limit
121+
self.filter: RowFilter | dict[str, Any] | None = row_filter
35122

36-
def set_filter(self, filter: "RowFilter") -> ReadRowsQuery:
37-
raise NotImplementedError
123+
@property
124+
def limit(self) -> int | None:
125+
return self._limit
38126

39-
def add_rows(self, row_id_list: list[str]) -> ReadRowsQuery:
40-
raise NotImplementedError
127+
@limit.setter
128+
def limit(self, new_limit: int | None):
129+
"""
130+
Set the maximum number of rows to return by this query.
131+
132+
None or 0 means no limit
133+
134+
Args:
135+
- new_limit: the new limit to apply to this query
136+
Returns:
137+
- a reference to this query for chaining
138+
Raises:
139+
- ValueError if new_limit is < 0
140+
"""
141+
if new_limit is not None and new_limit < 0:
142+
raise ValueError("limit must be >= 0")
143+
self._limit = new_limit
144+
145+
@property
146+
def filter(self) -> RowFilter | dict[str, Any] | None:
147+
return self._filter
148+
149+
@filter.setter
150+
def filter(self, row_filter: RowFilter | dict[str, Any] | None):
151+
"""
152+
Set a RowFilter to apply to this query
153+
154+
Args:
155+
- row_filter: a RowFilter to apply to this query
156+
Can be a RowFilter object or a dict representation
157+
Returns:
158+
- a reference to this query for chaining
159+
"""
160+
if not (
161+
isinstance(row_filter, dict)
162+
or isinstance(row_filter, RowFilter)
163+
or row_filter is None
164+
):
165+
raise ValueError("row_filter must be a RowFilter or dict")
166+
self._filter = row_filter
167+
168+
def add_key(self, row_key: str | bytes):
169+
"""
170+
Add a row key to this query
171+
172+
A query can contain multiple keys, but ranges should be preferred
173+
174+
Args:
175+
- row_key: a key to add to this query
176+
Returns:
177+
- a reference to this query for chaining
178+
Raises:
179+
- ValueError if an input is not a string or bytes
180+
"""
181+
if isinstance(row_key, str):
182+
row_key = row_key.encode()
183+
elif not isinstance(row_key, bytes):
184+
raise ValueError("row_key must be string or bytes")
185+
self.row_keys.add(row_key)
41186

42187
def add_range(
43-
self, start_key: str | bytes | None = None, end_key: str | bytes | None = None
44-
) -> ReadRowsQuery:
45-
raise NotImplementedError
188+
self,
189+
row_range: RowRange | dict[str, bytes],
190+
):
191+
"""
192+
Add a range of row keys to this query.
193+
194+
Args:
195+
- row_range: a range of row keys to add to this query
196+
Can be a RowRange object or a dict representation in
197+
RowRange proto format
198+
"""
199+
if not (isinstance(row_range, dict) or isinstance(row_range, RowRange)):
200+
raise ValueError("row_range must be a RowRange or dict")
201+
self.row_ranges.append(row_range)
46202

47203
def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]:
48204
"""
@@ -54,3 +210,27 @@ def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery
54210
query (if possible)
55211
"""
56212
raise NotImplementedError
213+
214+
def _to_dict(self) -> dict[str, Any]:
215+
"""
216+
Convert this query into a dictionary that can be used to construct a
217+
ReadRowsRequest protobuf
218+
"""
219+
row_ranges = []
220+
for r in self.row_ranges:
221+
dict_range = r._to_dict() if isinstance(r, RowRange) else r
222+
row_ranges.append(dict_range)
223+
row_keys = list(self.row_keys)
224+
row_keys.sort()
225+
row_set = {"row_keys": row_keys, "row_ranges": row_ranges}
226+
final_dict: dict[str, Any] = {
227+
"rows": row_set,
228+
}
229+
dict_filter = (
230+
self.filter.to_dict() if isinstance(self.filter, RowFilter) else self.filter
231+
)
232+
if dict_filter:
233+
final_dict["filter"] = dict_filter
234+
if self.limit is not None:
235+
final_dict["rows_limit"] = self.limit
236+
return final_dict

0 commit comments

Comments
 (0)