1
1
import re
2
+ from typing import Optional , Any
2
3
4
+ from mysql_mimic .session import Session
3
5
from mysql_mimic .charset import CharacterSet
4
6
from mysql_mimic .errors import MysqlError , ErrorCode
5
7
from mysql_mimic .results import ResultColumn , ResultSet , Column
6
8
from mysql_mimic .types import ColumnType
9
+ from mysql_mimic .variables import SystemVariables
7
10
8
11
9
12
class Admin :
10
13
"""
11
14
Administration Statements (plus some other things)
12
15
https://dev.mysql.com/doc/refman/8.0/en/sql-server-administration-statements.html
13
-
14
- Args:
15
- connection_id (int): connection ID
16
- session (mysql_mimic.session.Session): session
17
- variables (mysql_mimic.variables.SystemVariables): system variables
18
16
"""
19
17
20
18
# Regex for finding "information functions" to replace in SQL statements
@@ -139,18 +137,20 @@ class Admin:
139
137
re .IGNORECASE | re .VERBOSE ,
140
138
)
141
139
142
- def __init__ (self , connection_id , session , variables ):
140
+ def __init__ (
141
+ self , connection_id : int , session : Session , variables : SystemVariables
142
+ ):
143
143
self .connection_id = connection_id
144
144
self .session = session
145
- self .database = None
146
- self .username = None
145
+ self .database : Optional [ str ] = None
146
+ self .username : Optional [ str ] = None
147
147
self .vars = variables
148
148
149
- def replace_variables (self , sql ) :
149
+ def replace_variables (self , sql : str ) -> str :
150
150
sql = self .REGEX_INFO_FUNC .sub (self ._replace_info_func , sql )
151
151
return self .REGEX_SESSION_VAR .sub (self ._replace_session_var , sql )
152
152
153
- def _replace_info_func (self , matchobj ) :
153
+ def _replace_info_func (self , matchobj : re . Match ) -> str :
154
154
func = (matchobj .group ("func" ) or matchobj .group ("current_user" )).lower ()
155
155
if func == "connection_id" :
156
156
return str (self .connection_id )
@@ -164,13 +164,13 @@ def _replace_info_func(self, matchobj):
164
164
return f"'{ self .database } '" if self .database else "NULL"
165
165
raise MysqlError (f"Failed to parse system information function: { func } " )
166
166
167
- def _replace_session_var (self , matchobj ) :
167
+ def _replace_session_var (self , matchobj : re . Match ) -> str :
168
168
var = matchobj .group ("var" ).lower ()
169
169
if var in self .vars :
170
170
return f"'{ self .vars [var ]} '"
171
171
raise MysqlError (f"Unknown variable: { var } " , ErrorCode .UNKNOWN_SYSTEM_VARIABLE )
172
172
173
- async def parse (self , sql ) :
173
+ async def parse (self , sql : str ) -> Optional [ ResultSet ] :
174
174
m = self .REGEX_CMD .match (sql )
175
175
if not m :
176
176
return None
@@ -187,7 +187,7 @@ async def parse(self, sql):
187
187
return ResultSet ([], [])
188
188
raise MysqlError ("Failed to parse command" , ErrorCode .PARSE_ERROR )
189
189
190
- async def _parse_set (self , this ) :
190
+ async def _parse_set (self , this : str ) -> Optional [ ResultSet ] :
191
191
m = self .REGEX_SET_NAMES .match (this )
192
192
if m :
193
193
return await self ._parse_set_names (
@@ -209,11 +209,13 @@ async def _parse_set(self, this):
209
209
210
210
raise MysqlError ("Unsupported SET command" , ErrorCode .NOT_SUPPORTED_YET )
211
211
212
- async def _set_variable (self , key , val ) :
212
+ async def _set_variable (self , key : str , val : Any ) -> None :
213
213
self .vars [key ] = val
214
214
await self .session .set (** {key : val })
215
215
216
- async def _parse_set_names (self , charset_name , collation_name ):
216
+ async def _parse_set_names (
217
+ self , charset_name : str , collation_name : str
218
+ ) -> Optional [ResultSet ]:
217
219
await self ._set_variable ("character_set_client" , charset_name )
218
220
await self ._set_variable ("character_set_connection" , charset_name )
219
221
await self ._set_variable ("character_set_results" , charset_name )
@@ -224,14 +226,25 @@ async def _parse_set_names(self, charset_name, collation_name):
224
226
return ResultSet ([], [])
225
227
226
228
async def _parse_set_variables (
227
- self , global_ , persist , persist_only , session , user , var_name , value
228
- ): # pylint: disable=unused-argument
229
+ self ,
230
+ global_ : Optional [str ],
231
+ persist : Optional [str ],
232
+ persist_only : Optional [str ],
233
+ session : Optional [str ],
234
+ user : Optional [str ],
235
+ var_name : Optional [str ],
236
+ value : Optional [str ],
237
+ ) -> Optional [ResultSet ]: # pylint: disable=unused-argument
229
238
if global_ or persist or persist_only or user :
230
239
raise MysqlError (
231
240
"Only setting session variables is supported" ,
232
241
ErrorCode .NOT_SUPPORTED_YET ,
233
242
)
234
243
244
+ assert value is not None
245
+ assert var_name is not None
246
+ result : Any
247
+
235
248
value_lower = value .lower ()
236
249
if value_lower in {"true" , "on" }:
237
250
result = True
@@ -256,7 +269,7 @@ async def _parse_set_variables(
256
269
await self ._set_variable (var_name , result )
257
270
return ResultSet ([], [])
258
271
259
- async def _parse_show (self , this ) :
272
+ async def _parse_show (self , this : str ) -> Optional [ ResultSet ] :
260
273
m = self .REGEX_SHOW_VARS .match (this )
261
274
if m :
262
275
return await self ._parse_show_variables (
@@ -284,22 +297,29 @@ async def _parse_show(self, this):
284
297
285
298
raise MysqlError ("Unsupported SHOW command" , ErrorCode .NOT_SUPPORTED_YET )
286
299
287
- def _like_to_regex (self , like ) :
300
+ def _like_to_regex (self , like : Optional [ str ]) -> re . Pattern :
288
301
if like is None :
289
302
return re .compile (r".*" )
290
303
like = like .replace ("%" , ".*" )
291
304
like = like .replace ("_" , "." )
292
305
return re .compile (like )
293
306
294
307
async def _parse_show_columns (
295
- self , extended , full , db_name , tbl_name , like
296
- ): # pylint: disable=unused-argument
308
+ self ,
309
+ extended : bool ,
310
+ full : bool ,
311
+ db_name : Optional [str ],
312
+ tbl_name : Optional [str ],
313
+ like : re .Pattern ,
314
+ ) -> Optional [ResultSet ]: # pylint: disable=unused-argument
297
315
db_name = db_name or self .database
298
316
if not db_name :
299
317
raise MysqlError ("No database selected" , ErrorCode .NO_DB_ERROR )
300
318
301
- columns = await self .session .show_columns (db_name , tbl_name )
302
- columns = [c if isinstance (c , Column ) else Column (** c ) for c in columns ]
319
+ columns = [
320
+ c if isinstance (c , Column ) else Column (** c )
321
+ for c in await self .session .show_columns (db_name , tbl_name or "" )
322
+ ]
303
323
columns = [c for c in columns if like .match (c .name )]
304
324
305
325
result_columns = [
@@ -333,8 +353,8 @@ async def _parse_show_columns(
333
353
return ResultSet (rows = rows , columns = result_columns )
334
354
335
355
async def _parse_show_index (
336
- self , extended , db_name , tbl_name
337
- ): # pylint: disable=unused-argument
356
+ self , extended : bool , db_name : Optional [ str ] , tbl_name : Optional [ str ]
357
+ ) -> ResultSet : # pylint: disable=unused-argument
338
358
result_columns = [
339
359
ResultColumn ("Table" , ColumnType .STRING ),
340
360
ResultColumn ("Non_unique" , ColumnType .TINY ),
@@ -355,8 +375,8 @@ async def _parse_show_index(
355
375
return ResultSet (rows = [], columns = result_columns )
356
376
357
377
async def _parse_show_variables (
358
- self , global_ , session , like
359
- ): # pylint: disable=unused-argument
378
+ self , global_ : bool , session : bool , like : re . Pattern
379
+ ) -> ResultSet : # pylint: disable=unused-argument
360
380
rows = list (self .vars .items ())
361
381
rows = [(k , v ) for k , v in rows if like .match (k )]
362
382
return ResultSet (
0 commit comments