diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c65f68a6..f2d6a9fc 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -7,6 +7,7 @@ from __future__ import annotations import asyncio +import configparser import collections from collections.abc import Callable import enum @@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum): PGPASSFILE = '.pgpass' +PG_SERVICEFILE = '.pg_service.conf' + + def _read_password_file(passfile: pathlib.Path) \ -> typing.List[typing.Tuple[str, ...]]: @@ -269,6 +273,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, + service, servicefile, direct_tls, server_settings, target_session_attrs, krbsrvname, gsslib): # `auth_hosts` is the version of host information for the purposes @@ -281,6 +286,32 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if dsn: parsed = urllib.parse.urlparse(dsn) + query = None + if parsed.query: + query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) + for key, val in query.items(): + if isinstance(val, list): + query[key] = val[-1] + + if 'service' in query: + val = query.pop('service') + if not service and val: + service = val + + connection_service_file = servicefile + + if connection_service_file is None: + connection_service_file = os.getenv('PGSERVICEFILE') + + if connection_service_file is None: + homedir = compat.get_pg_home_directory() + if homedir: + connection_service_file = homedir / PG_SERVICEFILE + else: + connection_service_file = None + else: + connection_service_file = pathlib.Path(connection_service_file) + if parsed.scheme not in {'postgresql', 'postgres'}: raise exceptions.ClientConfigurationError( 'invalid DSN: scheme is expected to be either ' @@ -315,11 +346,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if password is None and dsn_password: password = urllib.parse.unquote(dsn_password) - if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] + if query: if 'port' in query: val = query.pop('port') @@ -406,12 +433,124 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if gsslib is None: gsslib = val + if 'service' in query: + val = query.pop('service') + if service is None: + service = val + if query: if server_settings is None: server_settings = query else: server_settings = {**query, **server_settings} + if connection_service_file is not None and service is not None: + pg_service = configparser.ConfigParser() + pg_service.read(connection_service_file) + if service in pg_service.sections(): + service_params = pg_service[service] + if 'port' in service_params: + val = service_params.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] + + if 'host' in service_params: + val = service_params.pop('host') + if not host and val: + host, port = _parse_hostlist(val, port) + + if 'dbname' in service_params: + val = service_params.pop('dbname') + if database is None: + database = val + + if 'database' in service_params: + val = service_params.pop('database') + if database is None: + database = val + + if 'user' in service_params: + val = service_params.pop('user') + if user is None: + user = val + + if 'password' in service_params: + val = service_params.pop('password') + if password is None: + password = val + + if 'passfile' in service_params: + val = service_params.pop('passfile') + if passfile is None: + passfile = val + + if 'sslmode' in service_params: + val = service_params.pop('sslmode') + if ssl is None: + ssl = val + + if 'sslcert' in service_params: + val = service_params.pop('sslcert') + if sslcert is None: + sslcert = val + + if 'sslkey' in service_params: + val = service_params.pop('sslkey') + if sslkey is None: + sslkey = val + + if 'sslrootcert' in service_params: + val = service_params.pop('sslrootcert') + if sslrootcert is None: + sslrootcert = val + + if 'sslnegotiation' in service_params: + val = service_params.pop('sslnegotiation') + if sslnegotiation is None: + sslnegotiation = val + + if 'sslcrl' in service_params: + val = service_params.pop('sslcrl') + if sslcrl is None: + sslcrl = val + + if 'sslpassword' in service_params: + val = service_params.pop('sslpassword') + if sslpassword is None: + sslpassword = val + + if 'ssl_min_protocol_version' in service_params: + val = service_params.pop( + 'ssl_min_protocol_version' + ) + if ssl_min_protocol_version is None: + ssl_min_protocol_version = val + + if 'ssl_max_protocol_version' in service_params: + val = service_params.pop( + 'ssl_max_protocol_version' + ) + if ssl_max_protocol_version is None: + ssl_max_protocol_version = val + + if 'target_session_attrs' in service_params: + dsn_target_session_attrs = service_params.pop( + 'target_session_attrs' + ) + if target_session_attrs is None: + target_session_attrs = dsn_target_session_attrs + + if 'krbsrvname' in service_params: + val = service_params.pop('krbsrvname') + if krbsrvname is None: + krbsrvname = val + + if 'gsslib' in service_params: + val = service_params.pop('gsslib') + if gsslib is None: + gsslib = val + if not service: + service = os.environ.get('PGSERVICE') if not host: hostspec = os.environ.get('PGHOST') if hostspec: @@ -724,7 +863,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): + target_session_attrs, krbsrvname, gsslib, + service, servicefile): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -754,7 +894,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, direct_tls=direct_tls, database=database, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + krbsrvname=krbsrvname, gsslib=gsslib, + service=service, servicefile=servicefile) config = _ClientConfiguration( command_timeout=command_timeout, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 3a86466c..4e3e5cf1 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2074,6 +2074,8 @@ async def _do_execute( async def connect(dsn=None, *, host=None, port=None, user=None, password=None, passfile=None, + service=None, + servicefile=None, database=None, loop=None, timeout=60, @@ -2183,6 +2185,14 @@ async def connect(dsn=None, *, (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` on Windows). + :param service: + The name of the postgres connection service stored in the postgres + connection service file. + + :param servicefile: + The location of the connnection service file used to store + connection parameters. + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -2395,6 +2405,9 @@ async def connect(dsn=None, *, .. versionchanged:: 0.30.0 Added the *krbsrvname* and *gsslib* parameters. + .. versionchanged:: 0.31.0 + Added the *servicefile* and *service* parameters. + .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext .. _create_default_context: https://docs.python.org/3/library/ssl.html#ssl.create_default_context @@ -2428,6 +2441,8 @@ async def connect(dsn=None, *, user=user, password=password, passfile=passfile, + service=service, + servicefile=servicefile, ssl=ssl, direct_tls=direct_tls, database=database, diff --git a/tests/test_connect.py b/tests/test_connect.py index 0037ee5e..ac95e314 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1116,7 +1116,8 @@ def run_testcase(self, testcase): env = testcase.get('env', {}) test_env = {'PGHOST': None, 'PGPORT': None, 'PGUSER': None, 'PGPASSWORD': None, - 'PGDATABASE': None, 'PGSSLMODE': None} + 'PGDATABASE': None, 'PGSSLMODE': None, + 'PGSERVICE': None, } test_env.update(env) dsn = testcase.get('dsn') @@ -1132,6 +1133,8 @@ def run_testcase(self, testcase): target_session_attrs = testcase.get('target_session_attrs') krbsrvname = testcase.get('krbsrvname') gsslib = testcase.get('gsslib') + service = testcase.get('service') + servicefile = testcase.get('servicefile') expected = testcase.get('result') expected_error = testcase.get('error') @@ -1157,7 +1160,8 @@ def run_testcase(self, testcase): direct_tls=direct_tls, server_settings=server_settings, target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + krbsrvname=krbsrvname, gsslib=gsslib, + service=service, servicefile=servicefile) params = { k: v for k, v in params._asdict().items() @@ -1236,6 +1240,111 @@ def test_connect_params(self): for testcase in self.TESTS: self.run_testcase(testcase) + def test_connect_connection_service_file(self): + connection_service_file = tempfile.NamedTemporaryFile( + 'w+t', delete=False) + connection_service_file.write(textwrap.dedent(''' +[test_service_dbname] +port=5433 +host=somehost +dbname=test_dbname +user=admin +password=test_password +target_session_attrs=primary +krbsrvname=fakekrbsrvname +gsslib=sspi + +[test_service_database] +port=5433 +host=somehost +database=test_dbname +user=admin +password=test_password +target_session_attrs=primary +krbsrvname=fakekrbsrvname +gsslib=sspi + ''')) + connection_service_file.close() + os.chmod(connection_service_file.name, stat.S_IWUSR | stat.S_IRUSR) + try: + # Test connection service file with dbname + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_dbname', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + # Test connection service file with database + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_database', + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + # Test that envvars are overridden by service file + self.run_testcase({ + 'dsn': 'postgresql://?service=test_service_dbname', + 'env': { + 'PGUSER': 'user', + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + # Test that dsn params overwrite service file + self.run_testcase({ + 'dsn': 'postgresql://?service={}&dbname={}'.format( + "test_service_dbname", "test_dbname_dsn" + ), + 'env': { + 'PGSERVICEFILE': connection_service_file.name + }, + 'result': ( + [('somehost', 5433)], + { + 'user': 'admin', + 'password': 'test_password', + 'database': 'test_dbname_dsn', + 'target_session_attrs': 'primary', + 'krbsrvname': 'fakekrbsrvname', + 'gsslib': 'sspi', + } + ) + }) + finally: + os.unlink(connection_service_file.name) + def test_connect_pgpass_regular(self): passfile = tempfile.NamedTemporaryFile('w+t', delete=False) passfile.write(textwrap.dedent(R'''