Skip to content

Commit f6df2fd

Browse files
CaiofcasCaio Fontes
authored and
Caio Fontes
committed
refactor: add type hints to query.py + type_checking in CI
chore: fix linting
1 parent 3e880fd commit f6df2fd

File tree

3 files changed

+96
-26
lines changed

3 files changed

+96
-26
lines changed

Diff for: elasticsearch_dsl/query.py

+60-22
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,60 @@
1616
# under the License.
1717

1818
import collections.abc
19+
from copy import deepcopy
1920
from itertools import chain
21+
from typing import Any, Callable, ClassVar, Optional, Protocol, TypeVar, Union, overload
2022

2123
# 'SF' looks unused but the test suite assumes it's available
2224
# from this module so others are liable to do so as well.
2325
from .function import SF # noqa: F401
2426
from .function import ScoreFunction
2527
from .utils import DslBase
2628

29+
_T = TypeVar("_T")
30+
_M = TypeVar("_M", bound=collections.abc.Mapping[str, Any])
2731

28-
def Q(name_or_query="match_all", **params):
32+
33+
class QProxiedProtocol(Protocol[_T]):
34+
_proxied: _T
35+
36+
37+
@overload
38+
def Q(name_or_query: collections.abc.MutableMapping[str, _M]) -> "Query": ...
39+
40+
41+
@overload
42+
def Q(name_or_query: "Query") -> "Query": ...
43+
44+
45+
@overload
46+
def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ...
47+
48+
49+
@overload
50+
def Q(name_or_query: str, **params: Any) -> "Query": ...
51+
52+
53+
def Q(
54+
name_or_query: Union[
55+
str,
56+
"Query",
57+
QProxiedProtocol[_T],
58+
collections.abc.MutableMapping[str, _M],
59+
] = "match_all",
60+
**params: Any,
61+
) -> Union["Query", _T]:
2962
# {"match": {"title": "python"}}
30-
if isinstance(name_or_query, collections.abc.Mapping):
63+
if isinstance(name_or_query, collections.abc.MutableMapping):
3164
if params:
3265
raise ValueError("Q() cannot accept parameters when passing in a dict.")
3366
if len(name_or_query) != 1:
3467
raise ValueError(
3568
'Q() can only accept dict with a single query ({"match": {...}}). '
3669
"Instead it got (%r)" % name_or_query
3770
)
38-
name, params = name_or_query.copy().popitem()
39-
return Query.get_dsl_class(name)(_expand__to_dot=False, **params)
71+
name, q_params = deepcopy(name_or_query).popitem()
72+
return Query.get_dsl_class(name)(_expand__to_dot=False, **q_params)
4073

4174
# MatchAll()
4275
if isinstance(name_or_query, Query):
@@ -57,26 +90,31 @@ def Q(name_or_query="match_all", **params):
5790
class Query(DslBase):
5891
_type_name = "query"
5992
_type_shortcut = staticmethod(Q)
60-
name = None
93+
name: ClassVar[Optional[str]] = None
94+
95+
# Add type annotations for methods not defined in every subclass
96+
__ror__: ClassVar[Callable[["Query", "Query"], "Query"]]
97+
__radd__: ClassVar[Callable[["Query", "Query"], "Query"]]
98+
__rand__: ClassVar[Callable[["Query", "Query"], "Query"]]
6199

62-
def __add__(self, other):
100+
def __add__(self, other: "Query") -> "Query":
63101
# make sure we give queries that know how to combine themselves
64102
# preference
65103
if hasattr(other, "__radd__"):
66104
return other.__radd__(self)
67105
return Bool(must=[self, other])
68106

69-
def __invert__(self):
107+
def __invert__(self) -> "Query":
70108
return Bool(must_not=[self])
71109

72-
def __or__(self, other):
110+
def __or__(self, other: "Query") -> "Query":
73111
# make sure we give queries that know how to combine themselves
74112
# preference
75113
if hasattr(other, "__ror__"):
76114
return other.__ror__(self)
77115
return Bool(should=[self, other])
78116

79-
def __and__(self, other):
117+
def __and__(self, other: "Query") -> "Query":
80118
# make sure we give queries that know how to combine themselves
81119
# preference
82120
if hasattr(other, "__rand__"):
@@ -87,17 +125,17 @@ def __and__(self, other):
87125
class MatchAll(Query):
88126
name = "match_all"
89127

90-
def __add__(self, other):
128+
def __add__(self, other: "Query") -> "Query":
91129
return other._clone()
92130

93131
__and__ = __rand__ = __radd__ = __add__
94132

95-
def __or__(self, other):
133+
def __or__(self, other: "Query") -> "MatchAll":
96134
return self
97135

98136
__ror__ = __or__
99137

100-
def __invert__(self):
138+
def __invert__(self) -> "MatchNone":
101139
return MatchNone()
102140

103141

@@ -107,17 +145,17 @@ def __invert__(self):
107145
class MatchNone(Query):
108146
name = "match_none"
109147

110-
def __add__(self, other):
148+
def __add__(self, other: "Query") -> "MatchNone":
111149
return self
112150

113151
__and__ = __rand__ = __radd__ = __add__
114152

115-
def __or__(self, other):
153+
def __or__(self, other: "Query") -> "Query":
116154
return other._clone()
117155

118156
__ror__ = __or__
119157

