Skip to content

Commit 066c14e

Browse files
feat: improve key for connector cache (#1172)
1 parent a348659 commit 066c14e

File tree

3 files changed

+69
-40
lines changed

3 files changed

+69
-40
lines changed

google/cloud/sql/connector/connector.py

+14-15
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import logging
2222
from threading import Thread
2323
from types import TracebackType
24-
from typing import Any, Dict, Optional, Type, Union
24+
from typing import Any, Dict, Optional, Tuple, Type, Union
2525

2626
import google.auth
2727
from google.auth.credentials import Credentials
@@ -131,7 +131,11 @@ def __init__(
131131
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
132132
loop=self._loop,
133133
)
134-
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
134+
# initialize dict to store caches, key is a tuple consisting of instance
135+
# connection name string and enable_iam_auth boolean flag
136+
self._cache: Dict[
137+
Tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache]
138+
] = {}
135139
self._client: Optional[CloudSQLClient] = None
136140

137141
# initialize credentials
@@ -262,15 +266,8 @@ async def connect_async(
262266
driver=driver,
263267
)
264268
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
265-
if instance_connection_string in self._cache:
266-
cache = self._cache[instance_connection_string]
267-
if enable_iam_auth != cache._enable_iam_auth:
268-
raise ValueError(
269-
f"connect() called with 'enable_iam_auth={enable_iam_auth}', "
270-
f"but previously used 'enable_iam_auth={cache._enable_iam_auth}'. "
271-
"If you require both for your use case, please use a new "
272-
"connector.Connector object."
273-
)
269+
if (instance_connection_string, enable_iam_auth) in self._cache:
270+
cache = self._cache[(instance_connection_string, enable_iam_auth)]
274271
else:
275272
if self._refresh_strategy == RefreshStrategy.LAZY:
276273
logger.debug(
@@ -297,7 +294,7 @@ async def connect_async(
297294
logger.debug(
298295
f"['{instance_connection_string}']: Connection info added to cache"
299296
)
300-
self._cache[instance_connection_string] = cache
297+
self._cache[(instance_connection_string, enable_iam_auth)] = cache
301298

302299
connect_func = {
303300
"pymysql": pymysql.connect,
@@ -333,7 +330,7 @@ async def connect_async(
333330
except Exception:
334331
# with an error from Cloud SQL Admin API call or IP type, invalidate
335332
# the cache and re-raise the error
336-
await self._remove_cached(instance_connection_string)
333+
await self._remove_cached(instance_connection_string, enable_iam_auth)
337334
raise
338335
logger.debug(
339336
f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
@@ -370,15 +367,17 @@ async def connect_async(
370367
await cache.force_refresh()
371368
raise
372369

373-
async def _remove_cached(self, instance_connection_string: str) -> None:
370+
async def _remove_cached(
371+
self, instance_connection_string: str, enable_iam_auth: bool
372+
) -> None:
374373
"""Stops all background refreshes and deletes the connection
375374
info cache from the map of caches.
376375
"""
377376
logger.debug(
378377
f"['{instance_connection_string}']: Removing connection info from cache"
379378
)
380379
# remove cache from stored caches and close it
381-
cache = self._cache.pop(instance_connection_string)
380+
cache = self._cache.pop((instance_connection_string, enable_iam_auth))
382381
await cache.close()
383382

384383
def __enter__(self) -> Any:

tests/system/test_connector_object.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,21 @@ def test_multiple_connectors() -> None:
7979
conn.execute(sqlalchemy.text("SELECT 1"))
8080

8181
instance_connection_string = os.environ["MYSQL_CONNECTION_NAME"]
82-
assert instance_connection_string in first_connector._cache
83-
assert instance_connection_string in second_connector._cache
8482
assert (
85-
first_connector._cache[instance_connection_string]
86-
!= second_connector._cache[instance_connection_string]
83+
instance_connection_string,
84+
first_connector._enable_iam_auth,
85+
) in first_connector._cache
86+
assert (
87+
instance_connection_string,
88+
second_connector._enable_iam_auth,
89+
) in second_connector._cache
90+
assert (
91+
first_connector._cache[
92+
(instance_connection_string, first_connector._enable_iam_auth)
93+
]
94+
!= second_connector._cache[
95+
(instance_connection_string, second_connector._enable_iam_auth)
96+
]
8797
)
8898
except Exception:
8999
raise

tests/unit/test_connector.py

+41-21
Original file line numberDiff line numberDiff line change
@@ -31,26 +31,46 @@
3131
from google.cloud.sql.connector.instance import RefreshAheadCache
3232

3333

34-
def test_connect_enable_iam_auth_error(
35-
fake_credentials: Credentials, cache: RefreshAheadCache
34+
@pytest.mark.asyncio
35+
async def test_connect_enable_iam_auth_error(
36+
fake_credentials: Credentials, fake_client: CloudSQLClient
3637
) -> None:
3738
"""Test that calling connect() with different enable_iam_auth
38-
argument values throws error."""
39+
argument values creates two cache entries."""
3940
connect_string = "test-project:test-region:test-instance"
40-
connector = Connector(credentials=fake_credentials)
41-
# set cache
42-
connector._cache[connect_string] = cache
43-
# try to connect using enable_iam_auth=True, should raise error
44-
with pytest.raises(ValueError) as exc_info:
45-
connector.connect(connect_string, "pg8000", enable_iam_auth=True)
46-
assert (
47-
exc_info.value.args[0] == "connect() called with 'enable_iam_auth=True', "
48-
"but previously used 'enable_iam_auth=False'. "
49-
"If you require both for your use case, please use a new "
50-
"connector.Connector object."
51-
)
52-
# remove cache entry to avoid destructor warnings
53-
connector._cache = {}
41+
async with Connector(
42+
credentials=fake_credentials, loop=asyncio.get_running_loop()
43+
) as connector:
44+
connector._client = fake_client
45+
# patch db connection creation
46+
with patch("google.cloud.sql.connector.asyncpg.connect") as mock_connect:
47+
mock_connect.return_value = True
48+
# connect with enable_iam_auth False
49+
connection = await connector.connect_async(
50+
connect_string,
51+
"asyncpg",
52+
user="my-user",
53+
password="my-pass",
54+
db="my-db",
55+
enable_iam_auth=False,
56+
)
57+
# verify connector made connection call
58+
assert connection is True
59+
# connect with enable_iam_auth True
60+
connection = await connector.connect_async(
61+
connect_string,
62+
"asyncpg",
63+
user="my-user",
64+
password="my-pass",
65+
db="my-db",
66+
enable_iam_auth=True,
67+
)
68+
# verify connector made connection call
69+
assert connection is True
70+
# verify both cache entries for same instance exist
71+
assert len(connector._cache) == 2
72+
assert (connect_string, True) in connector._cache
73+
assert (connect_string, False) in connector._cache
5474

5575

5676
async def test_connect_incompatible_driver_error(
@@ -305,15 +325,15 @@ async def test_Connector_remove_cached_bad_instance(
305325
conn_name = "bad-project:bad-region:bad-inst"
306326
# populate cache
307327
cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
308-
connector._cache[conn_name] = cache
328+
connector._cache[(conn_name, False)] = cache
309329
# aiohttp client should throw a 404 ClientResponseError
310330
with pytest.raises(ClientResponseError):
311331
await connector.connect_async(
312332
conn_name,
313333
"pg8000",
314334
)
315335
# check that cache has been removed from dict
316-
assert conn_name not in connector._cache
336+
assert (conn_name, False) not in connector._cache
317337

318338

319339
async def test_Connector_remove_cached_no_ip_type(
@@ -331,7 +351,7 @@ async def test_Connector_remove_cached_no_ip_type(
331351
conn_name = "test-project:test-region:test-instance"
332352
# populate cache
333353
cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
334-
connector._cache[conn_name] = cache
354+
connector._cache[(conn_name, False)] = cache
335355
# test instance does not have Private IP, thus should invalidate cache
336356
with pytest.raises(CloudSQLIPTypeError):
337357
await connector.connect_async(
@@ -342,7 +362,7 @@ async def test_Connector_remove_cached_no_ip_type(
342362
ip_type="private",
343363
)
344364
# check that cache has been removed from dict
345-
assert conn_name not in connector._cache
365+
assert (conn_name, False) not in connector._cache
346366

347367

348368
def test_default_universe_domain(fake_credentials: Credentials) -> None:

0 commit comments

Comments
 (0)