@@ -50,35 +50,36 @@ class VariableProcessor:
50
50
original values.
51
51
"""
52
52
53
- def __init__ (self , functions : Mapping , variables : Variables ):
53
+ def __init__ (
54
+ self , functions : Mapping , variables : Variables , expression : exp .Expression
55
+ ):
54
56
self ._functions = functions
55
57
self ._variables = variables
58
+ self ._expression = expression
56
59
57
60
# Stores the original system variable values.
58
61
self ._orig : Dict [str , str ] = {}
59
62
60
63
@contextmanager
61
- def set_variables (
62
- self , expression : exp .Expression
63
- ) -> Generator [exp .Expression , None , None ]:
64
- assignments = _get_var_assignments (expression )
64
+ def set_variables (self ) -> Generator [exp .Expression , None , None ]:
65
+ assignments = _get_var_assignments (self ._expression )
65
66
self ._orig = {k : self ._variables .get (k ) for k in assignments }
66
67
for k , v in assignments .items ():
67
68
self ._variables .set (k , v )
68
69
69
- self ._replace_variables (expression )
70
+ self ._replace_variables ()
70
71
71
- yield expression
72
+ yield self . _expression
72
73
73
74
for k , v in self ._orig .items ():
74
75
self ._variables .set (k , v )
75
76
76
- def _replace_variables (self , expression : exp . Expression ) -> None :
77
+ def _replace_variables (self ) -> None :
77
78
"""Replaces certain functions in the query with literals provided from the mapping in _functions,
78
79
and session parameters with the values of the session variables.
79
80
"""
80
- if isinstance (expression , exp .Set ):
81
- for setitem in expression .expressions :
81
+ if isinstance (self . _expression , exp .Set ):
82
+ for setitem in self . _expression .expressions :
82
83
if isinstance (setitem .this , exp .Binary ):
83
84
# In the case of statements like: SET @@foo = @@bar
84
85
# We only want to replace variables on the right
@@ -87,7 +88,7 @@ def _replace_variables(self, expression: exp.Expression) -> None:
87
88
setitem .this .expression .transform (self ._transform , copy = True ),
88
89
)
89
90
else :
90
- expression .transform (self ._transform , copy = False )
91
+ self . _expression .transform (self ._transform , copy = False )
91
92
92
93
def _transform (self , node : exp .Expression ) -> exp .Expression :
93
94
new_node = None
0 commit comments