Skip to content

Commit e68585b

Browse files
Add Type hints to query.py (#1821)
* refactor: add type hints to query.py + type_checking in CI chore: fix linting * fix: fix typing for older versions of python * refactor: add typing to query tests * chore: add type_check to CI * fix: fix typing for older python versions * fix: fix python version for ci --------- Co-authored-by: Miguel Grinberg <[email protected]>
1 parent b5435a8 commit e68585b

File tree

7 files changed

+206
-95
lines changed

7 files changed

+206
-95
lines changed

Diff for: .github/workflows/ci.yml

+15
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ jobs:
3939
- name: Lint the code
4040
run: nox -s lint
4141

42+
type_check:
43+
runs-on: ubuntu-latest
44+
steps:
45+
- name: Checkout Repository
46+
uses: actions/checkout@v3
47+
- name: Set up Python
48+
uses: actions/setup-python@v4
49+
with:
50+
python-version: "3.8"
51+
- name: Install dependencies
52+
run: |
53+
python3 -m pip install nox
54+
- name: Lint the code
55+
run: nox -s type_check
56+
4257
docs:
4358
runs-on: ubuntu-latest
4459
steps:

Diff for: elasticsearch_dsl/function.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
# under the License.
1717

1818
import collections.abc
19+
from typing import Dict
1920

2021
from .utils import DslBase
2122

2223

23-
def SF(name_or_sf, **params):
24+
# Incomplete annotation to not break query.py tests
25+
def SF(name_or_sf, **params) -> "ScoreFunction":
2426
# {"script_score": {"script": "_score"}, "filter": {}}
2527
if isinstance(name_or_sf, collections.abc.Mapping):
2628
if params:
@@ -86,7 +88,7 @@ class ScriptScore(ScoreFunction):
8688
class BoostFactor(ScoreFunction):
8789
name = "boost_factor"
8890

89-
def to_dict(self):
91+
def to_dict(self) -> Dict[str, int]:
9092
d = super().to_dict()
9193
if "value" in d[self.name]:
9294
d[self.name] = d[self.name].pop("value")

Diff for: elasticsearch_dsl/query.py

+74-23
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,73 @@
1616
# under the License.
1717

1818
import collections.abc
19+
from copy import deepcopy
1920
from itertools import chain
21+
from typing import (
22+
Any,
23+
Callable,
24+
ClassVar,
25+
List,
26+
Mapping,
27+
MutableMapping,
28+
Optional,
29+
Protocol,
30+
TypeVar,
31+
Union,
32+
cast,
33+
overload,
34+
)
2035

2136
# 'SF' looks unused but the test suite assumes it's available
2237
# from this module so others are liable to do so as well.
2338
from .function import SF # noqa: F401
2439
from .function import ScoreFunction
2540
from .utils import DslBase
2641

42+
_T = TypeVar("_T")
43+
_M = TypeVar("_M", bound=Mapping[str, Any])
2744

28-
def Q(name_or_query="match_all", **params):
45+
46+
class QProxiedProtocol(Protocol[_T]):
47+
_proxied: _T
48+
49+
50+
@overload
51+
def Q(name_or_query: MutableMapping[str, _M]) -> "Query": ...
52+
53+
54+
@overload
55+
def Q(name_or_query: "Query") -> "Query": ...
56+
57+
58+
@overload
59+
def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ...
60+
61+
62+
@overload
63+
def Q(name_or_query: str = "match_all", **params: Any) -> "Query": ...
64+
65+
66+
def Q(
67+
name_or_query: Union[
68+
str,
69+
"Query",
70+
QProxiedProtocol[_T],
71+
MutableMapping[str, _M],
72+
] = "match_all",
73+
**params: Any,
74+
) -> Union["Query", _T]:
2975
# {"match": {"title": "python"}}
30-
if isinstance(name_or_query, collections.abc.Mapping):
76+
if isinstance(name_or_query, collections.abc.MutableMapping):
3177
if params:
3278
raise ValueError("Q() cannot accept parameters when passing in a dict.")
3379
if len(name_or_query) != 1:
3480
raise ValueError(
3581
'Q() can only accept dict with a single query ({"match": {...}}). '
3682
"Instead it got (%r)" % name_or_query
3783
)
38-
name, params = name_or_query.copy().popitem()
39-
return Query.get_dsl_class(name)(_expand__to_dot=False, **params)
84+
name, q_params = deepcopy(name_or_query).popitem()
85+
return Query.get_dsl_class(name)(_expand__to_dot=False, **q_params)
4086

4187
# MatchAll()
4288
if isinstance(name_or_query, Query):
@@ -48,7 +94,7 @@ def Q(name_or_query="match_all", **params):
4894

4995
# s.query = Q('filtered', query=s.query)
5096
if hasattr(name_or_query, "_proxied"):
51-
return name_or_query._proxied
97+
return cast(QProxiedProtocol[_T], name_or_query)._proxied
5298

5399
# "match", title="python"
54100
return Query.get_dsl_class(name_or_query)(**params)
@@ -57,26 +103,31 @@ def Q(name_or_query="match_all", **params):
57103
class Query(DslBase):
58104
_type_name = "query"
59105
_type_shortcut = staticmethod(Q)
60-
name = None
106+
name: ClassVar[Optional[str]] = None
107+
108+
# Add type annotations for methods not defined in every subclass
109+
__ror__: ClassVar[Callable[["Query", "Query"], "Query"]]
110+
__radd__: ClassVar[Callable[["Query", "Query"], "Query"]]
111+
__rand__: ClassVar[Callable[["Query", "Query"], "Query"]]
61112

62-
def __add__(self, other):
113+
def __add__(self, other: "Query") -> "Query":
63114
# make sure we give queries that know how to combine themselves
64115
# preference
65116
if hasattr(other, "__radd__"):
66117
return other.__radd__(self)
67118
return Bool(must=[self, other])
68119

69-
def __invert__(self):
120+
def __invert__(self) -> "Query":
70121
return Bool(must_not=[self])
71122

72-
def __or__(self, other):
123+
def __or__(self, other: "Query") -> "Query":
73124
# make sure we give queries that know how to combine themselves
74125
# preference
75126
if hasattr(other, "__ror__"):
76127
return other.__ror__(self)
77128
return Bool(should=[self, other])
78129

79-
def __and__(self, other):
130+
def __and__(self, other: "Query") -> "Query":
80131
# make sure we give queries that know how to combine themselves
81132
# preference
82133
if hasattr(other, "__rand__"):
@@ -87,17 +138,17 @@ def __and__(self, other):
87138
class MatchAll(Query):
88139
name = "match_all"
89140

90-
def __add__(self, other):
141+
def __add__(self, other: "Query") -> "Query":
91142
return other._clone()
92143

93144
__and__ = __rand__ = __radd__ = __add__
94145

95-
def __or__(self, other):
146+
def __or__(self, other: "Query") -> "MatchAll":
96147
return self
97148

98149
__ror__ = __or__
99150

100-
def __invert__(self):
151+
def __invert__(self) -> "MatchNone":
101152
return MatchNone()
102153

103154

@@ -107,17 +158,17 @@ def __invert__(self):
107158
class MatchNone(Query):
108159
name = "match_none"
109160

110-
def __add__(self, other):
161+
def __add__(self, other: "Query") -> "MatchNone":
111162
return self
112163

113164
__and__ = __rand__ = __radd__ = __add__
114165

115-
def __or__(self, other):
166+
def __or__(self, other: "Query") -> "Query":
116167
return other._clone()
117168

118169
__ror__ = __or__
119170

120-
def __invert__(self):
171+
def __invert__(self) -> MatchAll:
121172
return MatchAll()
122173

123174

@@ -130,7 +181,7 @@ class Bool(Query):
130181
"filter": {"type": "query", "multi": True},
131182
}
132183

