|
| 1 | +#![allow(missing_docs)] |
| 2 | + |
| 3 | +use std::collections::HashMap; |
| 4 | + |
| 5 | +use crate::{ |
| 6 | + bson::{Bson, Document}, |
| 7 | + error::{ClientBulkWriteError, Error, ErrorKind, Result}, |
| 8 | + operation::bulk_write::BulkWrite as BulkWriteOperation, |
| 9 | + options::{BulkWriteOptions, WriteConcern, WriteModel}, |
| 10 | + results::BulkWriteResult, |
| 11 | + Client, |
| 12 | + ClientSession, |
| 13 | +}; |
| 14 | + |
| 15 | +use super::{action_impl, option_setters}; |
| 16 | + |
| 17 | +impl Client { |
| 18 | + pub fn bulk_write(&self, models: impl IntoIterator<Item = WriteModel>) -> BulkWrite { |
| 19 | + BulkWrite::new(self, models.into_iter().collect()) |
| 20 | + } |
| 21 | +} |
| 22 | + |
| 23 | +#[must_use] |
| 24 | +pub struct BulkWrite<'a> { |
| 25 | + client: &'a Client, |
| 26 | + models: Vec<WriteModel>, |
| 27 | + options: Option<BulkWriteOptions>, |
| 28 | + session: Option<&'a mut ClientSession>, |
| 29 | +} |
| 30 | + |
| 31 | +impl<'a> BulkWrite<'a> { |
| 32 | + option_setters!(options: BulkWriteOptions; |
| 33 | + ordered: bool, |
| 34 | + bypass_document_validation: bool, |
| 35 | + comment: Bson, |
| 36 | + let_vars: Document, |
| 37 | + verbose_results: bool, |
| 38 | + write_concern: WriteConcern, |
| 39 | + ); |
| 40 | + |
| 41 | + pub fn session(mut self, session: &'a mut ClientSession) -> BulkWrite<'a> { |
| 42 | + self.session = Some(session); |
| 43 | + self |
| 44 | + } |
| 45 | + |
| 46 | + fn new(client: &'a Client, models: Vec<WriteModel>) -> Self { |
| 47 | + Self { |
| 48 | + client, |
| 49 | + models, |
| 50 | + options: None, |
| 51 | + session: None, |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + fn is_ordered(&self) -> bool { |
| 56 | + self.options |
| 57 | + .as_ref() |
| 58 | + .and_then(|options| options.ordered) |
| 59 | + .unwrap_or(true) |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +#[action_impl] |
| 64 | +impl<'a> Action for BulkWrite<'a> { |
| 65 | + type Future = BulkWriteFuture; |
| 66 | + |
| 67 | + async fn execute(mut self) -> Result<BulkWriteResult> { |
| 68 | + #[cfg(feature = "in-use-encryption-unstable")] |
| 69 | + if self.client.should_auto_encrypt().await { |
| 70 | + use mongocrypt::error::{Error as EncryptionError, ErrorKind as EncryptionErrorKind}; |
| 71 | + |
| 72 | + let error = EncryptionError { |
| 73 | + kind: EncryptionErrorKind::Client, |
| 74 | + code: None, |
| 75 | + message: Some( |
| 76 | + "bulkWrite does not currently support automatic encryption".to_string(), |
| 77 | + ), |
| 78 | + }; |
| 79 | + return Err(ErrorKind::Encryption(error).into()); |
| 80 | + } |
| 81 | + |
| 82 | + resolve_write_concern_with_session!( |
| 83 | + self.client, |
| 84 | + self.options, |
| 85 | + self.session.as_deref_mut() |
| 86 | + )?; |
| 87 | + |
| 88 | + let mut total_attempted = 0; |
| 89 | + let mut execution_status = ExecutionStatus::None; |
| 90 | + |
| 91 | + while total_attempted < self.models.len() |
| 92 | + && execution_status.should_continue(self.is_ordered()) |
| 93 | + { |
| 94 | + let mut operation = BulkWriteOperation::new( |
| 95 | + self.client.clone(), |
| 96 | + &self.models[total_attempted..], |
| 97 | + total_attempted, |
| 98 | + self.options.as_ref(), |
| 99 | + ) |
| 100 | + .await; |
| 101 | + let result = self |
| 102 | + .client |
| 103 | + .execute_operation::<BulkWriteOperation>( |
| 104 | + &mut operation, |
| 105 | + self.session.as_deref_mut(), |
| 106 | + ) |
| 107 | + .await; |
| 108 | + total_attempted += operation.n_attempted; |
| 109 | + |
| 110 | + match result { |
| 111 | + Ok(result) => { |
| 112 | + execution_status = execution_status.with_success(result); |
| 113 | + } |
| 114 | + Err(error) => { |
| 115 | + execution_status = execution_status.with_failure(error); |
| 116 | + } |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + match execution_status { |
| 121 | + ExecutionStatus::Success(bulk_write_result) => Ok(bulk_write_result), |
| 122 | + ExecutionStatus::Error(error) => Err(error), |
| 123 | + ExecutionStatus::None => Err(ErrorKind::InvalidArgument { |
| 124 | + message: "bulk_write must be provided at least one write operation".into(), |
| 125 | + } |
| 126 | + .into()), |
| 127 | + } |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +/// Represents the execution status of a bulk write. The status starts at `None`, indicating that no |
| 132 | +/// writes have been attempted yet, and transitions to either `Success` or `Error` as batches are |
| 133 | +/// executed. The contents of `Error` can be inspected to determine whether a bulk write can |
| 134 | +/// continue with further batches or should be terminated. |
| 135 | +enum ExecutionStatus { |
| 136 | + Success(BulkWriteResult), |
| 137 | + Error(Error), |
| 138 | + None, |
| 139 | +} |
| 140 | + |
| 141 | +impl ExecutionStatus { |
| 142 | + fn with_success(mut self, result: BulkWriteResult) -> Self { |
| 143 | + match self { |
| 144 | + // Merge two successful sets of results together. |
| 145 | + Self::Success(ref mut current_result) => { |
| 146 | + current_result.merge(result); |
| 147 | + self |
| 148 | + } |
| 149 | + // Merge the results of the new batch into the existing bulk write error. |
| 150 | + Self::Error(ref mut current_error) => { |
| 151 | + let bulk_write_error = Self::get_current_bulk_write_error(current_error); |
| 152 | + bulk_write_error.merge_partial_results(result); |
| 153 | + self |
| 154 | + } |
| 155 | + Self::None => Self::Success(result), |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + fn with_failure(self, mut error: Error) -> Self { |
| 160 | + match self { |
| 161 | + // If the new error is a BulkWriteError, merge the successful results into the error's |
| 162 | + // partial result. Otherwise, create a new BulkWriteError with the existing results and |
| 163 | + // set its source as the error that just occurred. |
| 164 | + Self::Success(current_result) => match *error.kind { |
| 165 | + ErrorKind::ClientBulkWrite(ref mut bulk_write_error) => { |
| 166 | + bulk_write_error.merge_partial_results(current_result); |
| 167 | + Self::Error(error) |
| 168 | + } |
| 169 | + _ => { |
| 170 | + let bulk_write_error: Error = |
| 171 | + ErrorKind::ClientBulkWrite(ClientBulkWriteError { |
| 172 | + write_errors: HashMap::new(), |
| 173 | + write_concern_errors: Vec::new(), |
| 174 | + partial_result: Some(current_result), |
| 175 | + }) |
| 176 | + .into(); |
| 177 | + Self::Error(bulk_write_error.with_source(error)) |
| 178 | + } |
| 179 | + }, |
| 180 | + // If the new error is a BulkWriteError, merge its contents with the existing error. |
| 181 | + // Otherwise, set the new error as the existing error's source. |
| 182 | + Self::Error(mut current_error) => match *error.kind { |
| 183 | + ErrorKind::ClientBulkWrite(bulk_write_error) => { |
| 184 | + let current_bulk_write_error = |
| 185 | + Self::get_current_bulk_write_error(&mut current_error); |
| 186 | + current_bulk_write_error.merge(bulk_write_error); |
| 187 | + Self::Error(current_error) |
| 188 | + } |
| 189 | + _ => Self::Error(current_error.with_source(error)), |
| 190 | + }, |
| 191 | + Self::None => Self::Error(error), |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + /// Gets a BulkWriteError from a given Error. This method should only be called when adding a |
| 196 | + /// new result or error to the existing state, as it requires that the given Error's kind is |
| 197 | + /// ClientBulkWrite. |
| 198 | + fn get_current_bulk_write_error(error: &mut Error) -> &mut ClientBulkWriteError { |
| 199 | + match *error.kind { |
| 200 | + ErrorKind::ClientBulkWrite(ref mut bulk_write_error) => bulk_write_error, |
| 201 | + _ => unreachable!(), |
| 202 | + } |
| 203 | + } |
| 204 | + |
| 205 | + /// Whether further bulk write batches should be executed based on the current status of |
| 206 | + /// execution. |
| 207 | + fn should_continue(&self, ordered: bool) -> bool { |
| 208 | + match self { |
| 209 | + Self::Error(ref error) => { |
| 210 | + match *error.kind { |
| 211 | + ErrorKind::ClientBulkWrite(ref bulk_write_error) => { |
| 212 | + // A top-level error is always fatal. |
| 213 | + let top_level_error_occurred = error.source.is_some(); |
| 214 | + // A write error occurring during an ordered bulk write is fatal. |
| 215 | + let terminal_write_error_occurred = |
| 216 | + ordered && !bulk_write_error.write_errors.is_empty(); |
| 217 | + |
| 218 | + !top_level_error_occurred && !terminal_write_error_occurred |
| 219 | + } |
| 220 | + // A top-level error is always fatal. |
| 221 | + _ => false, |
| 222 | + } |
| 223 | + } |
| 224 | + _ => true, |
| 225 | + } |
| 226 | + } |
| 227 | +} |
0 commit comments