Skip to content

Commit 746331e

Browse files
committed
std: drop all messages in bounded channel when destroying the last receiver
1 parent a64ef7d commit 746331e

File tree

2 files changed

+109
-27
lines changed

2 files changed

+109
-27
lines changed

library/std/src/sync/mpmc/array.rs

+107-25
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use super::utils::{Backoff, CachePadded};
1515
use super::waker::SyncWaker;
1616

1717
use crate::cell::UnsafeCell;
18-
use crate::mem::MaybeUninit;
18+
use crate::mem::{self, MaybeUninit};
1919
use crate::ptr;
2020
use crate::sync::atomic::{self, AtomicUsize, Ordering};
2121
use crate::time::Instant;
@@ -25,7 +25,8 @@ struct Slot<T> {
2525
/// The current stamp.
2626
stamp: AtomicUsize,
2727

28-
/// The message in this slot.
28+
/// The message in this slot. Either read out in `read` or dropped through
29+
/// `discard_all_messages`.
2930
msg: UnsafeCell<MaybeUninit<T>>,
3031
}
3132

@@ -439,21 +440,122 @@ impl<T> Channel<T> {
439440
Some(self.cap)
440441
}
441442

442-
/// Disconnects the channel and wakes up all blocked senders and receivers.
443+
/// Disconnects senders and wakes up all blocked receivers.
443444
///
444445
/// Returns `true` if this call disconnected the channel.
445-
pub(crate) fn disconnect(&self) -> bool {
446+
pub(crate) fn disconnect_senders(&self) -> bool {
446447
let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
447448

448449
if tail & self.mark_bit == 0 {
449-
self.senders.disconnect();
450450
self.receivers.disconnect();
451451
true
452452
} else {
453453
false
454454
}
455455
}
456456

457+
/// Disconnects receivers and wakes up all blocked senders.
458+
///
459+
/// Returns `true` if this call disconnected the channel.
460+
///
461+
/// # Safety
462+
/// May only be called once upon dropping the last receiver. The
463+
/// destruction of all other receivers must have been observed with acquire
464+
/// ordering or stronger.
465+
pub(crate) unsafe fn disconnect_receivers(&self) -> bool {
466+
let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
467+
self.discard_all_messages(tail);
468+
469+
if tail & self.mark_bit == 0 {
470+
self.senders.disconnect();
471+
true
472+
} else {
473+
false
474+
}
475+
}
476+
477+
/// Discards all messages.
478+
///
479+
/// `tail` should be the current (and therefore last) value of `tail`.
480+
///
481+
/// # Safety
482+
/// This method must only be called when dropping the last receiver. The
483+
/// destruction of all other receivers must have been observed with acquire
484+
/// ordering or stronger.
485+
unsafe fn discard_all_messages(&self, tail: usize) {
486+
debug_assert!(self.is_disconnected());
487+
488+
/// Use a helper struct with a custom `Drop` to ensure all messages are
489+
/// dropped, even if a destructor panicks.
490+
struct DiscardState<'a, T> {
491+
channel: &'a Channel<T>,
492+
head: usize,
493+
tail: usize,
494+
backoff: Backoff,
495+
}
496+
497+
impl<'a, T> DiscardState<'a, T> {
498+
fn discard(&mut self) {
499+
loop {
500+
// Deconstruct the head.
501+
let index = self.head & (self.channel.mark_bit - 1);
502+
let lap = self.head & !(self.channel.one_lap - 1);
503+
504+
// Inspect the corresponding slot.
505+
debug_assert!(index < self.channel.buffer.len());
506+
let slot = unsafe { self.channel.buffer.get_unchecked(index) };
507+
let stamp = slot.stamp.load(Ordering::Acquire);
508+
509+
// If the stamp is ahead of the head by 1, we may drop the message.
510+
if self.head + 1 == stamp {
511+
self.head = if index + 1 < self.channel.cap {
512+
// Same lap, incremented index.
513+
// Set to `{ lap: lap, mark: 0, index: index + 1 }`.
514+
self.head + 1
515+
} else {
516+
// One lap forward, index wraps around to zero.
517+
// Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
518+
lap.wrapping_add(self.channel.one_lap)
519+
};
520+
521+
// We updated the head, so even if this descrutor panics,
522+
// we will not attempt to destroy the slot again.
523+
unsafe {
524+
(*slot.msg.get()).assume_init_drop();
525+
}
526+
// If the tail equals the head, that means the channel is empty.
527+
} else if self.tail == self.head {
528+
return;
529+
// Otherwise, a sender is about to write into the slot, so we need
530+
// to wait for it to update the stamp.
531+
} else {
532+
self.backoff.spin_heavy();
533+
}
534+
}
535+
}
536+
}
537+
538+
impl<'a, T> Drop for DiscardState<'a, T> {
539+
fn drop(&mut self) {
540+
self.discard();
541+
}
542+
}
543+
544+
let mut state = DiscardState {
545+
channel: self,
546+
// Only receivers modify `head`, so since we are the last one,
547+
// this value will not change and will not be observed (since
548+
// no new messages can be sent after disconnection).
549+
head: self.head.load(Ordering::Relaxed),
550+
tail: tail & !self.mark_bit,
551+
backoff: Backoff::new(),
552+
};
553+
state.discard();
554+
// This point is only reached if no destructor panics, so all messages
555+
// have already been dropped.
556+
mem::forget(state);
557+
}
558+
457559
/// Returns `true` if the channel is disconnected.
458560
pub(crate) fn is_disconnected(&self) -> bool {
459561
self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
@@ -483,23 +585,3 @@ impl<T> Channel<T> {
483585
head.wrapping_add(self.one_lap) == tail & !self.mark_bit
484586
}
485587
}
486-
487-
impl<T> Drop for Channel<T> {
488-
fn drop(&mut self) {
489-
// Get the index of the head.
490-
let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);
491-
492-
// Loop over all slots that hold a message and drop them.
493-
for i in 0..self.len() {
494-
// Compute the index of the next slot holding a message.
495-
let index = if hix + i < self.cap { hix + i } else { hix + i - self.cap };
496-
497-
unsafe {
498-
debug_assert!(index < self.buffer.len());
499-
let slot = self.buffer.get_unchecked_mut(index);
500-
let msg = &mut *slot.msg.get();
501-
msg.as_mut_ptr().drop_in_place();
502-
}
503-
}
504-
}
505-
}

library/std/src/sync/mpmc/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ impl<T> Drop for Sender<T> {
227227
fn drop(&mut self) {
228228
unsafe {
229229
match &self.flavor {
230-
SenderFlavor::Array(chan) => chan.release(|c| c.disconnect()),
230+
SenderFlavor::Array(chan) => chan.release(|c| c.disconnect_senders()),
231231
SenderFlavor::List(chan) => chan.release(|c| c.disconnect_senders()),
232232
SenderFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
233233
}
@@ -403,7 +403,7 @@ impl<T> Drop for Receiver<T> {
403403
fn drop(&mut self) {
404404
unsafe {
405405
match &self.flavor {
406-
ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect()),
406+
ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect_receivers()),
407407
ReceiverFlavor::List(chan) => chan.release(|c| c.disconnect_receivers()),
408408
ReceiverFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
409409
}

0 commit comments

Comments
 (0)