133-
def __add__(self, other):
184+
def __add__(self, other: Query) -> "Bool":
134185
q = self._clone()
135186
if isinstance(other, Bool):
136187
q.must += other.must
@@ -143,7 +194,7 @@ def __add__(self, other):
143194

144195
__radd__ = __add__
145196

146-
def __or__(self, other):
197+
def __or__(self, other: Query) -> Query:
147198
for q in (self, other):
148199
if isinstance(q, Bool) and not any(
149200
(q.must, q.must_not, q.filter, getattr(q, "minimum_should_match", None))
@@ -168,20 +219,20 @@ def __or__(self, other):
168219
__ror__ = __or__
169220

170221
@property
171-
def _min_should_match(self):
222+
def _min_should_match(self) -> int:
172223
return getattr(
173224
self,
174225
"minimum_should_match",
175226
0 if not self.should or (self.must or self.filter) else 1,
176227
)
177228

178-
def __invert__(self):
229+
def __invert__(self) -> Query:
179230
# Because an empty Bool query is treated like
180231
# MatchAll the inverse should be MatchNone
181232
if not any(chain(self.must, self.filter, self.should, self.must_not)):
182233
return MatchNone()
183234

184-
negations = []
235+
negations: List[Query] = []
185236
for q in chain(self.must, self.filter):
186237
negations.append(~q)
187238

@@ -195,7 +246,7 @@ def __invert__(self):
195246
return negations[0]
196247
return Bool(should=negations)
197248

198-
def __and__(self, other):
249+
def __and__(self, other: Query) -> Query:
199250
q = self._clone()
200251
if isinstance(other, Bool):
201252
q.must += other.must
@@ -247,7 +298,7 @@ class FunctionScore(Query):
247298
"functions": {"type": "score_function", "multi": True},
248299
}
249300

