diff --git a/drivers/android/process.rs b/drivers/android/process.rs
index fa0ce400ed35e8..58a3357bff6590 100644
--- a/drivers/android/process.rs
+++ b/drivers/android/process.rs
@@ -16,6 +16,7 @@ use kernel::{
     prelude::*,
     rbtree::RBTree,
     sync::{Guard, Mutex, Ref},
+    task::Task,
     user_ptr::{UserSlicePtr, UserSlicePtrReader},
     Error,
 };
@@ -33,10 +34,6 @@ use crate::{
 // TODO: Review this:
 // Lock order: Process::node_refs -> Process::inner -> Thread::inner
-extern "C" {
-    fn rust_helper_current_pid() -> c_types::c_int;
-}
-
 pub(crate) struct AllocationInfo {
     /// Range within the allocation where we can find the offsets to the object descriptors.
     pub(crate) offsets: Range<usize>,
 }
@@ -799,7 +796,7 @@ impl IoctlHandler for Process {
         cmd: u32,
         reader: &mut UserSlicePtrReader,
     ) -> Result<i32> {
-        let thread = this.get_thread(unsafe { rust_helper_current_pid() })?;
+        let thread = this.get_thread(Task::current().pid())?;
         match cmd {
             bindings::BINDER_SET_MAX_THREADS => this.set_max_threads(reader.read()?),
             bindings::BINDER_SET_CONTEXT_MGR => this.set_as_manager(None, &thread)?,
@@ -813,7 +810,7 @@ impl IoctlHandler for Process {
     }
 
     fn read_write(this: &Ref<Process>, file: &File, cmd: u32, data: UserSlicePtr) -> Result<i32> {
-        let thread = this.get_thread(unsafe { rust_helper_current_pid() })?;
+        let thread = this.get_thread(Task::current().pid())?;
         match cmd {
             bindings::BINDER_WRITE_READ => thread.write_read(data, file.is_blocking())?,
             bindings::BINDER_GET_NODE_DEBUG_INFO => this.get_node_debug_info(data)?,
@@ -939,7 +936,7 @@ impl FileOperations for Process {
     }
 
     fn poll(this: &Ref<Process>, file: &File, table: &PollTable) -> Result<u32> {
-        let thread = this.get_thread(unsafe { rust_helper_current_pid() })?;
+        let thread = this.get_thread(Task::current().pid())?;
         let (from_proc, mut mask) = thread.poll(file, table);
         if mask == 0 && from_proc && !this.inner.lock().work.is_empty() {
             mask |= bindings::POLLIN;
diff --git a/rust/helpers.c b/rust/helpers.c
index f2049e041bdd42..f57f9340f49f3b 100644
--- a/rust/helpers.c
+++ b/rust/helpers.c
@@ -60,15 +60,9 @@ void rust_helper_init_wait(struct wait_queue_entry *wq_entry)
 }
 EXPORT_SYMBOL_GPL(rust_helper_init_wait);
 
-int rust_helper_current_pid(void)
+int rust_helper_signal_pending(struct task_struct *t)
 {
-	return current->pid;
-}
-EXPORT_SYMBOL_GPL(rust_helper_current_pid);
-
-int rust_helper_signal_pending(void)
-{
-	return signal_pending(current);
+	return signal_pending(t);
 }
 EXPORT_SYMBOL_GPL(rust_helper_signal_pending);
 
@@ -171,6 +165,24 @@ void rust_helper_rb_link_node(struct rb_node *node, struct rb_node *parent,
 }
 EXPORT_SYMBOL_GPL(rust_helper_rb_link_node);
 
+struct task_struct *rust_helper_get_current(void)
+{
+	return current;
+}
+EXPORT_SYMBOL_GPL(rust_helper_get_current);
+
+void rust_helper_get_task_struct(struct task_struct *t)
+{
+	get_task_struct(t);
+}
+EXPORT_SYMBOL_GPL(rust_helper_get_task_struct);
+
+void rust_helper_put_task_struct(struct task_struct *t)
+{
+	put_task_struct(t);
+}
+EXPORT_SYMBOL_GPL(rust_helper_put_task_struct);
+
 /* We use bindgen's --size_t-is-usize option to bind the C size_t type
  * as the Rust usize type, so we can use it in contexts where Rust
  * expects a usize like slice (array) indices. usize is defined to be
diff --git a/rust/kernel/file.rs b/rust/kernel/file.rs
index 262a856fc4910f..015c5284e1c616 100644
--- a/rust/kernel/file.rs
+++ b/rust/kernel/file.rs
@@ -53,7 +53,7 @@ impl Drop for File {
     }
 }
 
-/// A wrapper for [`File`] that doesn't automatically decrement that refcount when dropped.
+/// A wrapper for [`File`] that doesn't automatically decrement the refcount when dropped.
 ///
 /// We need the wrapper because [`ManuallyDrop`] alone would allow callers to call
 /// [`ManuallyDrop::into_inner`]. This would allow an unsafe sequence to be triggered without
diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index 06e24011067eb0..7201cebd7981df 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -51,6 +51,7 @@ pub mod file_operations;
 pub mod miscdev;
 pub mod pages;
 pub mod str;
+pub mod task;
 pub mod traits;
 
 pub mod linked_list;
diff --git a/rust/kernel/sync/condvar.rs b/rust/kernel/sync/condvar.rs
index 29649c0a3a6db8..993087e6c23397 100644
--- a/rust/kernel/sync/condvar.rs
+++ b/rust/kernel/sync/condvar.rs
@@ -6,8 +6,7 @@
 //! variable.
 
 use super::{Guard, Lock, NeedsLockClass};
-use crate::bindings;
-use crate::str::CStr;
+use crate::{bindings, str::CStr, task::Task};
 use core::{cell::UnsafeCell, marker::PhantomPinned, mem::MaybeUninit, pin::Pin};
 
 extern "C" {
@@ -92,7 +91,7 @@ impl CondVar {
         // SAFETY: Both `wait` and `wait_list` point to valid memory.
         unsafe { bindings::finish_wait(self.wait_list.get(), wait.as_mut_ptr()) };
 
-        super::signal_pending()
+        Task::current().signal_pending()
     }
 
     /// Calls the kernel function to notify the appropriate number of threads with the given flags.
diff --git a/rust/kernel/sync/mod.rs b/rust/kernel/sync/mod.rs
index eac76d3a0b33ae..ce863109c06eb8 100644
--- a/rust/kernel/sync/mod.rs
+++ b/rust/kernel/sync/mod.rs
@@ -39,7 +39,6 @@ pub use mutex::Mutex;
 pub use spinlock::SpinLock;
 
 extern "C" {
-    fn rust_helper_signal_pending() -> c_types::c_int;
     fn rust_helper_cond_resched() -> c_types::c_int;
 }
 
@@ -78,12 +77,6 @@ pub trait NeedsLockClass {
     unsafe fn init(self: Pin<&mut Self>, name: &'static CStr, key: *mut bindings::lock_class_key);
 }
 
-/// Determines if a signal is pending on the current process.
-pub fn signal_pending() -> bool {
-    // SAFETY: No arguments, just checks `current` for pending signals.
-    unsafe { rust_helper_signal_pending() != 0 }
-}
-
 /// Reschedules the caller's task if needed.
 pub fn cond_resched() -> bool {
     // SAFETY: No arguments, reschedules `current` if needed.
diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs
new file mode 100644
index 00000000000000..d526d0056e7fdb
--- /dev/null
+++ b/rust/kernel/task.rs
@@ -0,0 +1,193 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Tasks (threads and processes).
+//!
+//! C header: [`include/linux/sched.h`](../../../../include/linux/sched.h).
+
+use crate::{bindings, c_types};
+use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref};
+
+extern "C" {
+    #[allow(improper_ctypes)]
+    fn rust_helper_signal_pending(t: *const bindings::task_struct) -> c_types::c_int;
+    #[allow(improper_ctypes)]
+    fn rust_helper_get_current() -> *mut bindings::task_struct;
+    #[allow(improper_ctypes)]
+    fn rust_helper_get_task_struct(t: *mut bindings::task_struct);
+    #[allow(improper_ctypes)]
+    fn rust_helper_put_task_struct(t: *mut bindings::task_struct);
+}
+
+/// Wraps the kernel's `struct task_struct`.
+///
+/// # Invariants
+///
+/// The pointer [`Task::ptr`] is non-null and valid. Its reference count is also non-zero.
+///
+/// # Examples
+///
+/// The following is an example of getting the PID of the current thread with zero additional cost
+/// when compared to the C version:
+///
+/// ```
+/// # use kernel::prelude::*;
+/// use kernel::task::Task;
+///
+/// # fn test() {
+/// Task::current().pid();
+/// # }
+/// ```
+///
+/// Getting the PID of the current process, also zero additional cost:
+///
+/// ```
+/// # use kernel::prelude::*;
+/// use kernel::task::Task;
+///
+/// # fn test() {
+/// Task::current().group_leader().pid();
+/// # }
+/// ```
+///
+/// Getting the current task and storing it in some struct. The reference count is automatically
+/// incremented when creating `State` and decremented when it is dropped:
+///
+/// ```
+/// # use kernel::prelude::*;
+/// use kernel::task::Task;
+///
+/// struct State {
+///     creator: Task,
+///     index: u32,
+/// }
+///
+/// impl State {
+///     fn new() -> Self {
+///         Self {
+///             creator: Task::current().clone(),
+///             index: 0,
+///         }
+///     }
+/// }
+/// ```
+pub struct Task {
+    pub(crate) ptr: *mut bindings::task_struct,
+}
+
+// SAFETY: Given that the task is referenced, it is OK to send it to another thread.
+unsafe impl Send for Task {}
+
+// SAFETY: It's OK to access `Task` through references from other threads because we're either
+// accessing properties that don't change (e.g., `pid`, `group_leader`) or that are properly
+// synchronised by C code (e.g., `signal_pending`).
+unsafe impl Sync for Task {}
+
+/// The type of process identifiers (PIDs).
+type Pid = bindings::pid_t;
+
+impl Task {
+    /// Returns a task reference for the currently executing task/thread.
+    pub fn current<'a>() -> TaskRef<'a> {
+        // SAFETY: Just an FFI call.
+        let ptr = unsafe { rust_helper_get_current() };
+
+        // SAFETY: If the current thread is still running, the current task is valid. Given
+        // that `TaskRef` is not `Send`, we know it cannot be transferred to another thread (where
+        // it could potentially outlive the caller).
+        unsafe { TaskRef::from_ptr(ptr) }
+    }
+
+    /// Returns the group leader of the given task.
+    pub fn group_leader(&self) -> TaskRef<'_> {
+        // SAFETY: By the type invariant, we know that `self.ptr` is non-null and valid.
+        let ptr = unsafe { (*self.ptr).group_leader };
+
+        // SAFETY: The lifetime of the returned task reference is tied to the lifetime of `self`,
+        // and given that a task has a reference to its group leader, we know it must be valid for
+        // the lifetime of the returned task reference.
+        unsafe { TaskRef::from_ptr(ptr) }
+    }
+
+    /// Returns the PID of the given task.
+    pub fn pid(&self) -> Pid {
+        // SAFETY: By the type invariant, we know that `self.ptr` is non-null and valid.
+        unsafe { (*self.ptr).pid }
+    }
+
+    /// Determines whether the given task has pending signals.
+    pub fn signal_pending(&self) -> bool {
+        // SAFETY: By the type invariant, we know that `self.ptr` is non-null and valid.
+        unsafe { rust_helper_signal_pending(self.ptr) != 0 }
+    }
+}
+
+impl PartialEq for Task {
+    fn eq(&self, other: &Self) -> bool {
+        self.ptr == other.ptr
+    }
+}
+
+impl Eq for Task {}
+
+impl Clone for Task {
+    fn clone(&self) -> Self {
+        // SAFETY: The type invariants guarantee that `self.ptr` has a non-zero reference count.
+        unsafe { rust_helper_get_task_struct(self.ptr) };
+
+        // INVARIANT: We incremented the reference count to account for the new `Task` being
+        // created.
+        Self { ptr: self.ptr }
+    }
+}
+
+impl Drop for Task {
+    fn drop(&mut self) {
+        // INVARIANT: We may decrement the refcount to zero, but the `Task` is being dropped, so
+        // this is not observable.
+        // SAFETY: The type invariants guarantee that `Task::ptr` has a non-zero reference count.
+        unsafe { rust_helper_put_task_struct(self.ptr) };
+    }
+}
+
+/// A wrapper for [`Task`] that doesn't automatically decrement the refcount when dropped.
+///
+/// We need the wrapper because [`ManuallyDrop`] alone would allow callers to call
+/// [`ManuallyDrop::into_inner`]. This would allow an unsafe sequence to be triggered without
+/// `unsafe` blocks because it would trigger an unbalanced call to `put_task_struct`.
+///
+/// We make this explicitly not [`Send`] so that we can use it to represent the current thread
+/// without having to increment/decrement its reference count.
+///
+/// # Invariants
+///
+/// The wrapped [`Task`] remains valid for the lifetime of the object.
+pub struct TaskRef<'a> {
+    task: ManuallyDrop<Task>,
+    _not_send: PhantomData<(&'a (), *mut ())>,
+}
+
+impl TaskRef<'_> {
+    /// Constructs a new `struct task_struct` wrapper that doesn't change its reference count.
+    ///
+    /// # Safety
+    ///
+    /// The pointer `ptr` must be non-null and valid for the lifetime of the object.
+    pub(crate) unsafe fn from_ptr(ptr: *mut bindings::task_struct) -> Self {
+        Self {
+            task: ManuallyDrop::new(Task { ptr }),
+            _not_send: PhantomData,
+        }
+    }
+}
+
+// SAFETY: It is OK to share a reference to the current thread with another thread because we know
+// the owner cannot go away while the shared reference exists (and `Task` itself is `Sync`).
+unsafe impl Sync for TaskRef<'_> {}
+
+impl Deref for TaskRef<'_> {
+    type Target = Task;
+
+    fn deref(&self) -> &Self::Target {
+        self.task.deref()
+    }
+}