
Commit 8cee824

bors[bot] and cuviper authored
Merge #1011
1011: Use pointers instead of `&self` in `Latch::set` r=cuviper a=cuviper

`Latch::set` can invalidate its own `&self`, because it releases the owning thread to continue execution, which may then invalidate the latch by deallocation, reuse, etc. We've known about this problem when it comes to accessing latch fields too late, but the possibly dangling reference was still a problem in itself, as in rust-lang/rust#55005.

The outcome of that issue was rust-lang/rust#98017, which omits the LLVM attribute `dereferenceable` on references to `!Freeze` types (those containing an `UnsafeCell`). However, miri's Stacked Borrows implementation is finer-grained than that, relaxing the rules only for the cell itself within the `!Freeze` type. For rayon, that resolves the dangling reference in atomic calls, but it remains a problem for the other fields of a `Latch`.

The easiest fix for rayon is to use a raw pointer instead of `&self`. We still end up with some temporary references for things like atomics, but those should be fine under the rules above.

Co-authored-by: Josh Stone <[email protected]>
2 parents ed98853 + f880d02, commit 8cee824
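To make the new contract concrete, here is a minimal sketch of the call shape this commit adopts, using toy types rather than rayon's real ones: `set` receives a raw pointer, performs the release store as its last action, and touches nothing through the pointer afterwards, because the woken thread may free the latch immediately.

use std::sync::atomic::{AtomicBool, Ordering};

// Toy stand-in for rayon's latches; the name and field are illustrative only.
struct ToyLatch {
    ready: AtomicBool,
}

impl ToyLatch {
    // Takes `*const Self` rather than `&self`: a reference parameter could be
    // marked `dereferenceable` for the whole call, but the waiting thread may
    // deallocate the latch the instant the store below becomes visible.
    unsafe fn set(this: *const Self) {
        // Last action through `this`: after this store the pointer may be
        // dangling, so nothing reads through it again.
        (*this).ready.store(true, Ordering::Release);
    }
}

fn main() {
    let latch = ToyLatch {
        ready: AtomicBool::new(false),
    };
    // `&latch` coerces to `*const ToyLatch` at the call boundary.
    unsafe { ToyLatch::set(&latch) };
    assert!(latch.ready.load(Ordering::Acquire));
}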

File tree

5 files changed: +95 -58 lines


rayon-core/src/broadcast/mod.rs (+4 -1)
@@ -1,4 +1,5 @@
 use crate::job::{ArcJob, StackJob};
+use crate::latch::LatchRef;
 use crate::registry::{Registry, WorkerThread};
 use crate::scope::ScopeLatch;
 use std::fmt;
@@ -107,7 +108,9 @@ where
     let n_threads = registry.num_threads();
     let current_thread = WorkerThread::current().as_ref();
     let latch = ScopeLatch::with_count(n_threads, current_thread);
-    let jobs: Vec<_> = (0..n_threads).map(|_| StackJob::new(&f, &latch)).collect();
+    let jobs: Vec<_> = (0..n_threads)
+        .map(|_| StackJob::new(&f, LatchRef::new(&latch)))
+        .collect();
     let job_refs = jobs.iter().map(|job| job.as_job_ref());
 
     registry.inject_broadcast(job_refs);

rayon-core/src/job.rs (+1 -1)
@@ -112,7 +112,7 @@ where
         let abort = unwind::AbortIfPanic;
         let func = (*this.func.get()).take().unwrap();
         (*this.result.get()) = JobResult::call(func);
-        this.latch.set();
+        Latch::set(&this.latch);
         mem::forget(abort);
     }
 }

rayon-core/src/latch.rs (+62 -28)
@@ -1,3 +1,5 @@
+use std::marker::PhantomData;
+use std::ops::Deref;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, Condvar, Mutex};
 use std::usize;
@@ -37,10 +39,15 @@ pub(super) trait Latch {
     ///
     /// Setting a latch triggers other threads to wake up and (in some
     /// cases) complete. This may, in turn, cause memory to be
-    /// allocated and so forth. One must be very careful about this,
+    /// deallocated and so forth. One must be very careful about this,
     /// and it's typically better to read all the fields you will need
     /// to access *before* a latch is set!
-    fn set(&self);
+    ///
+    /// This function operates on `*const Self` instead of `&self` to allow it
+    /// to become dangling during this call. The caller must ensure that the
+    /// pointer is valid upon entry, and not invalidated during the call by any
+    /// actions other than `set` itself.
+    unsafe fn set(this: *const Self);
 }
 
 pub(super) trait AsCoreLatch {
@@ -123,8 +130,8 @@ impl CoreLatch {
     /// doing some wakeups; those are encapsulated in the surrounding
     /// latch code.
     #[inline]
-    fn set(&self) -> bool {
-        let old_state = self.state.swap(SET, Ordering::AcqRel);
+    unsafe fn set(this: *const Self) -> bool {
+        let old_state = (*this).state.swap(SET, Ordering::AcqRel);
         old_state == SLEEPING
     }
 
@@ -186,29 +193,29 @@ impl<'r> AsCoreLatch for SpinLatch<'r> {
 
 impl<'r> Latch for SpinLatch<'r> {
     #[inline]
-    fn set(&self) {
+    unsafe fn set(this: *const Self) {
        let cross_registry;
 
-        let registry: &Registry = if self.cross {
+        let registry: &Registry = if (*this).cross {
             // Ensure the registry stays alive while we notify it.
             // Otherwise, it would be possible that we set the spin
             // latch and the other thread sees it and exits, causing
             // the registry to be deallocated, all before we get a
             // chance to invoke `registry.notify_worker_latch_is_set`.
-            cross_registry = Arc::clone(self.registry);
+            cross_registry = Arc::clone((*this).registry);
             &cross_registry
         } else {
             // If this is not a "cross-registry" spin-latch, then the
             // thread which is performing `set` is itself ensuring
             // that the registry stays alive. However, that doesn't
             // include this *particular* `Arc` handle if the waiting
             // thread then exits, so we must completely dereference it.
-            self.registry
+            (*this).registry
         };
-        let target_worker_index = self.target_worker_index;
+        let target_worker_index = (*this).target_worker_index;
 
-        // NOTE: Once we `set`, the target may proceed and invalidate `&self`!
-        if self.core_latch.set() {
+        // NOTE: Once we `set`, the target may proceed and invalidate `this`!
+        if CoreLatch::set(&(*this).core_latch) {
             // Subtle: at this point, we can no longer read from
             // `self`, because the thread owning this spin latch may
             // have awoken and deallocated the latch. Therefore, we
@@ -255,10 +262,10 @@ impl LockLatch {
 
 impl Latch for LockLatch {
     #[inline]
-    fn set(&self) {
-        let mut guard = self.m.lock().unwrap();
+    unsafe fn set(this: *const Self) {
+        let mut guard = (*this).m.lock().unwrap();
         *guard = true;
-        self.v.notify_all();
+        (*this).v.notify_all();
     }
 }
 
@@ -307,9 +314,9 @@ impl CountLatch {
     /// count, then the latch is **set**, and calls to `probe()` will
     /// return true. Returns whether the latch was set.
     #[inline]
-    pub(super) fn set(&self) -> bool {
-        if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
-            self.core_latch.set();
+    pub(super) unsafe fn set(this: *const Self) -> bool {
+        if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
+            CoreLatch::set(&(*this).core_latch);
             true
         } else {
             false
@@ -320,8 +327,12 @@ impl CountLatch {
     /// the latch is set, then the specific worker thread is tickled,
     /// which should be the one that owns this latch.
     #[inline]
-    pub(super) fn set_and_tickle_one(&self, registry: &Registry, target_worker_index: usize) {
-        if self.set() {
+    pub(super) unsafe fn set_and_tickle_one(
+        this: *const Self,
+        registry: &Registry,
+        target_worker_index: usize,
+    ) {
+        if Self::set(this) {
             registry.notify_worker_latch_is_set(target_worker_index);
         }
     }
@@ -362,19 +373,42 @@ impl CountLockLatch {
 
 impl Latch for CountLockLatch {
     #[inline]
-    fn set(&self) {
-        if self.counter.fetch_sub(1, Ordering::SeqCst) == 1 {
-            self.lock_latch.set();
+    unsafe fn set(this: *const Self) {
+        if (*this).counter.fetch_sub(1, Ordering::SeqCst) == 1 {
+            LockLatch::set(&(*this).lock_latch);
         }
     }
 }
 
-impl<'a, L> Latch for &'a L
-where
-    L: Latch,
-{
+/// `&L` without any implication of `dereferenceable` for `Latch::set`
+pub(super) struct LatchRef<'a, L> {
+    inner: *const L,
+    marker: PhantomData<&'a L>,
+}
+
+impl<L> LatchRef<'_, L> {
+    pub(super) fn new(inner: &L) -> LatchRef<'_, L> {
+        LatchRef {
+            inner,
+            marker: PhantomData,
+        }
+    }
+}
+
+unsafe impl<L: Sync> Sync for LatchRef<'_, L> {}
+
+impl<L> Deref for LatchRef<'_, L> {
+    type Target = L;
+
+    fn deref(&self) -> &L {
+        // SAFETY: if we have &self, the inner latch is still alive
+        unsafe { &*self.inner }
+    }
+}
+
+impl<L: Latch> Latch for LatchRef<'_, L> {
     #[inline]
-    fn set(&self) {
-        L::set(self);
+    unsafe fn set(this: *const Self) {
+        L::set((*this).inner);
     }
 }
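The `LatchRef` added above is the load-bearing piece of this file: the removed blanket `impl Latch for &L` would have kept a real reference alive across `set`. The following standalone sketch restates the same pattern outside rayon, under a hypothetical `RawRef` name: a raw pointer carries the data, and `PhantomData<&'a T>` preserves the lifetime for the borrow checker without any `dereferenceable` implication on the stored pointer.

use std::marker::PhantomData;
use std::ops::Deref;

// Hold the borrow as a raw pointer so the type never implies
// `dereferenceable`, and keep `PhantomData<&'a T>` so the borrow
// checker still enforces the lifetime `'a`.
struct RawRef<'a, T> {
    inner: *const T,
    marker: PhantomData<&'a T>,
}

impl<'a, T> RawRef<'a, T> {
    fn new(inner: &'a T) -> Self {
        // The `&'a T` coerces to `*const T` in the field position.
        RawRef {
            inner,
            marker: PhantomData,
        }
    }
}

impl<T> Deref for RawRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the `PhantomData` lifetime ties `self` to the original
        // borrow, so the pointee is still alive whenever `&self` exists.
        unsafe { &*self.inner }
    }
}

fn main() {
    let value = 42;
    let r = RawRef::new(&value);
    assert_eq!(*r, 42);
}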

rayon-core/src/registry.rs (+5 -5)
@@ -1,5 +1,5 @@
 use crate::job::{JobFifo, JobRef, StackJob};
-use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LockLatch, SpinLatch};
+use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatch, SpinLatch};
 use crate::log::Event::*;
 use crate::log::Logger;
 use crate::sleep::Sleep;
@@ -505,7 +505,7 @@ impl Registry {
                     assert!(injected && !worker_thread.is_null());
                     op(&*worker_thread, true)
                 },
-                l,
+                LatchRef::new(l),
             );
             self.inject(&[job.as_job_ref()]);
             job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
@@ -575,7 +575,7 @@ impl Registry {
     pub(super) fn terminate(&self) {
         if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
             for (i, thread_info) in self.thread_infos.iter().enumerate() {
-                thread_info.terminate.set_and_tickle_one(self, i);
+                unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
             }
         }
     }
@@ -869,7 +869,7 @@ unsafe fn main_loop(
     let registry = &*worker_thread.registry;
 
     // let registry know we are ready to do work
-    registry.thread_infos[index].primed.set();
+    Latch::set(&registry.thread_infos[index].primed);
 
     // Worker threads should not panic. If they do, just abort, as the
     // internal state of the threadpool is corrupted. Note that if
@@ -892,7 +892,7 @@ unsafe fn main_loop(
     debug_assert!(worker_thread.take_local_job().is_none());
 
     // let registry know we are done
-    registry.thread_infos[index].stopped.set();
+    Latch::set(&registry.thread_infos[index].stopped);
 
     // Normal termination, do not abort.
     mem::forget(abort_guard);
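A side effect visible throughout these call sites: once `set` takes `this: *const Self` instead of `&self`, it is no longer a method, so `latch.set()` becomes the fully qualified `Latch::set(&latch)`. A toy sketch of that shape, with a hypothetical `Flag` type standing in for a latch:

use std::sync::atomic::{AtomicBool, Ordering};

trait Latch {
    unsafe fn set(this: *const Self);
}

// Illustrative only; not rayon's actual latch.
struct Flag(AtomicBool);

impl Latch for Flag {
    unsafe fn set(this: *const Self) {
        (*this).0.store(true, Ordering::Release);
    }
}

fn main() {
    let flag = Flag(AtomicBool::new(false));
    // Fully qualified call: `set` has no `self` receiver, and the temporary
    // `&flag` coerces to `*const Flag` without surviving past the call.
    unsafe { Latch::set(&flag) };
    assert!(flag.0.load(Ordering::Acquire));
}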

rayon-core/src/scope/mod.rs (+23 -23)
@@ -540,10 +540,10 @@ impl<'scope> Scope<'scope> {
         BODY: FnOnce(&Scope<'scope>) + Send + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = HeapJob::new(move || {
+        let job = HeapJob::new(move || unsafe {
             // SAFETY: this job will execute before the scope ends.
-            let scope = unsafe { scope_ptr.as_ref() };
-            scope.base.execute_job(move || body(scope))
+            let scope = scope_ptr.as_ref();
+            ScopeBase::execute_job(&scope.base, move || body(scope))
         });
         let job_ref = self.base.heap_job_ref(job);
 
@@ -562,12 +562,12 @@ impl<'scope> Scope<'scope> {
         BODY: Fn(&Scope<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = ArcJob::new(move || {
+        let job = ArcJob::new(move || unsafe {
             // SAFETY: this job will execute before the scope ends.
-            let scope = unsafe { scope_ptr.as_ref() };
+            let scope = scope_ptr.as_ref();
             let body = &body;
             let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
-            scope.base.execute_job(func);
+            ScopeBase::execute_job(&scope.base, func)
         });
         self.base.inject_broadcast(job)
     }
@@ -600,10 +600,10 @@ impl<'scope> ScopeFifo<'scope> {
         BODY: FnOnce(&ScopeFifo<'scope>) + Send + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = HeapJob::new(move || {
+        let job = HeapJob::new(move || unsafe {
             // SAFETY: this job will execute before the scope ends.
-            let scope = unsafe { scope_ptr.as_ref() };
-            scope.base.execute_job(move || body(scope))
+            let scope = scope_ptr.as_ref();
+            ScopeBase::execute_job(&scope.base, move || body(scope))
         });
         let job_ref = self.base.heap_job_ref(job);
 
@@ -628,12 +628,12 @@ impl<'scope> ScopeFifo<'scope> {
         BODY: Fn(&ScopeFifo<'scope>, BroadcastContext<'_>) + Send + Sync + 'scope,
     {
         let scope_ptr = ScopePtr(self);
-        let job = ArcJob::new(move || {
+        let job = ArcJob::new(move || unsafe {
             // SAFETY: this job will execute before the scope ends.
-            let scope = unsafe { scope_ptr.as_ref() };
+            let scope = scope_ptr.as_ref();
             let body = &body;
             let func = move || BroadcastContext::with(move |ctx| body(scope, ctx));
-            scope.base.execute_job(func);
+            ScopeBase::execute_job(&scope.base, func)
         });
         self.base.inject_broadcast(job)
     }
@@ -688,36 +688,36 @@ impl<'scope> ScopeBase<'scope> {
     where
         FUNC: FnOnce() -> R,
     {
-        let result = self.execute_job_closure(func);
+        let result = unsafe { Self::execute_job_closure(self, func) };
         self.job_completed_latch.wait(owner);
         self.maybe_propagate_panic();
         result.unwrap() // only None if `op` panicked, and that would have been propagated
     }
 
     /// Executes `func` as a job, either aborting or executing as
     /// appropriate.
-    fn execute_job<FUNC>(&self, func: FUNC)
+    unsafe fn execute_job<FUNC>(this: *const Self, func: FUNC)
     where
         FUNC: FnOnce(),
     {
-        let _: Option<()> = self.execute_job_closure(func);
+        let _: Option<()> = Self::execute_job_closure(this, func);
     }
 
     /// Executes `func` as a job in scope. Adjusts the "job completed"
     /// counters and also catches any panic and stores it into
     /// `scope`.
-    fn execute_job_closure<FUNC, R>(&self, func: FUNC) -> Option<R>
+    unsafe fn execute_job_closure<FUNC, R>(this: *const Self, func: FUNC) -> Option<R>
     where
         FUNC: FnOnce() -> R,
     {
         match unwind::halt_unwinding(func) {
             Ok(r) => {
-                self.job_completed_latch.set();
+                Latch::set(&(*this).job_completed_latch);
                 Some(r)
             }
             Err(err) => {
-                self.job_panicked(err);
-                self.job_completed_latch.set();
+                (*this).job_panicked(err);
+                Latch::set(&(*this).job_completed_latch);
                 None
             }
         }
@@ -797,14 +797,14 @@ impl ScopeLatch {
 }
 
 impl Latch for ScopeLatch {
-    fn set(&self) {
-        match self {
+    unsafe fn set(this: *const Self) {
+        match &*this {
             ScopeLatch::Stealing {
                 latch,
                 registry,
                 worker_index,
-            } => latch.set_and_tickle_one(registry, *worker_index),
-            ScopeLatch::Blocking { latch } => latch.set(),
+            } => CountLatch::set_and_tickle_one(latch, registry, *worker_index),
+            ScopeLatch::Blocking { latch } => Latch::set(latch),
         }
     }
 }
