Skip to content

Commit ebaf944

Browse files
authored
feat: type annotations, consolidate on pytest (#13)
1 parent 899831b commit ebaf944

26 files changed

+1129
-881
lines changed

.github/workflows/tests.yml

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ jobs:
77
runs-on: ubuntu-latest
88
strategy:
99
matrix:
10-
python-version: ["3.8", "3.9", "3.10"]
10+
python-version: ["3.7", "3.8", "3.9", "3.10"]
1111
steps:
1212
- uses: actions/checkout@v2
1313
- name: Set up Python ${{ matrix.python-version }}
@@ -29,3 +29,6 @@ jobs:
2929
- name: Format
3030
run: |
3131
make format-check
32+
- name: Type annotations
33+
run: |
34+
make types

Makefile

+5
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@ run:
1717
lint:
1818
python -m pylint mysql_mimic/ tests/
1919

20+
types:
21+
python -m mypy -p mysql_mimic -p tests
22+
2023
test:
2124
coverage run --source=mysql_mimic -m pytest
2225
coverage report
2326
coverage html
2427

28+
check: format-check lint types test
29+
2530
build: clean
2631
python setup.py sdist bdist_wheel
2732

mysql_mimic/admin.py

+47-27
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
import re
2+
from typing import Optional, Any
23

4+
from mysql_mimic.session import Session
35
from mysql_mimic.charset import CharacterSet
46
from mysql_mimic.errors import MysqlError, ErrorCode
57
from mysql_mimic.results import ResultColumn, ResultSet, Column
68
from mysql_mimic.types import ColumnType
9+
from mysql_mimic.variables import SystemVariables
710

811

912
class Admin:
1013
"""
1114
Administration Statements (plus some other things)
1215
https://dev.mysql.com/doc/refman/8.0/en/sql-server-administration-statements.html
13-
14-
Args:
15-
connection_id (int): connection ID
16-
session (mysql_mimic.session.Session): session
17-
variables (mysql_mimic.variables.SystemVariables): system variables
1816
"""
1917

2018
# Regex for finding "information functions" to replace in SQL statements
@@ -139,18 +137,20 @@ class Admin:
139137
re.IGNORECASE | re.VERBOSE,
140138
)
141139

142-
def __init__(self, connection_id, session, variables):
140+
def __init__(
141+
self, connection_id: int, session: Session, variables: SystemVariables
142+
):
143143
self.connection_id = connection_id
144144
self.session = session
145-
self.database = None
146-
self.username = None
145+
self.database: Optional[str] = None
146+
self.username: Optional[str] = None
147147
self.vars = variables
148148

149-
def replace_variables(self, sql):
149+
def replace_variables(self, sql: str) -> str:
150150
sql = self.REGEX_INFO_FUNC.sub(self._replace_info_func, sql)
151151
return self.REGEX_SESSION_VAR.sub(self._replace_session_var, sql)
152152

153-
def _replace_info_func(self, matchobj):
153+
def _replace_info_func(self, matchobj: re.Match) -> str:
154154
func = (matchobj.group("func") or matchobj.group("current_user")).lower()
155155
if func == "connection_id":
156156
return str(self.connection_id)
@@ -164,13 +164,13 @@ def _replace_info_func(self, matchobj):
164164
return f"'{self.database}'" if self.database else "NULL"
165165
raise MysqlError(f"Failed to parse system information function: {func}")
166166

167-
def _replace_session_var(self, matchobj):
167+
def _replace_session_var(self, matchobj: re.Match) -> str:
168168
var = matchobj.group("var").lower()
169169
if var in self.vars:
170170
return f"'{self.vars[var]}'"
171171
raise MysqlError(f"Unknown variable: {var}", ErrorCode.UNKNOWN_SYSTEM_VARIABLE)
172172

173-
async def parse(self, sql):
173+
async def parse(self, sql: str) -> Optional[ResultSet]:
174174
m = self.REGEX_CMD.match(sql)
175175
if not m:
176176
return None
@@ -187,7 +187,7 @@ async def parse(self, sql):
187187
return ResultSet([], [])
188188
raise MysqlError("Failed to parse command", ErrorCode.PARSE_ERROR)
189189

