Skip to content

Commit d153cc1

Browse files
redgoldlaceelprans
authored andcommitted
Add support for the WHERE clause in copy_to methods
1 parent 313b2b2 commit d153cc1

File tree

4 files changed

+98
-14
lines changed

4 files changed

+98
-14
lines changed

Diff for: asyncpg/connection.py

+51-7
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ async def copy_to_table(self, table_name, *, source,
866866
delimiter=None, null=None, header=None,
867867
quote=None, escape=None, force_quote=None,
868868
force_not_null=None, force_null=None,
869-
encoding=None):
869+
encoding=None, where=None):
870870
"""Copy data to the specified table.
871871
872872
:param str table_name:
@@ -885,6 +885,15 @@ async def copy_to_table(self, table_name, *, source,
885885
:param str schema_name:
886886
An optional schema name to qualify the table.
887887
888+
:param str where:
889+
An optional SQL expression used to filter rows when copying.
890+
891+
.. note::
892+
893+
Usage of this parameter requires support for the
894+
``COPY FROM ... WHERE`` syntax, introduced in
895+
PostgreSQL version 12.
896+
888897
:param float timeout:
889898
Optional timeout value in seconds.
890899
@@ -912,6 +921,9 @@ async def copy_to_table(self, table_name, *, source,
912921
https://www.postgresql.org/docs/current/static/sql-copy.html
913922
914923
.. versionadded:: 0.11.0
924+
925+
.. versionadded:: 0.29.0
926+
Added the *where* parameter.
915927
"""
916928
tabname = utils._quote_ident(table_name)
917929
if schema_name:
@@ -923,21 +935,22 @@ async def copy_to_table(self, table_name, *, source,
923935
else:
924936
cols = ''
925937

938+
cond = self._format_copy_where(where)
926939
opts = self._format_copy_opts(
927940
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
928941
null=null, header=header, quote=quote, escape=escape,
929942
force_not_null=force_not_null, force_null=force_null,
930943
encoding=encoding
931944
)
932945

933-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
934-
tab=tabname, cols=cols, opts=opts)
946+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
947+
tab=tabname, cols=cols, opts=opts, cond=cond)
935948

936949
return await self._copy_in(copy_stmt, source, timeout)
937950

938951
async def copy_records_to_table(self, table_name, *, records,
939952
columns=None, schema_name=None,
940-
timeout=None):
953+
timeout=None, where=None):
941954
"""Copy a list of records to the specified table using binary COPY.
942955
943956
:param str table_name:
@@ -954,6 +967,16 @@ async def copy_records_to_table(self, table_name, *, records,
954967
:param str schema_name:
955968
An optional schema name to qualify the table.
956969
970+
:param str where:
971+
An optional SQL expression used to filter rows when copying.
972+
973+
.. note::
974+
975+
Usage of this parameter requires support for the
976+
``COPY FROM ... WHERE`` syntax, introduced in
977+
PostgreSQL version 12.
978+
979+
957980
:param float timeout:
958981
Optional timeout value in seconds.
959982
@@ -998,6 +1021,9 @@ async def copy_records_to_table(self, table_name, *, records,
9981021
9991022
.. versionchanged:: 0.24.0
10001023
The ``records`` argument may be an asynchronous iterable.
1024+
1025+
.. versionadded:: 0.29.0
1026+
Added the *where* parameter.
10011027
"""
10021028
tabname = utils._quote_ident(table_name)
10031029
if schema_name:
@@ -1015,14 +1041,27 @@ async def copy_records_to_table(self, table_name, *, records,
10151041

10161042
intro_ps = await self._prepare(intro_query, use_cache=True)
10171043

1044+
cond = self._format_copy_where(where)
10181045
opts = '(FORMAT binary)'
10191046

1020-
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
1021-
tab=tabname, cols=cols, opts=opts)
1047+
copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
1048+
tab=tabname, cols=cols, opts=opts, cond=cond)
10221049

10231050
return await self._protocol.copy_in(
10241051
copy_stmt, None, None, records, intro_ps._state, timeout)
10251052

1053+
def _format_copy_where(self, where):
1054+
if where and not self._server_caps.sql_copy_from_where:
1055+
raise exceptions.UnsupportedServerFeatureError(
1056+
'the `where` parameter requires PostgreSQL 12 or later')
1057+
1058+
if where:
1059+
where_clause = 'WHERE ' + where
1060+
else:
1061+
where_clause = ''
1062+
1063+
return where_clause
1064+
10261065
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
10271066
delimiter=None, null=None, header=None, quote=None,
10281067
escape=None, force_quote=None, force_not_null=None,
@@ -2404,7 +2443,7 @@ class _ConnectionProxy:
24042443
ServerCapabilities = collections.namedtuple(
24052444
'ServerCapabilities',
24062445
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
2407-
'sql_close_all', 'jit'])
2446+
'sql_close_all', 'sql_copy_from_where', 'jit'])
24082447
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'
24092448

24102449

