16
16
# under the License.
17
17
18
18
import collections .abc
19
+ from copy import deepcopy
19
20
from itertools import chain
21
+ from typing import (
22
+ Any ,
23
+ Callable ,
24
+ ClassVar ,
25
+ List ,
26
+ Mapping ,
27
+ MutableMapping ,
28
+ Optional ,
29
+ Protocol ,
30
+ TypeVar ,
31
+ Union ,
32
+ cast ,
33
+ overload ,
34
+ )
20
35
21
36
# 'SF' looks unused but the test suite assumes it's available
22
37
# from this module so others are liable to do so as well.
23
38
from .function import SF # noqa: F401
24
39
from .function import ScoreFunction
25
40
from .utils import DslBase
26
41
42
+ _T = TypeVar ("_T" )
43
+ _M = TypeVar ("_M" , bound = Mapping [str , Any ])
27
44
28
- def Q (name_or_query = "match_all" , ** params ):
45
+
46
+ class QProxiedProtocol (Protocol [_T ]):
47
+ _proxied : _T
48
+
49
+
50
+ @overload
51
+ def Q (name_or_query : MutableMapping [str , _M ]) -> "Query" : ...
52
+
53
+
54
+ @overload
55
+ def Q (name_or_query : "Query" ) -> "Query" : ...
56
+
57
+
58
+ @overload
59
+ def Q (name_or_query : QProxiedProtocol [_T ]) -> _T : ...
60
+
61
+
62
+ @overload
63
+ def Q (name_or_query : str = "match_all" , ** params : Any ) -> "Query" : ...
64
+
65
+
66
+ def Q (
67
+ name_or_query : Union [
68
+ str ,
69
+ "Query" ,
70
+ QProxiedProtocol [_T ],
71
+ MutableMapping [str , _M ],
72
+ ] = "match_all" ,
73
+ ** params : Any ,
74
+ ) -> Union ["Query" , _T ]:
29
75
# {"match": {"title": "python"}}
30
- if isinstance (name_or_query , collections .abc .Mapping ):
76
+ if isinstance (name_or_query , collections .abc .MutableMapping ):
31
77
if params :
32
78
raise ValueError ("Q() cannot accept parameters when passing in a dict." )
33
79
if len (name_or_query ) != 1 :
34
80
raise ValueError (
35
81
'Q() can only accept dict with a single query ({"match": {...}}). '
36
82
"Instead it got (%r)" % name_or_query
37
83
)
38
- name , params = name_or_query . copy ( ).popitem ()
39
- return Query .get_dsl_class (name )(_expand__to_dot = False , ** params )
84
+ name , q_params = deepcopy ( name_or_query ).popitem ()
85
+ return Query .get_dsl_class (name )(_expand__to_dot = False , ** q_params )
40
86
41
87
# MatchAll()
42
88
if isinstance (name_or_query , Query ):
@@ -48,7 +94,7 @@ def Q(name_or_query="match_all", **params):
48
94
49
95
# s.query = Q('filtered', query=s.query)
50
96
if hasattr (name_or_query , "_proxied" ):
51
- return name_or_query ._proxied
97
+ return cast ( QProxiedProtocol [ _T ], name_or_query ) ._proxied
52
98
53
99
# "match", title="python"
54
100
return Query .get_dsl_class (name_or_query )(** params )
@@ -57,26 +103,31 @@ def Q(name_or_query="match_all", **params):
57
103
class Query (DslBase ):
58
104
_type_name = "query"
59
105
_type_shortcut = staticmethod (Q )
60
- name = None
106
+ name : ClassVar [Optional [str ]] = None
107
+
108
+ # Add type annotations for methods not defined in every subclass
109
+ __ror__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
110
+ __radd__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
111
+ __rand__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
61
112
62
- def __add__ (self , other ) :
113
+ def __add__ (self , other : "Query" ) -> "Query" :
63
114
# make sure we give queries that know how to combine themselves
64
115
# preference
65
116
if hasattr (other , "__radd__" ):
66
117
return other .__radd__ (self )
67
118
return Bool (must = [self , other ])
68
119
69
- def __invert__ (self ):
120
+ def __invert__ (self ) -> "Query" :
70
121
return Bool (must_not = [self ])
71
122
72
- def __or__ (self , other ) :
123
+ def __or__ (self , other : "Query" ) -> "Query" :
73
124
# make sure we give queries that know how to combine themselves
74
125
# preference
75
126
if hasattr (other , "__ror__" ):
76
127
return other .__ror__ (self )
77
128
return Bool (should = [self , other ])
78
129
79
- def __and__ (self , other ) :
130
+ def __and__ (self , other : "Query" ) -> "Query" :
80
131
# make sure we give queries that know how to combine themselves
81
132
# preference
82
133
if hasattr (other , "__rand__" ):
@@ -87,17 +138,17 @@ def __and__(self, other):
87
138
class MatchAll (Query ):
88
139
name = "match_all"
89
140
90
- def __add__ (self , other ) :
141
+ def __add__ (self , other : "Query" ) -> "Query" :
91
142
return other ._clone ()
92
143
93
144
__and__ = __rand__ = __radd__ = __add__
94
145
95
- def __or__ (self , other ) :
146
+ def __or__ (self , other : "Query" ) -> "MatchAll" :
96
147
return self
97
148
98
149
__ror__ = __or__
99
150
100
- def __invert__ (self ):
151
+ def __invert__ (self ) -> "MatchNone" :
101
152
return MatchNone ()
102
153
103
154
@@ -107,17 +158,17 @@ def __invert__(self):
107
158
class MatchNone (Query ):
108
159
name = "match_none"
109
160
110
- def __add__ (self , other ) :
161
+ def __add__ (self , other : "Query" ) -> "MatchNone" :
111
162
return self
112
163
113
164
__and__ = __rand__ = __radd__ = __add__
114
165
115
- def __or__ (self , other ) :
166
+ def __or__ (self , other : "Query" ) -> "Query" :
116
167
return other ._clone ()
117
168
118
169
__ror__ = __or__
119
170
120
- def __invert__ (self ):
171
+ def __invert__ (self ) -> MatchAll :
121
172
return MatchAll ()
122
173
123
174
@@ -130,7 +181,7 @@ class Bool(Query):
130
181
"filter" : {"type" : "query" , "multi" : True },
131
182
}
132
183
133
- def __add__ (self , other ) :
184
+ def __add__ (self , other : Query ) -> "Bool" :
134
185
q = self ._clone ()
135
186
if isinstance (other , Bool ):
136
187
q .must += other .must
@@ -143,7 +194,7 @@ def __add__(self, other):
143
194
144
195
__radd__ = __add__
145
196
146
- def __or__ (self , other ) :
197
+ def __or__ (self , other : Query ) -> Query :
147
198
for q in (self , other ):
148
199
if isinstance (q , Bool ) and not any (
149
200
(q .must , q .must_not , q .filter , getattr (q , "minimum_should_match" , None ))
@@ -168,20 +219,20 @@ def __or__(self, other):
168
219
__ror__ = __or__
169
220
170
221
@property
171
- def _min_should_match (self ):
222
+ def _min_should_match (self ) -> int :
172
223
return getattr (
173
224
self ,
174
225
"minimum_should_match" ,
175
226
0 if not self .should or (self .must or self .filter ) else 1 ,
176
227
)
177
228
178
- def __invert__ (self ):
229
+ def __invert__ (self ) -> Query :
179
230
# Because an empty Bool query is treated like
180
231
# MatchAll the inverse should be MatchNone
181
232
if not any (chain (self .must , self .filter , self .should , self .must_not )):
182
233
return MatchNone ()
183
234
184
- negations = []
235
+ negations : List [ Query ] = []
185
236
for q in chain (self .must , self .filter ):
186
237
negations .append (~ q )
187
238
@@ -195,7 +246,7 @@ def __invert__(self):
195
246
return negations [0 ]
196
247
return Bool (should = negations )
197
248
198
- def __and__ (self , other ) :
249
+ def __and__ (self , other : Query ) -> Query :
199
250
q = self ._clone ()
200
251
if isinstance (other , Bool ):
201
252
q .must += other .must
@@ -247,7 +298,7 @@ class FunctionScore(Query):
247
298
"functions" : {"type" : "score_function" , "multi" : True },
248
299
}
249
300
250
- def __init__ (self , ** kwargs ):
301
+ def __init__ (self , ** kwargs : Any ):
251
302
if "functions" in kwargs :
252
303
pass
253
304
else :
0 commit comments