Skip to content

Commit 982b998

Browse files
committed
added new_transaction to query and aggregation classes
1 parent d63cb90 commit 982b998

File tree

5 files changed

+100
-7
lines changed

5 files changed

+100
-7
lines changed

google/cloud/datastore/aggregation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,13 +442,17 @@ def _next_page(self):
442442
return None
443443

444444
query_pb = self._build_protobuf()
445+
new_transaction_options = None
445446
transaction = self.client.current_transaction
446447
if transaction is None:
447448
transaction_id = None
448449
else:
449450
transaction_id = transaction.id
451+
if transation._begin_later and transaction._status == transaction._INITIAL:
452+
# if transaction hasn't been initialized, initialize it as part of this request
453+
new_transaction_options = transaction._options
450454
read_options = helpers.get_read_options(
451-
self._eventual, transaction_id, self._read_time
455+
self._eventual, transaction_id, self._read_time, new_transaction_options
452456
)
453457

454458
partition_id = entity_pb2.PartitionId(

google/cloud/datastore/client.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,7 @@ def _extended_lookup(
203203
new_transaction_options = None
204204
if transaction is not None:
205205
transaction_id = transaction.id
206-
if (
207-
transaction_id is None
208-
and transaction._begin_later
209-
and transaction._status == transaction._INITIAL
210-
):
206+
if and transaction._begin_later and transaction._status == transaction._INITIAL:
211207
# if transaction hasn't been initialized, initialize it as part of this request
212208
new_transaction_options = transaction._options
213209

google/cloud/datastore/query.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -778,13 +778,16 @@ def _next_page(self):
778778
return None
779779

780780
query_pb = self._build_protobuf()
781+
new_transaction_options = None
781782
transaction = self.client.current_transaction
782783
if transaction is None:
783784
transaction_id = None
784785
else:
785786
transaction_id = transaction.id
787+
# if transaction hasn't been initialized, initialize it as part of this request
788+
new_transaction_options = transaction._options
786789
read_options = helpers.get_read_options(
787-
self._eventual, transaction_id, self._read_time
790+
self._eventual, transaction_id, self._read_time, new_transaction_options
788791
)
789792

790793
partition_id = entity_pb2.PartitionId(

tests/unit/test_aggregation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,54 @@ def test_transaction_id_populated(database_id, aggregation_type, aggregation_arg
612612
assert read_options.transaction == client.current_transaction.id
613613

614614

615+
@pytest.mark.parametrize("database_id", [None, "somedb"], indirect=True)
616+
@pytest.mark.parametrize(
617+
"aggregation_type,aggregation_args",
618+
[
619+
("count", ()),
620+
(
621+
"sum",
622+
("appearances",),
623+
),
624+
("avg", ("appearances",)),
625+
],
626+
)
627+
def test_transaction_begin_later(database_id, aggregation_type, aggregation_args):
628+
"""
629+
When an aggregation is run in the context of a transaction with begin_later=True,
630+
the new_transaction field should be populated in the request read_options.
631+
"""
632+
import mock
633+
634+
# make a fake begin_later transaction
635+
transaction = mock.Mock()
636+
transaction.id = None
637+
transaction._begin_later = True
638+
transaction._state = transaction._INITIAL
639+
mock_datastore_api = mock.Mock()
640+
mock_gapic = mock_datastore_api.run_aggregation_query
641+
mock_gapic.return_value = _make_aggregation_query_response([])
642+
client = _Client(
643+
None,
644+
datastore_api=mock_datastore_api,
645+
database=database_id,
646+
transaction=transaction,
647+
)
648+
649+
query = _make_query(client)
650+
aggregation_query = _make_aggregation_query(client=client, query=query)
651+
652+
# initiate requested aggregation (ex count, sum, avg)
653+
getattr(aggregation_query, aggregation_type)(*aggregation_args)
654+
# run mock query
655+
list(aggregation_query.fetch())
656+
assert mock_gapic.call_count == 1
657+
request = mock_gapic.call_args[1]["request"]
658+
read_options = request["read_options"]
659+
# ensure new_transaction is populated
660+
assert read_options.transaction is None
661+
assert read_options.new_transaction == transaction._options
662+
615663
class _Client(object):
616664
def __init__(
617665
self,

tests/unit/test_query.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,48 @@ def test_transaction_id_populated(database_id):
698698
assert read_options.transaction == client.current_transaction.id
699699

700700

701+
@pytest.mark.parametrize("database_id", [None, "somedb"])
702+
def test_transaction_begin_later(database_id):
703+
"""
704+
When an aggregation is run in the context of a transaction with begin_later=True,
705+
the new_transaction field should be populated in the request read_options.
706+
"""
707+
import mock
708+
709+
710+
# make a fake begin_later transaction
711+
transaction = mock.Mock()
712+
transaction.id = None
713+
transaction._begin_later = True
714+
transaction._state = transaction._INITIAL
715+
716+
mock_datastore_api = mock.Mock()
717+
mock_gapic = mock_datastore_api.run_query
718+
719+
more_results_enum = 3 # NO_MORE_RESULTS
720+
response_pb = _make_query_response([], b"", more_results_enum, 0)
721+
mock_gapic.return_value = response_pb
722+
723+
client = _Client(
724+
None,
725+
datastore_api=mock_datastore_api,
726+
database=database_id,
727+
transaction=transaction,
728+
)
729+
730+
query = _make_query(client)
731+
# run mock query
732+
list(query.fetch())
733+
assert mock_gapic.call_count == 1
734+
request = mock_gapic.call_args[1]["request"]
735+
read_options = request["read_options"]
736+
# ensure new_transaction is populated
737+
assert read_options.transaction is None
738+
assert read_options.new_transaction == transaction._options
739+
740+
741+
742+
701743
def test_iterator_constructor_defaults():
702744
query = object()
703745
client = object()

0 commit comments

Comments
 (0)