Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions cranelift/jit/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ const READONLY_DATA_ALIGNMENT: u64 = 0x1;
/// A builder for `JITModule`.
pub struct JITBuilder {
isa: OwnedTargetIsa,
symbols: HashMap<String, *const u8>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8>>>,
symbols: HashMap<String, SendWrapper<*const u8>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
hotswap_enabled: bool,
}
Expand Down Expand Up @@ -116,7 +116,7 @@ impl JITBuilder {
where
K: Into<String>,
{
self.symbols.insert(name.into(), ptr);
self.symbols.insert(name.into(), SendWrapper(ptr));
self
}

Expand All @@ -129,7 +129,7 @@ impl JITBuilder {
K: Into<String>,
{
for (name, ptr) in symbols {
self.symbols.insert(name.into(), ptr);
self.symbols.insert(name.into(), SendWrapper(ptr));
}
self
}
Expand All @@ -140,7 +140,7 @@ impl JITBuilder {
/// symbol table. Symbol lookup fn's are called in reverse of the order in which they were added.
pub fn symbol_lookup_fn(
&mut self,
symbol_lookup_fn: Box<dyn Fn(&str) -> Option<*const u8>>,
symbol_lookup_fn: Box<dyn Fn(&str) -> Option<*const u8> + Send>,
Comment thread
MolotovCherry marked this conversation as resolved.
) -> &mut Self {
self.lookup_symbols.push(symbol_lookup_fn);
self
Expand All @@ -165,23 +165,32 @@ struct GotUpdate {
ptr: *const u8,
}

unsafe impl Send for GotUpdate {}

/// A wrapper that impls Send for the contents.
///
/// SAFETY: This must not be used for any types where it would be UB for them to be Send
#[derive(Copy, Clone)]
struct SendWrapper<T>(T);
Comment thread
MolotovCherry marked this conversation as resolved.
unsafe impl<T> Send for SendWrapper<T> {}

/// A `JITModule` implements `Module` and emits code and data into memory where it can be
/// directly called and accessed.
///
/// See the `JITBuilder` for a convenient way to construct `JITModule` instances.
pub struct JITModule {
isa: OwnedTargetIsa,
hotswap_enabled: bool,
symbols: RefCell<HashMap<String, *const u8>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8>>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String>,
symbols: RefCell<HashMap<String, SendWrapper<*const u8>>>,
lookup_symbols: Vec<Box<dyn Fn(&str) -> Option<*const u8> + Send>>,
libcall_names: Box<dyn Fn(ir::LibCall) -> String + Send + Sync>,
memory: MemoryHandle,
declarations: ModuleDeclarations,
function_got_entries: SecondaryMap<FuncId, Option<NonNull<AtomicPtr<u8>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<NonNull<[u8; 16]>>>,
data_object_got_entries: SecondaryMap<DataId, Option<NonNull<AtomicPtr<u8>>>>,
libcall_got_entries: HashMap<ir::LibCall, NonNull<AtomicPtr<u8>>>,
libcall_plt_entries: HashMap<ir::LibCall, NonNull<[u8; 16]>>,
function_got_entries: SecondaryMap<FuncId, Option<SendWrapper<NonNull<AtomicPtr<u8>>>>>,
function_plt_entries: SecondaryMap<FuncId, Option<SendWrapper<NonNull<[u8; 16]>>>>,
data_object_got_entries: SecondaryMap<DataId, Option<SendWrapper<NonNull<AtomicPtr<u8>>>>>,
libcall_got_entries: HashMap<ir::LibCall, SendWrapper<NonNull<AtomicPtr<u8>>>>,
libcall_plt_entries: HashMap<ir::LibCall, SendWrapper<NonNull<[u8; 16]>>>,
compiled_functions: SecondaryMap<FuncId, Option<CompiledBlob>>,
compiled_data_objects: SecondaryMap<DataId, Option<CompiledBlob>>,
functions_to_finalize: Vec<FuncId>,
Expand Down Expand Up @@ -215,15 +224,15 @@ impl JITModule {

fn lookup_symbol(&self, name: &str) -> Option<*const u8> {
match self.symbols.borrow_mut().entry(name.to_owned()) {
std::collections::hash_map::Entry::Occupied(occ) => Some(*occ.get()),
std::collections::hash_map::Entry::Occupied(occ) => Some(occ.get().0),
std::collections::hash_map::Entry::Vacant(vac) => {
let ptr = self
.lookup_symbols
.iter()
.rev() // Try last lookup function first
.find_map(|lookup| lookup(name));
if let Some(ptr) = ptr {
vac.insert(ptr);
vac.insert(SendWrapper(ptr));
}
ptr
}
Expand Down Expand Up @@ -266,7 +275,7 @@ impl JITModule {

fn new_func_plt_entry(&mut self, id: FuncId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.function_got_entries[id] = Some(got_entry);
self.function_got_entries[id] = Some(SendWrapper(got_entry));
let plt_entry = self.new_plt_entry(got_entry);
self.record_function_for_perf(
plt_entry.as_ptr().cast(),
Expand All @@ -276,12 +285,12 @@ impl JITModule {
self.declarations.get_function_decl(id).linkage_name(id)
),
);
self.function_plt_entries[id] = Some(plt_entry);
self.function_plt_entries[id] = Some(SendWrapper(plt_entry));
}

fn new_data_got_entry(&mut self, id: DataId, val: *const u8) {
let got_entry = self.new_got_entry(val);
self.data_object_got_entries[id] = Some(got_entry);
self.data_object_got_entries[id] = Some(SendWrapper(got_entry));
}

unsafe fn write_plt_entry_bytes(plt_ptr: *mut [u8; 16], got_ptr: NonNull<AtomicPtr<u8>>) {
Expand Down Expand Up @@ -350,24 +359,26 @@ impl JITModule {
/// Panics if there's no entry in the table for the given function.
pub fn read_got_entry(&self, func_id: FuncId) -> *const u8 {
let got_entry = self.function_got_entries[func_id].unwrap();
unsafe { got_entry.as_ref() }.load(Ordering::SeqCst)
unsafe { got_entry.0.as_ref() }.load(Ordering::SeqCst)
}

fn get_got_address(&self, name: &ModuleRelocTarget) -> NonNull<AtomicPtr<u8>> {
match *name {
ModuleRelocTarget::User { .. } => {
if ModuleDeclarations::is_function(name) {
let func_id = FuncId::from_name(name);
self.function_got_entries[func_id].unwrap()
self.function_got_entries[func_id].unwrap().0
} else {
let data_id = DataId::from_name(name);
self.data_object_got_entries[data_id].unwrap()
self.data_object_got_entries[data_id].unwrap().0
}
}
ModuleRelocTarget::LibCall(ref libcall) => *self
.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall)),
ModuleRelocTarget::LibCall(ref libcall) => {
self.libcall_got_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
}
_ => panic!("invalid name"),
}
}
Expand All @@ -379,6 +390,7 @@ impl JITModule {
let func_id = FuncId::from_name(name);
self.function_plt_entries[func_id]
.unwrap()
.0
.as_ptr()
.cast::<u8>()
} else {
Expand All @@ -389,6 +401,7 @@ impl JITModule {
.libcall_plt_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid name"),
Expand Down Expand Up @@ -543,9 +556,13 @@ impl JITModule {
continue;
};
let got_entry = module.new_got_entry(addr);
module.libcall_got_entries.insert(libcall, got_entry);
module
.libcall_got_entries
.insert(libcall, SendWrapper(got_entry));
let plt_entry = module.new_plt_entry(got_entry);
module.libcall_plt_entries.insert(libcall, plt_entry);
module
.libcall_plt_entries
.insert(libcall, SendWrapper(plt_entry));
}

module
Expand Down Expand Up @@ -719,7 +736,7 @@ impl Module for JITModule {

if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
entry: self.function_got_entries[id].unwrap().0,
ptr,
})
}
Expand All @@ -737,6 +754,7 @@ impl Module for JITModule {
.libcall_plt_entries
.get(libcall)
.unwrap_or_else(|| panic!("can't resolve libcall {}", libcall))
.0
.as_ptr()
.cast::<u8>(),
_ => panic!("invalid name"),
Expand Down Expand Up @@ -802,7 +820,7 @@ impl Module for JITModule {

if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.function_got_entries[id].unwrap(),
entry: self.function_got_entries[id].unwrap().0,
ptr,
})
}
Expand Down Expand Up @@ -907,7 +925,7 @@ impl Module for JITModule {
self.data_objects_to_finalize.push(id);
if self.isa.flags().is_pic() {
self.pending_got_updates.push(GotUpdate {
entry: self.data_object_got_entries[id].unwrap(),
entry: self.data_object_got_entries[id].unwrap().0,
ptr,
})
}
Expand Down
2 changes: 2 additions & 0 deletions cranelift/jit/src/compiled_blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub(crate) struct CompiledBlob {
pub(crate) relocs: Vec<ModuleReloc>,
}

unsafe impl Send for CompiledBlob {}

impl CompiledBlob {
pub(crate) fn perform_relocations(
&self,
Expand Down
2 changes: 2 additions & 0 deletions cranelift/jit/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ pub(crate) struct Memory {
branch_protection: BranchProtection,
}

unsafe impl Send for Memory {}

impl Memory {
pub(crate) fn new(branch_protection: BranchProtection) -> Self {
Self {
Expand Down