Skip to content

feat: Proto Columns Feature #909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ class CreateDatabaseRequest(proto.Message):
database_dialect (google.cloud.spanner_admin_database_v1.types.DatabaseDialect):
Optional. The dialect of the Cloud Spanner
Database.
proto_descriptors (bytes):
Proto descriptors used by CREATE/ALTER PROTO BUNDLE
statements in 'extra_statements' above. Contains a
protobuf-serialized
`google.protobuf.FileDescriptorSet <https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/descriptor.proto>`__.
"""

parent: str = proto.Field(
Expand All @@ -355,6 +360,10 @@ class CreateDatabaseRequest(proto.Message):
number=5,
enum=common.DatabaseDialect,
)
proto_descriptors: bytes = proto.Field(
proto.BYTES,
number=6,
)


class CreateDatabaseMetadata(proto.Message):
Expand Down Expand Up @@ -435,6 +444,10 @@ class UpdateDatabaseDdlRequest(proto.Message):
underscore. If the named operation already exists,
[UpdateDatabaseDdl][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl]
returns ``ALREADY_EXISTS``.
proto_descriptors (bytes):
Proto descriptors used by CREATE/ALTER PROTO BUNDLE
statements. Contains a protobuf-serialized
`google.protobuf.FileDescriptorSet <https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/descriptor.proto>`__.
"""

database: str = proto.Field(
Expand All @@ -449,6 +462,10 @@ class UpdateDatabaseDdlRequest(proto.Message):
proto.STRING,
number=3,
)
proto_descriptors: bytes = proto.Field(
proto.BYTES,
number=4,
)


class UpdateDatabaseDdlMetadata(proto.Message):
Expand Down Expand Up @@ -549,12 +566,20 @@ class GetDatabaseDdlResponse(proto.Message):
A list of formatted DDL statements defining
the schema of the database specified in the
request.
proto_descriptors (bytes):
Proto descriptors stored in the database. Contains a
protobuf-serialized
`google.protobuf.FileDescriptorSet <https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/descriptor.proto>`__.
"""

statements: MutableSequence[str] = proto.RepeatedField(
proto.STRING,
number=1,
)
proto_descriptors: bytes = proto.Field(
proto.BYTES,
number=2,
)


class ListDatabaseOperationsRequest(proto.Message):
Expand Down
43 changes: 39 additions & 4 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
import datetime
import decimal
import math
import base64

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
from google.protobuf.message import Message
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

from google.api_core import datetime_helpers
from google.cloud._helpers import _date_from_iso8601_date
Expand Down Expand Up @@ -170,6 +173,12 @@ def _make_value_pb(value):
return Value(null_value="NULL_VALUE")
else:
return Value(string_value=value)
if isinstance(value, Message):
value = value.SerializeToString()
if value is None:
return Value(null_value="NULL_VALUE")
else:
return Value(string_value=base64.b64encode(value))

raise ValueError("Unknown type: %s" % (value,))

Expand Down Expand Up @@ -198,7 +207,7 @@ def _make_list_value_pbs(values):
return [_make_list_value_pb(row) for row in values]


def _parse_value_pb(value_pb, field_type):
def _parse_value_pb(value_pb, field_type, field_name, column_info=None):
"""Convert a Value protobuf to cell data.

:type value_pb: :class:`~google.protobuf.struct_pb2.Value`
Expand All @@ -207,6 +216,12 @@ def _parse_value_pb(value_pb, field_type):
:type field_type: :class:`~google.cloud.spanner_v1.types.Type`
:param field_type: type code for the value

:type field_name: str
:param field_name: column name

:type column_info: dict
:param column_info: (Optional) dict of column name and column information

:rtype: varies on field_type
:returns: value extracted from value_pb
:raises ValueError: if unknown type is passed
Expand Down Expand Up @@ -234,18 +249,38 @@ def _parse_value_pb(value_pb, field_type):
return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
elif type_code == TypeCode.ARRAY:
return [
_parse_value_pb(item_pb, field_type.array_element_type)
_parse_value_pb(
item_pb, field_type.array_element_type, field_name, column_info
)
for item_pb in value_pb.list_value.values
]
elif type_code == TypeCode.STRUCT:
return [
_parse_value_pb(item_pb, field_type.struct_type.fields[i].type_)
_parse_value_pb(
item_pb, field_type.struct_type.fields[i].type_, field_name, column_info
)
for (i, item_pb) in enumerate(value_pb.list_value.values)
]
elif type_code == TypeCode.NUMERIC:
return decimal.Decimal(value_pb.string_value)
elif type_code == TypeCode.JSON:
return JsonObject.from_str(value_pb.string_value)
elif type_code == TypeCode.PROTO:
bytes_value = base64.b64decode(value_pb.string_value)
if column_info is not None and column_info.get(field_name) is not None:
proto_message = column_info.get(field_name)
if isinstance(proto_message, Message):
proto_message = proto_message.__deepcopy__()
proto_message.ParseFromString(bytes_value)
return proto_message
return bytes_value
elif type_code == TypeCode.ENUM:
int_value = int(value_pb.string_value)
if column_info is not None and column_info.get(field_name) is not None:
proto_enum = column_info.get(field_name)
if isinstance(proto_enum, EnumTypeWrapper):
return proto_enum.Name(int_value)
return int_value
else:
raise ValueError("Unknown type: %s" % (field_type,))