190-
async def _parse_set(self, this):
190+
async def _parse_set(self, this: str) -> Optional[ResultSet]:
191191
m = self.REGEX_SET_NAMES.match(this)
192192
if m:
193193
return await self._parse_set_names(
@@ -209,11 +209,13 @@ async def _parse_set(self, this):
209209

210210
raise MysqlError("Unsupported SET command", ErrorCode.NOT_SUPPORTED_YET)
211211

212-
async def _set_variable(self, key, val):
212+
async def _set_variable(self, key: str, val: Any) -> None:
213213
self.vars[key] = val
214214
await self.session.set(**{key: val})
215215

216-
async def _parse_set_names(self, charset_name, collation_name):
216+
async def _parse_set_names(
217+
self, charset_name: str, collation_name: str
218+
) -> Optional[ResultSet]:
217219
await self._set_variable("character_set_client", charset_name)
218220
await self._set_variable("character_set_connection", charset_name)
219221
await self._set_variable("character_set_results", charset_name)
@@ -224,14 +226,25 @@ async def _parse_set_names(self, charset_name, collation_name):
224226
return ResultSet([], [])
225227

226228
async def _parse_set_variables(
227-
self, global_, persist, persist_only, session, user, var_name, value
228-
): # pylint: disable=unused-argument
229+
self,
230+
global_: Optional[str],
231+
persist: Optional[str],
232+
persist_only: Optional[str],
233+
session: Optional[str],
234+
user: Optional[str],
235+
var_name: Optional[str],
236+
value: Optional[str],
237+
) -> Optional[ResultSet]: # pylint: disable=unused-argument
229238
if global_ or persist or persist_only or user:
230239
raise MysqlError(
231240
"Only setting session variables is supported",
232241
ErrorCode.NOT_SUPPORTED_YET,
233242
)
234243

