Skip to content

Commit d8453c7

Browse files
authored
fix: DatabaseWrapper method impl and potential bugfix (#545)
1 parent 7083f1d commit d8453c7

File tree

3 files changed

+138
-12
lines changed

3 files changed

+138
-12
lines changed

django_spanner/base.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# license that can be found in the LICENSE file or at
55
# https://developers.google.com/open-source/licenses/bsd
66

7+
import google.cloud.spanner_v1 as spanner
8+
79
from django.db.backends.base.base import BaseDatabaseWrapper
8-
from google.cloud import spanner_dbapi as Database, spanner_v1 as spanner
10+
from google.cloud import spanner_dbapi
911

1012
from .client import DatabaseClient
1113
from .creation import DatabaseCreation
@@ -81,6 +83,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
8183
# special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be
8284
# escaped on database side.
8385
pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")'
86+
8487
# These are all no-ops in favor of using REGEXP_CONTAINS in the customized
8588
# lookups.
8689
pattern_ops = {
@@ -92,7 +95,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
9295
"iendswith": "",
9396
}
9497

95-
Database = Database
98+
Database = spanner_dbapi
9699
SchemaEditorClass = DatabaseSchemaEditor
97100
creation_class = DatabaseCreation
98101
features_class = DatabaseFeatures
@@ -131,7 +134,7 @@ def get_connection_params(self):
131134
**self.settings_dict["OPTIONS"],
132135
}
133136

134-
def get_new_connection(self, **conn_params):
137+
def get_new_connection(self, conn_params):
135138
"""Create a new connection with corresponding connection parameters.
136139
137140
:type conn_params: list
@@ -145,11 +148,13 @@ def get_new_connection(self, **conn_params):
145148
:raises: :class:`ValueError` in case the given instance/database
146149
doesn't exist.
147150
"""
148-
return Database.connect(**conn_params)
151+
return self.Database.connect(**conn_params)
149152

150153
def init_connection_state(self):
151154
"""Initialize the state of the existing connection."""
152-
pass
155+
self.connection.close()
156+
database = self.connection.database
157+
self.connection.__init__(self.instance, database)
153158

154159
def create_cursor(self, name=None):
155160
"""Create a new Database cursor.
@@ -177,12 +182,13 @@ def is_usable(self):
177182
:rtype: bool
178183
:returns: True if the connection is open, otherwise False.
179184
"""
180-
if self.connection is None:
185+
if self.connection is None or self.connection.is_closed:
181186
return False
187+
182188
try:
183189
# Use a cursor directly, bypassing Django's utilities.
184190
self.connection.cursor().execute("SELECT 1")
185-
except Database.Error:
191+
except self.Database.Error:
186192
return False
187-
else:
188-
return True
193+
194+
return True

noxfile.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
from __future__ import absolute_import
1111

12-
import nox
1312
import os
1413
import shutil
1514

15+
import nox
1616

