Skip to content

Commit f5d7c4c

Browse files
RUST-1713 Bulk Write (#1034)
1 parent 31a0750 commit f5d7c4c

File tree

120 files changed

+14245
-2127
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

120 files changed

+14245
-2127
lines changed

src/action.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Action builder types.
22
33
mod aggregate;
4+
mod bulk_write;
45
mod client_options;
56
mod count;
67
mod create_collection;
@@ -31,8 +32,10 @@ mod watch;
3132

3233
use std::{future::IntoFuture, marker::PhantomData, ops::Deref};
3334

35+
use crate::bson::Document;
36+
3437
pub use aggregate::Aggregate;
35-
use bson::Document;
38+
pub use bulk_write::BulkWrite;
3639
pub use client_options::ParseConnectionString;
3740
pub use count::{CountDocuments, EstimatedDocumentCount};
3841
pub use create_collection::CreateCollection;

src/action/bulk_write.rs

+227
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
#![allow(missing_docs)]
2+
3+
use std::collections::HashMap;
4+
5+
use crate::{
6+
bson::{Bson, Document},
7+
error::{ClientBulkWriteError, Error, ErrorKind, Result},
8+
operation::bulk_write::BulkWrite as BulkWriteOperation,
9+
options::{BulkWriteOptions, WriteConcern, WriteModel},
10+
results::BulkWriteResult,
11+
Client,
12+
ClientSession,
13+
};
14+
15+
use super::{action_impl, option_setters};
16+
17+
impl Client {
18+
pub fn bulk_write(&self, models: impl IntoIterator<Item = WriteModel>) -> BulkWrite {
19+
BulkWrite::new(self, models.into_iter().collect())
20+
}
21+
}
22+
23+
#[must_use]
24+
pub struct BulkWrite<'a> {
25+
client: &'a Client,
26+
models: Vec<WriteModel>,
27+
options: Option<BulkWriteOptions>,
28+
session: Option<&'a mut ClientSession>,
29+
}
30+
31+
impl<'a> BulkWrite<'a> {
32+
option_setters!(options: BulkWriteOptions;
33+
ordered: bool,
34+
bypass_document_validation: bool,
35+
comment: Bson,
36+
let_vars: Document,
37+
verbose_results: bool,
38+
write_concern: WriteConcern,
39+
);
40+
41+
pub fn session(mut self, session: &'a mut ClientSession) -> BulkWrite<'a> {
42+
self.session = Some(session);
43+
self
44+
}
45+
46+
fn new(client: &'a Client, models: Vec<WriteModel>) -> Self {
47+
Self {
48+
client,
49+
models,
50+
options: None,
51+
session: None,
52+
}
53+
}
54+
55+
fn is_ordered(&self) -> bool {
56+
self.options
57+
.as_ref()
58+
.and_then(|options| options.ordered)
59+
.unwrap_or(true)
60+
}
61+
}
62+
63+
#[action_impl]
64+
impl<'a> Action for BulkWrite<'a> {
65+
type Future = BulkWriteFuture;
66+
67+
async fn execute(mut self) -> Result<BulkWriteResult> {
68+
#[cfg(feature = "in-use-encryption-unstable")]
69+
if self.client.should_auto_encrypt().await {
70+
use mongocrypt::error::{Error as EncryptionError, ErrorKind as EncryptionErrorKind};
71+
72+
let error = EncryptionError {
73+
kind: EncryptionErrorKind::Client,
74+
code: None,
75+
message: Some(
76+
"bulkWrite does not currently support automatic encryption".to_string(),
77+
),
78+
};
79+
return Err(ErrorKind::Encryption(error).into());
80+
}
81+
82+
resolve_write_concern_with_session!(
83+
self.client,
84+
self.options,
85+
self.session.as_deref_mut()
86+
)?;
87+
88+
let mut total_attempted = 0;
89+
let mut execution_status = ExecutionStatus::None;
90+
91+
while total_attempted < self.models.len()
92+
&& execution_status.should_continue(self.is_ordered())
93+
{
94+
let mut operation = BulkWriteOperation::new(
95+
self.client.clone(),
96+
&self.models[total_attempted..],
97+
total_attempted,
98+
self.options.as_ref(),
99+
)
100+
.await;
101+
let result = self
102+
.client
103+
.execute_operation::<BulkWriteOperation>(
104+
&mut operation,
105+
self.session.as_deref_mut(),
106+
)
107+
.await;
108+
total_attempted += operation.n_attempted;
109+
110+
match result {
111+
Ok(result) => {
112+
execution_status = execution_status.with_success(result);
113+
}
114+
Err(error) => {
115+
execution_status = execution_status.with_failure(error);
116+
}
117+
}
118+
}
119+
120+
match execution_status {
121+
ExecutionStatus::Success(bulk_write_result) => Ok(bulk_write_result),
122+
ExecutionStatus::Error(error) => Err(error),
123+
ExecutionStatus::None => Err(ErrorKind::InvalidArgument {
124+
message: "bulk_write must be provided at least one write operation".into(),
125+
}
126+
.into()),
127+
}
128+
}
129+
}
130+
131+
/// Represents the execution status of a bulk write. The status starts at `None`, indicating that no
132+
/// writes have been attempted yet, and transitions to either `Success` or `Error` as batches are
133+
/// executed. The contents of `Error` can be inspected to determine whether a bulk write can
134+
/// continue with further batches or should be terminated.
135+
enum ExecutionStatus {
136+
Success(BulkWriteResult),
137+
Error(Error),
138+
None,
139+
}
140+
141+
impl ExecutionStatus {
142+
fn with_success(mut self, result: BulkWriteResult) -> Self {
143+
match self {
144+
// Merge two successful sets of results together.
145+
Self::Success(ref mut current_result) => {
146+
current_result.merge(result);
147+
self
148+
}
149+
// Merge the results of the new batch into the existing bulk write error.
150+
Self::Error(ref mut current_error) => {
151+
let bulk_write_error = Self::get_current_bulk_write_error(current_error);
152+
bulk_write_error.merge_partial_results(result);
153+
self
154+
}
155+
Self::None => Self::Success(result),
156+
}
157+
}
158+
159+
fn with_failure(self, mut error: Error) -> Self {
160+
match self {
161+
// If the new error is a BulkWriteError, merge the successful results into the error's
162+
// partial result. Otherwise, create a new BulkWriteError with the existing results and
163+
// set its source as the error that just occurred.
164+
Self::Success(current_result) => match *error.kind {
165+
ErrorKind::ClientBulkWrite(ref mut bulk_write_error) => {
166+
bulk_write_error.merge_partial_results(current_result);
167+
Self::Error(error)
168+
}
169+
_ => {
170+
let bulk_write_error: Error =
171+
ErrorKind::ClientBulkWrite(ClientBulkWriteError {
172+
write_errors: HashMap::new(),
173+
write_concern_errors: Vec::new(),
174+
partial_result: Some(current_result),
175+
})
176+
.into();
177+
Self::Error(bulk_write_error.with_source(error))
178+
}
179+
},
180+
// If the new error is a BulkWriteError, merge its contents with the existing error.
181+
// Otherwise, set the new error as the existing error's source.
182+
Self::Error(mut current_error) => match *error.kind {
183+
ErrorKind::ClientBulkWrite(bulk_write_error) => {
184+
let current_bulk_write_error =
185+
Self::get_current_bulk_write_error(&mut current_error);
186+
current_bulk_write_error.merge(bulk_write_error);
187+
Self::Error(current_error)
188+
}
189+
_ => Self::Error(current_error.with_source(error)),
190+
},
191+
Self::None => Self::Error(error),
192+
}
193+
}
194+
195+
/// Gets a BulkWriteError from a given Error. This method should only be called when adding a
196+
/// new result or error to the existing state, as it requires that the given Error's kind is
197+
/// ClientBulkWrite.
198+
fn get_current_bulk_write_error(error: &mut Error) -> &mut ClientBulkWriteError {
199+
match *error.kind {
200+
ErrorKind::ClientBulkWrite(ref mut bulk_write_error) => bulk_write_error,
201+
_ => unreachable!(),
202+
}
203+
}
204+
205+
/// Whether further bulk write batches should be executed based on the current status of
206+
/// execution.
207+
fn should_continue(&self, ordered: bool) -> bool {
208+
match self {
209+
Self::Error(ref error) => {
210+
match *error.kind {
211+
ErrorKind::ClientBulkWrite(ref bulk_write_error) => {
212+
// A top-level error is always fatal.
213+
let top_level_error_occurred = error.source.is_some();
214+
// A write error occurring during an ordered bulk write is fatal.
215+
let terminal_write_error_occurred =
216+
ordered && !bulk_write_error.write_errors.is_empty();
217+
218+
!top_level_error_occurred && !terminal_write_error_occurred
219+
}
220+
// A top-level error is always fatal.
221+
_ => false,
222+
}
223+
}
224+
_ => true,
225+
}
226+
}
227+
}

src/action/insert_many.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,7 @@ impl<'a> Action for InsertMany<'a> {
103103
.as_ref()
104104
.and_then(|o| o.ordered)
105105
.unwrap_or(true);
106-
#[cfg(feature = "in-use-encryption-unstable")]
107-
let encrypted = self.coll.client().auto_encryption_opts().await.is_some();
108-
#[cfg(not(feature = "in-use-encryption-unstable"))]
109-
let encrypted = false;
106+
let encrypted = self.coll.client().should_auto_encrypt().await;
110107

111108
let mut cumulative_failure: Option<BulkWriteFailure> = None;
112109
let mut error_labels: HashSet<String> = Default::default();

src/action/insert_one.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,13 @@ impl<'a> Action for InsertOne<'a> {
8787
async fn execute(mut self) -> Result<InsertOneResult> {
8888
resolve_write_concern_with_session!(self.coll, self.options, self.session.as_ref())?;
8989

90-
#[cfg(feature = "in-use-encryption-unstable")]
91-
let encrypted = self.coll.client().auto_encryption_opts().await.is_some();
92-
#[cfg(not(feature = "in-use-encryption-unstable"))]
93-
let encrypted = false;
94-
9590
let doc = self.doc?;
9691

9792
let insert = Op::new(
9893
self.coll.namespace(),
9994
vec![doc.deref()],
10095
self.options.map(InsertManyOptions::from_insert_one_options),
101-
encrypted,
96+
self.coll.client().should_auto_encrypt().await,
10297
);
10398
self.coll
10499
.client()

0 commit comments

Comments
 (0)