Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 0 additions & 47 deletions compiler/rustc_data_structures/src/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,53 +188,6 @@ pub fn assert_dyn_send<T: ?Sized + PointeeSized + DynSend>() {}
pub fn assert_dyn_send_val<T: ?Sized + PointeeSized + DynSend>(_t: &T) {}
pub fn assert_dyn_send_sync_val<T: ?Sized + PointeeSized + DynSync + DynSend>(_t: &T) {}

#[derive(Copy, Clone)]
pub struct FromDyn<T>(T);

impl<T> FromDyn<T> {
#[inline(always)]
pub fn from(val: T) -> Self {
// Check that `sync::is_dyn_thread_safe()` is true on creation so we can
// implement `Send` and `Sync` for this structure when `T`
// implements `DynSend` and `DynSync` respectively.
assert!(crate::sync::is_dyn_thread_safe());
FromDyn(val)
}

#[inline(always)]
pub fn derive<O>(&self, val: O) -> FromDyn<O> {
// We already did the check for `sync::is_dyn_thread_safe()` when creating `Self`
FromDyn(val)
}

#[inline(always)]
pub fn into_inner(self) -> T {
self.0
}
}

// `FromDyn` is `Send` if `T` is `DynSend`, since it ensures that sync::is_dyn_thread_safe() is true.
unsafe impl<T: DynSend> Send for FromDyn<T> {}

// `FromDyn` is `Sync` if `T` is `DynSync`, since it ensures that sync::is_dyn_thread_safe() is true.
unsafe impl<T: DynSync> Sync for FromDyn<T> {}