1717
BLACK_VERSION = "black==19.10b0"
1818
BLACK_PATHS = [
@@ -60,12 +60,24 @@ def lint_setup_py(session):
6060

6161
def default(session):
6262
# Install all test dependencies, then install this package in-place.
63-
session.install("mock", "pytest", "pytest-cov")
63+
session.install(
64+
"django~=2.2", "mock", "mock-import", "pytest", "pytest-cov"
65+
)
6466
session.install("-e", ".")
6567

6668
# Run py.test against the unit tests.
6769
session.run(
68-
"py.test", "--quiet", os.path.join("tests", "unit"), *session.posargs
70+
"py.test",
71+
"--quiet",
72+
"--cov=django_spanner",
73+
"--cov=google.cloud",
74+
"--cov=tests.unit",
75+
"--cov-append",
76+
"--cov-config=.coveragerc",
77+
"--cov-report=",
78+
"--cov-fail-under=60",
79+
os.path.join("tests", "unit"),
80+
*session.posargs
6981
)
7082

7183

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Use of this source code is governed by a BSD-style
4+
# license that can be found in the LICENSE file or at
5+
# https://developers.google.com/open-source/licenses/bsd
6+
7+
import sys
8+
import unittest
9+
10+
from mock_import import mock_import
11+
from unittest import mock
12+
13+
14+
@mock_import()
15+
@unittest.skipIf(sys.version_info < (3, 6), reason="Skipping Python 3.5")
16+
class TestBase(unittest.TestCase):
17+
PROJECT = "project"
18+
INSTANCE_ID = "instance_id"
19+
DATABASE_ID = "database_id"
20+
USER_AGENT = "django_spanner/2.2.0a1"
21+
OPTIONS = {"option": "dummy"}
22+
23+
settings_dict = {
24+
"PROJECT": PROJECT,
25+
"INSTANCE": INSTANCE_ID,
26+
"NAME": DATABASE_ID,
27+
"user_agent": USER_AGENT,
28+
"OPTIONS": OPTIONS,
29+
}
30+
31+
def _get_target_class(self):
32+
from django_spanner.base import DatabaseWrapper
33+
34+
return DatabaseWrapper
35+
36+
def _make_one(self, *args, **kwargs):
37+
return self._get_target_class()(*args, **kwargs)
38+
39+
def test_property_instance(self):
40+
settings_dict = {"INSTANCE": "instance"}
41+
db_wrapper = self._make_one(settings_dict=settings_dict)
42+
43+
with mock.patch("django_spanner.base.spanner") as mock_spanner:
44+
mock_spanner.Client = mock_client = mock.MagicMock()
45+
mock_client().instance = mock_instance = mock.MagicMock()
46+
_ = db_wrapper.instance
47+
mock_instance.assert_called_once_with(settings_dict["INSTANCE"])
48+
49+
def test_property__nodb_connection(self):
50+
db_wrapper = self._make_one(None)
51+
with self.assertRaises(NotImplementedError):
52+
db_wrapper._nodb_connection()
53+
54+
def test_get_connection_params(self):
55+
db_wrapper = self._make_one(self.settings_dict)
56+
params = db_wrapper.get_connection_params()
57+
58+
self.assertEqual(params["project"], self.PROJECT)
59+
self.assertEqual(params["instance_id"], self.INSTANCE_ID)
60+
self.assertEqual(params["database_id"], self.DATABASE_ID)
61+
self.assertEqual(params["user_agent"], self.USER_AGENT)
62+
self.assertEqual(params["option"], self.OPTIONS["option"])
63+
64+
def test_get_new_connection(self):
65+
db_wrapper = self._make_one(self.settings_dict)
66+
db_wrapper.Database = mock_database = mock.MagicMock()
67+
mock_database.connect = mock_connect = mock.MagicMock()
68+
conn_params = {"test_param": "dummy"}
69+
db_wrapper.get_new_connection(conn_params)
70+
mock_connect.assert_called_once_with(**conn_params)
71+
72+
def test_init_connection_state(self):
73+
db_wrapper = self._make_one(self.settings_dict)
74+
db_wrapper.connection = mock_connection = mock.MagicMock()
75+
mock_connection.close = mock_close = mock.MagicMock()
76+
db_wrapper.init_connection_state()
77+
mock_close.assert_called_once_with()
78+
79+
def test_create_cursor(self):
80+
db_wrapper = self._make_one(self.settings_dict)
81+
db_wrapper.connection = mock_connection = mock.MagicMock()
82+
mock_connection.cursor = mock_cursor = mock.MagicMock()
83+
db_wrapper.create_cursor()
84+
mock_cursor.assert_called_once_with()
85+
86+
def test__set_autocommit(self):
87+
db_wrapper = self._make_one(self.settings_dict)
88+
db_wrapper.connection = mock_connection = mock.MagicMock()
89+
mock_connection.autocommit = False
90+
db_wrapper._set_autocommit(True)
91+
self.assertEqual(mock_connection.autocommit, True)
92+
93+
def test_is_usable(self):
94+
from google.cloud.spanner_dbapi.exceptions import Error
95+
96+
db_wrapper = self._make_one(self.settings_dict)
97+
db_wrapper.connection = None
98+
self.assertFalse(db_wrapper.is_usable())
99+
100+
db_wrapper.connection = mock_connection = mock.MagicMock()
101+
mock_connection.is_closed = True
102+
self.assertFalse(db_wrapper.is_usable())
103+
104+
mock_connection.is_closed = False
105+
self.assertTrue(db_wrapper.is_usable())
106+
107+
mock_connection.cursor = mock.MagicMock(side_effect=Error)
108+
self.assertFalse(db_wrapper.is_usable())

0 commit comments

Comments
 (0)