Skip to content

Commit 0d616f2

Browse files
committed
fixes
1 parent cd766cb commit 0d616f2

File tree

3 files changed

+73
-21
lines changed

3 files changed

+73
-21
lines changed

Diff for: asyncpg/_testbase/__init__.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -262,22 +262,28 @@ def create_pool(dsn=None, *,
262262
max_size=10,
263263
max_queries=50000,
264264
max_inactive_connection_lifetime=60.0,
265+
connect=None,
265266
setup=None,
266267
init=None,
267268
loop=None,
268269
pool_class=pg_pool.Pool,
269270
connection_class=pg_connection.Connection,
270271
record_class=asyncpg.Record,
271-
connect_fn=pg_connection.connect,
272272
**connect_kwargs):
273273
return pool_class(
274274
dsn,
275-
min_size=min_size, max_size=max_size,
276-
max_queries=max_queries, loop=loop, setup=setup, init=init,
275+
min_size=min_size,
276+
max_size=max_size,
277+
max_queries=max_queries,
278+
loop=loop,
279+
connect=connect,
280+
setup=setup,
281+
init=init,
277282
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
278283
connection_class=connection_class,
279-
record_class=record_class, connect_fn=connect_fn,
280-
**connect_kwargs)
284+
record_class=record_class,
285+
**connect_kwargs,
286+
)
281287

282288

283289
class ClusterTestCase(TestCase):

Diff for: asyncpg/pool.py

+42-15
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ class Pool:
313313

