@@ -866,7 +866,7 @@ async def copy_to_table(self, table_name, *, source,
866
866
delimiter = None , null = None , header = None ,
867
867
quote = None , escape = None , force_quote = None ,
868
868
force_not_null = None , force_null = None ,
869
- encoding = None ):
869
+ encoding = None , where = None ):
870
870
"""Copy data to the specified table.
871
871
872
872
:param str table_name:
@@ -885,6 +885,15 @@ async def copy_to_table(self, table_name, *, source,
885
885
:param str schema_name:
886
886
An optional schema name to qualify the table.
887
887
888
+ :param str where:
889
+ An optional SQL expression used to filter rows when copying.
890
+
891
+ .. note::
892
+
893
+ Usage of this parameter requires support for the
894
+ ``COPY FROM ... WHERE`` syntax, introduced in
895
+ PostgreSQL version 12.
896
+
888
897
:param float timeout:
889
898
Optional timeout value in seconds.
890
899
@@ -912,6 +921,9 @@ async def copy_to_table(self, table_name, *, source,
912
921
https://www.postgresql.org/docs/current/static/sql-copy.html
913
922
914
923
.. versionadded:: 0.11.0
924
+
925
+ .. versionadded:: 0.29.0
926
+ Added the *where* parameter.
915
927
"""
916
928
tabname = utils ._quote_ident (table_name )
917
929
if schema_name :
@@ -923,21 +935,22 @@ async def copy_to_table(self, table_name, *, source,
923
935
else :
924
936
cols = ''
925
937
938
+ cond = self ._format_copy_where (where )
926
939
opts = self ._format_copy_opts (
927
940
format = format , oids = oids , freeze = freeze , delimiter = delimiter ,
928
941
null = null , header = header , quote = quote , escape = escape ,
929
942
force_not_null = force_not_null , force_null = force_null ,
930
943
encoding = encoding
931
944
)
932
945
933
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
934
- tab = tabname , cols = cols , opts = opts )
946
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
947
+ tab = tabname , cols = cols , opts = opts , cond = cond )
935
948
936
949
return await self ._copy_in (copy_stmt , source , timeout )
937
950
938
951
async def copy_records_to_table (self , table_name , * , records ,
939
952
columns = None , schema_name = None ,
940
- timeout = None ):
953
+ timeout = None , where = None ):
941
954
"""Copy a list of records to the specified table using binary COPY.
942
955
943
956
:param str table_name:
@@ -954,6 +967,16 @@ async def copy_records_to_table(self, table_name, *, records,
954
967
:param str schema_name:
955
968
An optional schema name to qualify the table.
956
969
970
+ :param str where:
971
+ An optional SQL expression used to filter rows when copying.
972
+
973
+ .. note::
974
+
975
+ Usage of this parameter requires support for the
976
+ ``COPY FROM ... WHERE`` syntax, introduced in
977
+ PostgreSQL version 12.
978
+
979
+
957
980
:param float timeout:
958
981
Optional timeout value in seconds.
959
982
@@ -998,6 +1021,9 @@ async def copy_records_to_table(self, table_name, *, records,
998
1021
999
1022
.. versionchanged:: 0.24.0
1000
1023
The ``records`` argument may be an asynchronous iterable.
1024
+
1025
+ .. versionadded:: 0.29.0
1026
+ Added the *where* parameter.
1001
1027
"""
1002
1028
tabname = utils ._quote_ident (table_name )
1003
1029
if schema_name :
@@ -1015,14 +1041,27 @@ async def copy_records_to_table(self, table_name, *, records,
1015
1041
1016
1042
intro_ps = await self ._prepare (intro_query , use_cache = True )
1017
1043
1044
+ cond = self ._format_copy_where (where )
1018
1045
opts = '(FORMAT binary)'
1019
1046
1020
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}' .format (
1021
- tab = tabname , cols = cols , opts = opts )
1047
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond} ' .format (
1048
+ tab = tabname , cols = cols , opts = opts , cond = cond )
1022
1049
1023
1050
return await self ._protocol .copy_in (
1024
1051
copy_stmt , None , None , records , intro_ps ._state , timeout )
1025
1052
1053
+ def _format_copy_where (self , where ):
1054
+ if where and not self ._server_caps .sql_copy_from_where :
1055
+ raise exceptions .UnsupportedServerFeatureError (
1056
+ 'the `where` parameter requires PostgreSQL 12 or later' )
1057
+
1058
+ if where :
1059
+ where_clause = 'WHERE ' + where
1060
+ else :
1061
+ where_clause = ''
1062
+
1063
+ return where_clause
1064
+
1026
1065
def _format_copy_opts (self , * , format = None , oids = None , freeze = None ,
1027
1066
delimiter = None , null = None , header = None , quote = None ,
1028
1067
escape = None , force_quote = None , force_not_null = None ,
@@ -2404,7 +2443,7 @@ class _ConnectionProxy:
2404
2443
ServerCapabilities = collections .namedtuple (
2405
2444
'ServerCapabilities' ,
2406
2445
['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
2407
- 'sql_close_all' , 'jit' ])
2446
+ 'sql_close_all' , 'sql_copy_from_where' , ' jit' ])
2408
2447
ServerCapabilities .__doc__ = 'PostgreSQL server capabilities.'
2409
2448
2410
2449
@@ -2417,6 +2456,7 @@ def _detect_server_capabilities(server_version, connection_settings):
2417
2456
sql_reset = True
2418
2457
sql_close_all = False
2419
2458
jit = False
2459
+ sql_copy_from_where = False
2420
2460
elif hasattr (connection_settings , 'crdb_version' ):
2421
2461
# CockroachDB detected.
2422
2462
advisory_locks = False
@@ -2425,6 +2465,7 @@ def _detect_server_capabilities(server_version, connection_settings):
2425
2465
sql_reset = False
2426
2466
sql_close_all = False
2427
2467
jit = False
2468
+ sql_copy_from_where = False
2428
2469
elif hasattr (connection_settings , 'crate_version' ):
2429
2470
# CrateDB detected.
2430
2471
advisory_locks = False
@@ -2433,6 +2474,7 @@ def _detect_server_capabilities(server_version, connection_settings):
2433
2474
sql_reset = False
2434
2475
sql_close_all = False
2435
2476
jit = False
2477
+ sql_copy_from_where = False
2436
2478
else :
2437
2479
# Standard PostgreSQL server assumed.
2438
2480
advisory_locks = True
@@ -2441,13 +2483,15 @@ def _detect_server_capabilities(server_version, connection_settings):
2441
2483
sql_reset = True
2442
2484
sql_close_all = True
2443
2485
jit = server_version >= (11 , 0 )
2486
+ sql_copy_from_where = server_version .major >= 12
2444
2487
2445
2488
return ServerCapabilities (
2446
2489
advisory_locks = advisory_locks ,
2447
2490
notifications = notifications ,
2448
2491
plpgsql = plpgsql ,
2449
2492
sql_reset = sql_reset ,
2450
2493
sql_close_all = sql_close_all ,
2494
+ sql_copy_from_where = sql_copy_from_where ,
2451
2495
jit = jit ,
2452
2496
)
2453
2497
0 commit comments