diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 1d8425aa48..bf1094180e 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -268,9 +268,9 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None): elif type_code == TypeCode.PROTO: bytes_value = base64.b64decode(value_pb.string_value) if column_info is not None and column_info.get(field_name) is not None: - proto_message = column_info.get(field_name) - if isinstance(proto_message, Message): - proto_message = proto_message.__deepcopy__() + default_proto_message = column_info.get(field_name) + if isinstance(default_proto_message, Message): + proto_message = type(default_proto_message)() proto_message.ParseFromString(bytes_value) return proto_message return bytes_value diff --git a/samples/samples/conftest.py b/samples/samples/conftest.py index 6747199022..674d61099d 100644 --- a/samples/samples/conftest.py +++ b/samples/samples/conftest.py @@ -116,7 +116,13 @@ def multi_region_instance_config(spanner_client): @pytest.fixture(scope="module") def proto_descriptor_file(): - return open("../../samples/samples/testdata/descriptors.pb", 'rb').read() + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + file = open(filename, "rb") + yield file.read() + file.close() @pytest.fixture(scope="module") @@ -213,8 +219,7 @@ def sample_database( sample_instance, database_id, database_ddl, - database_dialect, - proto_descriptor_file): + database_dialect): if database_dialect == DatabaseDialect.POSTGRESQL: sample_database = sample_instance.database( database_id, @@ -242,7 +247,6 @@ def sample_database( sample_database = sample_instance.database( database_id, ddl_statements=database_ddl, - proto_descriptors=proto_descriptor_file ) if not sample_database.exists(): @@ -254,6 +258,31 @@ def sample_database( sample_database.drop() +@pytest.fixture(scope="module") +def sample_database_for_proto_columns( + spanner_client, + sample_instance, + database_id, + database_ddl_for_proto_columns, + database_dialect, + proto_descriptor_file, +): + if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: + sample_database = sample_instance.database( + database_id, + ddl_statements=database_ddl_for_proto_columns, + proto_descriptors=proto_descriptor_file, + ) + + if not sample_database.exists(): + operation = sample_database.create() + operation.result(OPERATION_TIMEOUT_SECONDS) + + yield sample_database + + sample_database.drop() + + @pytest.fixture(scope="module") def kms_key_name(spanner_client): return "projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}".format( diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 0542031f96..855a78949b 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -31,7 +31,11 @@ from google.cloud import spanner from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import param_types -from google.cloud.spanner_v1.data_types import JsonObject, get_proto_message, get_proto_enum +from google.cloud.spanner_v1.data_types import ( + JsonObject, + get_proto_message, + get_proto_enum, +) from google.iam.v1 import policy_pb2 from google.protobuf import field_mask_pb2 # type: ignore from google.type import expr_pb2 @@ -280,14 +284,20 @@ def create_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_create_database_with_default_leader] -# [START 
spanner_create_database_with_proto_descriptors] -def create_database_with_proto_descriptors(instance_id, database_id): +# [START spanner_create_database_with_proto_descriptor] +def create_database_with_proto_descriptor(instance_id, database_id): """Creates a database with proto descriptors and tables with proto columns for sample data.""" + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) # reads proto descriptor file as bytes - proto_descriptor_file = open("testdata/descriptors.pb", 'rb').read() + proto_descriptor_file = open(filename, "rb") + proto_descriptor = proto_descriptor_file.read() database = instance.database( database_id, @@ -296,7 +306,7 @@ spanner.examples.music.SingerInfo, spanner.examples.music.Genre, )""", - """CREATE TABLE SingersProto ( + """CREATE TABLE Singers ( SingerId INT64 NOT NULL, FirstName STRING(1024), LastName STRING(1024), @@ -306,18 +316,23 @@ SingerGenreArray ARRAY<spanner.examples.music.Genre>, ) PRIMARY KEY (SingerId)""", ], - proto_descriptors=proto_descriptor_file + proto_descriptors=proto_descriptor, ) operation = database.create() print("Waiting for operation to complete...") operation.result(OPERATION_TIMEOUT_SECONDS) + proto_descriptor_file.close() - print("Created database {} with proto descriptors on instance {}".format(database_id, instance_id)) + print( + "Created database {} with proto descriptors on instance {}".format( + database_id, instance_id + ) + ) -# [END spanner_create_database_with_proto_descriptors] +# [END spanner_create_database_with_proto_descriptor] # [START spanner_update_database_with_default_leader] @@ -348,14 +363,20 @@ def update_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_update_database_with_default_leader] -# [START spanner_update_database_with_proto_descriptors] -def update_database_with_proto_descriptors(instance_id, database_id): +# [START spanner_update_database_with_proto_descriptor] +def update_database_with_proto_descriptor(instance_id, database_id): """Updates a database with proto descriptors and creates a table with proto columns.""" + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) - proto_descriptor_file = open("testdata/descriptors.pb", 'rb').read() + proto_descriptor_file = open(filename, "rb") + proto_descriptor = proto_descriptor_file.read() operation = database.update_ddl( [ @@ -363,29 +384,28 @@ spanner.examples.music.SingerInfo, spanner.examples.music.Genre, )""", - """CREATE TABLE SingersProto ( + """CREATE TABLE Singers ( SingerId INT64 NOT NULL, FirstName STRING(1024), LastName STRING(1024), SingerInfo spanner.examples.music.SingerInfo, SingerGenre spanner.examples.music.Genre, + SingerInfoArray ARRAY<spanner.examples.music.SingerInfo>, + SingerGenreArray ARRAY<spanner.examples.music.Genre>, ) PRIMARY KEY (SingerId)""", ], - proto_descriptors=proto_descriptor_file + proto_descriptors=proto_descriptor, ) print("Waiting for operation to complete...") operation.result(OPERATION_TIMEOUT_SECONDS) + proto_descriptor_file.close() database.reload() - print( - "Database {} updated with proto descriptors".format( - database.name - ) - ) + print("Database 
{} updated with proto descriptors".format(database.name)) -# [END spanner_update_database_with_proto_descriptors] +# [END spanner_update_database_with_proto_descriptor] # [START spanner_get_database_ddl] @@ -398,7 +418,6 @@ def get_database_ddl(instance_id, database_id): print("Retrieved database DDL for {}".format(database_id)) for statement in ddl.statements: print(statement) - print(ddl.proto_descriptors) # [END spanner_get_database_ddl] @@ -2511,8 +2530,60 @@ def enable_fine_grained_access( # [END spanner_enable_fine_grained_access] -# [START spanner_insert_proto_columns_data_with_dml] -def insert_proto_columns_data_with_dml(instance_id, database_id): +# [START spanner_insert_proto_columns_data] +def insert_proto_columns_data(instance_id, database_id): + """Inserts sample proto column data into the given database. + + The database and table must already exist and can be created using + `create_database_with_proto_descriptor`. + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 2 + singer_info.birth_date = "February" + singer_info.nationality = "Country2" + singer_info.genre = singer_pb2.Genre.FOLK + + singer_info_array = [singer_info] + singer_genre_array = [singer_pb2.Genre.FOLK] + + with database.batch() as batch: + batch.insert( + table="Singers", + columns=( + "SingerId", + "FirstName", + "LastName", + "SingerInfo", + "SingerGenre", + "SingerInfoArray", + "SingerGenreArray", + ), + values=[ + ( + 2, + "Marc", + "Richards", + singer_info, + singer_pb2.Genre.ROCK, + singer_info_array, + singer_genre_array, + ), + (3, "Catalina", "Smith", None, None, None, None), + ], + ) + + print("Inserted data.") + + +# [END spanner_insert_proto_columns_data] + + +# [START spanner_insert_proto_columns_data_using_dml] +def insert_proto_columns_data_using_dml(instance_id, database_id): """Inserts sample proto column data into the given database using a DML statement.""" spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) @@ -2529,21 +2600,25 @@ def insert_singers_with_proto_column(transaction): row_ct = transaction.execute_update( - "INSERT INTO SingersProto (SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray," + "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray," " SingerGenreArray) " " VALUES (1, 'Virginia', 'Watson', @singerInfo, @singerGenre, @singerInfoArray, @singerGenreArray)", params={ "singerInfo": singer_info, "singerGenre": singer_pb2.Genre.ROCK, "singerInfoArray": singer_info_array, - "singerGenreArray": singer_genre_array + "singerGenreArray": singer_genre_array, }, param_types={ "singerInfo": param_types.ProtoMessage(singer_info), "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), - "singerInfoArray": param_types.Array(param_types.ProtoMessage(singer_info)), - "singerGenreArray": param_types.Array(param_types.ProtoEnum(singer_pb2.Genre)) - } + "singerInfoArray": param_types.Array( + param_types.ProtoMessage(singer_info) + ), + "singerGenreArray": param_types.Array( + param_types.ProtoEnum(singer_pb2.Genre) + ), + }, ) print("{} record(s) inserted.".format(row_ct)) @@ -2551,44 +2626,7 @@ def insert_singers_with_proto_column(transaction): database.run_in_transaction(insert_singers_with_proto_column) -# [END 
spanner_insert_proto_columns_data_with_dml] - - -# [START spanner_insert_proto_columns_data] -def insert_proto_columns_data(instance_id, database_id): - """Inserts sample proto column data into the given database. - - The database and table must already exist and can be created using - `create_database`. - """ - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - singer_info = singer_pb2.SingerInfo() - singer_info.singer_id = 2 - singer_info.birth_date = "February" - singer_info.nationality = "Country2" - singer_info.genre = singer_pb2.Genre.FOLK - - singer_info_array = [singer_info] - singer_genre_array = [singer_pb2.Genre.FOLK] - - with database.batch() as batch: - batch.insert( - table="SingersProto", - columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", - "SingerGenreArray"), - values=[ - (2, "Marc", "Richards", singer_info, singer_pb2.Genre.ROCK, singer_info_array, singer_genre_array), - (3, "Catalina", "Smith", None, None, None, None), - ], - ) - - print("Inserted data.") - - -# [END spanner_insert_proto_columns_data] +# [END spanner_insert_proto_columns_data_using_dml] # [START spanner_read_proto_columns_data] @@ -2601,73 +2639,102 @@ def read_proto_columns_data(instance_id, database_id): with database.snapshot() as snapshot: keyset = spanner.KeySet(all_=True) results = snapshot.read( - table="SingersProto", - columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", "SingerGenreArray"), + table="Singers", + columns=( + "SingerId", + "FirstName", + "LastName", + "SingerInfo", + "SingerGenre", + "SingerInfoArray", + "SingerGenreArray", + ), keyset=keyset, - column_info={"SingerInfo": singer_pb2.SingerInfo(), - "SingerGenre": singer_pb2.Genre, - "SingerInfoArray": singer_pb2.SingerInfo(), - "SingerGenreArray": singer_pb2.Genre}, + column_info={ + "SingerInfo": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenreArray": singer_pb2.Genre, + }, ) for row in results: - print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " - "SingerGenreArray: {}".format(*row)) + print( + "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " + "SingerGenreArray: {}".format(*row) + ) # [END spanner_read_proto_columns_data] -# [START spanner_read_proto_columns_data_using_helper_method] -def read_proto_columns_data_using_helper_method(instance_id, database_id): - """Reads sample proto column data from the database.""" +# [START spanner_read_proto_columns_data_using_dql] +def read_proto_columns_data_using_dql(instance_id, database_id): + """Queries sample proto column data from the database using SQL.""" spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) with database.snapshot() as snapshot: - keyset = spanner.KeySet(all_=True) - results = snapshot.read( - table="SingersProto", - columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", "SingerGenreArray"), - keyset=keyset, + results = snapshot.execute_sql( + "SELECT SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray, SingerGenreArray FROM Singers", + column_info={ + "SingerInfo": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenreArray": 
singer_pb2.Genre, + }, ) for row in results: - singer_info_proto_msg = get_proto_message(row[3], singer_pb2.SingerInfo()) - singer_genre_proto_enum = get_proto_enum(row[4], singer_pb2.Genre) - singer_info_list = get_proto_message(row[5], singer_pb2.SingerInfo()) - singer_genre_list = get_proto_enum(row[6], singer_pb2.Genre) - print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, SingerInfoArray: {}, " - "SingerGenreArray: {}".format(row[0], row[1], row[2], singer_info_proto_msg, singer_genre_proto_enum, - singer_info_list, singer_genre_list)) + print( + "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " + "SingerGenreArray: {}".format(*row) + ) -# [END spanner_read_proto_columns_data_using_helper_method] +# [END spanner_read_proto_columns_data_using_dql] -# [START spanner_query_proto_columns_data] -def query_proto_columns_data(instance_id, database_id): - """Queries sample proto column data from the database using SQL.""" +def read_proto_columns_data_using_helper_method(instance_id, database_id): + """Reads sample proto column data from the database.""" spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) with database.snapshot() as snapshot: - results = snapshot.execute_sql( - "SELECT SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray, SingerGenreArray FROM SingersProto", - column_info={"SingerInfo": singer_pb2.SingerInfo(), - "SingerGenre": singer_pb2.Genre, - "SingerInfoArray": singer_pb2.SingerInfo(), - "SingerGenreArray": singer_pb2.Genre}, + keyset = spanner.KeySet(all_=True) + results = snapshot.read( + table="Singers", + columns=( + "SingerId", + "FirstName", + "LastName", + "SingerInfo", + "SingerGenre", + "SingerInfoArray", + "SingerGenreArray", + ), + keyset=keyset, ) for row in results: - print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " - "SingerGenreArray: {}".format(*row)) - - -# [END spanner_query_proto_columns_data] + singer_info_proto_msg = get_proto_message(row[3], singer_pb2.SingerInfo()) + singer_genre_proto_enum = get_proto_enum(row[4], singer_pb2.Genre) + singer_info_list = get_proto_message(row[5], singer_pb2.SingerInfo()) + singer_genre_list = get_proto_enum(row[6], singer_pb2.Genre) + print( + "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, SingerInfoArray: {}, " + "SingerGenreArray: {}".format( + row[0], + row[1], + row[2], + singer_info_proto_msg, + singer_genre_proto_enum, + singer_info_list, + singer_genre_list, + ) + ) if __name__ == "__main__": # noqa: C901 @@ -2682,7 +2749,6 @@ def query_proto_columns_data(instance_id, database_id): subparsers = parser.add_subparsers(dest="command") subparsers.add_parser("create_instance", help=create_instance.__doc__) subparsers.add_parser("create_database", help=create_database.__doc__) - subparsers.add_parser("create_database_with_proto_descriptors", help=create_database_with_proto_descriptors.__doc__) subparsers.add_parser("get_database_ddl", help=get_database_ddl.__doc__) subparsers.add_parser("insert_data", help=insert_data.__doc__) subparsers.add_parser("delete_data", help=delete_data.__doc__) @@ -2788,13 +2854,28 @@ def query_proto_columns_data(instance_id, database_id): "read_data_with_database_role", help=read_data_with_database_role.__doc__ ) subparsers.add_parser("list_database_roles", help=list_database_roles.__doc__) - 
subparsers.add_parser("insert_proto_columns_data_with_dml", help=insert_proto_columns_data_with_dml.__doc__) - subparsers.add_parser("insert_proto_columns_data", help=insert_proto_columns_data.__doc__) - subparsers.add_parser("read_proto_columns_data", help=read_proto_columns_data.__doc__) subparsers.add_parser( - "read_proto_columns_data_using_helper_method", help=read_proto_columns_data_using_helper_method.__doc__ + "create_database_with_proto_descriptor", + help=create_database_with_proto_descriptor.__doc__, + ) + subparsers.add_parser( + "insert_proto_columns_data_using_dml", + help=insert_proto_columns_data_using_dml.__doc__, + ) + subparsers.add_parser( + "insert_proto_columns_data", help=insert_proto_columns_data.__doc__ + ) + subparsers.add_parser( + "read_proto_columns_data", help=read_proto_columns_data.__doc__ + ) + subparsers.add_parser( + "read_proto_columns_data_using_helper_method", + help=read_proto_columns_data_using_helper_method.__doc__, + ) + subparsers.add_parser( + "read_proto_columns_data_using_dql", + help=read_proto_columns_data_using_dql.__doc__, ) - subparsers.add_parser("query_proto_columns_data", help=query_proto_columns_data.__doc__) enable_fine_grained_access_parser = subparsers.add_parser( "enable_fine_grained_access", help=enable_fine_grained_access.__doc__ ) @@ -2812,8 +2893,6 @@ def query_proto_columns_data(instance_id, database_id): create_instance(args.instance_id) elif args.command == "create_database": create_database(args.instance_id, args.database_id) - elif args.command == "create_database_with_proto_descriptors": - create_database_with_proto_descriptors(args.instance_id, args.database_id) elif args.command == "get_database_ddl": get_database_ddl(args.instance_id, args.database_id) elif args.command == "insert_data": @@ -2938,13 +3017,15 @@ def query_proto_columns_data(instance_id, database_id): args.database_role, args.title, ) - elif args.command == "insert_proto_columns_data_with_dml": - insert_proto_columns_data_with_dml(args.instance_id, args.database_id) + elif args.command == "create_database_with_proto_descriptor": + create_database_with_proto_descriptor(args.instance_id, args.database_id) + elif args.command == "insert_proto_columns_data_using_dml": + insert_proto_columns_data_using_dml(args.instance_id, args.database_id) elif args.command == "insert_proto_columns_data": insert_proto_columns_data(args.instance_id, args.database_id) elif args.command == "read_proto_columns_data": read_proto_columns_data(args.instance_id, args.database_id) elif args.command == "read_proto_columns_data_using_helper_method": read_proto_columns_data_using_helper_method(args.instance_id, args.database_id) - elif args.command == "query_proto_columns_data": - query_proto_columns_data(args.instance_id, args.database_id) + elif args.command == "read_proto_columns_data_using_dql": + read_proto_columns_data_using_dql(args.instance_id, args.database_id) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index af80e0b535..90834d5339 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -45,13 +45,13 @@ """ CREATE_TABLE_SINGERS_PROTO = """\ -CREATE TABLE SingersProto ( -SingerId INT64 NOT NULL, -FirstName STRING(1024), -LastName STRING(1024), -SingerInfo spanner.examples.music.SingerInfo, -SingerGenre spanner.examples.music.Genre, -SingerInfoArray ARRAY, +CREATE TABLE Singers ( +SingerId INT64 NOT NULL, +FirstName STRING(1024), +LastName STRING(1024), +SingerInfo spanner.examples.music.SingerInfo, 
+SingerGenre spanner.examples.music.Genre, +SingerInfoArray ARRAY<spanner.examples.music.SingerInfo>, SingerGenreArray ARRAY<spanner.examples.music.Genre>, ) PRIMARY KEY (SingerId) """ @@ -119,7 +119,16 @@ def database_ddl(): Sample testcase modules can override as needed. """ - return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS, CREATE_PROTO_BUNDLE, CREATE_TABLE_SINGERS_PROTO] + return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS] + + +@pytest.fixture(scope="module") +def database_ddl_for_proto_columns(): + """Sequence of DDL statements used to set up the database for proto columns. + + Sample testcase modules can override as needed. + """ + return [CREATE_PROTO_BUNDLE, CREATE_TABLE_SINGERS_PROTO] @pytest.fixture(scope="module") @@ -184,8 +193,8 @@ def test_create_database_with_encryption_config( assert kms_key_name in out -def test_create_database_with_proto_descriptors(capsys, instance_id, database_id): - snippets.create_database_with_proto_descriptors(instance_id, database_id) +def test_create_database_with_proto_descriptor(capsys, instance_id, database_id): + snippets.create_database_with_proto_descriptor(instance_id, database_id) out, _ = capsys.readouterr() assert database_id in out assert instance_id in out @@ -809,23 +818,37 @@ def test_list_database_roles(capsys, instance_id, sample_database): assert "new_parent" in out +def test_update_database_with_proto_descriptor(capsys, sample_instance, create_database_id): + # Create a separate database here: the proto samples also create a Singers table, which would clash with the shared sample database. + sample_instance.database(create_database_id).create().result(240) + snippets.update_database_with_proto_descriptor(sample_instance.instance_id, create_database_id) + out, _ = capsys.readouterr() + assert "updated with proto descriptors" in out + database = sample_instance.database(create_database_id) + database.drop() + + @pytest.mark.dependency(name="insert_proto_columns_data_dml") -def test_insert_proto_columns_data_with_dml(capsys, instance_id, sample_database): - snippets.insert_proto_columns_data_with_dml(instance_id, sample_database.database_id) +def test_insert_proto_columns_data_using_dml(capsys, instance_id, sample_database_for_proto_columns): + snippets.insert_proto_columns_data_using_dml( + instance_id, sample_database_for_proto_columns.database_id + ) out, _ = capsys.readouterr() assert "record(s) inserted" in out @pytest.mark.dependency(name="insert_proto_columns_data") -def test_insert_proto_columns_data(capsys, instance_id, sample_database): - snippets.insert_proto_columns_data(instance_id, sample_database.database_id) +def test_insert_proto_columns_data(capsys, instance_id, sample_database_for_proto_columns): + snippets.insert_proto_columns_data(instance_id, sample_database_for_proto_columns.database_id) out, _ = capsys.readouterr() assert "Inserted data" in out -@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) -def test_query_proto_columns_data(capsys, instance_id, sample_database): - snippets.query_proto_columns_data(instance_id, sample_database.database_id) +@pytest.mark.dependency( + depends=["insert_proto_columns_data_dml", "insert_proto_columns_data"] +) +def test_read_proto_columns_data_using_dql(capsys, instance_id, sample_database_for_proto_columns): + snippets.read_proto_columns_data_using_dql(instance_id, sample_database_for_proto_columns.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out @@ -833,9 +856,11 @@ assert "SingerId: 3, 
FirstName: Catalina, LastName: Smith" in out -@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) -def test_read_proto_columns_data(capsys, instance_id, sample_database): - snippets.read_proto_columns_data(instance_id, sample_database.database_id) +@pytest.mark.dependency( + depends=["insert_proto_columns_data_dml", "insert_proto_columns_data"] +) +def test_read_proto_columns_data(capsys, instance_id, sample_database_for_proto_columns): + snippets.read_proto_columns_data(instance_id, sample_database_for_proto_columns.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out @@ -843,9 +868,15 @@ assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out -@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) -def test_read_proto_columns_data_using_helper_method(capsys, instance_id, sample_database): - snippets.read_proto_columns_data_using_helper_method(instance_id, sample_database.database_id) +@pytest.mark.dependency( + depends=["insert_proto_columns_data_dml", "insert_proto_columns_data"] +) +def test_read_proto_columns_data_using_helper_method( + capsys, instance_id, sample_database_for_proto_columns +): + snippets.read_proto_columns_data_using_helper_method( + instance_id, sample_database_for_proto_columns.database_id + ) out, _ = capsys.readouterr() assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out diff --git a/setup.py b/setup.py index 86f2203d20..650b452838 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ "grpc-google-iam-v1 >= 0.12.4, <1.0.0dev", "proto-plus >= 1.22.0, <2.0.0dev", "sqlparse >= 0.3.0", - "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", + "protobuf>=3.20.2,<5.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", ] extras = { "tracing": [ diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index e061a1eadf..cd64ca21f9 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -13,4 +13,4 @@ sqlparse==0.3.0 opentelemetry-api==1.1.0 opentelemetry-sdk==1.1.0 opentelemetry-instrumentation==0.20b0 -protobuf==3.19.5 +protobuf==3.20.2 diff --git a/tests/_fixtures.py b/tests/_fixtures.py index 0bd8fe163a..62616c6969 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ -28,6 +28,10 @@ phone_number STRING(1024) ) PRIMARY KEY (contact_id, phone_type), INTERLEAVE IN PARENT contacts ON DELETE CASCADE; +CREATE PROTO BUNDLE ( + spanner.examples.music.SingerInfo, + spanner.examples.music.Genre, + ); CREATE TABLE all_types ( pkey INT64 NOT NULL, int_value INT64, @@ -48,6 +52,10 @@ numeric_array ARRAY<NUMERIC>, json_value JSON, json_array ARRAY<JSON>, + proto_message_value spanner.examples.music.SingerInfo, + proto_message_array ARRAY<spanner.examples.music.SingerInfo>, + proto_enum_value spanner.examples.music.Genre, + proto_enum_array ARRAY<spanner.examples.music.Genre>, ) PRIMARY KEY (pkey); CREATE TABLE counters ( @@ -159,8 +167,22 @@ CREATE INDEX name ON contacts(first_name, last_name); """ +PROTO_COLUMNS_DDL = """\ +CREATE TABLE singers ( + singer_id INT64 NOT NULL, + first_name STRING(1024), + last_name STRING(1024), + singer_info spanner.examples.music.SingerInfo, + singer_genre spanner.examples.music.Genre, ) + PRIMARY KEY (singer_id); +CREATE INDEX SingerByGenre ON singers(singer_genre) STORING (first_name, last_name); +""" + DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(";") if stmt.strip()] EMULATOR_DDL_STATEMENTS 
= [ stmt.strip() for stmt in EMULATOR_DDL.split(";") if stmt.strip() ] PG_DDL_STATEMENTS = [stmt.strip() for stmt in PG_DDL.split(";") if stmt.strip()] +PROTO_COLUMNS_DDL_STATEMENTS = [ + stmt.strip() for stmt in PROTO_COLUMNS_DDL.split(";") if stmt.strip() +] diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index 60926b216e..b62d453512 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -65,6 +65,8 @@ ) ) +PROTO_COLUMNS_DDL_STATEMENTS = _fixtures.PROTO_COLUMNS_DDL_STATEMENTS + retry_true = retry.RetryResult(operator.truth) retry_false = retry.RetryResult(operator.not_) diff --git a/tests/system/_sample_data.py b/tests/system/_sample_data.py index a7f3b80a86..f7f23fc5d2 100644 --- a/tests/system/_sample_data.py +++ b/tests/system/_sample_data.py @@ -18,7 +18,7 @@ from google.api_core import datetime_helpers from google.cloud._helpers import UTC from google.cloud import spanner_v1 - +from samples.samples.testdata import singer_pb2 TABLE = "contacts" COLUMNS = ("contact_id", "first_name", "last_name", "email") @@ -33,6 +33,31 @@ COUNTERS_TABLE = "counters" COUNTERS_COLUMNS = ("name", "value") +SINGERS_PROTO_TABLE = "singers" +SINGERS_PROTO_COLUMNS = ( + "singer_id", + "first_name", + "last_name", + "singer_info", + "singer_genre", +) +SINGER_INFO_1 = singer_pb2.SingerInfo() +SINGER_GENRE_1 = singer_pb2.Genre.ROCK +SINGER_INFO_1.singer_id = 1 +SINGER_INFO_1.birth_date = "January" +SINGER_INFO_1.nationality = "Country1" +SINGER_INFO_1.genre = SINGER_GENRE_1 +SINGER_INFO_2 = singer_pb2.SingerInfo() +SINGER_GENRE_2 = singer_pb2.Genre.FOLK +SINGER_INFO_2.singer_id = 2 +SINGER_INFO_2.birth_date = "February" +SINGER_INFO_2.nationality = "Country2" +SINGER_INFO_2.genre = SINGER_GENRE_2 +SINGERS_PROTO_ROW_DATA = ( + (1, "Singer1", "Singer1", SINGER_INFO_1, SINGER_GENRE_1), + (2, "Singer2", "Singer2", SINGER_INFO_2, SINGER_GENRE_2), +) + def _assert_timestamp(value, nano_value): assert isinstance(value, datetime.datetime) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index fdeab14c8f..62b06019f5 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -74,6 +74,17 @@ def database_dialect(): ) +@pytest.fixture(scope="session") +def proto_descriptor_file(): + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + file = open(filename, "rb") + yield file.read() + file.close() + + @pytest.fixture(scope="session") def spanner_client(): if _helpers.USE_EMULATOR: @@ -177,7 +188,9 @@ def shared_instance( @pytest.fixture(scope="session") -def shared_database(shared_instance, database_operation_timeout, database_dialect): +def shared_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_database") pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: @@ -198,6 +211,7 @@ def shared_database(shared_instance, database_operation_timeout, database_dialec ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = database.create() operation.result(database_operation_timeout) # raises on failure / timeout. 
diff --git a/tests/system/test_backup_api.py b/tests/system/test_backup_api.py index dc80653786..6ffc74283e 100644 --- a/tests/system/test_backup_api.py +++ b/tests/system/test_backup_api.py @@ -94,7 +94,9 @@ def database_version_time(shared_database): @pytest.fixture(scope="session") -def second_database(shared_instance, database_operation_timeout, database_dialect): +def second_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_database2") pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: @@ -115,6 +117,7 @@ def second_database(shared_instance, database_operation_timeout, database_dialec ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = database.create() operation.result(database_operation_timeout) # raises on failure / timeout. diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 364c159da5..2108667c7e 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import time import uuid @@ -75,7 +74,11 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) def test_database_binding_of_fixed_size_pool( - not_emulator, shared_instance, databases_to_delete, not_postgres + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, + proto_descriptor_file, ): temp_db_id = _helpers.unique_id("fixed_size_db", separator="_") temp_db = shared_instance.database(temp_db_id) @@ -89,7 +92,9 @@ def test_database_binding_of_fixed_size_pool( "CREATE ROLE parent", "GRANT SELECT ON TABLE contacts TO ROLE parent", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. pool = FixedSizePool( @@ -102,7 +107,11 @@ def test_database_binding_of_fixed_size_pool( def test_database_binding_of_pinging_pool( - not_emulator, shared_instance, databases_to_delete, not_postgres + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, + proto_descriptor_file, ): temp_db_id = _helpers.unique_id("binding_db", separator="_") temp_db = shared_instance.database(temp_db_id) @@ -116,7 +125,9 @@ def test_database_binding_of_pinging_pool( "CREATE ROLE parent", "GRANT SELECT ON TABLE contacts TO ROLE parent", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. 
pool = PingingPool( @@ -291,7 +302,7 @@ def test_table_not_found(shared_instance): def test_update_ddl_w_operation_id( - shared_instance, databases_to_delete, database_dialect + shared_instance, databases_to_delete, database_dialect, proto_descriptor_file ): # We used to have: # @pytest.mark.skip( @@ -309,7 +320,11 @@ def test_update_ddl_w_operation_id( # random but shortish always start with letter operation_id = f"a{str(uuid.uuid4())[:8]}" - operation = temp_db.update_ddl(_helpers.DDL_STATEMENTS, operation_id=operation_id) + operation = temp_db.update_ddl( + _helpers.DDL_STATEMENTS, + operation_id=operation_id, + proto_descriptors=proto_descriptor_file, + ) assert operation_id == operation.operation.name.split("/")[-1] @@ -325,6 +340,7 @@ def test_update_ddl_w_pitr_invalid( not_postgres, shared_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") @@ -342,7 +358,7 @@ def test_update_ddl_w_pitr_invalid( f" SET OPTIONS (version_retention_period = '{retention_period}')" ] with pytest.raises(exceptions.InvalidArgument): - temp_db.update_ddl(ddl_statements) + temp_db.update_ddl(ddl_statements, proto_descriptors=proto_descriptor_file) def test_update_ddl_w_pitr_success( @@ -350,6 +366,7 @@ def test_update_ddl_w_pitr_success( not_postgres, shared_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") @@ -366,7 +383,9 @@ def test_update_ddl_w_pitr_success( f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (version_retention_period = '{retention_period}')" ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. temp_db.reload() @@ -379,6 +398,7 @@ def test_update_ddl_w_default_leader_success( not_postgres, multiregion_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool( labels={"testcase": "update_database_ddl_default_leader"}, @@ -398,7 +418,9 @@ def test_update_ddl_w_default_leader_success( f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (default_leader = '{default_leader}')" ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. temp_db.reload() @@ -411,6 +433,7 @@ def test_create_role_grant_access_success( shared_instance, databases_to_delete, not_postgres, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -428,7 +451,9 @@ def test_create_role_grant_access_success( f"CREATE ROLE {creator_role_orphan}", f"GRANT SELECT ON TABLE contacts TO ROLE {creator_role_parent}", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. # Perform select with orphan role on table contacts. 
@@ -460,6 +485,7 @@ def test_list_database_role_success( shared_instance, databases_to_delete, not_postgres, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -476,7 +502,9 @@ def test_list_database_role_success( f"CREATE ROLE {creator_role_parent}", f"CREATE ROLE {creator_role_orphan}", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. # List database roles. @@ -562,3 +590,30 @@ def _unit_of_work(transaction, name): rows = list(after.read(sd.COUNTERS_TABLE, sd.COUNTERS_COLUMNS, sd.ALL)) assert len(rows) == 2 + + +def test_create_table_with_proto_columns( + not_emulator, + not_postgres, + shared_instance, + databases_to_delete, + proto_descriptor_file, +): + proto_cols_db_id = _helpers.unique_id("proto-columns") + extra_ddl = [ + "CREATE PROTO BUNDLE (spanner.examples.music.SingerInfo, spanner.examples.music.Genre,)" + ] + + proto_cols_database = shared_instance.database( + proto_cols_db_id, + ddl_statements=extra_ddl + _helpers.PROTO_COLUMNS_DDL_STATEMENTS, + proto_descriptors=proto_descriptor_file, + ) + operation = proto_cols_database.create() + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + databases_to_delete.append(proto_cols_database) + + proto_cols_database.reload() + assert proto_cols_database.proto_descriptors is not None + assert any("PROTO BUNDLE" in stmt for stmt in proto_cols_database.ddl_statements) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 6b7afbe525..8b00073567 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import base64 import collections import datetime import decimal @@ -29,6 +29,7 @@ from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud._helpers import UTC from google.cloud.spanner_v1.data_types import JsonObject +from samples.samples.testdata import singer_pb2 from tests import _helpers as ot_helpers from . import _helpers from . 
import _sample_data @@ -57,6 +58,8 @@ JSON_2 = JsonObject( {"sample_object": {"name": "Anamika", "id": 2635}}, ) +SINGER_INFO = _sample_data.SINGER_INFO_1 +SINGER_GENRE = _sample_data.SINGER_GENRE_1 COUNTERS_TABLE = "counters" COUNTERS_COLUMNS = ("name", "value") @@ -81,6 +84,10 @@ "numeric_array", "json_value", "json_array", + "proto_message_value", + "proto_message_array", + "proto_enum_value", + "proto_enum_array", ) EMULATOR_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:-4] @@ -120,6 +127,8 @@ AllTypesRowData(pkey=109, numeric_value=NUMERIC_1), AllTypesRowData(pkey=110, json_value=JSON_1), AllTypesRowData(pkey=111, json_value=JsonObject([JSON_1, JSON_2])), + AllTypesRowData(pkey=112, proto_message_value=SINGER_INFO), + AllTypesRowData(pkey=113, proto_enum_value=SINGER_GENRE), # empty array values AllTypesRowData(pkey=201, int_array=[]), AllTypesRowData(pkey=202, bool_array=[]), @@ -130,6 +139,8 @@ AllTypesRowData(pkey=207, timestamp_array=[]), AllTypesRowData(pkey=208, numeric_array=[]), AllTypesRowData(pkey=209, json_array=[]), + AllTypesRowData(pkey=210, proto_message_array=[]), + AllTypesRowData(pkey=211, proto_enum_array=[]), # non-empty array values, including nulls AllTypesRowData(pkey=301, int_array=[123, 456, None]), AllTypesRowData(pkey=302, bool_array=[True, False, None]), @@ -142,6 +153,8 @@ AllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]), AllTypesRowData(pkey=308, numeric_array=[NUMERIC_1, NUMERIC_2, None]), AllTypesRowData(pkey=309, json_array=[JSON_1, JSON_2, None]), + AllTypesRowData(pkey=310, proto_message_array=[SINGER_INFO, None]), + AllTypesRowData(pkey=311, proto_enum_array=[SINGER_GENRE, None]), ) EMULATOR_ALL_TYPES_ROWDATA = ( # all nulls @@ -221,9 +234,16 @@ ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS ALL_TYPES_ROWDATA = LIVE_ALL_TYPES_ROWDATA +COLUMN_INFO = { + "proto_message_value": singer_pb2.SingerInfo(), + "proto_message_array": singer_pb2.SingerInfo(), +} + @pytest.fixture(scope="session") -def sessions_database(shared_instance, database_operation_timeout, database_dialect): +def sessions_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_sessions", separator="_") pool = spanner_v1.BurstyPool(labels={"testcase": "session_api"}) @@ -245,6 +265,7 @@ def sessions_database(shared_instance, database_operation_timeout, database_dial database_name, ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, + proto_descriptors=proto_descriptor_file, ) operation = sessions_database.create() @@ -459,7 +480,11 @@ def test_batch_insert_then_read_all_datatypes(sessions_database): batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA) with sessions_database.snapshot(read_timestamp=batch.committed) as snapshot: - rows = list(snapshot.read(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, sd.ALL)) + rows = list( + snapshot.read( + ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, sd.ALL, column_info=COLUMN_INFO + ) + ) sd._check_rows_data(rows, expected=ALL_TYPES_ROWDATA) @@ -1315,6 +1340,21 @@ def _unit_of_work(transaction): return committed +def _set_up_proto_table(database): + + sd = _sample_data + + def _unit_of_work(transaction): + transaction.delete(sd.SINGERS_PROTO_TABLE, sd.ALL) + transaction.insert( + sd.SINGERS_PROTO_TABLE, sd.SINGERS_PROTO_COLUMNS, sd.SINGERS_PROTO_ROW_DATA + ) + + committed = database.run_in_transaction(_unit_of_work) + + return committed + + def test_read_with_single_keys_index(sessions_database): # [START spanner_test_single_key_index_read] sd = 
_sample_data @@ -1464,7 +1504,11 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database): def test_read_w_index( - shared_instance, database_operation_timeout, databases_to_delete, database_dialect + shared_instance, + database_operation_timeout, + databases_to_delete, + database_dialect, + proto_descriptor_file, ): # Indexed reads cannot return non-indexed columns sd = _sample_data @@ -1492,9 +1536,12 @@ def test_read_w_index( else: temp_db = shared_instance.database( _helpers.unique_id("test_read", separator="_"), - ddl_statements=_helpers.DDL_STATEMENTS + extra_ddl, + ddl_statements=_helpers.DDL_STATEMENTS + + extra_ddl + + _helpers.PROTO_COLUMNS_DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = temp_db.create() operation.result(database_operation_timeout) # raises on failure / timeout. @@ -1510,6 +1557,28 @@ def test_read_w_index( expected = list(reversed([(row[0], row[2]) for row in _row_data(row_count)])) sd._check_rows_data(rows, expected) + # Test indexes on proto column types + if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: + # Indexed reads cannot return non-indexed columns + my_columns = ( + sd.SINGERS_PROTO_COLUMNS[0], + sd.SINGERS_PROTO_COLUMNS[1], + sd.SINGERS_PROTO_COLUMNS[4], + ) + committed = _set_up_proto_table(temp_db) + with temp_db.snapshot(read_timestamp=committed) as snapshot: + rows = list( + snapshot.read( + sd.SINGERS_PROTO_TABLE, + my_columns, + spanner_v1.KeySet(keys=[[singer_pb2.Genre.ROCK]]), + index="SingerByGenre", + ) + ) + row = sd.SINGERS_PROTO_ROW_DATA[0] + expected = list([(row[0], row[1], row[4])]) + sd._check_rows_data(rows, expected) + def test_read_w_single_key(sessions_database): # [START spanner_test_single_key_read] @@ -1922,12 +1991,17 @@ def _check_sql_results( expected, order=True, recurse_into_lists=True, + column_info=None, ): if order and "ORDER" not in sql: sql += " ORDER BY pkey" with database.snapshot() as snapshot: - rows = list(snapshot.execute_sql(sql, params=params, param_types=param_types)) + rows = list( + snapshot.execute_sql( + sql, params=params, param_types=param_types, column_info=column_info + ) + ) _sample_data._check_rows_data( rows, expected=expected, recurse_into_lists=recurse_into_lists @@ -2023,32 +2097,39 @@ def _bind_test_helper( array_value, expected_array_value=None, recurse_into_lists=True, + column_info=None, + expected_single_value=None, ): database.snapshot(multi_use=True) key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "v" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" + if expected_single_value is None: + expected_single_value = single_value + # Bind a non-null _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: single_value}, param_types={key: param_type}, - expected=[(single_value,)], + expected=[(expected_single_value,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind a null _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: None}, param_types={key: param_type}, expected=[(None,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind an array of @@ -2062,34 +2143,37 @@ def _bind_test_helper( _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: array_value}, param_types={key: array_type}, 
expected=[(expected_array_value,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind an empty array of _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: []}, param_types={key: array_type}, expected=[([],)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind a null array of _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: None}, param_types={key: array_type}, expected=[(None,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) @@ -2457,6 +2541,80 @@ def test_execute_sql_w_query_param_struct(sessions_database, not_postgres): ) +def test_execute_sql_w_proto_message_bindings( + not_emulator, not_postgres, sessions_database, database_dialect +): + singer_info = _sample_data.SINGER_INFO_1 + singer_info_bytes = base64.b64encode(singer_info.SerializeToString()) + + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoMessage(singer_info), + singer_info, + [singer_info, None], + column_info={"column": singer_pb2.SingerInfo()}, + ) + + # Tests compatibility between proto message and bytes column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoMessage(singer_info), + singer_info_bytes, + [singer_info_bytes, None], + expected_single_value=singer_info, + expected_array_value=[singer_info, None], + column_info={"column": singer_pb2.SingerInfo()}, + ) + + # Tests compatibility between proto message and bytes column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.BYTES, + singer_info, + [singer_info, None], + expected_single_value=singer_info_bytes, + expected_array_value=[singer_info_bytes, None], + ) + + +def test_execute_sql_w_proto_enum_bindings( + not_emulator, not_postgres, sessions_database, database_dialect +): + singer_genre = _sample_data.SINGER_GENRE_1 + + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoEnum(singer_pb2.Genre), + singer_genre, + [singer_genre, None], + ) + + # Tests compatibility between proto enum and int64 column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoEnum(singer_pb2.Genre), + 3, + [3, None], + expected_single_value="ROCK", + expected_array_value=["ROCK", None], + column_info={"column": singer_pb2.Genre}, + ) + + # Tests compatibility between proto enum and int64 column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.INT64, + singer_genre, + [singer_genre, None], + ) + + def test_execute_sql_returning_transfinite_floats(sessions_database, not_postgres): with sessions_database.snapshot(multi_use=True) as snapshot: diff --git a/tests/system/testdata/descriptors.pb b/tests/system/testdata/descriptors.pb new file mode 100644 index 0000000000..3ebb79420b Binary files /dev/null and b/tests/system/testdata/descriptors.pb differ diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 21434da191..b695f42564 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -313,6 +313,25 @@ def test_w_json_None(self): value_pb = self._callFUT(value) self.assertTrue(value_pb.HasField("null_value")) + def test_w_proto_message(self): + from google.protobuf.struct_pb2 import Value + import base64 + from samples.samples.testdata import singer_pb2 + + 
singer_info = singer_pb2.SingerInfo() + expected = Value(string_value=base64.b64encode(singer_info.SerializeToString())) + value_pb = self._callFUT(singer_info) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb, expected) + + def test_w_proto_enum(self): + from google.protobuf.struct_pb2 import Value + from samples.samples.testdata import singer_pb2 + + value_pb = self._callFUT(singer_pb2.Genre.ROCK) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, "3") + class Test_make_list_value_pb(unittest.TestCase): def _callFUT(self, *args, **kw): @@ -394,9 +413,10 @@ def test_w_null(self): from google.cloud.spanner_v1 import TypeCode field_type = Type(code=TypeCode.STRING) + field_name = "null_column" value_pb = Value(null_value=NULL_VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), None) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), None) def test_w_string(self): from google.protobuf.struct_pb2 import Value @@ -405,9 +425,10 @@ def test_w_string(self): VALUE = "Value" field_type = Type(code=TypeCode.STRING) + field_name = "string_column" value_pb = Value(string_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_bytes(self): from google.protobuf.struct_pb2 import Value @@ -416,9 +437,10 @@ def test_w_bytes(self): VALUE = b"Value" field_type = Type(code=TypeCode.BYTES) + field_name = "bytes_column" value_pb = Value(string_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_bool(self): from google.protobuf.struct_pb2 import Value @@ -427,9 +449,10 @@ def test_w_bool(self): VALUE = True field_type = Type(code=TypeCode.BOOL) + field_name = "bool_column" value_pb = Value(bool_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_int(self): from google.protobuf.struct_pb2 import Value @@ -438,9 +461,10 @@ def test_w_int(self): VALUE = 12345 field_type = Type(code=TypeCode.INT64) + field_name = "int_column" value_pb = Value(string_value=str(VALUE)) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_float(self): from google.protobuf.struct_pb2 import Value @@ -449,9 +473,10 @@ def test_w_float(self): VALUE = 3.14159 field_type = Type(code=TypeCode.FLOAT64) + field_name = "float_column" value_pb = Value(number_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_float_str(self): from google.protobuf.struct_pb2 import Value @@ -460,10 +485,13 @@ def test_w_float_str(self): VALUE = "3.14159" field_type = Type(code=TypeCode.FLOAT64) + field_name = "float_str_column" value_pb = Value(string_value=VALUE) expected_value = 3.14159 - self.assertEqual(self._callFUT(value_pb, field_type), expected_value) + self.assertEqual( + self._callFUT(value_pb, field_type, field_name), expected_value + ) def test_w_date(self): import datetime @@ -473,9 +501,10 @@ def test_w_date(self): VALUE = datetime.date.today() field_type = Type(code=TypeCode.DATE) + field_name = "date_column" value_pb = Value(string_value=VALUE.isoformat()) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + 
self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_timestamp_wo_nanos(self): import datetime @@ -488,9 +517,10 @@ def test_w_timestamp_wo_nanos(self): 2016, 12, 20, 21, 13, 47, microsecond=123456, tzinfo=datetime.timezone.utc ) field_type = Type(code=TypeCode.TIMESTAMP) + field_name = "nanos_column" value_pb = Value(string_value=datetime_helpers.to_rfc3339(value)) - parsed = self._callFUT(value_pb, field_type) + parsed = self._callFUT(value_pb, field_type, field_name) self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds) self.assertEqual(parsed, value) @@ -505,9 +535,10 @@ def test_w_timestamp_w_nanos(self): 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc ) field_type = Type(code=TypeCode.TIMESTAMP) + field_name = "timestamp_column" value_pb = Value(string_value=datetime_helpers.to_rfc3339(value)) - parsed = self._callFUT(value_pb, field_type) + parsed = self._callFUT(value_pb, field_type, field_name) self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds) self.assertEqual(parsed, value) @@ -519,9 +550,10 @@ def test_w_array_empty(self): field_type = Type( code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64) ) + field_name = "array_empty_column" value_pb = Value(list_value=ListValue(values=[])) - self.assertEqual(self._callFUT(value_pb, field_type), []) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), []) def test_w_array_non_empty(self): from google.protobuf.struct_pb2 import Value, ListValue @@ -531,13 +563,14 @@ def test_w_array_non_empty(self): field_type = Type( code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64) ) + field_name = "array_non_empty_column" VALUES = [32, 19, 5] values_pb = ListValue( values=[Value(string_value=str(value)) for value in VALUES] ) value_pb = Value(list_value=values_pb) - self.assertEqual(self._callFUT(value_pb, field_type), VALUES) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUES) def test_w_struct(self): from google.protobuf.struct_pb2 import Value @@ -554,9 +587,10 @@ def test_w_struct(self): ] ) field_type = Type(code=TypeCode.STRUCT, struct_type=struct_type_pb) + field_name = "struct_column" value_pb = Value(list_value=_make_list_value_pb(VALUES)) - self.assertEqual(self._callFUT(value_pb, field_type), VALUES) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUES) def test_w_numeric(self): import decimal @@ -566,9 +600,10 @@ def test_w_numeric(self): VALUE = decimal.Decimal("99999999999999999999999999999.999999999") field_type = Type(code=TypeCode.NUMERIC) + field_name = "numeric_column" value_pb = Value(string_value=str(VALUE)) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_json(self): import json @@ -580,9 +615,10 @@ def test_w_json(self): str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":")) field_type = Type(code=TypeCode.JSON) + field_name = "json_column" value_pb = Value(string_value=str_repr) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) VALUE = None str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":")) @@ -590,7 +626,7 @@ def test_w_json(self): field_type = Type(code=TypeCode.JSON) value_pb = Value(string_value=str_repr) - self.assertEqual(self._callFUT(value_pb, field_type), {}) + self.assertEqual(self._callFUT(value_pb, field_type, 
field_name), {}) def test_w_unknown_type(self): from google.protobuf.struct_pb2 import Value @@ -598,10 +634,44 @@ def test_w_unknown_type(self): from google.cloud.spanner_v1 import TypeCode field_type = Type(code=TypeCode.TYPE_CODE_UNSPECIFIED) + field_name = "unknown_column" value_pb = Value(string_value="Borked") with self.assertRaises(ValueError): - self._callFUT(value_pb, field_type) + self._callFUT(value_pb, field_type, field_name) + + def test_w_proto_message(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + import base64 + from samples.samples.testdata import singer_pb2 + + VALUE = singer_pb2.SingerInfo() + field_type = Type(code=TypeCode.PROTO) + field_name = "proto_message_column" + value_pb = Value(string_value=base64.b64encode(VALUE.SerializeToString())) + column_info = {"proto_message_column": singer_pb2.SingerInfo()} + + self.assertEqual( + self._callFUT(value_pb, field_type, field_name, column_info), VALUE + ) + + def test_w_proto_enum(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + from samples.samples.testdata import singer_pb2 + + VALUE = "ROCK" + field_type = Type(code=TypeCode.ENUM) + field_name = "proto_enum_column" + value_pb = Value(string_value=str(singer_pb2.Genre.ROCK)) + column_info = {"proto_enum_column": singer_pb2.Genre} + + self.assertEqual( + self._callFUT(value_pb, field_type, field_name, column_info), VALUE + ) class Test_parse_list_value_pbs(unittest.TestCase): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bff89320c7..dbff6c5107 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -186,6 +186,14 @@ def test_ctor_w_encryption_config(self): self.assertIs(database._instance, instance) self.assertEqual(database._encryption_config, encryption_config) + def test_ctor_w_proto_descriptors(self): + + instance = _Instance(self.INSTANCE_NAME) + database = self._make_one(self.DATABASE_ID, instance, proto_descriptors=b"") + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertEqual(database._proto_descriptors, b"") + def test_from_pb_bad_database_name(self): from google.cloud.spanner_admin_database_v1 import Database @@ -351,6 +359,15 @@ def test_default_leader(self): default_leader = database._default_leader = "us-east4" self.assertEqual(database.default_leader, default_leader) + def test_proto_descriptors(self): + + instance = _Instance(self.INSTANCE_NAME) + pool = _Pool() + database = self._make_one( + self.DATABASE_ID, instance, pool=pool, proto_descriptors=b"" + ) + self.assertEqual(database.proto_descriptors, b"") + def test_spanner_api_property_w_scopeless_creds(self): client = _Client() @@ -622,6 +639,41 @@ def test_create_success_w_encryption_config_dict(self): metadata=[("google-cloud-resource-prefix", database.name)], ) + def test_create_success_w_proto_descriptors(self): + from tests._fixtures import DDL_STATEMENTS + from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest + + op_future = object() + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.return_value = op_future + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + proto_descriptors = b"" + database = self._make_one( + self.DATABASE_ID, + instance, + ddl_statements=DDL_STATEMENTS, + 
pool=pool, + proto_descriptors=proto_descriptors, + ) + + future = database.create() + + self.assertIs(future, op_future) + + expected_request = CreateDatabaseRequest( + parent=self.INSTANCE_NAME, + create_statement="CREATE DATABASE {}".format(self.DATABASE_ID), + extra_statements=DDL_STATEMENTS, + proto_descriptors=proto_descriptors, + ) + + api.create_database.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + def test_exists_grpc_error(self): from google.api_core.exceptions import Unknown @@ -877,6 +929,34 @@ def test_update_ddl_w_operation_id(self): metadata=[("google-cloud-resource-prefix", database.name)], ) + def test_update_ddl_w_proto_descriptors(self): + from tests._fixtures import DDL_STATEMENTS + from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest + + op_future = object() + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.return_value = op_future + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + future = database.update_ddl(DDL_STATEMENTS, proto_descriptors=b"") + + self.assertIs(future, op_future) + + expected_request = UpdateDatabaseDdlRequest( + database=self.DATABASE_NAME, + statements=DDL_STATEMENTS, + operation_id="", + proto_descriptors=b"", + ) + + api.update_database_ddl.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + def test_drop_grpc_error(self): from google.api_core.exceptions import Unknown diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index e0a0f663cf..e45b3f051c 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -557,6 +557,7 @@ def test_database_factory_explicit(self): pool = _Pool() logger = mock.create_autospec(Logger, instance=True) encryption_config = {"kms_key_name": "kms_key_name"} + proto_descriptors = b"" database = instance.database( DATABASE_ID, @@ -565,6 +566,7 @@ def test_database_factory_explicit(self): logger=logger, encryption_config=encryption_config, database_role=DATABASE_ROLE, + proto_descriptors=proto_descriptors, ) self.assertIsInstance(database, Database) @@ -576,6 +578,7 @@ def test_database_factory_explicit(self): self.assertIs(pool._bound, database) self.assertIs(database._encryption_config, encryption_config) self.assertIs(database.database_role, DATABASE_ROLE) + self.assertIs(database._proto_descriptors, proto_descriptors) def test_list_databases(self): from google.cloud.spanner_admin_database_v1 import Database as DatabasePB diff --git a/tests/unit/test_param_types.py b/tests/unit/test_param_types.py index 02f41c1f25..fad171c918 100644 --- a/tests/unit/test_param_types.py +++ b/tests/unit/test_param_types.py @@ -71,3 +71,37 @@ def test_it(self): found = param_types.PG_JSONB self.assertEqual(found, expected) + + +class Test_ProtoMessageParamType(unittest.TestCase): + def test_it(self): + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + from google.cloud.spanner_v1 import param_types + from samples.samples.testdata import singer_pb2 + + singer_info = singer_pb2.SingerInfo() + expected = Type( + code=TypeCode.PROTO, proto_type_fqn=singer_info.DESCRIPTOR.full_name + ) + + found = param_types.ProtoMessage(singer_info) + + self.assertEqual(found, expected) + + +class Test_ProtoEnumParamType(unittest.TestCase): + def 
test_it(self): + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + from google.cloud.spanner_v1 import param_types + from samples.samples.testdata import singer_pb2 + + singer_genre = singer_pb2.Genre + expected = Type( + code=TypeCode.ENUM, proto_type_fqn=singer_genre.DESCRIPTOR.full_name + ) + + found = param_types.ProtoEnum(singer_genre) + + self.assertEqual(found, expected) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index edad4ce777..ce9b205264 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -614,7 +614,12 @@ def test_read(self): self.assertIs(found, snapshot().read.return_value) snapshot().read.assert_called_once_with( - TABLE_NAME, COLUMNS, KEYSET, INDEX, LIMIT + TABLE_NAME, + COLUMNS, + KEYSET, + INDEX, + LIMIT, + column_info=None, ) def test_execute_sql_not_created(self): @@ -645,6 +650,7 @@ def test_execute_sql_defaults(self): request_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, + column_info=None, ) def test_execute_sql_non_default_retry(self): @@ -675,6 +681,7 @@ def test_execute_sql_non_default_retry(self): request_options=None, timeout=None, retry=None, + column_info=None, ) def test_execute_sql_explicit(self): @@ -703,6 +710,7 @@ def test_execute_sql_explicit(self): request_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, + column_info=None, ) def test_batch_not_created(self):
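
For reference, a minimal end-to-end sketch of how the pieces exercised by these tests fit together: the ProtoMessage/ProtoEnum param types, proto values passed in mutations, and the column_info argument used to deserialize proto columns on reads. This is illustrative only, not part of the patch: the instance, database, and column names are made up, the SingerInfo fields are assumed from the sample singer.proto, and the singer_pb2 import path follows the test fixtures rather than a real application layout.

    from google.cloud import spanner
    from google.cloud.spanner_v1 import param_types
    # Import path as used by the unit tests; a real application would import
    # its own generated _pb2 module instead.
    from samples.samples.testdata import singer_pb2

    client = spanner.Client()
    database = client.instance("my-instance").database("my-database")

    # SingerInfo field names (singer_id, nationality) are assumed from the
    # sample singer.proto shipped with these tests.
    singer_info = singer_pb2.SingerInfo(singer_id=1, nationality="Country1")

    # Proto messages and proto enum values can be written directly; the
    # helpers serialize them (message -> base64 bytes, enum -> its number).
    with database.batch() as batch:
        batch.insert(
            table="Singers",
            columns=("SingerId", "SingerInfo", "SingerGenre"),
            values=[(1, singer_info, singer_pb2.Genre.ROCK)],
        )

    with database.snapshot() as snapshot:
        results = snapshot.execute_sql(
            "SELECT SingerId, SingerInfo, SingerGenre FROM Singers "
            "WHERE SingerGenre = @genre",
            params={"genre": singer_pb2.Genre.ROCK},
            param_types={"genre": param_types.ProtoEnum(singer_pb2.Genre)},
            # column_info tells the client how to deserialize proto columns:
            # a default message instance for message columns, the enum class
            # for enum columns.
            column_info={
                "SingerInfo": singer_pb2.SingerInfo(),
                "SingerGenre": singer_pb2.Genre,
            },
        )
        for row in results:
            print(row)

The column_info dictionary is keyed by column name and mirrors the fixtures used in test_w_proto_message and test_w_proto_enum above; without it, proto message columns come back as raw bytes and enum columns as integers.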