314314
__slots__ = (
315315
'_queue', '_loop', '_minsize', '_maxsize',
316-
'_init', '_connect_fn', '_connect_args', '_connect_kwargs',
316+
'_init', '_connect', '_connect_args', '_connect_kwargs',
317317
'_holders', '_initialized', '_initializing', '_closing',
318318
'_closed', '_connection_class', '_record_class', '_generation',
319319
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -324,12 +324,12 @@ def __init__(self, *connect_args,
324324
max_size,
325325
max_queries,
326326
max_inactive_connection_lifetime,
327-
setup,
328-
init,
327+
connect = None,
328+
setup = None,
329+
init = None,
329330
loop,
330331
connection_class,
331332
record_class,
332-
connect_fn,
333333
**connect_kwargs):
334334

335335
if len(connect_args) > 1:
@@ -386,12 +386,14 @@ def __init__(self, *connect_args,
386386
self._closing = False
387387
self._closed = False
388388
self._generation = 0
389-
self._init = init
389+
390+
self._connect = connect if connect is not None else connection.connect
390391
self._connect_args = connect_args
391392
self._connect_kwargs = connect_kwargs
392-
self._connect_fn = connect_fn
393393

394394
self._setup = setup
395+
self._init = init
396+
395397
self._max_queries = max_queries
396398
self._max_inactive_connection_lifetime = \
397399
max_inactive_connection_lifetime
@@ -505,13 +507,25 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
505507
self._connect_kwargs = connect_kwargs
506508

507509
async def _get_new_connection(self):
508-
con = await self._connect_fn(
510+
con = await self._connect(
509511
*self._connect_args,
510512
loop=self._loop,
511513
connection_class=self._connection_class,
512514
record_class=self._record_class,
513515
**self._connect_kwargs,
514516
)
517+
if not isinstance(con, self._connection_class):
518+
good = self._connection_class
519+
good_n = f'{good.__module__}.{good.__name__}'
520+
bad = type(con)
521+
if bad.__module__ == "builtins":
522+
bad_n = bad.__name__
523+
else:
524+
bad_n = f'{bad.__module__}.{bad.__name__}'
525+
raise exceptions.InterfaceError(
526+
"expected pool connect callback to return an instance of "
527+
f"'{good_n}', got " f"'{bad_n}'"
528+
)
515529

516530
if self._init is not None:
517531
try:
@@ -1003,6 +1017,7 @@ def create_pool(dsn=None, *,
10031017
max_size=10,
10041018
max_queries=50000,
10051019
max_inactive_connection_lifetime=300.0,
1020+
connect=None,
10061021
setup=None,
10071022
init=None,
10081023
loop=None,
@@ -1085,6 +1100,13 @@ def create_pool(dsn=None, *,
10851100
Number of seconds after which inactive connections in the
10861101
pool will be closed. Pass ``0`` to disable this mechanism.
10871102
1103+
:param coroutine connect:
1104+
A coroutine that is called instead of
1105+
:func:`~asyncpg.connection.connect` whenever the pool needs to make a
1106+
new connection. Must return an instance of type specified by
1107+
*connection_class* or :class:`~asyncpg.connection.Connection` if
1108+
*connection_class* was not specified.
1109+
10881110
:param coroutine setup:
10891111
A coroutine to prepare a connection right before it is returned
10901112
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
@@ -1099,10 +1121,6 @@ def create_pool(dsn=None, *,
10991121
or :meth:`Connection.set_type_codec() <\
11001122
asyncpg.connection.Connection.set_type_codec>`.
11011123
1102-
:param coroutine connect_fn:
1103-
A coroutine with signature identical to :func:`~asyncpg.connection.connect`. This can be used to add custom
1104-
authentication or ssl logic when creating a connection, as is required by GCP's cloud-sql-python-connector.
1105-
11061124
:param loop:
11071125
An asyncio event loop instance. If ``None``, the default
11081126
event loop will be used.
@@ -1129,12 +1147,21 @@ def create_pool(dsn=None, *,
11291147
11301148
.. versionchanged:: 0.22.0
11311149
Added the *record_class* parameter.
1150+
1151+
.. versionchanged:: 0.30.0
1152+
Added the *connect* parameter.
11321153
"""
11331154
return Pool(
11341155
dsn,
11351156
connection_class=connection_class,
1136-
record_class=record_class, connect_fn=connection.connect,
1137-
min_size=min_size, max_size=max_size,
1138-
max_queries=max_queries, loop=loop, setup=setup, init=init,
1157+
record_class=record_class,
1158+
min_size=min_size,
1159+
max_size=max_size,
1160+
max_queries=max_queries,
1161+
loop=loop,
1162+
connect=connect,
1163+
setup=setup,
1164+
init=init,
11391165
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
1140-
**connect_kwargs)
1166+
**connect_kwargs,
1167+
)

Diff for: tests/test_pool.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,12 @@ async def setup(con):
136136

137137
async def test_pool_07(self):
138138
cons = set()
139+
connect_called = 0
140+
141+
async def connect(*args, **kwargs):
142+
nonlocal connect_called
143+
connect_called += 1
144+
return await pg_connection.connect(*args, **kwargs)
139145

140146
async def setup(con):
141147
if con._con not in cons: # `con` is `PoolConnectionProxy`.
@@ -152,13 +158,26 @@ async def user(pool):
152158
raise RuntimeError('init was not called')
153159

154160
async with self.create_pool(database='postgres',
155-
min_size=2, max_size=5,
161+
min_size=2,
162+
max_size=5,
163+
connect=connect,
156164
init=init,
157165
setup=setup) as pool:
158166
users = asyncio.gather(*[user(pool) for _ in range(10)])
159167
await users
160168

161169
self.assertEqual(len(cons), 5)
170+
self.assertEqual(connect_called, 5)
171+
172+
async def bad_connect(*args, **kwargs):
173+
return 1
174+
175+
with self.assertRaisesRegex(
176+
asyncpg.InterfaceError,
177+
"expected pool connect callback to return an instance of "
178+
"'asyncpg\\.connection\\.Connection', got 'int'"
179+
):
180+
await self.create_pool(database='postgres', connect=bad_connect)
162181

163182
async def test_pool_08(self):
164183
pool = await self.create_pool(database='postgres',

0 commit comments

Comments
 (0)