13
13
# limitations under the License.
14
14
#
15
15
from __future__ import annotations
16
- from typing import TYPE_CHECKING
16
+ from typing import TYPE_CHECKING , Any
17
+ from .row_response import row_key
18
+ from dataclasses import dataclass
19
+ from google .cloud .bigtable .row_filters import RowFilter
17
20
18
21
if TYPE_CHECKING :
19
- from google .cloud .bigtable .row_filters import RowFilter
20
22
from google .cloud .bigtable import RowKeySamples
21
23
22
24
25
@dataclass
class _RangePoint:
    """One endpoint of a row range: a key plus an inclusivity flag"""

    # row key that marks this end of the range
    key: row_key
    # True when the key itself is part of the range (closed bound),
    # False when the range stops short of it (open bound)
    is_inclusive: bool
31
+
32
+
33
@dataclass
class RowRange:
    """
    A contiguous range of row keys to include in a query

    Either bound may be left out; a missing bound is simply not emitted
    when the range is converted to a dictionary.

    Attributes:
      - start: lower bound of the range, or None if no start key was given
      - end: upper bound of the range, or None if no end key was given
    """

    start: _RangePoint | None
    end: _RangePoint | None

    def __init__(
        self,
        start_key: str | bytes | None = None,
        end_key: str | bytes | None = None,
        start_is_inclusive: bool | None = None,
        end_is_inclusive: bool | None = None,
    ):
        """
        Create a new RowRange

        Args:
          - start_key: first key of the range. str values are encoded to bytes
          - end_key: last key of the range. str values are encoded to bytes
          - start_is_inclusive: whether start_key itself is part of the range.
              Defaults to True when omitted
          - end_is_inclusive: whether end_key itself is part of the range.
              Defaults to False when omitted
        Raises:
          - ValueError if an inclusivity flag is passed without its key,
              if a key is neither str nor bytes, or if start_key sorts
              after end_key
        """
        # check for invalid combinations of arguments
        if start_is_inclusive is None:
            start_is_inclusive = True
        elif start_key is None:
            raise ValueError("start_is_inclusive must be set with start_key")
        if end_is_inclusive is None:
            end_is_inclusive = False
        elif end_key is None:
            raise ValueError("end_is_inclusive must be set with end_key")
        # ensure that start_key and end_key are bytes
        if isinstance(start_key, str):
            start_key = start_key.encode()
        elif start_key is not None and not isinstance(start_key, bytes):
            raise ValueError("start_key must be a string or bytes")
        if isinstance(end_key, str):
            end_key = end_key.encode()
        elif end_key is not None and not isinstance(end_key, bytes):
            raise ValueError("end_key must be a string or bytes")
        # reject inverted ranges up front; both keys are bytes at this point,
        # so the comparison is a plain lexicographic byte comparison
        if start_key is not None and end_key is not None and start_key > end_key:
            raise ValueError("start_key must be less than or equal to end_key")

        self.start = (
            _RangePoint(start_key, start_is_inclusive)
            if start_key is not None
            else None
        )
        self.end = (
            _RangePoint(end_key, end_is_inclusive) if end_key is not None else None
        )

    def _to_dict(self) -> dict[str, bytes]:
        """
        Converts this object to a dictionary

        Each bound that is present contributes one entry, keyed
        "start_key_closed"/"start_key_open" or "end_key_closed"/"end_key_open"
        depending on its inclusivity. Missing bounds contribute nothing.
        """
        output: dict[str, bytes] = {}
        if self.start is not None:
            key = "start_key_closed" if self.start.is_inclusive else "start_key_open"
            output[key] = self.start.key
        if self.end is not None:
            key = "end_key_closed" if self.end.is_inclusive else "end_key_open"
            output[key] = self.end.key
        return output
