diff --git a/google/cloud/spanner_v1/backup.py b/google/cloud/spanner_v1/backup.py index 4938aa7403..9068816705 100644 --- a/google/cloud/spanner_v1/backup.py +++ b/google/cloud/spanner_v1/backup.py @@ -19,6 +19,8 @@ from google.cloud.exceptions import NotFound from google.cloud.spanner_admin_database_v1 import Backup as BackupPB +from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig +from google.cloud.spanner_admin_database_v1 import CreateBackupRequest from google.cloud.spanner_v1._helpers import _metadata_with_prefix _BACKUP_NAME_RE = re.compile( @@ -57,10 +59,24 @@ class Backup(object): the externally consistent copy of the database. If not present, it is the same as the `create_time` of the backup. + + :type encryption_config: + :class:`~google.cloud.spanner_admin_database_v1.types.CreateBackupEncryptionConfig` + or :class:`dict` + :param encryption_config: + (Optional) Encryption configuration for the backup. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_admin_database_v1.types.CreateBackupEncryptionConfig` """ def __init__( - self, backup_id, instance, database="", expire_time=None, version_time=None + self, + backup_id, + instance, + database="", + expire_time=None, + version_time=None, + encryption_config=None, ): self.backup_id = backup_id self._instance = instance @@ -71,6 +87,11 @@ def __init__( self._size_bytes = None self._state = None self._referencing_databases = None + self._encryption_info = None + if type(encryption_config) == dict: + self._encryption_config = CreateBackupEncryptionConfig(**encryption_config) + else: + self._encryption_config = encryption_config @property def name(self): @@ -156,6 +177,22 @@ def referencing_databases(self): """ return self._referencing_databases + @property + def encryption_info(self): + """Encryption info for this backup. + :rtype: :class:`~google.clod.spanner_admin_database_v1.types.EncryptionInfo` + :returns: a class representing the encryption info + """ + return self._encryption_info + + @property + def encryption_config(self): + """Encryption config for this database. + :rtype: :class:`~google.cloud.spanner_admin_instance_v1.types.CreateBackupEncryptionConfig` + :returns: an object representing the encryption config for this database + """ + return self._encryption_config + @classmethod def from_pb(cls, backup_pb, instance): """Create an instance of this class from a protobuf message. @@ -207,6 +244,13 @@ def create(self): raise ValueError("expire_time not set") if not self._database: raise ValueError("database not set") + if ( + self.encryption_config + and self.encryption_config.kms_key_name + and self.encryption_config.encryption_type + != CreateBackupEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION + ): + raise ValueError("kms_key_name only used with CUSTOMER_MANAGED_ENCRYPTION") api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) backup = BackupPB( @@ -215,12 +259,14 @@ def create(self): version_time=self.version_time, ) - future = api.create_backup( + request = CreateBackupRequest( parent=self._instance.name, backup_id=self.backup_id, backup=backup, - metadata=metadata, + encryption_config=self._encryption_config, ) + + future = api.create_backup(request=request, metadata=metadata,) return future def exists(self): @@ -255,6 +301,7 @@ def reload(self): self._size_bytes = pb.size_bytes self._state = BackupPB.State(pb.state) self._referencing_databases = pb.referencing_databases + self._encryption_info = pb.encryption_info def update_expire_time(self, new_expire_time): """Update the expire time of this backup. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index db34d095c7..3b367445e9 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -47,6 +47,9 @@ SpannerGrpcTransport, ) from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest +from google.cloud.spanner_admin_database_v1 import EncryptionConfig +from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig +from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest from google.cloud.spanner_v1 import ( ExecuteSqlRequest, @@ -108,12 +111,27 @@ class Database(object): is `True` to log commit statistics. If not passed, a logger will be created when needed that will log the commit statistics to stdout. + :type encryption_config: + :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` + or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig` + or :class:`dict` + :param encryption_config: + (Optional) Encryption configuration for the database. + If a dict is provided, it must be of the same form as either of the protobuf + messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` + or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig` """ _spanner_api = None def __init__( - self, database_id, instance, ddl_statements=(), pool=None, logger=None + self, + database_id, + instance, + ddl_statements=(), + pool=None, + logger=None, + encryption_config=None, ): self.database_id = database_id self._instance = instance @@ -126,6 +144,7 @@ def __init__( self._earliest_version_time = None self.log_commit_stats = False self._logger = logger + self._encryption_config = encryption_config if pool is None: pool = BurstyPool() @@ -242,6 +261,14 @@ def earliest_version_time(self): """ return self._earliest_version_time + @property + def encryption_config(self): + """Encryption config for this database. + :rtype: :class:`~google.cloud.spanner_admin_instance_v1.types.EncryptionConfig` + :returns: an object representing the encryption config for this database + """ + return self._encryption_config + @property def ddl_statements(self): """DDL Statements used to define database schema. @@ -325,11 +352,14 @@ def create(self): db_name = self.database_id if "-" in db_name: db_name = "`%s`" % (db_name,) + if type(self._encryption_config) == dict: + self._encryption_config = EncryptionConfig(**self._encryption_config) request = CreateDatabaseRequest( parent=self._instance.name, create_statement="CREATE DATABASE %s" % (db_name,), extra_statements=list(self._ddl_statements), + encryption_config=self._encryption_config, ) future = api.create_database(request=request, metadata=metadata) return future @@ -372,6 +402,7 @@ def reload(self): self._restore_info = response.restore_info self._version_retention_period = response.version_retention_period self._earliest_version_time = response.earliest_version_time + self._encryption_config = response.encryption_config def update_ddl(self, ddl_statements, operation_id=""): """Update DDL for this database. @@ -588,8 +619,8 @@ def run_in_transaction(self, func, *args, **kw): def restore(self, source): """Restore from a backup to this database. - :type backup: :class:`~google.cloud.spanner_v1.backup.Backup` - :param backup: the path of the backup being restored from. + :type source: :class:`~google.cloud.spanner_v1.backup.Backup` + :param source: the path of the source being restored from. :rtype: :class:`~google.api_core.operation.Operation` :returns: a future used to poll the status of the create request @@ -601,14 +632,26 @@ def restore(self, source): """ if source is None: raise ValueError("Restore source not specified") + if type(self._encryption_config) == dict: + self._encryption_config = RestoreDatabaseEncryptionConfig( + **self._encryption_config + ) + if ( + self.encryption_config + and self.encryption_config.kms_key_name + and self.encryption_config.encryption_type + != RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION + ): + raise ValueError("kms_key_name only used with CUSTOMER_MANAGED_ENCRYPTION") api = self._instance._client.database_admin_api metadata = _metadata_with_prefix(self.name) - future = api.restore_database( + request = RestoreDatabaseRequest( parent=self._instance.name, database_id=self.database_id, backup=source.name, - metadata=metadata, + encryption_config=self._encryption_config, ) + future = api.restore_database(request=request, metadata=metadata,) return future def is_ready(self): diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 5ea297734c..5a9cf95f5a 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -357,7 +357,14 @@ def delete(self): api.delete_instance(name=self.name, metadata=metadata) - def database(self, database_id, ddl_statements=(), pool=None, logger=None): + def database( + self, + database_id, + ddl_statements=(), + pool=None, + logger=None, + encryption_config=None, + ): """Factory to create a database within this instance. :type database_id: str @@ -377,11 +384,26 @@ def database(self, database_id, ddl_statements=(), pool=None, logger=None): will be created when needed that will log the commit statistics to stdout. + :type encryption_config: + :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` + or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig` + or :class:`dict` + :param encryption_config: + (Optional) Encryption configuration for the database. + If a dict is provided, it must be of the same form as either of the protobuf + messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` + or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig` + :rtype: :class:`~google.cloud.spanner_v1.database.Database` :returns: a database owned by this instance. """ return Database( - database_id, self, ddl_statements=ddl_statements, pool=pool, logger=logger + database_id, + self, + ddl_statements=ddl_statements, + pool=pool, + logger=logger, + encryption_config=encryption_config, ) def list_databases(self, page_size=None): @@ -408,7 +430,14 @@ def list_databases(self, page_size=None): ) return page_iter - def backup(self, backup_id, database="", expire_time=None, version_time=None): + def backup( + self, + backup_id, + database="", + expire_time=None, + version_time=None, + encryption_config=None, + ): """Factory to create a backup within this instance. :type backup_id: str @@ -430,6 +459,14 @@ def backup(self, backup_id, database="", expire_time=None, version_time=None): consistent copy of the database. If not present, it is the same as the `create_time` of the backup. + :type encryption_config: + :class:`~google.cloud.spanner_admin_database_v1.types.CreateBackupEncryptionConfig` + or :class:`dict` + :param encryption_config: + (Optional) Encryption configuration for the backup. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_admin_database_v1.types.CreateBackupEncryptionConfig` + :rtype: :class:`~google.cloud.spanner_v1.backup.Backup` :returns: a backup owned by this instance. """ @@ -440,6 +477,7 @@ def backup(self, backup_id, database="", expire_time=None, version_time=None): database=database.name, expire_time=expire_time, version_time=version_time, + encryption_config=encryption_config, ) except AttributeError: return Backup( @@ -448,6 +486,7 @@ def backup(self, backup_id, database="", expire_time=None, version_time=None): database=database, expire_time=expire_time, version_time=version_time, + encryption_config=encryption_config, ) def list_backups(self, filter_="", page_size=None): diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 575f79746e..8be207ef06 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -738,6 +738,11 @@ def test_create_invalid(self): op.result() def test_backup_workflow(self): + from google.cloud.spanner_admin_database_v1 import ( + CreateBackupEncryptionConfig, + EncryptionConfig, + RestoreDatabaseEncryptionConfig, + ) from datetime import datetime from datetime import timedelta from pytz import UTC @@ -746,6 +751,9 @@ def test_backup_workflow(self): backup_id = "backup_id" + unique_resource_id("_") expire_time = datetime.utcnow() + timedelta(days=3) expire_time = expire_time.replace(tzinfo=UTC) + encryption_config = CreateBackupEncryptionConfig( + encryption_type=CreateBackupEncryptionConfig.EncryptionType.GOOGLE_DEFAULT_ENCRYPTION, + ) # Create backup. backup = instance.backup( @@ -753,6 +761,7 @@ def test_backup_workflow(self): database=self._db, expire_time=expire_time, version_time=self.database_version_time, + encryption_config=encryption_config, ) operation = backup.create() self.to_delete.append(backup) @@ -771,6 +780,7 @@ def test_backup_workflow(self): self.assertEqual(self.database_version_time, backup.version_time) self.assertIsNotNone(backup.size_bytes) self.assertIsNotNone(backup.state) + self.assertEqual(encryption_config, backup.encryption_config) # Update with valid argument. valid_expire_time = datetime.utcnow() + timedelta(days=7) @@ -780,7 +790,10 @@ def test_backup_workflow(self): # Restore database to same instance. restored_id = "restored_db" + unique_resource_id("_") - database = instance.database(restored_id) + encryption_config = RestoreDatabaseEncryptionConfig( + encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.GOOGLE_DEFAULT_ENCRYPTION, + ) + database = instance.database(restored_id, encryption_config=encryption_config) self.to_drop.append(database) operation = database.restore(source=backup) restored_db = operation.result() @@ -791,6 +804,9 @@ def test_backup_workflow(self): metadata = operation.metadata self.assertEqual(self.database_version_time, metadata.backup_info.version_time) + database.reload() + expected_encryption_config = EncryptionConfig() + self.assertEqual(expected_encryption_config, database.encryption_config) database.drop() backup.delete() diff --git a/tests/unit/test_backup.py b/tests/unit/test_backup.py index bf6ce68a84..335ccb564b 100644 --- a/tests/unit/test_backup.py +++ b/tests/unit/test_backup.py @@ -62,18 +62,52 @@ def test_ctor_defaults(self): self.assertIsNone(backup._expire_time) def test_ctor_non_defaults(self): + from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig + instance = _Instance(self.INSTANCE_NAME) timestamp = self._make_timestamp() + encryption_config = CreateBackupEncryptionConfig( + encryption_type=CreateBackupEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, + kms_key_name="key_name", + ) backup = self._make_one( - self.BACKUP_ID, instance, database=self.DATABASE_NAME, expire_time=timestamp + self.BACKUP_ID, + instance, + database=self.DATABASE_NAME, + expire_time=timestamp, + encryption_config=encryption_config, + ) + + self.assertEqual(backup.backup_id, self.BACKUP_ID) + self.assertIs(backup._instance, instance) + self.assertEqual(backup._database, self.DATABASE_NAME) + self.assertIsNotNone(backup._expire_time) + self.assertIs(backup._expire_time, timestamp) + self.assertEqual(backup.encryption_config, encryption_config) + + def test_ctor_w_encryption_config_dict(self): + from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig + + instance = _Instance(self.INSTANCE_NAME) + timestamp = self._make_timestamp() + + encryption_config = {"encryption_type": 3, "kms_key_name": "key_name"} + backup = self._make_one( + self.BACKUP_ID, + instance, + database=self.DATABASE_NAME, + expire_time=timestamp, + encryption_config=encryption_config, ) + expected_encryption_config = CreateBackupEncryptionConfig(**encryption_config) self.assertEqual(backup.backup_id, self.BACKUP_ID) self.assertIs(backup._instance, instance) self.assertEqual(backup._database, self.DATABASE_NAME) self.assertIsNotNone(backup._expire_time) self.assertIs(backup._expire_time, timestamp) + self.assertEqual(backup.encryption_config, expected_encryption_config) def test_from_pb_project_mismatch(self): from google.cloud.spanner_admin_database_v1 import Backup @@ -170,10 +204,32 @@ def test_referencing_databases_property(self): expected = backup._referencing_databases = [self.DATABASE_NAME] self.assertEqual(backup.referencing_databases, expected) + def test_encrpytion_info_property(self): + from google.cloud.spanner_admin_database_v1 import EncryptionInfo + + instance = _Instance(self.INSTANCE_NAME) + backup = self._make_one(self.BACKUP_ID, instance) + expected = backup._encryption_info = EncryptionInfo( + kms_key_version="kms_key_version" + ) + self.assertEqual(backup.encryption_info, expected) + + def test_encryption_config_property(self): + from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig + + instance = _Instance(self.INSTANCE_NAME) + backup = self._make_one(self.BACKUP_ID, instance) + expected = backup._encryption_config = CreateBackupEncryptionConfig( + encryption_type=CreateBackupEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, + kms_key_name="kms_key_name", + ) + self.assertEqual(backup.encryption_config, expected) + def test_create_grpc_error(self): from google.api_core.exceptions import GoogleAPICallError from google.api_core.exceptions import Unknown from google.cloud.spanner_admin_database_v1 import Backup + from google.cloud.spanner_admin_database_v1 import CreateBackupRequest client = _Client() api = client.database_admin_api = self._make_database_admin_api() @@ -190,16 +246,18 @@ def test_create_grpc_error(self): with self.assertRaises(GoogleAPICallError): backup.create() + request = CreateBackupRequest( + parent=self.INSTANCE_NAME, backup_id=self.BACKUP_ID, backup=backup_pb, + ) + api.create_backup.assert_called_once_with( - parent=self.INSTANCE_NAME, - backup_id=self.BACKUP_ID, - backup=backup_pb, - metadata=[("google-cloud-resource-prefix", backup.name)], + request=request, metadata=[("google-cloud-resource-prefix", backup.name)], ) def test_create_already_exists(self): from google.cloud.exceptions import Conflict from google.cloud.spanner_admin_database_v1 import Backup + from google.cloud.spanner_admin_database_v1 import CreateBackupRequest client = _Client() api = client.database_admin_api = self._make_database_admin_api() @@ -216,16 +274,18 @@ def test_create_already_exists(self): with self.assertRaises(Conflict): backup.create() + request = CreateBackupRequest( + parent=self.INSTANCE_NAME, backup_id=self.BACKUP_ID, backup=backup_pb, + ) + api.create_backup.assert_called_once_with( - parent=self.INSTANCE_NAME, - backup_id=self.BACKUP_ID, - backup=backup_pb, - metadata=[("google-cloud-resource-prefix", backup.name)], + request=request, metadata=[("google-cloud-resource-prefix", backup.name)], ) def test_create_instance_not_found(self): from google.cloud.exceptions import NotFound from google.cloud.spanner_admin_database_v1 import Backup + from google.cloud.spanner_admin_database_v1 import CreateBackupRequest client = _Client() api = client.database_admin_api = self._make_database_admin_api() @@ -242,11 +302,12 @@ def test_create_instance_not_found(self): with self.assertRaises(NotFound): backup.create() + request = CreateBackupRequest( + parent=self.INSTANCE_NAME, backup_id=self.BACKUP_ID, backup=backup_pb, + ) + api.create_backup.assert_called_once_with( - parent=self.INSTANCE_NAME, - backup_id=self.BACKUP_ID, - backup=backup_pb, - metadata=[("google-cloud-resource-prefix", backup.name)], + request=request, metadata=[("google-cloud-resource-prefix", backup.name)], ) def test_create_expire_time_not_set(self): @@ -266,6 +327,8 @@ def test_create_database_not_set(self): def test_create_success(self): from google.cloud.spanner_admin_database_v1 import Backup + from google.cloud.spanner_admin_database_v1 import CreateBackupRequest + from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig from datetime import datetime from datetime import timedelta from pytz import UTC @@ -279,12 +342,14 @@ def test_create_success(self): version_timestamp = datetime.utcnow() - timedelta(minutes=5) version_timestamp = version_timestamp.replace(tzinfo=UTC) expire_timestamp = self._make_timestamp() + encryption_config = {"encryption_type": 3, "kms_key_name": "key_name"} backup = self._make_one( self.BACKUP_ID, instance, database=self.DATABASE_NAME, expire_time=expire_timestamp, version_time=version_timestamp, + encryption_config=encryption_config, ) backup_pb = Backup( @@ -296,13 +361,39 @@ def test_create_success(self): future = backup.create() self.assertIs(future, op_future) - api.create_backup.assert_called_once_with( + expected_encryption_config = CreateBackupEncryptionConfig(**encryption_config) + request = CreateBackupRequest( parent=self.INSTANCE_NAME, backup_id=self.BACKUP_ID, backup=backup_pb, - metadata=[("google-cloud-resource-prefix", backup.name)], + encryption_config=expected_encryption_config, ) + api.create_backup.assert_called_once_with( + request=request, metadata=[("google-cloud-resource-prefix", backup.name)], + ) + + def test_create_w_invalid_encryption_config(self): + from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + expire_timestamp = self._make_timestamp() + encryption_config = { + "encryption_type": CreateBackupEncryptionConfig.EncryptionType.GOOGLE_DEFAULT_ENCRYPTION, + "kms_key_name": "key_name", + } + backup = self._make_one( + self.BACKUP_ID, + instance, + database=self.DATABASE_NAME, + expire_time=expire_timestamp, + encryption_config=encryption_config, + ) + + with self.assertRaises(ValueError): + backup.create() + def test_exists_grpc_error(self): from google.api_core.exceptions import Unknown @@ -442,8 +533,10 @@ def test_reload_not_found(self): def test_reload_success(self): from google.cloud.spanner_admin_database_v1 import Backup + from google.cloud.spanner_admin_database_v1 import EncryptionInfo timestamp = self._make_timestamp() + encryption_info = EncryptionInfo(kms_key_version="kms_key_version") client = _Client() backup_pb = Backup( @@ -455,6 +548,7 @@ def test_reload_success(self): size_bytes=10, state=1, referencing_databases=[], + encryption_info=encryption_info, ) api = client.database_admin_api = self._make_database_admin_api() api.get_backup.return_value = backup_pb @@ -470,6 +564,7 @@ def test_reload_success(self): self.assertEqual(backup.size_bytes, 10) self.assertEqual(backup.state, Backup.State.CREATING) self.assertEqual(backup.referencing_databases, []) + self.assertEqual(backup.encryption_info, encryption_info) api.get_backup.assert_called_once_with( name=self.BACKUP_NAME, diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 148bb79b0e..4bd7f7659e 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -159,6 +159,18 @@ def test_ctor_w_explicit_logger(self): self.assertFalse(database.log_commit_stats) self.assertEqual(database._logger, logger) + def test_ctor_w_encryption_config(self): + from google.cloud.spanner_admin_database_v1 import EncryptionConfig + + instance = _Instance(self.INSTANCE_NAME) + encryption_config = EncryptionConfig(kms_key_name="kms_key") + database = self._make_one( + self.DATABASE_ID, instance, encryption_config=encryption_config + ) + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertEqual(database._encryption_config, encryption_config) + def test_from_pb_bad_database_name(self): from google.cloud.spanner_admin_database_v1 import Database @@ -295,6 +307,17 @@ def test_logger_property_custom(self): logger = database._logger = mock.create_autospec(logging.Logger, instance=True) self.assertEqual(database.logger, logger) + def test_encryption_config(self): + from google.cloud.spanner_admin_database_v1 import EncryptionConfig + + instance = _Instance(self.INSTANCE_NAME) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + encryption_config = database._encryption_config = mock.create_autospec( + EncryptionConfig, instance=True + ) + self.assertEqual(database.encryption_config, encryption_config) + def test_spanner_api_property_w_scopeless_creds(self): client = _Client() @@ -432,6 +455,7 @@ def test_create_grpc_error(self): parent=self.INSTANCE_NAME, create_statement="CREATE DATABASE {}".format(self.DATABASE_ID), extra_statements=[], + encryption_config=None, ) api.create_database.assert_called_once_with( @@ -458,6 +482,7 @@ def test_create_already_exists(self): parent=self.INSTANCE_NAME, create_statement="CREATE DATABASE `{}`".format(DATABASE_ID_HYPHEN), extra_statements=[], + encryption_config=None, ) api.create_database.assert_called_once_with( @@ -483,6 +508,7 @@ def test_create_instance_not_found(self): parent=self.INSTANCE_NAME, create_statement="CREATE DATABASE {}".format(self.DATABASE_ID), extra_statements=[], + encryption_config=None, ) api.create_database.assert_called_once_with( @@ -493,6 +519,7 @@ def test_create_instance_not_found(self): def test_create_success(self): from tests._fixtures import DDL_STATEMENTS from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest + from google.cloud.spanner_admin_database_v1 import EncryptionConfig op_future = object() client = _Client() @@ -500,8 +527,13 @@ def test_create_success(self): api.create_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() + encryption_config = EncryptionConfig(kms_key_name="kms_key_name") database = self._make_one( - self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool + self.DATABASE_ID, + instance, + ddl_statements=DDL_STATEMENTS, + pool=pool, + encryption_config=encryption_config, ) future = database.create() @@ -512,6 +544,44 @@ def test_create_success(self): parent=self.INSTANCE_NAME, create_statement="CREATE DATABASE {}".format(self.DATABASE_ID), extra_statements=DDL_STATEMENTS, + encryption_config=encryption_config, + ) + + api.create_database.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + + def test_create_success_w_encryption_config_dict(self): + from tests._fixtures import DDL_STATEMENTS + from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest + from google.cloud.spanner_admin_database_v1 import EncryptionConfig + + op_future = object() + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.return_value = op_future + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + encryption_config = {"kms_key_name": "kms_key_name"} + database = self._make_one( + self.DATABASE_ID, + instance, + ddl_statements=DDL_STATEMENTS, + pool=pool, + encryption_config=encryption_config, + ) + + future = database.create() + + self.assertIs(future, op_future) + + expected_encryption_config = EncryptionConfig(**encryption_config) + expected_request = CreateDatabaseRequest( + parent=self.INSTANCE_NAME, + create_statement="CREATE DATABASE {}".format(self.DATABASE_ID), + extra_statements=DDL_STATEMENTS, + encryption_config=expected_encryption_config, ) api.create_database.assert_called_once_with( @@ -611,6 +681,7 @@ def test_reload_not_found(self): def test_reload_success(self): from google.cloud.spanner_admin_database_v1 import Database + from google.cloud.spanner_admin_database_v1 import EncryptionConfig from google.cloud.spanner_admin_database_v1 import GetDatabaseDdlResponse from google.cloud.spanner_admin_database_v1 import RestoreInfo from google.cloud._helpers import _datetime_to_pb_timestamp @@ -621,6 +692,7 @@ def test_reload_success(self): client = _Client() ddl_pb = GetDatabaseDdlResponse(statements=DDL_STATEMENTS) + encryption_config = EncryptionConfig(kms_key_name="kms_key") api = client.database_admin_api = self._make_database_admin_api() api.get_database_ddl.return_value = ddl_pb db_pb = Database( @@ -629,6 +701,7 @@ def test_reload_success(self): restore_info=restore_info, version_retention_period="1d", earliest_version_time=_datetime_to_pb_timestamp(timestamp), + encryption_config=encryption_config, ) api.get_database.return_value = db_pb instance = _Instance(self.INSTANCE_NAME, client=client) @@ -642,6 +715,7 @@ def test_reload_success(self): self.assertEqual(database._version_retention_period, "1d") self.assertEqual(database._earliest_version_time, timestamp) self.assertEqual(database._ddl_statements, tuple(DDL_STATEMENTS)) + self.assertEqual(database._encryption_config, encryption_config) api.get_database_ddl.assert_called_once_with( database=self.DATABASE_NAME, @@ -1128,6 +1202,7 @@ def test_restore_backup_unspecified(self): def test_restore_grpc_error(self): from google.api_core.exceptions import Unknown + from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest client = _Client() api = client.database_admin_api = self._make_database_admin_api() @@ -1140,15 +1215,20 @@ def test_restore_grpc_error(self): with self.assertRaises(Unknown): database.restore(backup) - api.restore_database.assert_called_once_with( + expected_request = RestoreDatabaseRequest( parent=self.INSTANCE_NAME, database_id=self.DATABASE_ID, backup=self.BACKUP_NAME, + ) + + api.restore_database.assert_called_once_with( + request=expected_request, metadata=[("google-cloud-resource-prefix", database.name)], ) def test_restore_not_found(self): from google.api_core.exceptions import NotFound + from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest client = _Client() api = client.database_admin_api = self._make_database_admin_api() @@ -1161,34 +1241,115 @@ def test_restore_not_found(self): with self.assertRaises(NotFound): database.restore(backup) - api.restore_database.assert_called_once_with( + expected_request = RestoreDatabaseRequest( parent=self.INSTANCE_NAME, database_id=self.DATABASE_ID, backup=self.BACKUP_NAME, + ) + + api.restore_database.assert_called_once_with( + request=expected_request, metadata=[("google-cloud-resource-prefix", database.name)], ) def test_restore_success(self): + from google.cloud.spanner_admin_database_v1 import ( + RestoreDatabaseEncryptionConfig, + ) + from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest + op_future = object() client = _Client() api = client.database_admin_api = self._make_database_admin_api() api.restore_database.return_value = op_future instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() - database = self._make_one(self.DATABASE_ID, instance, pool=pool) + encryption_config = RestoreDatabaseEncryptionConfig( + encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, + kms_key_name="kms_key_name", + ) + database = self._make_one( + self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config + ) backup = _Backup(self.BACKUP_NAME) future = database.restore(backup) self.assertIs(future, op_future) + expected_request = RestoreDatabaseRequest( + parent=self.INSTANCE_NAME, + database_id=self.DATABASE_ID, + backup=self.BACKUP_NAME, + encryption_config=encryption_config, + ) + api.restore_database.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + + def test_restore_success_w_encryption_config_dict(self): + from google.cloud.spanner_admin_database_v1 import ( + RestoreDatabaseEncryptionConfig, + ) + from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest + + op_future = object() + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.restore_database.return_value = op_future + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + encryption_config = { + "encryption_type": RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, + "kms_key_name": "kms_key_name", + } + database = self._make_one( + self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config + ) + backup = _Backup(self.BACKUP_NAME) + + future = database.restore(backup) + + self.assertIs(future, op_future) + + expected_encryption_config = RestoreDatabaseEncryptionConfig( + encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, + kms_key_name="kms_key_name", + ) + expected_request = RestoreDatabaseRequest( parent=self.INSTANCE_NAME, database_id=self.DATABASE_ID, backup=self.BACKUP_NAME, + encryption_config=expected_encryption_config, + ) + + api.restore_database.assert_called_once_with( + request=expected_request, metadata=[("google-cloud-resource-prefix", database.name)], ) + def test_restore_w_invalid_encryption_config_dict(self): + from google.cloud.spanner_admin_database_v1 import ( + RestoreDatabaseEncryptionConfig, + ) + + client = _Client() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + encryption_config = { + "encryption_type": RestoreDatabaseEncryptionConfig.EncryptionType.GOOGLE_DEFAULT_ENCRYPTION, + "kms_key_name": "kms_key_name", + } + database = self._make_one( + self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config + ) + backup = _Backup(self.BACKUP_NAME) + + with self.assertRaises(ValueError): + database.restore(backup) + def test_is_ready(self): from google.cloud.spanner_admin_database_v1 import Database diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index c1d02c5728..2ed777b25b 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -498,9 +498,14 @@ def test_database_factory_explicit(self): DATABASE_ID = "database-id" pool = _Pool() logger = mock.create_autospec(Logger, instance=True) + encryption_config = {"kms_key_name": "kms_key_name"} database = instance.database( - DATABASE_ID, ddl_statements=DDL_STATEMENTS, pool=pool, logger=logger + DATABASE_ID, + ddl_statements=DDL_STATEMENTS, + pool=pool, + logger=logger, + encryption_config=encryption_config, ) self.assertIsInstance(database, Database) @@ -510,6 +515,7 @@ def test_database_factory_explicit(self): self.assertIs(database._pool, pool) self.assertIs(database._logger, logger) self.assertIs(pool._bound, database) + self.assertIs(database._encryption_config, encryption_config) def test_list_databases(self): from google.cloud.spanner_admin_database_v1 import Database as DatabasePB @@ -603,15 +609,23 @@ def test_backup_factory_explicit(self): import datetime from google.cloud._helpers import UTC from google.cloud.spanner_v1.backup import Backup + from google.cloud.spanner_admin_database_v1 import CreateBackupEncryptionConfig client = _Client(self.PROJECT) instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME) BACKUP_ID = "backup-id" DATABASE_NAME = "database-name" timestamp = datetime.datetime.utcnow().replace(tzinfo=UTC) + encryption_config = CreateBackupEncryptionConfig( + encryption_type=CreateBackupEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION, + kms_key_name="kms_key_name", + ) backup = instance.backup( - BACKUP_ID, database=DATABASE_NAME, expire_time=timestamp + BACKUP_ID, + database=DATABASE_NAME, + expire_time=timestamp, + encryption_config=encryption_config, ) self.assertIsInstance(backup, Backup) @@ -619,6 +633,7 @@ def test_backup_factory_explicit(self): self.assertIs(backup._instance, instance) self.assertEqual(backup._database, DATABASE_NAME) self.assertIs(backup._expire_time, timestamp) + self.assertEqual(backup._encryption_config, encryption_config) def test_list_backups_defaults(self): from google.cloud.spanner_admin_database_v1 import Backup as BackupPB