120-
def __invert__(self):
158+
def __invert__(self) -> MatchAll:
121159
return MatchAll()
122160

123161

@@ -130,7 +168,7 @@ class Bool(Query):
130168
"filter": {"type": "query", "multi": True},
131169
}
132170

133-
def __add__(self, other):
171+
def __add__(self, other: Query) -> "Bool":
134172
q = self._clone()
135173
if isinstance(other, Bool):
136174
q.must += other.must
@@ -143,7 +181,7 @@ def __add__(self, other):
143181

144182
__radd__ = __add__
145183

146-
def __or__(self, other):
184+
def __or__(self, other: Query) -> Query:
147185
for q in (self, other):
148186
if isinstance(q, Bool) and not any(
149187
(q.must, q.must_not, q.filter, getattr(q, "minimum_should_match", None))
@@ -168,20 +206,20 @@ def __or__(self, other):
168206
__ror__ = __or__
169207

170208
@property
171-
def _min_should_match(self):
209+
def _min_should_match(self) -> int:
172210
return getattr(
173211
self,
174212
"minimum_should_match",
175213
0 if not self.should or (self.must or self.filter) else 1,
176214
)
177215

178-
def __invert__(self):
216+
def __invert__(self) -> Query:
179217
# Because an empty Bool query is treated like
180218
# MatchAll the inverse should be MatchNone
181219
if not any(chain(self.must, self.filter, self.should, self.must_not)):
182220
return MatchNone()
183221

184-
negations = []
222+
negations: list[Query] = []
185223
for q in chain(self.must, self.filter):
186224
negations.append(~q)
187225

@@ -195,7 +233,7 @@ def __invert__(self):
195233
return negations[0]
196234
return Bool(should=negations)
197235

198-
def __and__(self, other):
236+
def __and__(self, other: Query) -> Query:
199237
q = self._clone()
200238
if isinstance(other, Bool):
201239
q.must += other.must
@@ -247,7 +285,7 @@ class FunctionScore(Query):
247285
"functions": {"type": "score_function", "multi": True},
248286
}
249287

250-
def __init__(self, **kwargs):
288+
def __init__(self, **kwargs: Any):
251289
if "functions" in kwargs:
252290
pass
253291
else:

Diff for: elasticsearch_dsl/utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import collections.abc
2020
from copy import copy
21+
from typing import Any, Optional, Self
2122

2223
from .exceptions import UnknownDslObject, ValidationException
2324

@@ -251,7 +252,9 @@ class DslBase(metaclass=DslMeta):
251252
_param_defs = {}
252253

253254
@classmethod
254-
def get_dsl_class(cls, name, default=None):
255+
def get_dsl_class(
256+
cls: type[Self], name: str, default: Optional[str] = None
257+
) -> type[Self]:
255258
try:
256259
return cls._classes[name]
257260
except KeyError:
@@ -261,7 +264,7 @@ def get_dsl_class(cls, name, default=None):
261264
f"DSL class `{name}` does not exist in {cls._type_name}."
262265
)
263266

264-
def __init__(self, _expand__to_dot=None, **params):
267+
def __init__(self, _expand__to_dot: Optional[bool] = None, **params: Any) -> None:
265268
if _expand__to_dot is None:
266269
_expand__to_dot = EXPAND__TO_DOT
267270
self._params = {}
@@ -390,7 +393,7 @@ def to_dict(self):
390393
d[pname] = value
391394
return {self.name: d}
392395

393-
def _clone(self):
396+
def _clone(self) -> Self:
394397
c = self.__class__()
395398
for attr in self._params:
396399
c._params[attr] = copy(self._params[attr])

Diff for: noxfile.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import subprocess
19+
1820
import nox
1921

2022
SOURCE_FILES = (
@@ -27,6 +29,8 @@
2729
"utils/",
2830
)
2931

32+
TYPED_FILES = ("elasticsearch_dsl/query.py",)
33+
3034

3135
@nox.session(
3236
python=[
@@ -72,10 +76,35 @@ def lint(session):
7276
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
7377
session.run("isort", "--check", *SOURCE_FILES)
7478
session.run("python", "utils/run-unasync.py", "--check")
75-
session.run("flake8", "--ignore=E501,E741,W503", *SOURCE_FILES)
79+
session.run("flake8", "--ignore=E501,E741,W503,E704", *SOURCE_FILES)
7680
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
7781

7882

83+
@nox.session(python="3.12")
84+
def type_check(session):
85+
session.install("mypy", ".[develop]")
86+
errors = []
87+
popen = subprocess.Popen(
88+
"mypy --strict elasticsearch_dsl",
89+
env=session.env,
90+
shell=True,
91+
stdout=subprocess.PIPE,
92+
stderr=subprocess.STDOUT,
93+
)
94+
95+
mypy_output = ""
96+
while popen.poll() is None:
97+
mypy_output += popen.stdout.read(8192).decode()
98+
mypy_output += popen.stdout.read().decode()
99+
100+
for line in mypy_output.split("\n"):
101+
filepath = line.partition(":")[0]
102+
if filepath in TYPED_FILES:
103+
errors.append(line)
104+
if errors:
105+
session.error("\n" + "\n".join(sorted(set(errors))))
106+
107+
79108
@nox.session()
80109
def docs(session):
81110
session.install(".[develop]")

0 commit comments

Comments
 (0)