Skip to content

Commit eb41b0d

Browse files
authored
feat: Implementing client side statements in dbapi (starting with commit) (#1037)
* Implementing client side statement in dbapi starting with commit * Fixing comments * Adding dependency on "deprecated" package * Fix in setup.py * Fixing tests * Lint issue fix * Resolving comments * Fixing formatting issue
1 parent 07fbc45 commit eb41b0d

File tree

9 files changed

+292
-91
lines changed

9 files changed

+292
-91
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2023 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from google.cloud.spanner_dbapi.parsed_statement import (
15+
ParsedStatement,
16+
ClientSideStatementType,
17+
)
18+
19+
20+
def execute(connection, parsed_statement: ParsedStatement):
21+
"""Executes the client side statements by calling the relevant method.
22+
23+
It is an internal method that can make backwards-incompatible changes.
24+
25+
:type parsed_statement: ParsedStatement
26+
:param parsed_statement: parsed_statement based on the sql query
27+
"""
28+
if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT:
29+
return connection.commit()
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2023 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import re
16+
17+
from google.cloud.spanner_dbapi.parsed_statement import (
18+
ParsedStatement,
19+
StatementType,
20+
ClientSideStatementType,
21+
)
22+
23+
RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE)
24+
25+
26+
def parse_stmt(query):
27+
"""Parses the sql query to check if it matches with any of the client side
28+
statement regex.
29+
30+
It is an internal method that can make backwards-incompatible changes.
31+
32+
:type query: str
33+
:param query: sql query
34+
35+
:rtype: ParsedStatement
36+
:returns: ParsedStatement object.
37+
"""
38+
if RE_COMMIT.match(query):
39+
return ParsedStatement(
40+
StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT
41+
)
42+
return None

google/cloud/spanner_dbapi/cursor.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@
3232
from google.cloud.spanner_dbapi.exceptions import OperationalError
3333
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
3434

35-
from google.cloud.spanner_dbapi import _helpers
35+
from google.cloud.spanner_dbapi import _helpers, client_side_statement_executor
3636
from google.cloud.spanner_dbapi._helpers import ColumnInfo
3737
from google.cloud.spanner_dbapi._helpers import CODE_TO_DISPLAY_SIZE
3838

3939
from google.cloud.spanner_dbapi import parse_utils
4040
from google.cloud.spanner_dbapi.parse_utils import get_param_types
4141
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
42+
from google.cloud.spanner_dbapi.parsed_statement import StatementType
4243
from google.cloud.spanner_dbapi.utils import PeekIterator
4344
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
4445

@@ -210,7 +211,10 @@ def _batch_DDLs(self, sql):
210211
for ddl in sqlparse.split(sql):
211212
if ddl:
212213
ddl = ddl.rstrip(";")
213-
if parse_utils.classify_stmt(ddl) != parse_utils.STMT_DDL:
214+
if (
215+
parse_utils.classify_statement(ddl).statement_type
216+
!= StatementType.DDL
217+
):
214218
raise ValueError("Only DDL statements may be batched.")
215219

216220
statements.append(ddl)
@@ -239,8 +243,12 @@ def execute(self, sql, args=None):
239243
self._handle_DQL(sql, args or None)
240244
return
241245

242-
class_ = parse_utils.classify_stmt(sql)
243-
if class_ == parse_utils.STMT_DDL:
246+
parsed_statement = parse_utils.classify_statement(sql)
247+
if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
248+
return client_side_statement_executor.execute(
249+
self.connection, parsed_statement
250+
)
251+
if parsed_statement.statement_type == StatementType.DDL:
244252
self._batch_DDLs(sql)
245253
if self.connection.autocommit:
246254
self.connection.run_prior_DDL_statements()
@@ -251,7 +259,7 @@ def execute(self, sql, args=None):
251259
# self._run_prior_DDL_statements()
252260
self.connection.run_prior_DDL_statements()
253261

254-
if class_ == parse_utils.STMT_UPDATING:
262+
if parsed_statement.statement_type == StatementType.UPDATE:
255263
sql = parse_utils.ensure_where_clause(sql)
256264

