
feat(spanner): add support for txn changestream exclusion #1152

Merged

21 changes: 18 additions & 3 deletions google/cloud/spanner_v1/batch.py
@@ -147,7 +147,11 @@ def _check_state(self):
raise ValueError("Batch already committed")

def commit(
self, return_commit_stats=False, request_options=None, max_commit_delay=None
self,
return_commit_stats=False,
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
):
"""Commit mutations to the database.

@@ -178,7 +182,10 @@ def commit(
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
txn_options = TransactionOptions(
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
trace_attributes = {"num_mutations": len(self._mutations)}

if request_options is None:
@@ -270,7 +277,7 @@ def group(self):
self._mutation_groups.append(mutation_group)
return MutationGroup(self._session, mutation_group.mutations)

def batch_write(self, request_options=None):
def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
"""Executes batch_write.

:type request_options:
@@ -280,6 +287,13 @@ def batch_write(self, request_options=None):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

:type exclude_txn_from_change_streams: bool
:param exclude_txn_from_change_streams:
(Optional) If true, the mutations in this request are excluded from being recorded in
change streams that have the DDL option `allow_txn_exclusion=true`. They are still
recorded in change streams whose `allow_txn_exclusion` option is false or unset.

:rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
:returns: a sequence of responses for each batch.
"""
@@ -302,6 +316,7 @@ def batch_write(self, request_options=None):
session=self._session.name,
mutation_groups=self._mutation_groups,
request_options=request_options,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
method = functools.partial(
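
For reference, a minimal sketch of how the new `batch_write` flag could be passed from application code; the instance, database, and table names below are hypothetical, and `mutation_groups()` is used as a context manager per the existing client API:

from google.cloud import spanner

# Hypothetical instance, database, and table names for illustration only.
client = spanner.Client()
database = client.instance("test-instance").database("test-db")

with database.mutation_groups() as groups:
    group = groups.group()
    group.insert("Singers", columns=("SingerId", "Name"), values=[(1, "Marc")])
    # Mutation groups in this request are omitted from change streams that
    # were created with the DDL option allow_txn_exclusion=true.
    for response in groups.batch_write(exclude_txn_from_change_streams=True):
        print(response.status)
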
43 changes: 39 additions & 4 deletions google/cloud/spanner_v1/database.py
@@ -619,6 +619,7 @@ def execute_partitioned_dml(
param_types=None,
query_options=None,
request_options=None,
exclude_txn_from_change_streams=False,
):
"""Execute a partitionable DML statement.

@@ -651,6 +652,13 @@
Please note, the `transactionTag` setting will be ignored as it is
not supported for partitioned DML.

:type exclude_txn_from_change_streams: bool
:param exclude_txn_from_change_streams:
(Optional) If true, the transaction is excluded from being recorded in change streams
that have the DDL option `allow_txn_exclusion=true`. It is still recorded in change
streams whose `allow_txn_exclusion` option is false or unset.

:rtype: int
:returns: Count of rows affected by the DML statement.
"""
@@ -673,7 +681,8 @@
api = self.spanner_api

txn_options = TransactionOptions(
partitioned_dml=TransactionOptions.PartitionedDml()
partitioned_dml=TransactionOptions.PartitionedDml(),
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)

metadata = _metadata_with_prefix(self.name)
@@ -752,7 +761,12 @@ def snapshot(self, **kw):
"""
return SnapshotCheckout(self, **kw)

def batch(self, request_options=None, max_commit_delay=None):
def batch(
self,
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
):
"""Return an object which wraps a batch.

The wrapper *must* be used as a context manager, with the batch
@@ -771,10 +785,19 @@ def batch(self, request_options=None, max_commit_delay=None):
in order to improve throughput. Value must be between 0ms and
500ms.

:type exclude_txn_from_change_streams: bool
:param exclude_txn_from_change_streams:
(Optional) If true, the transaction is excluded from being recorded in change streams
that have the DDL option `allow_txn_exclusion=true`. It is still recorded in change
streams whose `allow_txn_exclusion` option is false or unset.

:rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout`
:returns: new wrapper
"""
return BatchCheckout(self, request_options, max_commit_delay)
return BatchCheckout(
self, request_options, max_commit_delay, exclude_txn_from_change_streams
)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
@@ -840,6 +863,10 @@ def run_in_transaction(self, func, *args, **kw):
"max_commit_delay" will be removed and used to set the
max_commit_delay for the request. Value must be between
0ms and 500ms.
"exclude_txn_from_change_streams" if true, instructs the transaction to be excluded
from being recorded in change streams with the DDL option `allow_txn_exclusion=true`.
This does not exclude the transaction from being recorded in the change streams with
the DDL option `allow_txn_exclusion` being false or unset.

:rtype: Any
:returns: The return value of ``func``.
@@ -1103,7 +1130,13 @@ class BatchCheckout(object):
in order to improve throughput.
"""

def __init__(self, database, request_options=None, max_commit_delay=None):
def __init__(
self,
database,
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
):
self._database = database
self._session = self._batch = None
if request_options is None:
@@ -1113,6 +1146,7 @@ def __init__(self, database, request_options=None, max_commit_delay=None):
else:
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams

def __enter__(self):
"""Begin ``with`` block."""
@@ -1130,6 +1164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return_commit_stats=self._database.log_commit_stats,
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
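
For reference, a sketch of the new parameter on `Database.batch()` and `Database.execute_partitioned_dml()`, using placeholder instance, database, and table names:

from google.cloud import spanner

# Placeholder instance, database, and table names for illustration only.
client = spanner.Client()
database = client.instance("test-instance").database("test-db")

# Mutations committed through this batch are excluded from change streams
# created with the DDL option allow_txn_exclusion=true.
with database.batch(exclude_txn_from_change_streams=True) as batch:
    batch.delete("Singers", spanner.KeySet(keys=[(1,)]))

# The same flag is applied to the partitioned-DML transaction options.
row_count = database.execute_partitioned_dml(
    "DELETE FROM Singers WHERE Name IS NULL",
    exclude_txn_from_change_streams=True,
)
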
8 changes: 8 additions & 0 deletions google/cloud/spanner_v1/session.py
@@ -387,6 +387,10 @@ def run_in_transaction(self, func, *args, **kw):
request options for the commit request.
"max_commit_delay" will be removed and used to set the max commit delay for the request.
"transaction_tag" will be removed and used to set the transaction tag for the request.
"exclude_txn_from_change_streams" if true, instructs the transaction to be excluded
from being recorded in change streams with the DDL option `allow_txn_exclusion=true`.
This does not exclude the transaction from being recorded in the change streams with
the DDL option `allow_txn_exclusion` being false or unset.

:rtype: Any
:returns: The return value of ``func``.
@@ -398,12 +402,16 @@
commit_request_options = kw.pop("commit_request_options", None)
max_commit_delay = kw.pop("max_commit_delay", None)
transaction_tag = kw.pop("transaction_tag", None)
exclude_txn_from_change_streams = kw.pop(
"exclude_txn_from_change_streams", None
)
attempts = 0

while True:
if self._transaction is None:
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams
else:
txn = self._transaction

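
For reference, a sketch of the keyword flowing through `run_in_transaction`; the database and DML below are placeholders:

from google.cloud import spanner

# Placeholder instance and database names for illustration only.
client = spanner.Client()
database = client.instance("test-instance").database("test-db")

def update_singer(transaction):
    # DML executed inside the retried read-write transaction.
    transaction.execute_update(
        "UPDATE Singers SET Name = 'Marc' WHERE SingerId = 1"
    )

# The keyword is popped from **kw and copied onto the Transaction object, so
# the whole read-write transaction is excluded from change streams that were
# created with the DDL option allow_txn_exclusion=true.
database.run_in_transaction(update_singer, exclude_txn_from_change_streams=True)
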
11 changes: 9 additions & 2 deletions google/cloud/spanner_v1/transaction.py
@@ -55,6 +55,7 @@ class Transaction(_SnapshotBase, _BatchBase):
_execute_sql_count = 0
_lock = threading.Lock()
_read_only = False
exclude_txn_from_change_streams = False

def __init__(self, session):
if session._transaction is not None:
@@ -86,7 +87,10 @@ def _make_txn_selector(self):

if self._transaction_id is None:
return TransactionSelector(
begin=TransactionOptions(read_write=TransactionOptions.ReadWrite())
begin=TransactionOptions(
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=self.exclude_txn_from_change_streams,
)
)
else:
return TransactionSelector(id=self._transaction_id)
@@ -137,7 +141,10 @@ def begin(self):
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
txn_options = TransactionOptions(
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=self.exclude_txn_from_change_streams,
)
with trace_call("CloudSpanner.BeginTransaction", self._session):
method = functools.partial(
api.begin_transaction,
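
For reference, roughly the `TransactionOptions` that `_make_txn_selector()` and `begin()` build when the attribute is set; this is a sketch, and it assumes an installed `google-cloud-spanner` whose `TransactionOptions` proto already exposes the field:

from google.cloud.spanner_v1 import TransactionOptions

# Sketch of the options constructed by the diff above; not a public API call.
opts = TransactionOptions(
    read_write=TransactionOptions.ReadWrite(),
    exclude_txn_from_change_streams=True,
)
assert type(opts).pb(opts).HasField("read_write")
assert opts.exclude_txn_from_change_streams
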
42 changes: 38 additions & 4 deletions tests/unit/test_batch.py
@@ -259,7 +259,12 @@ def test_commit_ok(self):
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)

def _test_commit_with_options(self, request_options=None, max_commit_delay_in=None):
def _test_commit_with_options(
self,
request_options=None,
max_commit_delay_in=None,
exclude_txn_from_change_streams=False,
):
import datetime
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
@@ -276,7 +281,9 @@ def _test_commit_with_options(self, request_options=None, max_commit_delay_in=No
batch.transaction_tag = self.TRANSACTION_TAG
batch.insert(TABLE_NAME, COLUMNS, VALUES)
committed = batch.commit(
request_options=request_options, max_commit_delay=max_commit_delay_in
request_options=request_options,
max_commit_delay=max_commit_delay_in,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)

self.assertEqual(committed, now)
@@ -301,6 +308,10 @@ def _test_commit_with_options(self, request_options=None, max_commit_delay_in=No
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
self.assertEqual(
single_use_txn.exclude_txn_from_change_streams,
exclude_txn_from_change_streams,
)
self.assertEqual(
metadata,
[
@@ -355,6 +366,14 @@ def test_commit_w_max_commit_delay(self):
max_commit_delay_in=datetime.timedelta(milliseconds=100),
)

def test_commit_w_exclude_txn_from_change_streams(self):
request_options = RequestOptions(
request_tag="tag-1",
)
self._test_commit_with_options(
request_options=request_options, exclude_txn_from_change_streams=True
)

def test_context_mgr_already_committed(self):
import datetime
from google.cloud._helpers import UTC
@@ -499,7 +518,9 @@ def test_batch_write_grpc_error(self):
attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1),
)

def _test_batch_write_with_request_options(self, request_options=None):
def _test_batch_write_with_request_options(
self, request_options=None, exclude_txn_from_change_streams=False
):
import datetime
from google.cloud.spanner_v1 import BatchWriteResponse
from google.cloud._helpers import UTC
@@ -519,7 +540,10 @@ def _test_batch_write_with_request_options(self, request_options=None):
group = groups.group()
group.insert(TABLE_NAME, COLUMNS, VALUES)

response_iter = groups.batch_write(request_options)
response_iter = groups.batch_write(
request_options,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
self.assertEqual(len(response_iter), 1)
self.assertEqual(response_iter[0], response)

@@ -528,6 +552,7 @@ def _test_batch_write_with_request_options(self, request_options=None):
mutation_groups,
actual_request_options,
metadata,
request_exclude_txn_from_change_streams,
) = api._batch_request
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutation_groups, groups._mutation_groups)
@@ -545,6 +570,9 @@
else:
expected_request_options = request_options
self.assertEqual(actual_request_options, expected_request_options)
self.assertEqual(
request_exclude_txn_from_change_streams, exclude_txn_from_change_streams
)

self.assertSpanAttributes(
"CloudSpanner.BatchWrite",
@@ -567,6 +595,11 @@ def test_batch_write_w_incorrect_tag_dictionary_error(self):
with self.assertRaises(ValueError):
self._test_batch_write_with_request_options({"incorrect_tag": "tag-1-1"})

def test_batch_write_w_exclude_txn_from_change_streams(self):
self._test_batch_write_with_request_options(
exclude_txn_from_change_streams=True
)


class _Session(object):
def __init__(self, database=None, name=TestBatch.SESSION_NAME):
@@ -625,6 +658,7 @@ def batch_write(
request.mutation_groups,
request.request_options,
metadata,
request.exclude_txn_from_change_streams,
)
if self._rpc_error:
raise Unknown("error")
16 changes: 14 additions & 2 deletions tests/unit/test_database.py
@@ -1083,6 +1083,7 @@ def _execute_partitioned_dml_helper(
query_options=None,
request_options=None,
retried=False,
exclude_txn_from_change_streams=False,
):
from google.api_core.exceptions import Aborted
from google.api_core.retry import Retry
@@ -1129,13 +1130,19 @@ def _execute_partitioned_dml_helper(
api.execute_streaming_sql.return_value = iterator

row_count = database.execute_partitioned_dml(
dml, params, param_types, query_options, request_options
dml,
params,
param_types,
query_options,
request_options,
exclude_txn_from_change_streams,
)

self.assertEqual(row_count, 2)

txn_options = TransactionOptions(
partitioned_dml=TransactionOptions.PartitionedDml()
partitioned_dml=TransactionOptions.PartitionedDml(),
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)

api.begin_transaction.assert_called_with(
Expand Down Expand Up @@ -1250,6 +1257,11 @@ def test_execute_partitioned_dml_w_req_tag_used(self):
def test_execute_partitioned_dml_wo_params_retry_aborted(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True)

def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self):
self._execute_partitioned_dml_helper(
dml=DML_WO_PARAM, exclude_txn_from_change_streams=True
)

def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session
