Skip to content

Commit 52b1a0a

Browse files
authored
feat: pass custom Client object to dbapi (#911)
1 parent 520d6d7 commit 52b1a0a

File tree

2 files changed

+66
-13
lines changed

2 files changed

+66
-13
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ def connect(
497497
credentials=None,
498498
pool=None,
499499
user_agent=None,
500+
client=None,
500501
):
501502
"""Creates a connection to a Google Cloud Spanner database.
502503
@@ -529,25 +530,31 @@ def connect(
529530
:param user_agent: (Optional) User agent to be used with this connection's
530531
requests.
531532
533+
:type client: Concrete subclass of
534+
:class:`~google.cloud.spanner_v1.Client`.
535+
:param client: (Optional) Custom user provided Client Object
536+
532537
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
533538
:returns: Connection object associated with the given Google Cloud Spanner
534539
resource.
535540
"""
536-
537-
client_info = ClientInfo(
538-
user_agent=user_agent or DEFAULT_USER_AGENT,
539-
python_version=PY_VERSION,
540-
client_library_version=spanner.__version__,
541-
)
542-
543-
if isinstance(credentials, str):
544-
client = spanner.Client.from_service_account_json(
545-
credentials, project=project, client_info=client_info
541+
if client is None:
542+
client_info = ClientInfo(
543+
user_agent=user_agent or DEFAULT_USER_AGENT,
544+
python_version=PY_VERSION,
545+
client_library_version=spanner.__version__,
546546
)
547+
if isinstance(credentials, str):
548+
client = spanner.Client.from_service_account_json(
549+
credentials, project=project, client_info=client_info
550+
)
551+
else:
552+
client = spanner.Client(
553+
project=project, credentials=credentials, client_info=client_info
554+
)
547555
else:
548-
client = spanner.Client(
549-
project=project, credentials=credentials, client_info=client_info
550-
)
556+
if project is not None and client.project != project:
557+
raise ValueError("project in url does not match client object project")
551558

552559
instance = client.instance(instance_id)
553560
conn = Connection(instance, instance.database(database_id, pool=pool))

tests/unit/spanner_dbapi/test_connection.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import mock
1919
import unittest
2020
import warnings
21+
import pytest
2122

2223
PROJECT = "test-project"
2324
INSTANCE = "test-instance"
@@ -915,7 +916,52 @@ def test_request_priority(self):
915916
sql, params, param_types=param_types, request_options=None
916917
)
917918

919+
@mock.patch("google.cloud.spanner_v1.Client")
920+
def test_custom_client_connection(self, mock_client):
921+
from google.cloud.spanner_dbapi import connect
922+
923+
client = _Client()
924+
connection = connect("test-instance", "test-database", client=client)
925+
self.assertTrue(connection.instance._client == client)
926+
927+
@mock.patch("google.cloud.spanner_v1.Client")
928+
def test_invalid_custom_client_connection(self, mock_client):
929+
from google.cloud.spanner_dbapi import connect
930+
931+
client = _Client()
932+
with pytest.raises(ValueError):
933+
connect(
934+
"test-instance",
935+
"test-database",
936+
project="invalid_project",
937+
client=client,
938+
)
939+
918940

919941
def exit_ctx_func(self, exc_type, exc_value, traceback):
920942
"""Context __exit__ method mock."""
921943
pass
944+
945+
946+
class _Client(object):
947+
def __init__(self, project="project_id"):
948+
self.project = project
949+
self.project_name = "projects/" + self.project
950+
951+
def instance(self, instance_id="instance_id"):
952+
return _Instance(name=instance_id, client=self)
953+
954+
955+
class _Instance(object):
956+
def __init__(self, name="instance_id", client=None):
957+
self.name = name
958+
self._client = client
959+
960+
def database(self, database_id="database_id", pool=None):
961+
return _Database(database_id, pool)
962+
963+
964+
class _Database(object):
965+
def __init__(self, database_id="database_id", pool=None):
966+
self.name = database_id
967+
self.pool = pool

0 commit comments

Comments
 (0)