Skip to content

Commit ec510b3

Browse files
Finish support for Postgres COPY (#1345)
* feat(postgres): WIP implement `COPY FROM/TO STDIN` Signed-off-by: Austin Bonander <[email protected]> * feat(postgres): WIP implement `COPY FROM/TO STDIN` Signed-off-by: Austin Bonander <[email protected]> * test and complete support for postgres copy Co-authored-by: Austin Bonander <[email protected]>
1 parent 687fbf9 commit ec510b3

File tree

9 files changed

+527
-6
lines changed

9 files changed

+527
-6
lines changed

sqlx-core/src/ext/async_stream.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ macro_rules! try_stream {
6262
($($block:tt)*) => {
6363
crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move {
6464
macro_rules! r#yield {
65-
($v:expr) => {
65+
($v:expr) => {{
6666
let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await;
67-
}
67+
}}
6868
}
6969

7070
$($block)*

sqlx-core/src/io/buf_stream.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ pub struct BufStream<S>
1515
where
1616
S: AsyncRead + AsyncWrite + Unpin,
1717
{
18-
stream: S,
18+
pub(crate) stream: S,
1919

2020
// writes with `write` to the underlying stream are buffered
2121
// this can be flushed with `flush`

sqlx-core/src/postgres/connection/mod.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ use crate::error::Error;
1111
use crate::executor::Executor;
1212
use crate::ext::ustr::UStr;
1313
use crate::io::Decode;
14-
use crate::postgres::connection::stream::PgStream;
1514
use crate::postgres::message::{
1615
Close, Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus,
1716
};
1817
use crate::postgres::statement::PgStatementMetadata;
1918
use crate::postgres::{PgConnectOptions, PgTypeInfo, Postgres};
2019
use crate::transaction::Transaction;
2120

21+
pub use self::stream::PgStream;
22+
2223
pub(crate) mod describe;
2324
mod establish;
2425
mod executor;
@@ -66,7 +67,7 @@ pub struct PgConnection {
6667

6768
impl PgConnection {
6869
// will return when the connection is ready for another query
69-
async fn wait_until_ready(&mut self) -> Result<(), Error> {
70+
pub(in crate::postgres) async fn wait_until_ready(&mut self) -> Result<(), Error> {
7071
if !self.stream.wbuf.is_empty() {
7172
self.stream.flush().await?;
7273
}

sqlx-core/src/postgres/copy.rs

+317
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
use crate::error::{Error, Result};
2+
use crate::ext::async_stream::TryAsyncStream;
3+
use crate::pool::{Pool, PoolConnection};
4+
use crate::postgres::connection::PgConnection;
5+
use crate::postgres::message::{
6+
CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
7+
};
8+
use crate::postgres::Postgres;
9+
use bytes::{BufMut, Bytes};
10+
use futures_core::stream::BoxStream;
11+
use smallvec::alloc::borrow::Cow;
12+
use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt};
13+
use std::convert::TryFrom;
14+
use std::ops::{Deref, DerefMut};
15+
16+
impl PgConnection {
17+
/// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
18+
/// to Postgres. This is a more efficient way to import data into Postgres as compared to
19+
/// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
20+
///
21+
/// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
22+
/// returned.
23+
///
24+
/// Command examples and accepted formats for `COPY` data are shown here:
25+
/// https://www.postgresql.org/docs/current/sql-copy.html
26+
///
27+
/// ### Note
28+
/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
29+
/// will return an error the next time it is used.
30+
pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
31+
PgCopyIn::begin(self, statement).await
32+
}
33+
34+
/// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
35+
/// from Postgres. This is a more efficient way to export data from Postgres but
36+
/// arrives in chunks of one of a few data formats (text/CSV/binary).
37+
///
38+
/// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
39+
/// an error is returned.
40+
///
41+
/// Note that once this process has begun, unless you read the stream to completion,
42+
/// it can only be canceled in two ways:
43+
///
44+
/// 1. by closing the connection, or:
45+
/// 2. by using another connection to kill the server process that is sending the data as shown
46+
/// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
47+
///
48+
/// If you don't read the stream to completion, the next time the connection is used it will
49+
/// need to read and discard all the remaining queued data, which could take some time.
50+
///
51+
/// Command examples and accepted formats for `COPY` data are shown here:
52+
/// https://www.postgresql.org/docs/current/sql-copy.html
53+
#[allow(clippy::needless_lifetimes)]
54+
pub async fn copy_out_raw<'c>(
55+
&'c mut self,
56+
statement: &str,
57+
) -> Result<BoxStream<'c, Result<Bytes>>> {
58+
pg_begin_copy_out(self, statement).await
59+
}
60+
}
61+
62+
impl Pool<Postgres> {
63+
/// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
64+
/// This is a more efficient way to import data into Postgres as compared to
65+
/// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
66+
///
67+
/// A single connection will be checked out for the duration.
68+
///
69+
/// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
70+
/// returned.
71+
///
72+
/// Command examples and accepted formats for `COPY` data are shown here:
73+
/// https://www.postgresql.org/docs/current/sql-copy.html
74+
///
75+
/// ### Note
76+
/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
77+
/// will return an error the next time it is used.
78+
pub async fn copy_in_raw(
79+
&mut self,
80+
statement: &str,
81+
) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
82+
PgCopyIn::begin(self.acquire().await?, statement).await
83+
}
84+
85+
/// Issue a `COPY TO STDOUT` statement and begin streaming data
86+
/// from Postgres. This is a more efficient way to export data from Postgres but
87+
/// arrives in chunks of one of a few data formats (text/CSV/binary).
88+
///
89+
/// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
90+
/// an error is returned.
91+
///
92+
/// Note that once this process has begun, unless you read the stream to completion,
93+
/// it can only be canceled in two ways:
94+
///
95+
/// 1. by closing the connection, or:
96+
/// 2. by using another connection to kill the server process that is sending the data as shown
97+
/// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
98+
///
99+
/// If you don't read the stream to completion, the next time the connection is used it will
100+
/// need to read and discard all the remaining queued data, which could take some time.
101+
///
102+
/// Command examples and accepted formats for `COPY` data are shown here:
103+
/// https://www.postgresql.org/docs/current/sql-copy.html
104+
pub async fn copy_out_raw(
105+
&mut self,
106+
statement: &str,
107+
) -> Result<BoxStream<'static, Result<Bytes>>> {
108+
pg_begin_copy_out(self.acquire().await?, statement).await
109+
}
110+
}
111+
112+
/// A connection in streaming `COPY FROM STDIN` mode.
113+
///
114+
/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
115+
///
116+
/// ### Note
117+
/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
118+
/// will return an error the next time it is used.
119+
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
120+
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
121+
conn: Option<C>,
122+
response: CopyResponse,
123+
}
124+
125+
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
126+
async fn begin(mut conn: C, statement: &str) -> Result<Self> {
127+
conn.wait_until_ready().await?;
128+
conn.stream.send(Query(statement)).await?;
129+
130+
let response: CopyResponse = conn
131+
.stream
132+
.recv_expect(MessageFormat::CopyInResponse)
133+
.await?;
134+
135+
Ok(PgCopyIn {
136+
conn: Some(conn),
137+
response,
138+
})
139+
}
140+
141+
/// Send a chunk of `COPY` data.
142+
///
143+
/// If you're copying data from an `AsyncRead`, maybe consider [Self::copy_from] instead.
144+
pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
145+
self.conn
146+
.as_deref_mut()
147+
.expect("send_data: conn taken")
148+
.stream
149+
.send(CopyData(data))
150+
.await?;
151+
152+
Ok(self)
153+
}
154+
155+
/// Copy data directly from `source` to the database without requiring an intermediate buffer.
156+
///
157+
/// `source` will be read to the end.
158+
///
159+
/// ### Note
160+
/// You must still call either [Self::finish] or [Self::abort] to complete the process.
161+
pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
162+
// this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
163+
struct BufGuard<'s>(&'s mut Vec<u8>);
164+
165+
impl Drop for BufGuard<'_> {
166+
fn drop(&mut self) {
167+
self.0.clear()
168+
}
169+
}
170+
171+
let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
172+
173+
// flush any existing messages in the buffer and clear it
174+
conn.stream.flush().await?;
175+
176+
{
177+
let buf_stream = &mut *conn.stream;
178+
let stream = &mut buf_stream.stream;
179+
180+
// ensures the buffer isn't left in an inconsistent state
181+
let mut guard = BufGuard(&mut buf_stream.wbuf);
182+
183+
let buf: &mut Vec<u8> = &mut guard.0;
184+
buf.push(b'd'); // CopyData format code
185+
buf.resize(5, 0); // reserve space for the length
186+
187+
loop {
188+
let read = match () {
189+
// Tokio lets us read into the buffer without zeroing first
190+
#[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))]
191+
_ if buf.len() != buf.capacity() => {
192+
// in case we have some data in the buffer, which can occur
193+
// if the previous write did not fill the buffer
194+
buf.truncate(5);
195+
source.read_buf(buf).await?
196+
}
197+
_ => {
198+
// should be a no-op unless len != capacity
199+
buf.resize(buf.capacity(), 0);
200+
source.read(&mut buf[5..]).await?
201+
}
202+
};
203+
204+
if read == 0 {
205+
break;
206+
}
207+
208+
let read32 = u32::try_from(read)
209+
.map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
210+
211+
(&mut buf[1..]).put_u32(read32 + 4);
212+
213+
stream.write_all(&buf[..read + 5]).await?;
214+
stream.flush().await?;
215+
}
216+
}
217+
218+
Ok(self)
219+
}
220+
221+
/// Signal that the `COPY` process should be aborted and any data received should be discarded.
222+
///
223+
/// The given message can be used for indicating the reason for the abort in the database logs.
224+
///
225+
/// The server is expected to respond with an error, so only _unexpected_ errors are returned.
226+
pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
227+
let mut conn = self
228+
.conn
229+
.take()
230+
.expect("PgCopyIn::fail_with: conn taken illegally");
231+
232+
conn.stream.send(CopyFail::new(msg)).await?;
233+
234+
match conn.stream.recv().await {
235+
Ok(msg) => Err(err_protocol!(
236+
"fail_with: expected ErrorResponse, got: {:?}",
237+
msg.format
238+
)),
239+
Err(Error::Database(e)) => {
240+
match e.code() {
241+
Some(Cow::Borrowed("57014")) => {
242+
// postgres abort received error code
243+
conn.stream
244+
.recv_expect(MessageFormat::ReadyForQuery)
245+
.await?;
246+
Ok(())
247+
}
248+
_ => Err(Error::Database(e)),
249+
}
250+
}
251+
Err(e) => Err(e),
252+
}
253+
}
254+
255+
/// Signal that the `COPY` process is complete.
256+
///
257+
/// The number of rows affected is returned.
258+
pub async fn finish(mut self) -> Result<u64> {
259+
let mut conn = self
260+
.conn
261+
.take()
262+
.expect("CopyWriter::finish: conn taken illegally");
263+
264+
conn.stream.send(CopyDone).await?;
265+
let cc: CommandComplete = conn
266+
.stream
267+
.recv_expect(MessageFormat::CommandComplete)
268+
.await?;
269+
270+
conn.stream
271+
.recv_expect(MessageFormat::ReadyForQuery)
272+
.await?;
273+
274+
Ok(cc.rows_affected())
275+
}
276+
}
277+
278+
impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
279+
fn drop(&mut self) {
280+
if let Some(mut conn) = self.conn.take() {
281+
conn.stream.write(CopyFail::new(
282+
"PgCopyIn dropped without calling finish() or fail()",
283+
));
284+
}
285+
}
286+
}
287+
288+
async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
289+
mut conn: C,
290+
statement: &str,
291+
) -> Result<BoxStream<'c, Result<Bytes>>> {
292+
conn.wait_until_ready().await?;
293+
conn.stream.send(Query(statement)).await?;
294+
295+
let _: CopyResponse = conn
296+
.stream
297+
.recv_expect(MessageFormat::CopyOutResponse)
298+
.await?;
299+
300+
let stream: TryAsyncStream<'c, Bytes> = try_stream! {
301+
loop {
302+
let msg = conn.stream.recv().await?;
303+
match msg.format {
304+
MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
305+
MessageFormat::CopyDone => {
306+
let _ = msg.decode::<CopyDone>()?;
307+
conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
308+
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
309+
return Ok(())
310+
},
311+
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
312+
}
313+
}
314+
};
315+
316+
Ok(Box::pin(stream))
317+
}

0 commit comments

Comments
 (0)