257265
sql, args = sql_pyformat_args_to_spanner(sql, args or None)
@@ -276,7 +284,7 @@ def execute(self, sql, args=None):
276284
self.connection.retry_transaction()
277285
return
278286

279-
if class_ == parse_utils.STMT_NON_UPDATING:
287+
if parsed_statement.statement_type == StatementType.QUERY:
280288
self._handle_DQL(sql, args or None)
281289
else:
282290
self.connection.database.run_in_transaction(
@@ -309,19 +317,29 @@ def executemany(self, operation, seq_of_params):
309317
self._result_set = None
310318
self._row_count = _UNSET_COUNT
311319

312-
class_ = parse_utils.classify_stmt(operation)
313-
if class_ == parse_utils.STMT_DDL:
320+
parsed_statement = parse_utils.classify_statement(operation)
321+
if parsed_statement.statement_type == StatementType.DDL:
314322
raise ProgrammingError(
315323
"Executing DDL statements with executemany() method is not allowed."
316324
)
317325

326+
if parsed_statement.statement_type == StatementType.CLIENT_SIDE:
327+
raise ProgrammingError(
328+
"Executing the following operation: "
329+
+ operation
330+
+ ", with executemany() method is not allowed."
331+
)
332+
318333
# For every operation, we've got to ensure that any prior DDL
319334
# statements were run.
320335
self.connection.run_prior_DDL_statements()
321336

322337
many_result_set = StreamedManyResultSets()
323338

324-
if class_ in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
339+
if parsed_statement.statement_type in (
340+
StatementType.INSERT,
341+
StatementType.UPDATE,
342+
):
325343
statements = []
326344

327345
for params in seq_of_params:

google/cloud/spanner_dbapi/parse_utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import sqlparse
2222
from google.cloud import spanner_v1 as spanner
2323
from google.cloud.spanner_v1 import JsonObject
24+
from . import client_side_statement_parser
25+
from deprecated import deprecated
2426

2527
from .exceptions import Error
28+
from .parsed_statement import ParsedStatement, StatementType
2629
from .types import DateStr, TimestampStr
2730
from .utils import sanitize_literals_for_upload
2831

@@ -174,12 +177,11 @@
174177
RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL)
175178

176179

180+
@deprecated(reason="This method is deprecated. Use _classify_stmt method")
177181
def classify_stmt(query):
178182
"""Determine SQL query type.
179-
180183
:type query: str
181184
:param query: A SQL query.
182-
183185
:rtype: str
184186
:returns: The query type name.
185187
"""
@@ -203,6 +205,39 @@ def classify_stmt(query):
203205
return STMT_UPDATING
204206

205207

