Skip to content

Commit e3df089

Browse files
RUST-1222 Cancel in-progress operations when SDAM heartbeats time out (#1249)
1 parent 450c8a3 commit e3df089

File tree

57 files changed

+1034
-503
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1034
-503
lines changed

src/client/auth/aws.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async fn authenticate_stream_inner(
8383
);
8484
let client_first = sasl_start.into_command();
8585

86-
let server_first_response = conn.send_command(client_first, None).await?;
86+
let server_first_response = conn.send_message(client_first).await?;
8787

8888
let server_first = ServerFirst::parse(server_first_response.auth_response_body(MECH_NAME)?)?;
8989
server_first.validate(&nonce)?;
@@ -135,7 +135,7 @@ async fn authenticate_stream_inner(
135135

136136
let client_second = sasl_continue.into_command();
137137

138-
let server_second_response = conn.send_command(client_second, None).await?;
138+
let server_second_response = conn.send_message(client_second).await?;
139139
let server_second = SaslResponse::parse(
140140
MECH_NAME,
141141
server_second_response.auth_response_body(MECH_NAME)?,

src/client/auth/oidc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ async fn send_sasl_command(
880880
conn: &mut Connection,
881881
command: crate::cmap::Command,
882882
) -> Result<SaslResponse> {
883-
let response = conn.send_command(command, None).await?;
883+
let response = conn.send_message(command).await?;
884884
SaslResponse::parse(
885885
MONGODB_OIDC_STR,
886886
response.auth_response_body(MONGODB_OIDC_STR)?,

src/client/auth/plain.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ pub(crate) async fn authenticate_stream(
3535
)
3636
.into_command();
3737

38-
let response = conn.send_command(sasl_start, None).await?;
38+
let response = conn.send_message(sasl_start).await?;
3939
let sasl_response = SaslResponse::parse("PLAIN", response.auth_response_body("PLAIN")?)?;
4040

4141
if !sasl_response.done {

src/client/auth/scram.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ impl ScramVersion {
151151

152152
let command = client_first.to_command(self);
153153

154-
let server_first = conn.send_command(command, None).await?;
154+
let server_first = conn.send_message(command).await?;
155155

156156
Ok(FirstRound {
157157
client_first,
@@ -215,7 +215,7 @@ impl ScramVersion {
215215

216216
let command = client_final.to_command();
217217

218-
let server_final_response = conn.send_command(command, None).await?;
218+
let server_final_response = conn.send_message(command).await?;
219219
let server_final = ServerFinal::parse(server_final_response.auth_response_body("SCRAM")?)?;
220220
server_final.validate(salted_password.as_slice(), &client_final, self)?;
221221

@@ -231,7 +231,7 @@ impl ScramVersion {
231231
);
232232
let command = noop.into_command();
233233

234-
let server_noop_response = conn.send_command(command, None).await?;
234+
let server_noop_response = conn.send_message(command).await?;
235235
let server_noop_response_document: Document =
236236
server_noop_response.auth_response_body("SCRAM")?;
237237

src/client/auth/x509.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pub(crate) async fn send_client_first(
4343
) -> Result<RawCommandResponse> {
4444
let command = build_client_first(credential, server_api);
4545

46-
conn.send_command(command, None).await
46+
conn.send_message(command).await
4747
}
4848

4949
/// Performs X.509 authentication for a given stream.

src/client/executor.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,13 +614,12 @@ impl Client {
614614
}
615615

616616
let should_redact = cmd.should_redact();
617-
let should_compress = cmd.should_compress();
618617

619618
let cmd_name = cmd.name.clone();
620619
let target_db = cmd.target_db.clone();
621620

622-
#[allow(unused_mut)]
623-
let mut message = Message::from_command(cmd, Some(request_id))?;
621+
let mut message = Message::try_from(cmd)?;
622+
message.request_id = Some(request_id);
624623
#[cfg(feature = "in-use-encryption")]
625624
{
626625
let guard = self.inner.csfle.read().await;
@@ -652,7 +651,7 @@ impl Client {
652651
.await;
653652

654653
let start_time = Instant::now();
655-
let command_result = match connection.send_message(message, should_compress).await {
654+
let command_result = match connection.send_message(message).await {
656655
Ok(response) => {
657656
async fn handle_response<T: Operation>(
658657
client: &Client,

src/cmap/conn.rs

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ use derive_where::derive_where;
99
use serde::Serialize;
1010
use tokio::{
1111
io::BufStream,
12-
sync::{mpsc, Mutex},
12+
sync::{
13+
broadcast::{self, error::RecvError},
14+
mpsc,
15+
Mutex,
16+
},
1317
};
1418

1519
use self::wire::{Message, MessageFlags};
@@ -171,12 +175,44 @@ impl Connection {
171175
self.error.is_some()
172176
}
173177

178+
pub(crate) async fn send_message_with_cancellation(
179+
&mut self,
180+
message: impl TryInto<Message, Error = impl Into<Error>>,
181+
cancellation_receiver: &mut broadcast::Receiver<()>,
182+
) -> Result<RawCommandResponse> {
183+
tokio::select! {
184+
biased;
185+
186+
// A lagged error indicates that more heartbeats failed than the channel's capacity
187+
// between checking out this connection and executing the operation. If this occurs,
188+
// then proceed with cancelling the operation. RecvError::Closed can be ignored, as
189+
// the sender (and by extension the connection pool) dropping does not indicate that
190+
// the operation should be cancelled.
191+
Ok(_) | Err(RecvError::Lagged(_)) = cancellation_receiver.recv() => {
192+
let error: Error = ErrorKind::ConnectionPoolCleared {
193+
message: format!(
194+
"Connection to {} interrupted due to server monitor timeout",
195+
self.address,
196+
)
197+
}.into();
198+
self.error = Some(error.clone());
199+
Err(error)
200+
}
201+
// This future is not cancellation safe because it contains calls to methods that are
202+
// not cancellation safe (e.g. AsyncReadExt::read_exact). However, in the case that
203+
// this future is cancelled because a cancellation message was received, this
204+
// connection will be closed upon being returned to the pool, so any data loss on its
205+
// underlying stream is not an issue.
206+
result = self.send_message(message) => result,
207+
}
208+
}
209+
174210
pub(crate) async fn send_message(
175211
&mut self,
176-
message: Message,
177-
// This value is only read if a compression feature flag is enabled.
178-
#[allow(unused_variables)] can_compress: bool,
212+
message: impl TryInto<Message, Error = impl Into<Error>>,
179213
) -> Result<RawCommandResponse> {
214+
let message = message.try_into().map_err(Into::into)?;
215+
180216
if self.more_to_come {
181217
return Err(Error::internal(format!(
182218
"attempted to send a new message to {} but moreToCome bit was set",
@@ -192,7 +228,7 @@ impl Connection {
192228
feature = "snappy-compression"
193229
))]
194230
let write_result = match self.compressor {
195-
Some(ref compressor) if can_compress => {
231+
Some(ref compressor) if message.should_compress => {
196232
message
197233
.write_op_compressed_to(&mut self.stream, compressor)
198234
.await
@@ -232,21 +268,6 @@ impl Connection {
232268
))
233269
}
234270

235-
/// Executes a `Command` and returns a `CommandResponse` containing the result from the server.
236-
///
237-
/// An `Ok(...)` result simply means the server received the command and that the driver
238-
/// driver received the response; it does not imply anything about the success of the command
239-
/// itself.
240-
pub(crate) async fn send_command(
241-
&mut self,
242-
command: Command,
243-
request_id: impl Into<Option<i32>>,
244-
) -> Result<RawCommandResponse> {
245-
let to_compress = command.should_compress();
246-
let message = Message::from_command(command, request_id.into())?;
247-
self.send_message(message, to_compress).await
248-
}
249-
250271
/// Receive the next message from the connection.
251272
/// This will return an error if the previous response on this connection did not include the
252273
/// moreToCome flag.
@@ -378,6 +399,7 @@ pub(crate) struct PendingConnection {
378399
pub(crate) generation: PoolGeneration,
379400
pub(crate) event_emitter: CmapEventEmitter,
380401
pub(crate) time_created: Instant,
402+
pub(crate) cancellation_receiver: Option<broadcast::Receiver<()>>,
381403
}
382404

383405
impl PendingConnection {

src/cmap/conn/pooled.rs

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@ use std::{
55
};
66

77
use derive_where::derive_where;
8-
use tokio::sync::{mpsc, Mutex};
8+
use tokio::sync::{broadcast, mpsc, Mutex};
99

1010
use super::{
1111
CmapEventEmitter,
1212
Connection,
1313
ConnectionGeneration,
1414
ConnectionInfo,
15+
Message,
1516
PendingConnection,
1617
PinnedConnectionHandle,
1718
PoolManager,
19+
RawCommandResponse,
1820
};
1921
use crate::{
2022
bson::oid::ObjectId,
@@ -50,7 +52,7 @@ pub(crate) struct PooledConnection {
5052
}
5153

5254
/// The state of a pooled connection.
53-
#[derive(Clone, Debug)]
55+
#[derive(Debug)]
5456
enum PooledConnectionState {
5557
/// The state associated with a connection checked into the connection pool.
5658
CheckedIn { available_time: Instant },
@@ -59,6 +61,10 @@ enum PooledConnectionState {
5961
CheckedOut {
6062
/// The manager used to check this connection back into the pool.
6163
pool_manager: PoolManager,
64+
65+
/// The receiver to receive a cancellation notice. Only present on non-load-balanced
66+
/// connections.
67+
cancellation_receiver: Option<broadcast::Receiver<()>>,
6268
},
6369

6470
/// The state associated with a pinned connection.
@@ -140,6 +146,24 @@ impl PooledConnection {
140146
.and_then(|sd| sd.service_id)
141147
}
142148

149+
/// Sends a message on this connection.
150+
pub(crate) async fn send_message(
151+
&mut self,
152+
message: impl TryInto<Message, Error = impl Into<Error>>,
153+
) -> Result<RawCommandResponse> {
154+
match self.state {
155+
PooledConnectionState::CheckedOut {
156+
cancellation_receiver: Some(ref mut cancellation_receiver),
157+
..
158+
} => {
159+
self.connection
160+
.send_message_with_cancellation(message, cancellation_receiver)
161+
.await
162+
}
163+
_ => self.connection.send_message(message).await,
164+
}
165+
}
166+
143167
/// Updates the state of the connection to indicate that it is checked into the pool.
144168
pub(crate) fn mark_checked_in(&mut self) {
145169
if !matches!(self.state, PooledConnectionState::CheckedIn { .. }) {
@@ -155,8 +179,15 @@ impl PooledConnection {
155179
}
156180

157181
/// Updates the state of the connection to indicate that it is checked out of the pool.
158-
pub(crate) fn mark_checked_out(&mut self, pool_manager: PoolManager) {
159-
self.state = PooledConnectionState::CheckedOut { pool_manager };
182+
pub(crate) fn mark_checked_out(
183+
&mut self,
184+
pool_manager: PoolManager,
185+
cancellation_receiver: Option<broadcast::Receiver<()>>,
186+
) {
187+
self.state = PooledConnectionState::CheckedOut {
188+
pool_manager,
189+
cancellation_receiver,
190+
};
160191
}
161192

162193
/// Whether this connection is idle.
@@ -175,15 +206,14 @@ impl PooledConnection {
175206
Instant::now().duration_since(available_time) >= max_idle_time
176207
}
177208

178-
/// Nullifies the internal state of this connection and returns it in a new [PooledConnection].
179-
/// If a state is provided, then the new connection will contain that state; otherwise, this
180-
/// connection's state will be cloned.
181-
fn take(&mut self, state: impl Into<Option<PooledConnectionState>>) -> Self {
209+
/// Nullifies the internal state of this connection and returns it in a new [PooledConnection]
210+
/// with the given state.
211+
fn take(&mut self, new_state: PooledConnectionState) -> Self {
182212
Self {
183213
connection: self.connection.take(),
184214
generation: self.generation,
185215
event_emitter: self.event_emitter.clone(),
186-
state: state.into().unwrap_or_else(|| self.state.clone()),
216+
state: new_state,
187217
}
188218
}
189219

@@ -196,7 +226,9 @@ impl PooledConnection {
196226
self.id
197227
)))
198228
}
199-
PooledConnectionState::CheckedOut { ref pool_manager } => {
229+
PooledConnectionState::CheckedOut {
230+
ref pool_manager, ..
231+
} => {
200232
let (tx, rx) = mpsc::channel(1);
201233
self.state = PooledConnectionState::Pinned {
202234
// Mark the connection as in-use while the operation currently using the
@@ -286,10 +318,11 @@ impl Drop for PooledConnection {
286318
// Nothing needs to be done when a checked-in connection is dropped.
287319
PooledConnectionState::CheckedIn { .. } => Ok(()),
288320
// A checked-out connection should be sent back to the connection pool.
289-
PooledConnectionState::CheckedOut { pool_manager } => {
321+
PooledConnectionState::CheckedOut { pool_manager, .. } => {
290322
let pool_manager = pool_manager.clone();
291-
let mut dropped_connection = self.take(None);
292-
dropped_connection.mark_checked_in();
323+
let dropped_connection = self.take(PooledConnectionState::CheckedIn {
324+
available_time: Instant::now(),
325+
});
293326
pool_manager.check_in(dropped_connection)
294327
}
295328
// A pinned connection should be returned to its pinner or to the connection pool.
@@ -339,7 +372,11 @@ impl Drop for PooledConnection {
339372
}
340373
// The pinner of this connection has been dropped while the connection was
341374
// sitting in its channel, so the connection should be returned to the pool.
342-
PinnedState::Returned { .. } => pool_manager.check_in(self.take(None)),
375+
PinnedState::Returned { .. } => {
376+
pool_manager.check_in(self.take(PooledConnectionState::CheckedIn {
377+
available_time: Instant::now(),
378+
}))
379+
}
343380
}
344381
}
345382
};

0 commit comments

Comments
 (0)