10
10
import io
11
11
import os
12
12
import tempfile
13
+ import unittest
13
14
14
15
import asyncpg
15
16
from asyncpg import _testbase as tb
@@ -415,7 +416,6 @@ async def test_copy_to_table_basics(self):
415
416
'*a5*|b5' ,
416
417
'*!**|*n-u-l-l*' ,
417
418
'n-u-l-l|bb' ,
418
- '_-_filtered_-_value_-_|never-here'
419
419
]).encode ('utf-8' )
420
420
)
421
421
f .seek (0 )
@@ -432,7 +432,7 @@ async def test_copy_to_table_basics(self):
432
432
schema_name = 'public' , format = 'csv' ,
433
433
delimiter = '|' , null = 'n-u-l-l' , header = True ,
434
434
quote = '*' , escape = '!' , force_not_null = ('a' ,),
435
- force_null = force_null , where = 'a <> \' _-_filtered_-_value_-_ \' ' )
435
+ force_null = force_null )
436
436
437
437
self .assertEqual (res , 'COPY 7' )
438
438
@@ -636,16 +636,44 @@ async def test_copy_records_to_table_1(self):
636
636
]
637
637
638
638
records .append (('a-100' , None , None ))
639
- records .append (('b-999' , None , None ))
640
639
641
640
res = await self .con .copy_records_to_table (
642
- 'copytab' , records = records , where = 'a <> \' b-999 \' ' )
641
+ 'copytab' , records = records )
643
642
644
643
self .assertEqual (res , 'COPY 101' )
645
644
646
645
finally :
647
646
await self .con .execute ('DROP TABLE copytab' )
648
647
648
+ async def test_copy_records_to_table_where (self ):
649
+ if not self .con ._server_caps .sql_copy_from_where :
650
+ raise unittest .SkipTest (
651
+ 'COPY WHERE not supported on server' )
652
+
653
+ await self .con .execute ('''
654
+ CREATE TABLE copytab_where(a text, b int, c timestamptz);
655
+ ''' )
656
+
657
+ try :
658
+ date = datetime .datetime .now (tz = datetime .timezone .utc )
659
+ delta = datetime .timedelta (days = 1 )
660
+
661
+ records = [
662
+ ('a-{}' .format (i ), i , date + delta )
663
+ for i in range (100 )
664
+ ]
665
+
666
+ records .append (('a-100' , None , None ))
667
+ records .append (('b-999' , None , None ))
668
+
669
+ res = await self .con .copy_records_to_table (
670
+ 'copytab_where' , records = records , where = 'a <> \' b-999\' ' )
671
+
672
+ self .assertEqual (res , 'COPY 101' )
673
+
674
+ finally :
675
+ await self .con .execute ('DROP TABLE copytab_where' )
676
+
649
677
async def test_copy_records_to_table_async (self ):
650
678
await self .con .execute ('''
651
679
CREATE TABLE copytab_async(a text, b int, c timestamptz);
@@ -660,11 +688,9 @@ async def record_generator():
660
688
yield ('a-{}' .format (i ), i , date + delta )
661
689
662
690
yield ('a-100' , None , None )
663
- yield ('b-999' , None , None )
664
691
665
692
res = await self .con .copy_records_to_table (
666
693
'copytab_async' , records = record_generator (),
667
- where = 'a <> \' b-999\' '
668
694
)
669
695
670
696
self .assertEqual (res , 'COPY 101' )
0 commit comments