Skip to content

Commit 514e675

Browse files
committed
feat: add support for restore a database with CMEK
1 parent 023b2e5 commit 514e675

File tree

4 files changed

+129
-35
lines changed

4 files changed

+129
-35
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
)
4848
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
4949
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
50+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
51+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
5052
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
5153
from google.cloud.spanner_v1 import ExecuteSqlRequest
5254
from google.cloud.spanner_v1 import (
@@ -102,8 +104,9 @@ class Database(object):
102104
or :class:`dict`
103105
:param encryption_config:
104106
(Optional) Encryption information about the database.
105-
If a dict is provided, it must be of the same form as the protobuf
106-
message :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
107+
If a dict is provided, it must be of the same form as either of the protobuf
108+
messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
109+
or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
107110
"""
108111

109112
_spanner_api = None
@@ -123,11 +126,7 @@ def __init__(
123126
self._state = None
124127
self._create_time = None
125128
self._restore_info = None
126-
127-
if type(encryption_config) == dict:
128-
self._encryption_config = EncryptionConfig(**encryption_config)
129-
else:
130-
self._encryption_config = encryption_config
129+
self._encryption_config = encryption_config
131130

132131
if pool is None:
133132
pool = BurstyPool()
@@ -297,6 +296,8 @@ def create(self):
297296
db_name = self.database_id
298297
if "-" in db_name:
299298
db_name = "`%s`" % (db_name,)
299+
if type(self._encryption_config) == dict:
300+
self._encryption_config = EncryptionConfig(**self._encryption_config)
300301

301302
request = CreateDatabaseRequest(
302303
parent=self._instance.name,
@@ -560,8 +561,8 @@ def run_in_transaction(self, func, *args, **kw):
560561
def restore(self, source):
561562
"""Restore from a backup to this database.
562563
563-
:type backup: :class:`~google.cloud.spanner_v1.backup.Backup`
564-
:param backup: the path of the backup being restored from.
564+
:type source: :class:`~google.cloud.spanner_v1.backup.Backup`
565+
:param source: the path of the source being restored from.
565566
566567
:rtype: :class:`~google.api_core.operation.Operation`
567568
:returns: a future used to poll the status of the create request
@@ -575,10 +576,16 @@ def restore(self, source):
575576
raise ValueError("Restore source not specified")
576577
api = self._instance._client.database_admin_api
577578
metadata = _metadata_with_prefix(self.name)
578-
future = api.restore_database(
579+
if type(self._encryption_config) == dict:
580+
self._encryption_config = RestoreDatabaseEncryptionConfig(**self._encryption_config)
581+
request = RestoreDatabaseRequest(
579582
parent=self._instance.name,
580583
database_id=self.database_id,
581584
backup=source.name,
585+
encryption_config=self._encryption_config
586+
)
587+
future = api.restore_database(
588+
request=request,
582589
metadata=metadata,
583590
)
584591
return future

google/cloud/spanner_v1/instance.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,14 @@ def database(
373373
:param pool: (Optional) session pool to be used by database.
374374
375375
:type encryption_config:
376-
:class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
376+
:class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` or
377+
:class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
377378
or :class:`dict`
378379
:param encryption_config:
379380
(Optional) Encryption information about the database.
380-
If a dict is provided, it must be of the same form as the protobuf
381-
message :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig
381+
If a dict is provided, it must be of the same form as either of the protobuf
382+
messages :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig`
383+
or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig`
382384
383385
:rtype: :class:`~google.cloud.spanner_v1.database.Database`
384386
:returns: a database owned by this instance.
@@ -444,7 +446,13 @@ def backup(self, backup_id, database="", expire_time=None, encryption_config=Non
444446
backup_id, self, database=database.name, expire_time=expire_time, encryption_config=encryption_config
445447
)
446448
except AttributeError:
447-
return Backup(backup_id, self, database=database, expire_time=expire_time)
449+
return Backup(
450+
backup_id,
451+
self,
452+
database=database,
453+
expire_time=expire_time,
454+
encryption_config=encryption_config
455+
)
448456

449457
def list_backups(self, filter_="", page_size=None):
450458
"""List backups for the instance.

tests/unit/test_database.py

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,6 @@ def test_ctor_w_encryption_config(self):
157157
self.assertIs(database._instance, instance)
158158
self.assertEqual(database._encryption_config, encryption_config)
159159

160-
def test_ctor_w_encryption_config_dict(self):
161-
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
162-
163-
instance = _Instance(self.INSTANCE_NAME)
164-
encryption_config_dict = {"kms_key_name": "kms_key"}
165-
encryption_config = EncryptionConfig(kms_key_name="kms_key")
166-
database = self._make_one(
167-
self.DATABASE_ID, instance, encryption_config=encryption_config_dict
168-
)
169-
self.assertEqual(database.database_id, self.DATABASE_ID)
170-
self.assertIs(database._instance, instance)
171-
self.assertEqual(database._encryption_config, encryption_config)
172-
173-
174160
def test_from_pb_bad_database_name(self):
175161
from google.cloud.spanner_admin_database_v1 import Database
176162

@@ -487,15 +473,17 @@ def test_create_instance_not_found(self):
487473
def test_create_success(self):
488474
from tests._fixtures import DDL_STATEMENTS
489475
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
476+
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
490477

491478
op_future = object()
492479
client = _Client()
493480
api = client.database_admin_api = self._make_database_admin_api()
494481
api.create_database.return_value = op_future
495482
instance = _Instance(self.INSTANCE_NAME, client=client)
496483
pool = _Pool()
484+
encryption_config = EncryptionConfig(kms_key_name="kms_key_name")
497485
database = self._make_one(
498-
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool
486+
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool, encryption_config=encryption_config
499487
)
500488

501489
future = database.create()
@@ -506,7 +494,40 @@ def test_create_success(self):
506494
parent=self.INSTANCE_NAME,
507495
create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
508496
extra_statements=DDL_STATEMENTS,
509-
encryption_config=None,
497+
encryption_config=encryption_config,
498+
)
499+
500+
api.create_database.assert_called_once_with(
501+
request=expected_request,
502+
metadata=[("google-cloud-resource-prefix", database.name)],
503+
)
504+
505+
def test_create_success_w_encryption_config_dict(self):
506+
from tests._fixtures import DDL_STATEMENTS
507+
from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest
508+
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
509+
510+
op_future = object()
511+
client = _Client()
512+
api = client.database_admin_api = self._make_database_admin_api()
513+
api.create_database.return_value = op_future
514+
instance = _Instance(self.INSTANCE_NAME, client=client)
515+
pool = _Pool()
516+
encryption_config = {"kms_key_name": "kms_key_name"}
517+
database = self._make_one(
518+
self.DATABASE_ID, instance, ddl_statements=DDL_STATEMENTS, pool=pool, encryption_config=encryption_config
519+
)
520+
521+
future = database.create()
522+
523+
self.assertIs(future, op_future)
524+
525+
expected_encryption_config = EncryptionConfig(**encryption_config)
526+
expected_request = CreateDatabaseRequest(
527+
parent=self.INSTANCE_NAME,
528+
create_statement="CREATE DATABASE {}".format(self.DATABASE_ID),
529+
extra_statements=DDL_STATEMENTS,
530+
encryption_config=expected_encryption_config,
510531
)
511532

512533
api.create_database.assert_called_once_with(
@@ -1123,6 +1144,7 @@ def test_restore_backup_unspecified(self):
11231144

11241145
def test_restore_grpc_error(self):
11251146
from google.api_core.exceptions import Unknown
1147+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
11261148

11271149
client = _Client()
11281150
api = client.database_admin_api = self._make_database_admin_api()
@@ -1135,15 +1157,20 @@ def test_restore_grpc_error(self):
11351157
with self.assertRaises(Unknown):
11361158
database.restore(backup)
11371159

1138-
api.restore_database.assert_called_once_with(
1160+
expected_request = RestoreDatabaseRequest(
11391161
parent=self.INSTANCE_NAME,
11401162
database_id=self.DATABASE_ID,
11411163
backup=self.BACKUP_NAME,
1164+
)
1165+
1166+
api.restore_database.assert_called_once_with(
1167+
request=expected_request,
11421168
metadata=[("google-cloud-resource-prefix", database.name)],
11431169
)
11441170

11451171
def test_restore_not_found(self):
11461172
from google.api_core.exceptions import NotFound
1173+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
11471174

11481175
client = _Client()
11491176
api = client.database_admin_api = self._make_database_admin_api()
@@ -1156,31 +1183,84 @@ def test_restore_not_found(self):
11561183
with self.assertRaises(NotFound):
11571184
database.restore(backup)
11581185

1159-
api.restore_database.assert_called_once_with(
1186+
expected_request = RestoreDatabaseRequest(
11601187
parent=self.INSTANCE_NAME,
11611188
database_id=self.DATABASE_ID,
11621189
backup=self.BACKUP_NAME,
1190+
)
1191+
1192+
api.restore_database.assert_called_once_with(
1193+
request=expected_request,
11631194
metadata=[("google-cloud-resource-prefix", database.name)],
11641195
)
11651196

11661197
def test_restore_success(self):
1198+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
1199+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
1200+
11671201
op_future = object()
11681202
client = _Client()
11691203
api = client.database_admin_api = self._make_database_admin_api()
11701204
api.restore_database.return_value = op_future
11711205
instance = _Instance(self.INSTANCE_NAME, client=client)
11721206
pool = _Pool()
1173-
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1207+
encryption_config = RestoreDatabaseEncryptionConfig(
1208+
encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
1209+
kms_key_name="kms_key_name"
1210+
)
1211+
database = self._make_one(self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config)
11741212
backup = _Backup(self.BACKUP_NAME)
11751213

11761214
future = database.restore(backup)
11771215

11781216
self.assertIs(future, op_future)
11791217

1218+
expected_request = RestoreDatabaseRequest(
1219+
parent=self.INSTANCE_NAME,
1220+
database_id=self.DATABASE_ID,
1221+
backup=self.BACKUP_NAME,
1222+
encryption_config=encryption_config
1223+
)
1224+
11801225
api.restore_database.assert_called_once_with(
1226+
request=expected_request,
1227+
metadata=[("google-cloud-resource-prefix", database.name)],
1228+
)
1229+
1230+
def test_restore_success_w_encryption_config_dict(self):
1231+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseEncryptionConfig
1232+
from google.cloud.spanner_admin_database_v1 import RestoreDatabaseRequest
1233+
1234+
op_future = object()
1235+
client = _Client()
1236+
api = client.database_admin_api = self._make_database_admin_api()
1237+
api.restore_database.return_value = op_future
1238+
instance = _Instance(self.INSTANCE_NAME, client=client)
1239+
pool = _Pool()
1240+
encryption_config = {
1241+
'encryption_type': RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
1242+
'kms_key_name': 'kms_key_name'
1243+
}
1244+
database = self._make_one(self.DATABASE_ID, instance, pool=pool, encryption_config=encryption_config)
1245+
backup = _Backup(self.BACKUP_NAME)
1246+
1247+
future = database.restore(backup)
1248+
1249+
self.assertIs(future, op_future)
1250+
1251+
expected_encryption_config = RestoreDatabaseEncryptionConfig(
1252+
encryption_type=RestoreDatabaseEncryptionConfig.EncryptionType.CUSTOMER_MANAGED_ENCRYPTION,
1253+
kms_key_name="kms_key_name"
1254+
)
1255+
expected_request = RestoreDatabaseRequest(
11811256
parent=self.INSTANCE_NAME,
11821257
database_id=self.DATABASE_ID,
11831258
backup=self.BACKUP_NAME,
1259+
encryption_config=expected_encryption_config
1260+
)
1261+
1262+
api.restore_database.assert_called_once_with(
1263+
request=expected_request,
11841264
metadata=[("google-cloud-resource-prefix", database.name)],
11851265
)
11861266

tests/unit/test_instance.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,15 +488,14 @@ def test_database_factory_defaults(self):
488488
self.assertIs(pool._database, database)
489489

490490
def test_database_factory_explicit(self):
491-
from google.cloud.spanner_admin_database_v1 import EncryptionConfig
492491
from google.cloud.spanner_v1.database import Database
493492
from tests._fixtures import DDL_STATEMENTS
494493

495494
client = _Client(self.PROJECT)
496495
instance = self._make_one(self.INSTANCE_ID, client, self.CONFIG_NAME)
497496
DATABASE_ID = "database-id"
498497
pool = _Pool()
499-
encryption_config = EncryptionConfig(kms_key_name="kms_key")
498+
encryption_config = {"kms_key_name": "kms_key_name"}
500499

501500
database = instance.database(
502501
DATABASE_ID,

0 commit comments

Comments
 (0)