Skip to content

feat: support partitioned dml in dbapi #1103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
)
if statement_type == ClientSideStatementType.RUN_PARTITIONED_QUERY:
return connection.run_partitioned_query(parsed_statement)
if statement_type == ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE:
return connection._set_autocommit_dml_mode(parsed_statement)


def _get_streamed_result_set(column_name, type_code, column_values):
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
RE_RUN_PARTITIONED_QUERY = re.compile(
r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)
RE_SET_AUTOCOMMIT_DML_MODE = re.compile(
r"^\s*(SET)\s+(AUTOCOMMIT_DML_MODE)\s+(=)\s+(.+)", re.IGNORECASE
)


def parse_stmt(query):
Expand Down Expand Up @@ -82,6 +85,10 @@ def parse_stmt(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params.append(match.group(3))
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
elif RE_SET_AUTOCOMMIT_DML_MODE.match(query):
match = re.search(RE_SET_AUTOCOMMIT_DML_MODE, query)
client_side_statement_params.append(match.group(4))
client_side_statement_type = ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE,
Expand Down
50 changes: 50 additions & 0 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.spanner_dbapi.parse_utils import _get_statement_type
from google.cloud.spanner_dbapi.parsed_statement import (
StatementType,
AutocommitDmlMode,
)
from google.cloud.spanner_dbapi.partition_helper import PartitionId
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(self, instance, database=None, read_only=False):
self._batch_mode = BatchMode.NONE
self._batch_dml_executor: BatchDmlExecutor = None
self._transaction_helper = TransactionRetryHelper(self)
self._autocommit_dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL

@property
def spanner_client(self):
Expand Down Expand Up @@ -167,6 +169,23 @@ def database(self):
"""
return self._database

@property
def autocommit_dml_mode(self):
"""Modes for executing DML statements in autocommit mode for this connection.

The DML autocommit modes are:
1) TRANSACTIONAL - DML statements are executed as single read-write transaction.
After successful execution, the DML statement is guaranteed to have been applied
exactly once to the database.

2) PARTITIONED_NON_ATOMIC - DML statements are executed as partitioned DML transactions.
If an error occurs during the execution of the DML statement, it is possible that the
statement has been applied to some but not all of the rows specified in the statement.

:rtype: :class:`~google.cloud.spanner_dbapi.parsed_statement.AutocommitDmlMode`
"""
return self._autocommit_dml_mode

@property
@deprecated(
reason="This method is deprecated. Use _spanner_transaction_started field"
Expand Down Expand Up @@ -577,6 +596,37 @@ def run_partitioned_query(
partitioned_query, statement.params, statement.param_types
)

@check_not_closed
def _set_autocommit_dml_mode(
self,
parsed_statement: ParsedStatement,
):
autocommit_dml_mode_str = parsed_statement.client_side_statement_params[0]
autocommit_dml_mode = AutocommitDmlMode[autocommit_dml_mode_str.upper()]
self.set_autocommit_dml_mode(autocommit_dml_mode)

def set_autocommit_dml_mode(
self,
autocommit_dml_mode,
):
"""
Sets the mode for executing DML statements in autocommit mode for this connection.
This mode is only used when the connection is in autocommit mode, and may only
be set while the transaction is in autocommit mode and not in a temporary transaction.
"""

if self._client_transaction_started is True:
raise ProgrammingError(
"Cannot set autocommit DML mode while not in autocommit mode or while a transaction is active."
)
if self.read_only is True:
raise ProgrammingError(
"Cannot set autocommit DML mode for a read-only connection."
)
if self._batch_mode is not BatchMode.NONE:
raise ProgrammingError("Cannot set autocommit DML mode while in a batch.")
self._autocommit_dml_mode = autocommit_dml_mode

def _partitioned_query_validation(self, partitioned_query, statement):
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
StatementType,
Statement,
ParsedStatement,
AutocommitDmlMode,
)
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
Expand Down Expand Up @@ -272,6 +273,17 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
self._batch_DDLs(sql)
if not self.connection._client_transaction_started:
self.connection.run_prior_DDL_statements()
elif (
self.connection.autocommit_dml_mode
is AutocommitDmlMode.PARTITIONED_NON_ATOMIC
):
self._row_count = self.connection.database.execute_partitioned_dml(
sql,
params=args,
param_types=self._parsed_statement.statement.param_types,
request_options=self.connection.request_options,
)
self._result_set = None
else:
self._execute_in_rw_transaction()

Expand Down
6 changes: 6 additions & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class ClientSideStatementType(Enum):
PARTITION_QUERY = 9
RUN_PARTITION = 10
RUN_PARTITIONED_QUERY = 11
SET_AUTOCOMMIT_DML_MODE = 12


class AutocommitDmlMode(Enum):
TRANSACTIONAL = 1
PARTITIONED_NON_ATOMIC = 2


@dataclass
Expand Down
22 changes: 22 additions & 0 deletions tests/system/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
OperationalError,
RetryAborted,
)
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1 import gapic_version as package_version
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
Expand Down Expand Up @@ -669,6 +670,27 @@ def test_run_partitioned_query(self):
assert len(rows) == 10
self._conn.commit()

def test_partitioned_dml_query(self):
"""Test partitioned_dml query works in autocommit mode."""
self._cursor.execute("start batch dml")
for i in range(1, 11):
self._insert_row(i)
self._cursor.execute("run batch")
self._conn.commit()

self._conn.autocommit = True
self._cursor.execute("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC")
self._cursor.execute("DELETE FROM contacts WHERE contact_id > 3")
assert self._cursor.rowcount == 7

self._cursor.execute("set autocommit_dml_mode = TRANSACTIONAL")
assert self._conn.autocommit_dml_mode == AutocommitDmlMode.TRANSACTIONAL

self._conn.autocommit = False
# Test changing autocommit_dml_mode is not allowed when connection is in autocommit mode
with pytest.raises(ProgrammingError):
self._cursor.execute("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC")

def _insert_row(self, i):
self._cursor.execute(
f"""
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
ParsedStatement,
StatementType,
Statement,
ClientSideStatementType,
AutocommitDmlMode,
)