impl<T> std::ops::Deref for FromDyn<T> {
type Target = T;

#[inline(always)]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<T> std::ops::DerefMut for FromDyn<T> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

// A wrapper to convert a struct that is already a `Send` or `Sync` into
// an instance of `DynSend` and `DynSync`, since the compiler cannot infer
// it automatically in some cases. (e.g. Box<dyn Send / Sync>)
Expand Down
50 changes: 49 additions & 1 deletion compiler/rustc_data_structures/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ pub use self::atomic::AtomicU64;
pub use self::freeze::{FreezeLock, FreezeReadGuard, FreezeWriteGuard};
#[doc(no_inline)]
pub use self::lock::{Lock, LockGuard, Mode};
pub use self::mode::{is_dyn_thread_safe, set_dyn_thread_safe_mode};
pub use self::mode::{
FromDyn, check_dyn_thread_safe, is_dyn_thread_safe, set_dyn_thread_safe_mode,
};
pub use self::parallel::{
broadcast, par_fns, par_for_each_in, par_join, par_map, parallel_guard, spawn,
try_par_for_each_in,
Expand Down Expand Up @@ -64,12 +66,20 @@ mod atomic {
mod mode {
use std::sync::atomic::{AtomicU8, Ordering};

use crate::sync::{DynSend, DynSync};

const UNINITIALIZED: u8 = 0;
const DYN_NOT_THREAD_SAFE: u8 = 1;
const DYN_THREAD_SAFE: u8 = 2;

static DYN_THREAD_SAFE_MODE: AtomicU8 = AtomicU8::new(UNINITIALIZED);

// Whether thread safety is enabled (due to running under multiple threads).
#[inline]
pub fn check_dyn_thread_safe() -> Option<FromDyn<()>> {
is_dyn_thread_safe().then_some(FromDyn(()))
}

// Whether thread safety is enabled (due to running under multiple threads).
#[inline]
pub fn is_dyn_thread_safe() -> bool {
Expand Down Expand Up @@ -99,6 +109,44 @@ mod mode {
// Check that the mode was either uninitialized or was already set to the requested mode.
assert!(previous.is_ok() || previous == Err(set));
}

#[derive(Copy, Clone)]
pub struct FromDyn<T>(T);

impl<T> FromDyn<T> {
#[inline(always)]
pub fn derive<O>(&self, val: O) -> FromDyn<O> {
// We already did the check for `sync::is_dyn_thread_safe()` when creating `Self`
FromDyn(val)
}

#[inline(always)]
pub fn into_inner(self) -> T {
self.0
}
}

// `FromDyn` is `Send` if `T` is `DynSend`, since it ensures that sync::is_dyn_thread_safe() is true.
unsafe impl<T: DynSend> Send for FromDyn<T> {}

// `FromDyn` is `Sync` if `T` is `DynSync`, since it ensures that sync::is_dyn_thread_safe() is true.
unsafe impl<T: DynSync> Sync for FromDyn<T> {}

impl<T> std::ops::Deref for FromDyn<T> {
type Target = T;

#[inline(always)]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<T> std::ops::DerefMut for FromDyn<T> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
}

/// This makes locks panic if they are already held.
Expand Down
63 changes: 37 additions & 26 deletions compiler/rustc_data_structures/src/sync/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ where
}

pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
if mode::is_dyn_thread_safe() {
let func = FromDyn::from(func);
if let Some(proof) = mode::check_dyn_thread_safe() {
let func = proof.derive(func);
rustc_thread_pool::spawn(|| {
(func.into_inner())();
});
Expand All @@ -73,8 +73,8 @@ pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
/// Use that for the longest running function for better scheduling.
pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) {
parallel_guard(|guard: &ParallelGuard| {
if mode::is_dyn_thread_safe() {
let funcs = FromDyn::from(funcs);
if let Some(proof) = mode::check_dyn_thread_safe() {
let funcs = proof.derive(funcs);
rustc_thread_pool::scope(|s| {
let Some((first, rest)) = funcs.into_inner().split_at_mut_checked(1) else {
return;
Expand All @@ -84,7 +84,7 @@ pub fn par_fns(funcs: &mut [&mut (dyn FnMut() + DynSend)]) {
// order when using a single thread. This ensures the execution order matches
// that of a single threaded rustc.
for f in rest.iter_mut().rev() {
let f = FromDyn::from(f);
let f = proof.derive(f);
s.spawn(|_| {
guard.run(|| (f.into_inner())());
});
Expand All @@ -108,13 +108,13 @@ where
A: FnOnce() -> RA + DynSend,
B: FnOnce() -> RB + DynSend,
{
if mode::is_dyn_thread_safe() {
let oper_a = FromDyn::from(oper_a);
let oper_b = FromDyn::from(oper_b);
if let Some(proof) = mode::check_dyn_thread_safe() {
let oper_a = proof.derive(oper_a);
let oper_b = proof.derive(oper_b);
let (a, b) = parallel_guard(|guard| {
rustc_thread_pool::join(
move || guard.run(move || FromDyn::from(oper_a.into_inner()())),
move || guard.run(move || FromDyn::from(oper_b.into_inner()())),
move || guard.run(move || proof.derive(oper_a.into_inner()())),
move || guard.run(move || proof.derive(oper_b.into_inner()())),
)
});
(a.unwrap().into_inner(), b.unwrap().into_inner())
Expand All @@ -127,8 +127,9 @@ fn par_slice<I: DynSend>(
items: &mut [I],
guard: &ParallelGuard,
for_each: impl Fn(&mut I) + DynSync + DynSend,
proof: FromDyn<()>,
) {
let for_each = FromDyn::from(for_each);
let for_each = proof.derive(for_each);
let mut items = for_each.derive(items);
rustc_thread_pool::scope(|s| {
let proof = items.derive(());
Expand All @@ -150,9 +151,9 @@ pub fn par_for_each_in<I: DynSend, T: IntoIterator<Item = I>>(
for_each: impl Fn(&I) + DynSync + DynSend,
) {
parallel_guard(|guard| {
if mode::is_dyn_thread_safe() {
if let Some(proof) = mode::check_dyn_thread_safe() {
let mut items: Vec<_> = t.into_iter().collect();
par_slice(&mut items, guard, |i| for_each(&*i))
par_slice(&mut items, guard, |i| for_each(&*i), proof)
} else {
t.into_iter().for_each(|i| {
guard.run(|| for_each(&i));
Expand All @@ -173,16 +174,21 @@ where
<T as IntoIterator>::Item: DynSend,
{
parallel_guard(|guard| {
if mode::is_dyn_thread_safe() {
if let Some(proof) = mode::check_dyn_thread_safe() {
let mut items: Vec<_> = t.into_iter().collect();

let error = Mutex::new(None);

par_slice(&mut items, guard, |i| {
if let Err(err) = for_each(&*i) {
*error.lock() = Some(err);
}
});
par_slice(
&mut items,
guard,
|i| {
if let Err(err) = for_each(&*i) {
*error.lock() = Some(err);
}
},
proof,
);

if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) }
} else {
Expand All @@ -196,15 +202,20 @@ pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterato
map: impl Fn(I) -> R + DynSync + DynSend,
) -> C {
parallel_guard(|guard| {
if mode::is_dyn_thread_safe() {
let map = FromDyn::from(map);
if let Some(proof) = mode::check_dyn_thread_safe() {
let map = proof.derive(map);

let mut items: Vec<(Option<I>, Option<R>)> =
t.into_iter().map(|i| (Some(i), None)).collect();

par_slice(&mut items, guard, |i| {
i.1 = Some(map(i.0.take().unwrap()));
});
par_slice(
&mut items,
guard,
|i| {
i.1 = Some(map(i.0.take().unwrap()));
},
proof,
);

items.into_iter().filter_map(|i| i.1).collect()
} else {
Expand All @@ -214,8 +225,8 @@ pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterato
}

pub fn broadcast<R: DynSend>(op: impl Fn(usize) -> R + DynSync) -> Vec<R> {
if mode::is_dyn_thread_safe() {
let op = FromDyn::from(op);
if let Some(proof) = mode::check_dyn_thread_safe() {
let op = proof.derive(op);
let results = rustc_thread_pool::broadcast(|context| op.derive(op(context.index())));
results.into_iter().map(|r| r.into_inner()).collect()
} else {
Expand Down
9 changes: 4 additions & 5 deletions compiler/rustc_interface/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,14 @@ pub(crate) fn run_in_thread_pool_with_globals<
use std::process;

use rustc_data_structures::defer;
use rustc_data_structures::sync::FromDyn;
use rustc_middle::ty::tls;
use rustc_query_impl::break_query_cycles;

let thread_stack_size = init_stack_size(thread_builder_diag);

let registry = sync::Registry::new(std::num::NonZero::new(threads).unwrap());

if !sync::is_dyn_thread_safe() {
let Some(proof) = sync::check_dyn_thread_safe() else {
return run_in_thread_with_globals(
thread_stack_size,
edition,
Expand All @@ -204,9 +203,9 @@ pub(crate) fn run_in_thread_pool_with_globals<
f(current_gcx, jobserver_proxy)
},
);
}
};

let current_gcx = FromDyn::from(CurrentGcx::new());
let current_gcx = proof.derive(CurrentGcx::new());
let current_gcx2 = current_gcx.clone();

let proxy = Proxy::new();
Expand Down Expand Up @@ -278,7 +277,7 @@ internal compiler error: query cycle handler thread panicked, aborting process";
// `Send` in the parallel compiler.
rustc_span::create_session_globals_then(edition, extra_symbols, Some(sm_inputs), || {
rustc_span::with_session_globals(|session_globals| {
let session_globals = FromDyn::from(session_globals);
let session_globals = proof.derive(session_globals);
builder
.build_scoped(
// Initialize each new worker thread when created.
Expand Down
Loading