83
+
84
+
23
85
class ReadRowsQuery :
24
86
"""
25
87
Class to encapsulate details of a read row request
26
88
"""
27
89
28
90
def __init__ (
29
- self , row_keys : list [str | bytes ] | str | bytes | None = None , limit = None
91
+ self ,
92
+ row_keys : list [str | bytes ] | str | bytes | None = None ,
93
+ row_ranges : list [RowRange ] | RowRange | None = None ,
94
+ limit : int | None = None ,
95
+ row_filter : RowFilter | None = None ,
30
96
):
31
- pass
97
+ """
98
+ Create a new ReadRowsQuery
32
99
33
- def set_limit (self , limit : int ) -> ReadRowsQuery :
34
- raise NotImplementedError
100
+ Args:
101
+ - row_keys: row keys to include in the query
102
+ a query can contain multiple keys, but ranges should be preferred
103
+ - row_ranges: ranges of rows to include in the query
104
+ - limit: the maximum number of rows to return. None or 0 means no limit
105
+ default: None (no limit)
106
+ - row_filter: a RowFilter to apply to the query
107
+ """
108
+ self .row_keys : set [bytes ] = set ()
109
+ self .row_ranges : list [RowRange | dict [str , bytes ]] = []
110
+ if row_ranges :
111
+ if isinstance (row_ranges , RowRange ):
112
+ row_ranges = [row_ranges ]
113
+ for r in row_ranges :
114
+ self .add_range (r )
115
+ if row_keys :
116
+ if not isinstance (row_keys , list ):
117
+ row_keys = [row_keys ]
118
+ for k in row_keys :
119
+ self .add_key (k )
120
+ self .limit : int | None = limit
121
+ self .filter : RowFilter | dict [str , Any ] | None = row_filter
35
122
36
- def set_filter (self , filter : "RowFilter" ) -> ReadRowsQuery :
37
- raise NotImplementedError
123
@property
def limit(self) -> int | None:
    """The maximum number of rows returned by this query"""
    return self._limit

@limit.setter
def limit(self, new_limit: int | None):
    """
    Update the cap on the number of rows returned by this query.

    None or 0 means no limit

    Args:
      - new_limit: the limit to store on this query
    Raises:
      - ValueError if new_limit is < 0
    """
    is_negative = new_limit is not None and new_limit < 0
    if is_negative:
        raise ValueError("limit must be >= 0")
    self._limit = new_limit
144
+
145
@property
def filter(self) -> RowFilter | dict[str, Any] | None:
    """The filter stored on this query, or None if there is none"""
    return self._filter

@filter.setter
def filter(self, row_filter: RowFilter | dict[str, Any] | None):
    """
    Store the filter to apply to this query

    Args:
      - row_filter: the filter to apply to this query
          Can be a RowFilter object, a dict representation, or None
    Raises:
      - ValueError if row_filter is not a RowFilter, a dict, or None
    """
    if (
        not isinstance(row_filter, dict)
        and not isinstance(row_filter, RowFilter)
        and row_filter is not None
    ):
        raise ValueError("row_filter must be a RowFilter or dict")
    self._filter = row_filter
167
+
168
def add_key(self, row_key: str | bytes):
    """
    Include one more row key in this query

    A query can contain multiple keys, but ranges should be preferred

    Args:
      - row_key: the key to include. str values are encoded to bytes
          before being stored
    Raises:
      - ValueError if an input is not a string or bytes
    """
    if not isinstance(row_key, (str, bytes)):
        raise ValueError("row_key must be string or bytes")
    encoded_key = row_key.encode() if isinstance(row_key, str) else row_key
    # row_keys is a set, so adding the same key twice has no further effect
    self.row_keys.add(encoded_key)
41
186
42
187
def add_range(
    self,
    row_range: RowRange | dict[str, bytes],
):
    """
    Include one more range of row keys in this query.

    Args:
      - row_range: the range to include
          Can be a RowRange object or a dict representation in
          RowRange proto format
    Raises:
      - ValueError if row_range is neither a RowRange nor a dict
    """
    if not isinstance(row_range, dict) and not isinstance(row_range, RowRange):
        raise ValueError("row_range must be a RowRange or dict")
    # stored as given; no conversion happens at this point
    self.row_ranges.append(row_range)
46
202
47
203
def shard (self , shard_keys : "RowKeySamples" | None = None ) -> list [ReadRowsQuery ]:
48
204
"""
@@ -54,3 +210,27 @@ def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery
54
210
query (if possible)
55
211
"""
56
212
raise NotImplementedError
213
+
214
def _to_dict(self) -> dict[str, Any]:
    """
    Build the dictionary form of this query, suitable for constructing a
    ReadRowsRequest protobuf

    The result always has a "rows" entry holding the sorted row keys and
    the row ranges. "filter" is added only when the filter is truthy, and
    "rows_limit" only when the limit is not None.
    """
    # RowRange objects are converted; dict ranges are passed through as-is
    serialized_ranges = [
        entry._to_dict() if isinstance(entry, RowRange) else entry
        for entry in self.row_ranges
    ]
    request: dict[str, Any] = {
        "rows": {
            "row_keys": sorted(self.row_keys),
            "row_ranges": serialized_ranges,
        },
    }
    serialized_filter = self.filter
    if isinstance(serialized_filter, RowFilter):
        serialized_filter = serialized_filter.to_dict()
    if serialized_filter:
        request["filter"] = serialized_filter
    row_limit = self.limit
    if row_limit is not None:
        request["rows_limit"] = row_limit
    return request
0 commit comments