244+
assert value is not None
245+
assert var_name is not None
246+
result: Any
247+
235248
value_lower = value.lower()
236249
if value_lower in {"true", "on"}:
237250
result = True
@@ -256,7 +269,7 @@ async def _parse_set_variables(
256269
await self._set_variable(var_name, result)
257270
return ResultSet([], [])
258271

259-
async def _parse_show(self, this):
272+
async def _parse_show(self, this: str) -> Optional[ResultSet]:
260273
m = self.REGEX_SHOW_VARS.match(this)
261274
if m:
262275
return await self._parse_show_variables(
@@ -284,22 +297,29 @@ async def _parse_show(self, this):
284297

285298
raise MysqlError("Unsupported SHOW command", ErrorCode.NOT_SUPPORTED_YET)
286299

287-
def _like_to_regex(self, like):
300+
def _like_to_regex(self, like: Optional[str]) -> re.Pattern:
288301
if like is None:
289302
return re.compile(r".*")
290303
like = like.replace("%", ".*")
291304
like = like.replace("_", ".")
292305
return re.compile(like)
293306

294307
async def _parse_show_columns(
295-
self, extended, full, db_name, tbl_name, like
296-
): # pylint: disable=unused-argument
308+
self,
309+
extended: bool,
310+
full: bool,
311+
db_name: Optional[str],
312+
tbl_name: Optional[str],
313+
like: re.Pattern,
314+
) -> Optional[ResultSet]: # pylint: disable=unused-argument
297315
db_name = db_name or self.database
298316
if not db_name:
299317
raise MysqlError("No database selected", ErrorCode.NO_DB_ERROR)
300318

301-
columns = await self.session.show_columns(db_name, tbl_name)
302-
columns = [c if isinstance(c, Column) else Column(**c) for c in columns]
319+
columns = [
320+
c if isinstance(c, Column) else Column(**c)
321+
for c in await self.session.show_columns(db_name, tbl_name or "")
322+
]
303323
columns = [c for c in columns if like.match(c.name)]
304324

305325
result_columns = [
@@ -333,8 +353,8 @@ async def _parse_show_columns(
333353
return ResultSet(rows=rows, columns=result_columns)
334354

335355
async def _parse_show_index(
336-
self, extended, db_name, tbl_name
337-
): # pylint: disable=unused-argument
356+
self, extended: bool, db_name: Optional[str], tbl_name: Optional[str]
357+
) -> ResultSet: # pylint: disable=unused-argument
338358
result_columns = [
339359
ResultColumn("Table", ColumnType.STRING),
340360
ResultColumn("Non_unique", ColumnType.TINY),
@@ -355,8 +375,8 @@ async def _parse_show_index(
355375
return ResultSet(rows=[], columns=result_columns)
356376

357377
async def _parse_show_variables(
358-
self, global_, session, like
359-
): # pylint: disable=unused-argument
378+
self, global_: bool, session: bool, like: re.Pattern
379+
) -> ResultSet: # pylint: disable=unused-argument
360380
rows = list(self.vars.items())
361381
rows = [(k, v) for k, v in rows if like.match(k)]
362382
return ResultSet(

mysql_mimic/auth.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from __future__ import annotations
2+
13
import io
24
from copy import copy
35
from hashlib import sha1
46
import logging
57
import secrets
68
from dataclasses import dataclass
7-
from typing import Optional, Dict
9+
from typing import Optional, Dict, AsyncGenerator, Union, Tuple
810

911
from mysql_mimic.types import read_str_null
1012
from mysql_mimic.utils import xor
@@ -35,21 +37,25 @@ class AuthInfo:
3537
data: bytes
3638
user: User
3739
connect_attrs: Dict[str, str]
38-
client_plugin_name: str
39-
handshake_auth_data: bytes
40+
client_plugin_name: Optional[str]
41+
handshake_auth_data: Optional[bytes]
4042
handshake_plugin_name: str
4143

42-
def copy(self, data):
44+
def copy(self, data: bytes) -> AuthInfo:
4345
new = copy(self)
4446
new.data = data
4547
return new
4648

4749

50+
Decision = Union[Success, Forbidden, bytes]
51+
AuthState = AsyncGenerator[Decision, AuthInfo]
52+
53+
4854
class AuthPlugin:
4955
name = ""
50-
client_plugin_name = None # None means any
56+
client_plugin_name: Optional[str] = None # None means any
5157

52-
async def auth(self, auth_info=None):
58+
async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
5359
"""
5460
Begin the authentication lifecycle.
5561
@@ -58,7 +64,9 @@ async def auth(self, auth_info=None):
5864
"""
5965
yield Forbidden()
6066

61-
async def start(self, auth_info=None):
67+
async def start(
68+
self, auth_info: Optional[AuthInfo] = None
69+
) -> Tuple[Decision, AuthState]:
6270
state = self.auth(auth_info)
6371
data = await state.__anext__()
6472
return data, state
@@ -67,7 +75,7 @@ async def start(self, auth_info=None):
6775
class GullibleAuthPlugin(AuthPlugin):
6876
name = "mysql_mimic_gullible"
6977

70-
async def auth(self, auth_info=None):
78+
async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
7179
if not auth_info:
7280
auth_info = yield b"\x00" * 20 # 20 bytes of filler to be ignored
7381
yield Success(authenticated_as=auth_info.username)
@@ -77,27 +85,27 @@ class AbstractMysqlClearPasswordAuthPlugin(AuthPlugin):
7785
name = "abstract_mysql_clear_password"
7886
client_plugin_name = "mysql_clear_password"
7987

80-
async def auth(self, auth_info=None):
88+
async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
8189
if not auth_info:
8290
auth_info = yield b"\x00" * 20 # 20 bytes of filler to be ignored
8391

8492
r = io.BytesIO(auth_info.data)
8593
password = read_str_null(r).decode()
8694
authenticated_as = await self.check(auth_info.username, password)
87-
if authenticated_as:
95+
if authenticated_as is not None:
8896
yield Success(authenticated_as)
8997
else:
9098
yield Forbidden()
9199

92-
async def check(self, username, password):
100+
async def check(self, username: str, password: str) -> Optional[str]:
93101
return username
94102

95103

96104
class MysqlNativePasswordAuthPlugin(AuthPlugin):
97105
name = "mysql_native_password"
98106
client_plugin_name = "mysql_native_password"
99107

100-
async def auth(self, auth_info=None):
108+
async def auth(self, auth_info: Optional[AuthInfo] = None) -> AuthState:
101109
if (
102110
auth_info
103111
and auth_info.handshake_plugin_name == self.name
@@ -138,5 +146,5 @@ async def auth(self, auth_info=None):
138146
yield Forbidden()
139147

140148

141-
def get_mysql_native_password_auth_string(password):
149+
def get_mysql_native_password_auth_string(password: str) -> str:
142150
return sha1(sha1(password.encode("utf-8")).digest()).hexdigest()

mysql_mimic/charset.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
from enum import IntEnum
23

34

@@ -45,19 +46,19 @@ class CharacterSet(IntEnum):
4546
utf8mb4 = 255
4647

4748
@property
48-
def codec(self):
49+
def codec(self) -> str:
4950
if self.name == "utf8mb4":
5051
return "utf8"
5152
return self.name
5253

5354
@property
54-
def default_collation(self):
55+
def default_collation(self) -> Collation:
5556
return DEFAULT_COLLATIONS[self]
5657

57-
def decode(self, b):
58+
def decode(self, b: bytes) -> str:
5859
return b.decode(self.codec)
5960

60-
def encode(self, s):
61+
def encode(self, s: str) -> bytes:
6162
return s.encode(self.codec)
6263

6364

@@ -287,11 +288,11 @@ class Collation(IntEnum):
287288
utf8mb4_0900_ai_ci = 255
288289

289290
@property
290-
def codec(self):
291+
def codec(self) -> str:
291292
return self.charset.codec
292293

293294
@property
294-
def charset(self):
295+
def charset(self) -> CharacterSet:
295296
return DEFAULT_CHARACTER_SETS[self]
296297

297298

0 commit comments

Comments
 (0)