250-
def __init__(self, **kwargs):
301+
def __init__(self, **kwargs: Any):
251302
if "functions" in kwargs:
252303
pass
253304
else:

Diff for: elasticsearch_dsl/utils.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
import collections.abc
2020
from copy import copy
21+
from typing import Any, Dict, Optional, Type
22+
23+
from typing_extensions import Self
2124

2225
from .exceptions import UnknownDslObject, ValidationException
2326

@@ -251,7 +254,9 @@ class DslBase(metaclass=DslMeta):
251254
_param_defs = {}
252255

253256
@classmethod
254-
def get_dsl_class(cls, name, default=None):
257+
def get_dsl_class(
258+
cls: Type[Self], name: str, default: Optional[str] = None
259+
) -> Type[Self]:
255260
try:
256261
return cls._classes[name]
257262
except KeyError:
@@ -261,7 +266,7 @@ def get_dsl_class(cls, name, default=None):
261266
f"DSL class `{name}` does not exist in {cls._type_name}."
262267
)
263268

264-
def __init__(self, _expand__to_dot=None, **params):
269+
def __init__(self, _expand__to_dot: Optional[bool] = None, **params: Any) -> None:
265270
if _expand__to_dot is None:
266271
_expand__to_dot = EXPAND__TO_DOT
267272
self._params = {}
@@ -351,7 +356,8 @@ def __getattr__(self, name):
351356
return AttrDict(value)
352357
return value
353358

354-
def to_dict(self):
359+
# TODO: This type annotation can probably be made tighter
360+
def to_dict(self) -> Dict[str, Dict[str, Any]]:
355361
"""
356362
Serialize the DSL object to plain dict
357363
"""
@@ -390,7 +396,7 @@ def to_dict(self):
390396
d[pname] = value
391397
return {self.name: d}
392398

393-
def _clone(self):
399+
def _clone(self) -> Self:
394400
c = self.__class__()
395401
for attr in self._params:
396402
c._params[attr] = copy(self._params[attr])

Diff for: mypy.ini

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[mypy-elasticsearch_dsl.query]
2+
# Allow reexport of SF for tests
3+
implicit_reexport = True

Diff for: noxfile.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import subprocess
19+
1820
import nox
1921

2022
SOURCE_FILES = (
@@ -27,6 +29,11 @@
2729
"utils/",
2830
)
2931

32+
TYPED_FILES = (
33+
"elasticsearch_dsl/query.py",
34+
"tests/test_query.py",
35+
)
36+
3037

3138
@nox.session(
3239
python=[
@@ -72,10 +79,35 @@ def lint(session):
7279
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
7380
session.run("isort", "--check", *SOURCE_FILES)
7481
session.run("python", "utils/run-unasync.py", "--check")
75-
session.run("flake8", "--ignore=E501,E741,W503", *SOURCE_FILES)
82+
session.run("flake8", "--ignore=E501,E741,W503,E704", *SOURCE_FILES)
7683
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
7784

7885

86+
@nox.session(python="3.8")
87+
def type_check(session):
88+
session.install("mypy", ".[develop]")
89+
errors = []
90+
popen = subprocess.Popen(
91+
"mypy --strict elasticsearch_dsl tests",
92+
env=session.env,
93+
shell=True,
94+
stdout=subprocess.PIPE,
95+
stderr=subprocess.STDOUT,
96+
)
97+
98+
mypy_output = ""
99+
while popen.poll() is None:
100+
mypy_output += popen.stdout.read(8192).decode()
101+
mypy_output += popen.stdout.read().decode()
102+
103+
for line in mypy_output.split("\n"):
104+
filepath = line.partition(":")[0]
105+
if filepath in TYPED_FILES:
106+
errors.append(line)
107+
if errors:
108+
session.error("\n" + "\n".join(sorted(set(errors))))
109+
110+
79111
@nox.session()
80112
def docs(session):
81113
session.install(".[develop]")

0 commit comments

Comments
 (0)