@@ -2417,6 +2456,7 @@ def _detect_server_capabilities(server_version, connection_settings):
24172456
sql_reset = True
24182457
sql_close_all = False
24192458
jit = False
2459+
sql_copy_from_where = False
24202460
elif hasattr(connection_settings, 'crdb_version'):
24212461
# CockroachDB detected.
24222462
advisory_locks = False
@@ -2425,6 +2465,7 @@ def _detect_server_capabilities(server_version, connection_settings):
24252465
sql_reset = False
24262466
sql_close_all = False
24272467
jit = False
2468+
sql_copy_from_where = False
24282469
elif hasattr(connection_settings, 'crate_version'):
24292470
# CrateDB detected.
24302471
advisory_locks = False
@@ -2433,6 +2474,7 @@ def _detect_server_capabilities(server_version, connection_settings):
24332474
sql_reset = False
24342475
sql_close_all = False
24352476
jit = False
2477+
sql_copy_from_where = False
24362478
else:
24372479
# Standard PostgreSQL server assumed.
24382480
advisory_locks = True
@@ -2441,13 +2483,15 @@ def _detect_server_capabilities(server_version, connection_settings):
24412483
sql_reset = True
24422484
sql_close_all = True
24432485
jit = server_version >= (11, 0)
2486+
sql_copy_from_where = server_version.major >= 12
24442487

24452488
return ServerCapabilities(
24462489
advisory_locks=advisory_locks,
24472490
notifications=notifications,
24482491
plpgsql=plpgsql,
24492492
sql_reset=sql_reset,
24502493
sql_close_all=sql_close_all,
2494+
sql_copy_from_where=sql_copy_from_where,
24512495
jit=jit,
24522496
)
24532497

Diff for: asyncpg/exceptions/_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
15+
'ClientConfigurationError',
1516
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
1617
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
17-
'ClientConfigurationError')
18+
'UnsupportedServerFeatureError')
1819

1920

2021
def _is_asyncpg_class(cls):
@@ -233,6 +234,10 @@ class UnsupportedClientFeatureError(InterfaceError):
233234
"""Requested feature is unsupported by asyncpg."""
234235

235236

237+
class UnsupportedServerFeatureError(InterfaceError):
238+
"""Requested feature is unsupported by PostgreSQL server."""
239+
240+
236241
class InterfaceWarning(InterfaceMessage, UserWarning):
237242
"""A warning caused by an improper use of asyncpg API."""
238243

Diff for: asyncpg/pool.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,8 @@ async def copy_to_table(
711711
force_quote=None,
712712
force_not_null=None,
713713
force_null=None,
714-
encoding=None
714+
encoding=None,
715+
where=None
715716
):
716717
"""Copy data to the specified table.
717718
@@ -740,7 +741,8 @@ async def copy_to_table(
740741
force_quote=force_quote,
741742
force_not_null=force_not_null,
742743
force_null=force_null,
743-
encoding=encoding
744+
encoding=encoding,
745+
where=where
744746
)
745747

746748
async def copy_records_to_table(
@@ -750,7 +752,8 @@ async def copy_records_to_table(
750752
records,
751753
columns=None,
752754
schema_name=None,
753-
timeout=None
755+
timeout=None,
756+
where=None
754757
):
755758
"""Copy a list of records to the specified table using binary COPY.
756759
@@ -767,7 +770,8 @@ async def copy_records_to_table(
767770
records=records,
768771
columns=columns,
769772
schema_name=schema_name,
770-
timeout=timeout
773+
timeout=timeout,
774+
where=where
771775
)
772776

773777
def acquire(self, *, timeout=None):

Diff for: tests/test_copy.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import io
1111
import os
1212
import tempfile
13+
import unittest
1314

1415
import asyncpg
1516
from asyncpg import _testbase as tb
@@ -414,7 +415,7 @@ async def test_copy_to_table_basics(self):
414415
'*a4*|b4',
415416
'*a5*|b5',
416417
'*!**|*n-u-l-l*',
417-
'n-u-l-l|bb'
418+
'n-u-l-l|bb',
418419
]).encode('utf-8')
419420
)
420421
f.seek(0)
@@ -644,6 +645,35 @@ async def test_copy_records_to_table_1(self):
644645
finally:
645646
await self.con.execute('DROP TABLE copytab')
646647

648+
async def test_copy_records_to_table_where(self):
649+
if not self.con._server_caps.sql_copy_from_where:
650+
raise unittest.SkipTest(
651+
'COPY WHERE not supported on server')
652+
653+
await self.con.execute('''
654+
CREATE TABLE copytab_where(a text, b int, c timestamptz);
655+
''')
656+
657+
try:
658+
date = datetime.datetime.now(tz=datetime.timezone.utc)
659+
delta = datetime.timedelta(days=1)
660+
661+
records = [
662+
('a-{}'.format(i), i, date + delta)
663+
for i in range(100)
664+
]
665+
666+
records.append(('a-100', None, None))
667+
records.append(('b-999', None, None))
668+
669+
res = await self.con.copy_records_to_table(
670+
'copytab_where', records=records, where='a <> \'b-999\'')
671+
672+
self.assertEqual(res, 'COPY 101')
673+
674+
finally:
675+
await self.con.execute('DROP TABLE copytab_where')
676+
647677
async def test_copy_records_to_table_async(self):
648678
await self.con.execute('''
649679
CREATE TABLE copytab_async(a text, b int, c timestamptz);
@@ -660,7 +690,8 @@ async def record_generator():
660690
yield ('a-100', None, None)
661691

662692
res = await self.con.copy_records_to_table(
663-
'copytab_async', records=record_generator())
693+
'copytab_async', records=record_generator(),
694+
)
664695

665696
self.assertEqual(res, 'COPY 101')
666697

0 commit comments

Comments
 (0)