|
| 1 | +use crate::error::{Error, Result}; |
| 2 | +use crate::ext::async_stream::TryAsyncStream; |
| 3 | +use crate::pool::{Pool, PoolConnection}; |
| 4 | +use crate::postgres::connection::PgConnection; |
| 5 | +use crate::postgres::message::{ |
| 6 | + CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, |
| 7 | +}; |
| 8 | +use crate::postgres::Postgres; |
| 9 | +use bytes::{BufMut, Bytes}; |
| 10 | +use futures_core::stream::BoxStream; |
| 11 | +use smallvec::alloc::borrow::Cow; |
| 12 | +use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt}; |
| 13 | +use std::convert::TryFrom; |
| 14 | +use std::ops::{Deref, DerefMut}; |
| 15 | + |
| 16 | +impl PgConnection { |
| 17 | + /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data |
| 18 | + /// to Postgres. This is a more efficient way to import data into Postgres as compared to |
| 19 | + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). |
| 20 | + /// |
| 21 | + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is |
| 22 | + /// returned. |
| 23 | + /// |
| 24 | + /// Command examples and accepted formats for `COPY` data are shown here: |
| 25 | + /// https://www.postgresql.org/docs/current/sql-copy.html |
| 26 | + /// |
| 27 | + /// ### Note |
| 28 | + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection |
| 29 | + /// will return an error the next time it is used. |
| 30 | + pub async fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> { |
| 31 | + PgCopyIn::begin(self, statement).await |
| 32 | + } |
| 33 | + |
| 34 | + /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data |
| 35 | + /// from Postgres. This is a more efficient way to export data from Postgres but |
| 36 | + /// arrives in chunks of one of a few data formats (text/CSV/binary). |
| 37 | + /// |
| 38 | + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, |
| 39 | + /// an error is returned. |
| 40 | + /// |
| 41 | + /// Note that once this process has begun, unless you read the stream to completion, |
| 42 | + /// it can only be canceled in two ways: |
| 43 | + /// |
| 44 | + /// 1. by closing the connection, or: |
| 45 | + /// 2. by using another connection to kill the server process that is sending the data as shown |
| 46 | + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). |
| 47 | + /// |
| 48 | + /// If you don't read the stream to completion, the next time the connection is used it will |
| 49 | + /// need to read and discard all the remaining queued data, which could take some time. |
| 50 | + /// |
| 51 | + /// Command examples and accepted formats for `COPY` data are shown here: |
| 52 | + /// https://www.postgresql.org/docs/current/sql-copy.html |
| 53 | + #[allow(clippy::needless_lifetimes)] |
| 54 | + pub async fn copy_out_raw<'c>( |
| 55 | + &'c mut self, |
| 56 | + statement: &str, |
| 57 | + ) -> Result<BoxStream<'c, Result<Bytes>>> { |
| 58 | + pg_begin_copy_out(self, statement).await |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +impl Pool<Postgres> { |
| 63 | + /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres. |
| 64 | + /// This is a more efficient way to import data into Postgres as compared to |
| 65 | + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). |
| 66 | + /// |
| 67 | + /// A single connection will be checked out for the duration. |
| 68 | + /// |
| 69 | + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is |
| 70 | + /// returned. |
| 71 | + /// |
| 72 | + /// Command examples and accepted formats for `COPY` data are shown here: |
| 73 | + /// https://www.postgresql.org/docs/current/sql-copy.html |
| 74 | + /// |
| 75 | + /// ### Note |
| 76 | + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection |
| 77 | + /// will return an error the next time it is used. |
| 78 | + pub async fn copy_in_raw( |
| 79 | + &mut self, |
| 80 | + statement: &str, |
| 81 | + ) -> Result<PgCopyIn<PoolConnection<Postgres>>> { |
| 82 | + PgCopyIn::begin(self.acquire().await?, statement).await |
| 83 | + } |
| 84 | + |
| 85 | + /// Issue a `COPY TO STDOUT` statement and begin streaming data |
| 86 | + /// from Postgres. This is a more efficient way to export data from Postgres but |
| 87 | + /// arrives in chunks of one of a few data formats (text/CSV/binary). |
| 88 | + /// |
| 89 | + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, |
| 90 | + /// an error is returned. |
| 91 | + /// |
| 92 | + /// Note that once this process has begun, unless you read the stream to completion, |
| 93 | + /// it can only be canceled in two ways: |
| 94 | + /// |
| 95 | + /// 1. by closing the connection, or: |
| 96 | + /// 2. by using another connection to kill the server process that is sending the data as shown |
| 97 | + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). |
| 98 | + /// |
| 99 | + /// If you don't read the stream to completion, the next time the connection is used it will |
| 100 | + /// need to read and discard all the remaining queued data, which could take some time. |
| 101 | + /// |
| 102 | + /// Command examples and accepted formats for `COPY` data are shown here: |
| 103 | + /// https://www.postgresql.org/docs/current/sql-copy.html |
| 104 | + pub async fn copy_out_raw( |
| 105 | + &mut self, |
| 106 | + statement: &str, |
| 107 | + ) -> Result<BoxStream<'static, Result<Bytes>>> { |
| 108 | + pg_begin_copy_out(self.acquire().await?, statement).await |
| 109 | + } |
| 110 | +} |
| 111 | + |
| 112 | +/// A connection in streaming `COPY FROM STDIN` mode. |
| 113 | +/// |
| 114 | +/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. |
| 115 | +/// |
| 116 | +/// ### Note |
| 117 | +/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection |
| 118 | +/// will return an error the next time it is used. |
| 119 | +#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] |
| 120 | +pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> { |
| 121 | + conn: Option<C>, |
| 122 | + response: CopyResponse, |
| 123 | +} |
| 124 | + |
| 125 | +impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> { |
| 126 | + async fn begin(mut conn: C, statement: &str) -> Result<Self> { |
| 127 | + conn.wait_until_ready().await?; |
| 128 | + conn.stream.send(Query(statement)).await?; |
| 129 | + |
| 130 | + let response: CopyResponse = conn |
| 131 | + .stream |
| 132 | + .recv_expect(MessageFormat::CopyInResponse) |
| 133 | + .await?; |
| 134 | + |
| 135 | + Ok(PgCopyIn { |
| 136 | + conn: Some(conn), |
| 137 | + response, |
| 138 | + }) |
| 139 | + } |
| 140 | + |
| 141 | + /// Send a chunk of `COPY` data. |
| 142 | + /// |
| 143 | + /// If you're copying data from an `AsyncRead`, maybe consider [Self::copy_from] instead. |
| 144 | + pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> { |
| 145 | + self.conn |
| 146 | + .as_deref_mut() |
| 147 | + .expect("send_data: conn taken") |
| 148 | + .stream |
| 149 | + .send(CopyData(data)) |
| 150 | + .await?; |
| 151 | + |
| 152 | + Ok(self) |
| 153 | + } |
| 154 | + |
| 155 | + /// Copy data directly from `source` to the database without requiring an intermediate buffer. |
| 156 | + /// |
| 157 | + /// `source` will be read to the end. |
| 158 | + /// |
| 159 | + /// ### Note |
| 160 | + /// You must still call either [Self::finish] or [Self::abort] to complete the process. |
| 161 | + pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { |
| 162 | + // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing |
| 163 | + struct BufGuard<'s>(&'s mut Vec<u8>); |
| 164 | + |
| 165 | + impl Drop for BufGuard<'_> { |
| 166 | + fn drop(&mut self) { |
| 167 | + self.0.clear() |
| 168 | + } |
| 169 | + } |
| 170 | + |
| 171 | + let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken"); |
| 172 | + |
| 173 | + // flush any existing messages in the buffer and clear it |
| 174 | + conn.stream.flush().await?; |
| 175 | + |
| 176 | + { |
| 177 | + let buf_stream = &mut *conn.stream; |
| 178 | + let stream = &mut buf_stream.stream; |
| 179 | + |
| 180 | + // ensures the buffer isn't left in an inconsistent state |
| 181 | + let mut guard = BufGuard(&mut buf_stream.wbuf); |
| 182 | + |
| 183 | + let buf: &mut Vec<u8> = &mut guard.0; |
| 184 | + buf.push(b'd'); // CopyData format code |
| 185 | + buf.resize(5, 0); // reserve space for the length |
| 186 | + |
| 187 | + loop { |
| 188 | + let read = match () { |
| 189 | + // Tokio lets us read into the buffer without zeroing first |
| 190 | + #[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))] |
| 191 | + _ if buf.len() != buf.capacity() => { |
| 192 | + // in case we have some data in the buffer, which can occur |
| 193 | + // if the previous write did not fill the buffer |
| 194 | + buf.truncate(5); |
| 195 | + source.read_buf(buf).await? |
| 196 | + } |
| 197 | + _ => { |
| 198 | + // should be a no-op unless len != capacity |
| 199 | + buf.resize(buf.capacity(), 0); |
| 200 | + source.read(&mut buf[5..]).await? |
| 201 | + } |
| 202 | + }; |
| 203 | + |
| 204 | + if read == 0 { |
| 205 | + break; |
| 206 | + } |
| 207 | + |
| 208 | + let read32 = u32::try_from(read) |
| 209 | + .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; |
| 210 | + |
| 211 | + (&mut buf[1..]).put_u32(read32 + 4); |
| 212 | + |
| 213 | + stream.write_all(&buf[..read + 5]).await?; |
| 214 | + stream.flush().await?; |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + Ok(self) |
| 219 | + } |
| 220 | + |
| 221 | + /// Signal that the `COPY` process should be aborted and any data received should be discarded. |
| 222 | + /// |
| 223 | + /// The given message can be used for indicating the reason for the abort in the database logs. |
| 224 | + /// |
| 225 | + /// The server is expected to respond with an error, so only _unexpected_ errors are returned. |
| 226 | + pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> { |
| 227 | + let mut conn = self |
| 228 | + .conn |
| 229 | + .take() |
| 230 | + .expect("PgCopyIn::fail_with: conn taken illegally"); |
| 231 | + |
| 232 | + conn.stream.send(CopyFail::new(msg)).await?; |
| 233 | + |
| 234 | + match conn.stream.recv().await { |
| 235 | + Ok(msg) => Err(err_protocol!( |
| 236 | + "fail_with: expected ErrorResponse, got: {:?}", |
| 237 | + msg.format |
| 238 | + )), |
| 239 | + Err(Error::Database(e)) => { |
| 240 | + match e.code() { |
| 241 | + Some(Cow::Borrowed("57014")) => { |
| 242 | + // postgres abort received error code |
| 243 | + conn.stream |
| 244 | + .recv_expect(MessageFormat::ReadyForQuery) |
| 245 | + .await?; |
| 246 | + Ok(()) |
| 247 | + } |
| 248 | + _ => Err(Error::Database(e)), |
| 249 | + } |
| 250 | + } |
| 251 | + Err(e) => Err(e), |
| 252 | + } |
| 253 | + } |
| 254 | + |
| 255 | + /// Signal that the `COPY` process is complete. |
| 256 | + /// |
| 257 | + /// The number of rows affected is returned. |
| 258 | + pub async fn finish(mut self) -> Result<u64> { |
| 259 | + let mut conn = self |
| 260 | + .conn |
| 261 | + .take() |
| 262 | + .expect("CopyWriter::finish: conn taken illegally"); |
| 263 | + |
| 264 | + conn.stream.send(CopyDone).await?; |
| 265 | + let cc: CommandComplete = conn |
| 266 | + .stream |
| 267 | + .recv_expect(MessageFormat::CommandComplete) |
| 268 | + .await?; |
| 269 | + |
| 270 | + conn.stream |
| 271 | + .recv_expect(MessageFormat::ReadyForQuery) |
| 272 | + .await?; |
| 273 | + |
| 274 | + Ok(cc.rows_affected()) |
| 275 | + } |
| 276 | +} |
| 277 | + |
| 278 | +impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> { |
| 279 | + fn drop(&mut self) { |
| 280 | + if let Some(mut conn) = self.conn.take() { |
| 281 | + conn.stream.write(CopyFail::new( |
| 282 | + "PgCopyIn dropped without calling finish() or fail()", |
| 283 | + )); |
| 284 | + } |
| 285 | + } |
| 286 | +} |
| 287 | + |
| 288 | +async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>( |
| 289 | + mut conn: C, |
| 290 | + statement: &str, |
| 291 | +) -> Result<BoxStream<'c, Result<Bytes>>> { |
| 292 | + conn.wait_until_ready().await?; |
| 293 | + conn.stream.send(Query(statement)).await?; |
| 294 | + |
| 295 | + let _: CopyResponse = conn |
| 296 | + .stream |
| 297 | + .recv_expect(MessageFormat::CopyOutResponse) |
| 298 | + .await?; |
| 299 | + |
| 300 | + let stream: TryAsyncStream<'c, Bytes> = try_stream! { |
| 301 | + loop { |
| 302 | + let msg = conn.stream.recv().await?; |
| 303 | + match msg.format { |
| 304 | + MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0), |
| 305 | + MessageFormat::CopyDone => { |
| 306 | + let _ = msg.decode::<CopyDone>()?; |
| 307 | + conn.stream.recv_expect(MessageFormat::CommandComplete).await?; |
| 308 | + conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; |
| 309 | + return Ok(()) |
| 310 | + }, |
| 311 | + _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) |
| 312 | + } |
| 313 | + } |
| 314 | + }; |
| 315 | + |
| 316 | + Ok(Box::pin(stream)) |
| 317 | +} |
0 commit comments