Skip to content

Commit 35a4f74

Browse files
author
Montana Low
committed
test and complete support for postgres copy
1 parent 0346f16 commit 35a4f74

File tree

6 files changed

+121
-7
lines changed

6 files changed

+121
-7
lines changed

sqlx-core/src/postgres/connection/mod.rs

-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ use crate::transaction::Transaction;
2020

2121
pub use self::stream::PgStream;
2222

23-
pub use self::stream::PgStream;
24-
2523
pub(crate) mod describe;
2624
mod establish;
2725
mod executor;

sqlx-core/src/postgres/copy.rs

+20-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::postgres::message::{
88
use crate::postgres::Postgres;
99
use bytes::{BufMut, Bytes};
1010
use futures_core::stream::BoxStream;
11+
use smallvec::alloc::borrow::Cow;
1112
use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
1213
use std::convert::TryFrom;
1314
use std::ops::{Deref, DerefMut};
@@ -235,8 +236,18 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
235236
"fail_with: expected ErrorResponse, got: {:?}",
236237
msg.format
237238
)),
238-
// FIXME: inspect the `DatabaseError` to make sure we're not discarding another error
239-
Err(Error::Database(_db)) => Ok(()),
239+
Err(Error::Database(e)) => {
240+
match e.code() {
241+
Some(Cow::Borrowed("57014")) => {
242+
// postgres abort received error code
243+
conn.stream
244+
.recv_expect(MessageFormat::ReadyForQuery)
245+
.await?;
246+
Ok(())
247+
}
248+
_ => Err(Error::Database(e)),
249+
}
250+
}
240251
Err(e) => Err(e),
241252
}
242253
}
@@ -256,6 +267,10 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
256267
.recv_expect(MessageFormat::CommandComplete)
257268
.await?;
258269

270+
conn.stream
271+
.recv_expect(MessageFormat::ReadyForQuery)
272+
.await?;
273+
259274
Ok(cc.rows_affected())
260275
}
261276
}
@@ -279,7 +294,7 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
279294

280295
let _: CopyResponse = conn
281296
.stream
282-
.recv_expect(MessageFormat::CopyInResponse)
297+
.recv_expect(MessageFormat::CopyOutResponse)
283298
.await?;
284299

285300
let stream: TryAsyncStream<'c, Bytes> = try_stream! {
@@ -289,6 +304,8 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
289304
MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
290305
MessageFormat::CopyDone => {
291306
let _ = msg.decode::<CopyDone>()?;
307+
conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
308+
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
292309
return Ok(())
293310
},
294311
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))

sqlx-core/src/postgres/message/copy.rs

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use crate::io::{BufExt, BufMutExt, Decode, Encode};
33
use bytes::{Buf, BufMut, Bytes};
44
use std::ops::Deref;
55

6-
76
/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse`
87
pub struct CopyResponse {
98
pub format: i8,

sqlx-core/src/postgres/message/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ impl MessageFormat {
106106
b'C' => MessageFormat::CommandComplete,
107107
b'd' => MessageFormat::CopyData,
108108
b'c' => MessageFormat::CopyDone,
109+
b'G' => MessageFormat::CopyInResponse,
110+
b'H' => MessageFormat::CopyOutResponse,
109111
b'D' => MessageFormat::DataRow,
110112
b'E' => MessageFormat::ErrorResponse,
111113
b'I' => MessageFormat::EmptyQueryResponse,

tests/postgres/postgres.rs

+97-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::TryStreamExt;
1+
use futures::{StreamExt, TryStreamExt};
22
use sqlx::postgres::{
33
PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgSeverity,
44
};
@@ -1104,3 +1104,99 @@ async fn test_pg_server_num() -> anyhow::Result<()> {
11041104

11051105
Ok(())
11061106
}
1107+
1108+
#[sqlx_macros::test]
1109+
async fn it_can_copy_in() -> anyhow::Result<()> {
1110+
let mut conn = new::<Postgres>().await?;
1111+
conn.execute(
1112+
r#"
1113+
CREATE TEMPORARY TABLE users (id INTEGER NOT NULL);
1114+
"#,
1115+
)
1116+
.await?;
1117+
1118+
let mut copy = conn
1119+
.copy_in_raw(
1120+
r#"
1121+
COPY users (id) FROM STDIN WITH (FORMAT CSV, HEADER);
1122+
"#,
1123+
)
1124+
.await?;
1125+
1126+
copy.send("id\n1\n2\n".as_bytes()).await?;
1127+
let rows = copy.finish().await?;
1128+
assert_eq!(rows, 2);
1129+
1130+
// conn is safe for reuse
1131+
let value = sqlx::query("select 1 + 1")
1132+
.try_map(|row: PgRow| row.try_get::<i32, _>(0))
1133+
.fetch_one(&mut conn)
1134+
.await?;
1135+
1136+
assert_eq!(2i32, value);
1137+
1138+
Ok(())
1139+
}
1140+
1141+
#[sqlx_macros::test]
1142+
async fn it_can_abort_copy_in() -> anyhow::Result<()> {
1143+
let mut conn = new::<Postgres>().await?;
1144+
conn.execute(
1145+
r#"
1146+
CREATE TEMPORARY TABLE users (id INTEGER NOT NULL);
1147+
"#,
1148+
)
1149+
.await?;
1150+
1151+
let mut copy = conn
1152+
.copy_in_raw(
1153+
r#"
1154+
COPY users (id) FROM STDIN WITH (FORMAT CSV, HEADER);
1155+
"#,
1156+
)
1157+
.await?;
1158+
1159+
copy.abort("this is only a test").await?;
1160+
1161+
// conn is safe for reuse
1162+
let value = sqlx::query("select 1 + 1")
1163+
.try_map(|row: PgRow| row.try_get::<i32, _>(0))
1164+
.fetch_one(&mut conn)
1165+
.await?;
1166+
1167+
assert_eq!(2i32, value);
1168+
1169+
Ok(())
1170+
}
1171+
1172+
#[sqlx_macros::test]
1173+
async fn it_can_copy_out() -> anyhow::Result<()> {
1174+
let mut conn = new::<Postgres>().await?;
1175+
1176+
{
1177+
let mut copy = conn
1178+
.copy_out_raw(
1179+
"
1180+
COPY (SELECT generate_series(1, 2) AS id) TO STDOUT WITH (FORMAT CSV, HEADER);
1181+
",
1182+
)
1183+
.await?;
1184+
1185+
assert_eq!(copy.next().await.unwrap().unwrap(), "id\n");
1186+
assert_eq!(copy.next().await.unwrap().unwrap(), "1\n");
1187+
assert_eq!(copy.next().await.unwrap().unwrap(), "2\n");
1188+
if copy.next().await.is_some() {
1189+
anyhow::bail!("Unexpected data from COPY");
1190+
}
1191+
}
1192+
1193+
// conn is safe for reuse
1194+
let value = sqlx::query("select 1 + 1")
1195+
.try_map(|row: PgRow| row.try_get::<i32, _>(0))
1196+
.fetch_one(&mut conn)
1197+
.await?;
1198+
1199+
assert_eq!(2i32, value);
1200+
1201+
Ok(())
1202+
}

tests/sqlite/.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
sqlite.db
2+

0 commit comments

Comments
 (0)