Expand All @@ -266,7 +301,7 @@ def _parse_list_value_pbs(rows, row_type):
for row in rows:
row_data = []
for value_pb, field in zip(row.values, row_type.fields):
row_data.append(_parse_value_pb(value_pb, field.type_))
row_data.append(_parse_value_pb(value_pb, field.type_, field.name))
result.append(row_data)
return result

Expand Down
110 changes: 110 additions & 0 deletions google/cloud/spanner_v1/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
"""Custom data types for spanner."""

import json
import types

from google.protobuf.message import Message
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper


class JsonObject(dict):
Expand Down Expand Up @@ -71,3 +75,109 @@ def serialize(self):
return json.dumps(self._array_value, sort_keys=True, separators=(",", ":"))

return json.dumps(self, sort_keys=True, separators=(",", ":"))


def _proto_message(bytes_val, proto_message_object):
"""Helper for :func:`get_proto_message`.
parses serialized protocol buffer bytes data into proto message.

Args:
bytes_val (bytes): bytes object.
proto_message_object (Message): Message object for parsing

Returns:
Message: parses serialized protocol buffer data into this message.

Raises:
ValueError: if the input proto_message_object is not of type Message
"""
if isinstance(bytes_val, types.NoneType):
return None

if not isinstance(bytes_val, bytes):
raise ValueError("Expected input bytes_val to be a string")

proto_message = proto_message_object.__deepcopy__()
proto_message.ParseFromString(bytes_val)
return proto_message


def _proto_enum(int_val, proto_enum_object):
"""Helper for :func:`get_proto_enum`.
parses int value into string containing the name of an enum value.

Args:
int_val (int): integer value.
proto_enum_object (EnumTypeWrapper): Enum object.

Returns:
str: string containing the name of an enum value.

Raises:
ValueError: if the input proto_enum_object is not of type EnumTypeWrapper
"""
if isinstance(int_val, types.NoneType):
return None

if not isinstance(int_val, int):
raise ValueError("Expected input int_val to be a integer")

return proto_enum_object.Name(int_val)


def get_proto_message(bytes_string, proto_message_object):
    """Parse serialized protocol buffer data into proto message(s).

    Accepts either a single ``bytes`` payload or a list of payloads and
    returns a message or a list of messages respectively.

    Args:
        bytes_string (bytes or list[bytes]): serialized protocol buffer
            data, a list of such payloads, or ``None``.
        proto_message_object (Message): message instance used as the
            template for parsing.

    Returns:
        Message or list[Message]: the parsed message(s), or ``None`` when
        ``bytes_string`` is ``None``.

    Raises:
        ValueError: if ``proto_message_object`` is not of type Message, or
            if ``bytes_string`` is not ``bytes`` or a list.
    """
    # ``types.NoneType`` only exists on Python 3.10+, so an identity check
    # is used to keep this function working on older interpreters.
    if bytes_string is None:
        return None

    if not isinstance(proto_message_object, Message):
        raise ValueError("Input proto_message_object should be of type Message")

    if not isinstance(bytes_string, (bytes, list)):
        raise ValueError(
            "Expected input bytes_string to be a string or list of strings"
        )

    if isinstance(bytes_string, list):
        # Each element is validated and parsed individually by the helper.
        return [_proto_message(item, proto_message_object) for item in bytes_string]

    return _proto_message(bytes_string, proto_message_object)


