16
16
# under the License.
17
17
18
18
import collections .abc
19
+ from copy import deepcopy
19
20
from itertools import chain
21
+ from typing import Any , Callable , ClassVar , Optional , Protocol , TypeVar , Union , overload
20
22
21
23
# 'SF' looks unused but the test suite assumes it's available
22
24
# from this module so others are liable to do so as well.
23
25
from .function import SF # noqa: F401
24
26
from .function import ScoreFunction
25
27
from .utils import DslBase
26
28
29
+ _T = TypeVar ("_T" )
30
+ _M = TypeVar ("_M" , bound = collections .abc .Mapping [str , Any ])
27
31
28
- def Q (name_or_query = "match_all" , ** params ):
32
+
33
+ class QProxiedProtocol (Protocol [_T ]):
34
+ _proxied : _T
35
+
36
+
37
+ @overload
38
+ def Q (name_or_query : collections .abc .MutableMapping [str , _M ]) -> "Query" : ...
39
+
40
+
41
+ @overload
42
+ def Q (name_or_query : "Query" ) -> "Query" : ...
43
+
44
+
45
+ @overload
46
+ def Q (name_or_query : QProxiedProtocol [_T ]) -> _T : ...
47
+
48
+
49
+ @overload
50
+ def Q (name_or_query : str , ** params : Any ) -> "Query" : ...
51
+
52
+
53
+ def Q (
54
+ name_or_query : Union [
55
+ str ,
56
+ "Query" ,
57
+ QProxiedProtocol [_T ],
58
+ collections .abc .MutableMapping [str , _M ],
59
+ ] = "match_all" ,
60
+ ** params : Any ,
61
+ ) -> Union ["Query" , _T ]:
29
62
# {"match": {"title": "python"}}
30
- if isinstance (name_or_query , collections .abc .Mapping ):
63
+ if isinstance (name_or_query , collections .abc .MutableMapping ):
31
64
if params :
32
65
raise ValueError ("Q() cannot accept parameters when passing in a dict." )
33
66
if len (name_or_query ) != 1 :
34
67
raise ValueError (
35
68
'Q() can only accept dict with a single query ({"match": {...}}). '
36
69
"Instead it got (%r)" % name_or_query
37
70
)
38
- name , params = name_or_query . copy ( ).popitem ()
39
- return Query .get_dsl_class (name )(_expand__to_dot = False , ** params )
71
+ name , q_params = deepcopy ( name_or_query ).popitem ()
72
+ return Query .get_dsl_class (name )(_expand__to_dot = False , ** q_params )
40
73
41
74
# MatchAll()
42
75
if isinstance (name_or_query , Query ):
@@ -57,26 +90,31 @@ def Q(name_or_query="match_all", **params):
57
90
class Query (DslBase ):
58
91
_type_name = "query"
59
92
_type_shortcut = staticmethod (Q )
60
- name = None
93
+ name : ClassVar [Optional [str ]] = None
94
+
95
+ # Add type annotations for methods not defined in every subclass
96
+ __ror__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
97
+ __radd__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
98
+ __rand__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
61
99
62
- def __add__ (self , other ) :
100
+ def __add__ (self , other : "Query" ) -> "Query" :
63
101
# make sure we give queries that know how to combine themselves
64
102
# preference
65
103
if hasattr (other , "__radd__" ):
66
104
return other .__radd__ (self )
67
105
return Bool (must = [self , other ])
68
106
69
- def __invert__ (self ):
107
+ def __invert__ (self ) -> "Query" :
70
108
return Bool (must_not = [self ])
71
109
72
- def __or__ (self , other ) :
110
+ def __or__ (self , other : "Query" ) -> "Query" :
73
111
# make sure we give queries that know how to combine themselves
74
112
# preference
75
113
if hasattr (other , "__ror__" ):
76
114
return other .__ror__ (self )
77
115
return Bool (should = [self , other ])
78
116
79
- def __and__ (self , other ) :
117
+ def __and__ (self , other : "Query" ) -> "Query" :
80
118
# make sure we give queries that know how to combine themselves
81
119
# preference
82
120
if hasattr (other , "__rand__" ):
@@ -87,17 +125,17 @@ def __and__(self, other):
87
125
class MatchAll (Query ):
88
126
name = "match_all"
89
127
90
- def __add__ (self , other ) :
128
+ def __add__ (self , other : "Query" ) -> "Query" :
91
129
return other ._clone ()
92
130
93
131
__and__ = __rand__ = __radd__ = __add__
94
132
95
- def __or__ (self , other ) :
133
+ def __or__ (self , other : "Query" ) -> "MatchAll" :
96
134
return self
97
135
98
136
__ror__ = __or__
99
137
100
- def __invert__ (self ):
138
+ def __invert__ (self ) -> "MatchNone" :
101
139
return MatchNone ()
102
140
103
141
@@ -107,17 +145,17 @@ def __invert__(self):
107
145
class MatchNone (Query ):
108
146
name = "match_none"
109
147
110
- def __add__ (self , other ) :
148
+ def __add__ (self , other : "Query" ) -> "MatchNone" :
111
149
return self
112
150
113
151
__and__ = __rand__ = __radd__ = __add__
114
152
115
- def __or__ (self , other ) :
153
+ def __or__ (self , other : "Query" ) -> "Query" :
116
154
return other ._clone ()
117
155
118
156
__ror__ = __or__
119
157
120
- def __invert__ (self ):
158
+ def __invert__ (self ) -> MatchAll :
121
159
return MatchAll ()
122
160
123
161
@@ -130,7 +168,7 @@ class Bool(Query):
130
168
"filter" : {"type" : "query" , "multi" : True },
131
169
}
132
170
133
- def __add__ (self , other ) :
171
+ def __add__ (self , other : Query ) -> "Bool" :
134
172
q = self ._clone ()
135
173
if isinstance (other , Bool ):
136
174
q .must += other .must
@@ -143,7 +181,7 @@ def __add__(self, other):
143
181
144
182
__radd__ = __add__
145
183
146
- def __or__ (self , other ) :
184
+ def __or__ (self , other : Query ) -> Query :
147
185
for q in (self , other ):
148
186
if isinstance (q , Bool ) and not any (
149
187
(q .must , q .must_not , q .filter , getattr (q , "minimum_should_match" , None ))
@@ -168,20 +206,20 @@ def __or__(self, other):
168
206
__ror__ = __or__
169
207
170
208
@property
171
- def _min_should_match (self ):
209
+ def _min_should_match (self ) -> int :
172
210
return getattr (
173
211
self ,
174
212
"minimum_should_match" ,
175
213
0 if not self .should or (self .must or self .filter ) else 1 ,
176
214
)
177
215
178
- def __invert__ (self ):
216
+ def __invert__ (self ) -> Query :
179
217
# Because an empty Bool query is treated like
180
218
# MatchAll the inverse should be MatchNone
181
219
if not any (chain (self .must , self .filter , self .should , self .must_not )):
182
220
return MatchNone ()
183
221
184
- negations = []
222
+ negations : list [ Query ] = []
185
223
for q in chain (self .must , self .filter ):
186
224
negations .append (~ q )
187
225
@@ -195,7 +233,7 @@ def __invert__(self):
195
233
return negations [0 ]
196
234
return Bool (should = negations )
197
235
198
- def __and__ (self , other ) :
236
+ def __and__ (self , other : Query ) -> Query :
199
237
q = self ._clone ()
200
238
if isinstance (other , Bool ):
201
239
q .must += other .must
@@ -247,7 +285,7 @@ class FunctionScore(Query):
247
285
"functions" : {"type" : "score_function" , "multi" : True },
248
286
}
249
287
250
- def __init__ (self , ** kwargs ):
288
+ def __init__ (self , ** kwargs : Any ):
251
289
if "functions" in kwargs :
252
290
pass
253
291
else :
0 commit comments