Skip to content

Commit dc8c312

Browse files
author
cecily_carver
committed
Reverting to at-use instantiation of VariableProcessor.
1 parent 91427cd commit dc8c312

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

mysql_mimic/session.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ class Session(BaseSession):
187187
dialect: Type[Dialect] = MySQL
188188

189189
def __init__(self, variables: Variables | None = None):
190-
self._variable_processor: VariableProcessor
191190
self.variables = variables or SessionVariables(GlobalVariables())
192191

193192
# Query middlewares.
@@ -278,9 +277,6 @@ async def close(self) -> None:
278277

279278
async def handle_query(self, sql: str, attrs: Dict[str, str]) -> AllowedResult:
280279
self.timestamp = datetime.now(tz=self.timezone())
281-
self._variable_processor = VariableProcessor(
282-
mysql_function_mapping(self), self.variables
283-
)
284280
result = None
285281
for expression in self._parse(sql):
286282
if not expression:
@@ -306,7 +302,9 @@ async def _query_info_schema(self, expression: exp.Expression) -> AllowedResult:
306302

307303
async def _set_var_middleware(self, q: Query) -> AllowedResult:
308304
"""Handles SET_VAR hints and replaces functions defined in the _functions mapping with their mapped values."""
309-
with self._variable_processor.set_variables(q.expression):
305+
with VariableProcessor(
306+
mysql_function_mapping(self), self.variables, q.expression
307+
).set_variables():
310308
return await q.next()
311309

312310
async def _use_middleware(self, q: Query) -> AllowedResult:

mysql_mimic/variable_processor.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -50,35 +50,36 @@ class VariableProcessor:
5050
original values.
5151
"""
5252

53-
def __init__(self, functions: Mapping, variables: Variables):
53+
def __init__(
54+
self, functions: Mapping, variables: Variables, expression: exp.Expression
55+
):
5456
self._functions = functions
5557
self._variables = variables
58+
self._expression = expression
5659

5760
# Stores the original system variable values.
5861
self._orig: Dict[str, str] = {}
5962

6063
@contextmanager
61-
def set_variables(
62-
self, expression: exp.Expression
63-
) -> Generator[exp.Expression, None, None]:
64-
assignments = _get_var_assignments(expression)
64+
def set_variables(self) -> Generator[exp.Expression, None, None]:
65+
assignments = _get_var_assignments(self._expression)
6566
self._orig = {k: self._variables.get(k) for k in assignments}
6667
for k, v in assignments.items():
6768
self._variables.set(k, v)
6869

69-
self._replace_variables(expression)
70+
self._replace_variables()
7071

71-
yield expression
72+
yield self._expression
7273

7374
for k, v in self._orig.items():
7475
self._variables.set(k, v)
7576

76-
def _replace_variables(self, expression: exp.Expression) -> None:
77+
def _replace_variables(self) -> None:
7778
"""Replaces certain functions in the query with literals provided from the mapping in _functions,
7879
and session parameters with the values of the session variables.
7980
"""
80-
if isinstance(expression, exp.Set):
81-
for setitem in expression.expressions:
81+
if isinstance(self._expression, exp.Set):
82+
for setitem in self._expression.expressions:
8283
if isinstance(setitem.this, exp.Binary):
8384
# In the case of statements like: SET @@foo = @@bar
8485
# We only want to replace variables on the right
@@ -87,7 +88,7 @@ def _replace_variables(self, expression: exp.Expression) -> None:
8788
setitem.this.expression.transform(self._transform, copy=True),
8889
)
8990
else:
90-
expression.transform(self._transform, copy=False)
91+
self._expression.transform(self._transform, copy=False)
9192

9293
def _transform(self, node: exp.Expression) -> exp.Expression:
9394
new_node = None

0 commit comments

Comments
 (0)