def get_proto_enum(int_value, proto_enum_object):
    """Convert an integer or a list of integers into enum value name(s).

    Args:
        int_value (int or list[int]): integer enum value, a list of such
            values, or ``None``.
        proto_enum_object (EnumTypeWrapper): enum wrapper used to look up
            the names.

    Returns:
        str or list[str]: the enum value name(s), or ``None`` when
        ``int_value`` is ``None``.

    Raises:
        ValueError: if ``proto_enum_object`` is not of type
            EnumTypeWrapper, or if ``int_value`` is not an ``int`` or a
            list.
    """
    # ``types.NoneType`` only exists on Python 3.10+, so an identity check
    # is used to keep this function working on older interpreters.
    if int_value is None:
        return None

    if not isinstance(proto_enum_object, EnumTypeWrapper):
        raise ValueError("Input proto_enum_object should be of type EnumTypeWrapper")

    if not isinstance(int_value, (int, list)):
        raise ValueError("Expected input int_value to be a integer or list of integers")

    if isinstance(int_value, list):
        # Each element is validated and converted individually by the helper.
        return [_proto_enum(item, proto_enum_object) for item in int_value]

    return _proto_enum(int_value, proto_enum_object)
20 changes: 19 additions & 1 deletion google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ class Database(object):
(Optional) database dialect for the database
:type database_role: str or None
:param database_role: (Optional) user-assigned database_role for the session.
:type proto_descriptors: bytes
:param proto_descriptors: (Optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE
statements in 'ddl_statements' above.
"""

_spanner_api = None
Expand All @@ -138,6 +141,7 @@ def __init__(
encryption_config=None,
database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED,
database_role=None,
proto_descriptors=None,
):
self.database_id = database_id
self._instance = instance
Expand All @@ -155,6 +159,7 @@ def __init__(
self._encryption_config = encryption_config
self._database_dialect = database_dialect
self._database_role = database_role
self._proto_descriptors = proto_descriptors

if pool is None:
pool = BurstyPool(database_role=database_role)
Expand Down Expand Up @@ -328,6 +333,14 @@ def database_role(self):
"""
return self._database_role

@property
def proto_descriptors(self):
    """Proto Descriptors for this database.

    Read-only view of the value stored on the instance. It is initially
    whatever was passed to the constructor (``None`` by default) and is
    overwritten by :meth:`reload` with the descriptors from the
    database DDL response.

    :rtype: bytes
    :returns: bytes representing the proto descriptors for this database
    """
    return self._proto_descriptors

@property
def logger(self):
"""Logger used by the database.
Expand Down Expand Up @@ -411,6 +424,7 @@ def create(self):
extra_statements=list(self._ddl_statements),
encryption_config=self._encryption_config,
database_dialect=self._database_dialect,
proto_descriptors=self._proto_descriptors,
)
future = api.create_database(request=request, metadata=metadata)
return future
Expand Down Expand Up @@ -447,6 +461,7 @@ def reload(self):
metadata = _metadata_with_prefix(self.name)
response = api.get_database_ddl(database=self.name, metadata=metadata)
self._ddl_statements = tuple(response.statements)
self._proto_descriptors = response.proto_descriptors
response = api.get_database(name=self.name, metadata=metadata)
self._state = DatabasePB.State(response.state)
self._create_time = response.create_time
Expand All @@ -458,7 +473,7 @@ def reload(self):
self._default_leader = response.default_leader
self._database_dialect = response.database_dialect

def update_ddl(self, ddl_statements, operation_id=""):
def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
"""Update DDL for this database.

Apply any configured schema from :attr:`ddl_statements`.
Expand All @@ -470,6 +485,8 @@ def update_ddl(self, ddl_statements, operation_id=""):
:param ddl_statements: a list of DDL statements to use on this database
:type operation_id: str
:param operation_id: (optional) a string ID for the long-running operation
:type proto_descriptors: bytes
:param proto_descriptors: (optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE statements

:rtype: :class:`google.api_core.operation.Operation`
:returns: an operation instance
Expand All @@ -483,6 +500,7 @@ def update_ddl(self, ddl_statements, operation_id=""):
database=self.name,
statements=ddl_statements,
operation_id=operation_id,
proto_descriptors=proto_descriptors,
)

future = api.update_database_ddl(request=request, metadata=metadata)
Expand Down
6 changes: 6 additions & 0 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def database(
encryption_config=None,
database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED,
database_role=None,
proto_descriptors=None,
):
"""Factory to create a database within this instance.

Expand Down Expand Up @@ -467,6 +468,10 @@ def database(
:param database_dialect:
(Optional) database dialect for the database

:type proto_descriptors: bytes
:param proto_descriptors: (Optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE
statements in 'ddl_statements' above.

:rtype: :class:`~google.cloud.spanner_v1.database.Database`
:returns: a database owned by this instance.
"""
Expand All @@ -479,6 +484,7 @@ def database(
encryption_config=encryption_config,
database_dialect=database_dialect,
database_role=database_role,
proto_descriptors=proto_descriptors,
)

def list_databases(self, page_size=None):
Expand Down
Loading