208+
def classify_statement(query):
209+
"""Determine SQL query type.
210+
211+
It is an internal method that can make backwards-incompatible changes.
212+
213+
:type query: str
214+
:param query: A SQL query.
215+
216+
:rtype: ParsedStatement
217+
:returns: parsed statement attributes.
218+
"""
219+
# sqlparse will strip Cloud Spanner comments,
220+
# still, special commenting styles, like
221+
# PostgreSQL dollar quoted comments are not
222+
# supported and will not be stripped.
223+
query = sqlparse.format(query, strip_comments=True).strip()
224+
parsed_statement = client_side_statement_parser.parse_stmt(query)
225+
if parsed_statement is not None:
226+
return parsed_statement
227+
if RE_DDL.match(query):
228+
return ParsedStatement(StatementType.DDL, query)
229+
230+
if RE_IS_INSERT.match(query):
231+
return ParsedStatement(StatementType.INSERT, query)
232+
233+
if RE_NON_UPDATE.match(query) or RE_WITH.match(query):
234+
# As of 13-March-2020, Cloud Spanner only supports WITH for DQL
235+
# statements and doesn't yet support WITH for DML statements.
236+
return ParsedStatement(StatementType.QUERY, query)
237+
238+
return ParsedStatement(StatementType.UPDATE, query)
239+
240+
206241
def sql_pyformat_args_to_spanner(sql, params):
207242
"""
208243
Transform pyformat set SQL to named arguments for Cloud Spanner.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 20203 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from enum import Enum
17+
18+
19+
class StatementType(Enum):
20+
CLIENT_SIDE = 1
21+
DDL = 2
22+
QUERY = 3
23+
UPDATE = 4
24+
INSERT = 5
25+
26+
27+
class ClientSideStatementType(Enum):
28+
COMMIT = 1
29+
BEGIN = 2
30+
31+
32+
@dataclass
33+
class ParsedStatement:
34+
statement_type: StatementType
35+
query: str
36+
client_side_statement_type: ClientSideStatementType = None

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"proto-plus >= 1.22.0, <2.0.0dev",
4343
"sqlparse >= 0.4.4",
4444
"protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5",
45+
"deprecated >= 1.2.14",
4546
]
4647
extras = {
4748
"tracing": [

tests/system/test_dbapi.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from google.cloud import spanner_v1
2222
from google.cloud._helpers import UTC
23+
24+
from google.cloud.spanner_dbapi import Cursor
2325
from google.cloud.spanner_dbapi.connection import connect
2426
from google.cloud.spanner_dbapi.connection import Connection
2527
from google.cloud.spanner_dbapi.exceptions import ProgrammingError
@@ -72,37 +74,11 @@ def dbapi_database(raw_database):
7274

7375
def test_commit(shared_instance, dbapi_database):
7476
"""Test committing a transaction with several statements."""
75-
want_row = (
76-
1,
77-
"updated-first-name",
78-
"last-name",
79-
80-
)
8177
# connect to the test database
8278
conn = Connection(shared_instance, dbapi_database)
8379
cursor = conn.cursor()
8480

85-
# execute several DML statements within one transaction
86-
cursor.execute(
87-
"""
88-
INSERT INTO contacts (contact_id, first_name, last_name, email)
89-
VALUES (1, 'first-name', 'last-name', '[email protected]')
90-
"""
91-
)
92-
cursor.execute(
93-
"""
94-
UPDATE contacts
95-
SET first_name = 'updated-first-name'
96-
WHERE first_name = 'first-name'
97-
"""
98-
)
99-
cursor.execute(
100-
"""
101-
UPDATE contacts
102-
SET email = '[email protected]'
103-
WHERE email = '[email protected]'
104-
"""
105-
)
81+
want_row = _execute_common_precommit_statements(cursor)
10682
conn.commit()
10783

10884
# read the resulting data from the database
@@ -116,6 +92,25 @@ def test_commit(shared_instance, dbapi_database):
11692
conn.close()
11793

11894

95+
def test_commit_client_side(shared_instance, dbapi_database):
96+
"""Test committing a transaction with several statements."""
97+
# connect to the test database
98+
conn = Connection(shared_instance, dbapi_database)
99+
cursor = conn.cursor()
100+
101+
want_row = _execute_common_precommit_statements(cursor)
102+
cursor.execute("""COMMIT""")
103+
104+
# read the resulting data from the database
105+
cursor.execute("SELECT * FROM contacts")
106+
got_rows = cursor.fetchall()
107+
conn.commit()
108+
cursor.close()
109+
conn.close()
110+
111+
assert got_rows == [want_row]
112+
113+
119114
def test_rollback(shared_instance, dbapi_database):
120115
"""Test rollbacking a transaction with several statements."""
121116
want_row = (2, "first-name", "last-name", "[email protected]")
@@ -810,3 +805,33 @@ def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
810805
assert cur.fetchone() == (1, "first-name")
811806
assert cur.rowcount == 1
812807
conn.commit()
808+
809+
810+
def _execute_common_precommit_statements(cursor: Cursor):
811+
# execute several DML statements within one transaction
812+
cursor.execute(
813+
"""
814+
INSERT INTO contacts (contact_id, first_name, last_name, email)
815+
VALUES (1, 'first-name', 'last-name', '[email protected]')
816+
"""
817+
)
818+
cursor.execute(
819+
"""
820+
UPDATE contacts
821+
SET first_name = 'updated-first-name'
822+
WHERE first_name = 'first-name'
823+
"""
824+
)
825+
cursor.execute(
826+
"""
827+
UPDATE contacts
828+
SET email = '[email protected]'
829+
WHERE email = '[email protected]'
830+
"""
831+
)
832+
return (
833+
1,
834+
"updated-first-name",
835+
"last-name",
836+
837+
)

0 commit comments

Comments
 (0)