PROJECT = "test-project"
Expand Down Expand Up @@ -433,6 +435,62 @@ def test_abort_dml_batch(self, mock_batch_dml_executor):
self.assertEqual(self._under_test._batch_mode, BatchMode.NONE)
self.assertEqual(self._under_test._batch_dml_executor, None)

def test_set_autocommit_dml_mode_with_autocommit_false(self):
self._under_test.autocommit = False
parsed_statement = ParsedStatement(
StatementType.CLIENT_SIDE,
Statement("sql"),
ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
["PARTITIONED_NON_ATOMIC"],
)

with self.assertRaises(ProgrammingError):
self._under_test._set_autocommit_dml_mode(parsed_statement)

def test_set_autocommit_dml_mode_with_readonly(self):
self._under_test.autocommit = True
self._under_test.read_only = True
parsed_statement = ParsedStatement(
StatementType.CLIENT_SIDE,
Statement("sql"),
ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
["PARTITIONED_NON_ATOMIC"],
)

with self.assertRaises(ProgrammingError):
self._under_test._set_autocommit_dml_mode(parsed_statement)

def test_set_autocommit_dml_mode_with_batch_mode(self):
self._under_test.autocommit = True
parsed_statement = ParsedStatement(
StatementType.CLIENT_SIDE,
Statement("sql"),
ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
["PARTITIONED_NON_ATOMIC"],
)

self._under_test._set_autocommit_dml_mode(parsed_statement)

assert (
self._under_test.autocommit_dml_mode
== AutocommitDmlMode.PARTITIONED_NON_ATOMIC
)

def test_set_autocommit_dml_mode(self):
self._under_test.autocommit = True
parsed_statement = ParsedStatement(
StatementType.CLIENT_SIDE,
Statement("sql"),
ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
["PARTITIONED_NON_ATOMIC"],
)

self._under_test._set_autocommit_dml_mode(parsed_statement)
assert (
self._under_test.autocommit_dml_mode
== AutocommitDmlMode.PARTITIONED_NON_ATOMIC
)

@mock.patch("google.cloud.spanner_v1.database.Database", autospec=True)
def test_run_prior_DDL_statements(self, mock_database):
from google.cloud.spanner_dbapi import Connection, InterfaceError
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ def test_run_partitioned_query_classify_stmt(self):
),
)

def test_set_autocommit_dml_mode_stmt(self):
parsed_statement = classify_statement(
" set autocommit_dml_mode = PARTITIONED_NON_ATOMIC "
)
self.assertEqual(
parsed_statement,
ParsedStatement(
StatementType.CLIENT_SIDE,
Statement("set autocommit_dml_mode = PARTITIONED_NON_ATOMIC"),
ClientSideStatementType.SET_AUTOCOMMIT_DML_MODE,
["PARTITIONED_NON_ATOMIC"],
),
)

@unittest.skipIf(skip_condition, skip_message)
def test_sql_pyformat_args_to_spanner(self):
from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
Expand Down