diff --git a/src/file_operations.rs b/src/file_operations.rs index 1f7d85cc..93df083a 100644 --- a/src/file_operations.rs +++ b/src/file_operations.rs @@ -138,23 +138,46 @@ unsafe extern "C" fn llseek_callback( } } +unsafe extern "C" fn fsync_callback( + file: *mut bindings::file, + start: bindings::loff_t, + end: bindings::loff_t, + datasync: c_types::c_int, +) -> c_types::c_int { + let start = match start.try_into() { + Ok(v) => v, + Err(_) => return Error::EINVAL.to_kernel_errno(), + }; + let end = match end.try_into() { + Ok(v) => v, + Err(_) => return Error::EINVAL.to_kernel_errno(), + }; + let datasync = datasync != 0; + let fsync = T::FSYNC.unwrap(); + let f = &*((*file).private_data as *const T); + match fsync(f, &File::from_ptr(file), start, end, datasync) { + Ok(result) => result as c_types::c_int, + Err(e) => e.to_kernel_errno(), + } +} + pub(crate) struct FileOperationsVtable(marker::PhantomData); impl FileOperationsVtable { pub(crate) const VTABLE: bindings::file_operations = bindings::file_operations { open: Some(open_callback::), release: Some(release_callback::), - read: if let Some(_) = T::READ { + read: if T::READ.is_some() { Some(read_callback::) } else { None }, - write: if let Some(_) = T::WRITE { + write: if T::WRITE.is_some() { Some(write_callback::) } else { None }, - llseek: if let Some(_) = T::SEEK { + llseek: if T::SEEK.is_some() { Some(llseek_callback::) } else { None @@ -176,7 +199,11 @@ impl FileOperationsVtable { fasync: None, flock: None, flush: None, - fsync: None, + fsync: if T::FSYNC.is_some() { + Some(fsync_callback::) + } else { + None + }, get_unmapped_area: None, iterate: None, #[cfg(kernel_4_7_0_or_greater)] @@ -207,6 +234,7 @@ impl FileOperationsVtable { pub type ReadFn = Option KernelResult<()>>; pub type WriteFn = Option KernelResult<()>>; pub type SeekFn = Option KernelResult>; +pub type FSync = Option KernelResult>; /// `FileOperations` corresponds to the kernel's `struct file_operations`. You /// implement this trait whenever you'd create a `struct file_operations`. @@ -228,4 +256,6 @@ pub trait FileOperations: Sync + Sized { /// Changes the position of the file. Corresponds to the `llseek` function /// pointer in `struct file_operations`. const SEEK: SeekFn = None; + + const FSYNC: FSync = None; } diff --git a/tests/chrdev/src/lib.rs b/tests/chrdev/src/lib.rs index 71aa3333..e96e1a88 100644 --- a/tests/chrdev/src/lib.rs +++ b/tests/chrdev/src/lib.rs @@ -3,7 +3,7 @@ extern crate alloc; use alloc::string::ToString; -use core::sync::atomic::{AtomicUsize, Ordering}; +use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use linux_kernel_module::{self, cstr}; @@ -83,6 +83,59 @@ impl linux_kernel_module::file_operations::FileOperations for WriteFile { ); } +struct FSyncFile { + data_synced: AtomicBool, + meta_synced: AtomicBool, +} + +impl linux_kernel_module::file_operations::FileOperations for FSyncFile { + fn open() -> linux_kernel_module::KernelResult { + Ok(FSyncFile { + data_synced: AtomicBool::new(true), + meta_synced: AtomicBool::new(true), + }) + } + + const READ: linux_kernel_module::file_operations::ReadFn = Some( + |this: &Self, + _file: &linux_kernel_module::file_operations::File, + buf: &mut linux_kernel_module::user_ptr::UserSlicePtrWriter, + _offset: u64| + -> linux_kernel_module::KernelResult<()> { + let data = (this.data_synced.load(Ordering::SeqCst) as i32).to_string(); + let meta = (this.meta_synced.load(Ordering::SeqCst) as i32).to_string(); + buf.write((data + &meta).as_bytes())?; + Ok(()) + }, + ); + + const WRITE: linux_kernel_module::file_operations::WriteFn = Some( + |this: &Self, + _buf: &mut linux_kernel_module::user_ptr::UserSlicePtrReader, + _offset: u64| + -> linux_kernel_module::KernelResult<()> { + this.data_synced.store(false, Ordering::SeqCst); + this.meta_synced.store(false, Ordering::SeqCst); + Ok(()) + }, + ); + + const FSYNC: linux_kernel_module::file_operations::FSync = Some( + |this: &Self, + _file: &linux_kernel_module::file_operations::File, + _start: u64, + _end: u64, + datasync: bool| + -> linux_kernel_module::KernelResult { + this.data_synced.store(true, Ordering::SeqCst); + if !datasync { + this.meta_synced.store(true, Ordering::SeqCst); + } + Ok(0) + }, + ); +} + struct ChrdevTestModule { _chrdev_registration: linux_kernel_module::chrdev::Registration, } @@ -90,10 +143,11 @@ struct ChrdevTestModule { impl linux_kernel_module::KernelModule for ChrdevTestModule { fn init() -> linux_kernel_module::KernelResult { let chrdev_registration = - linux_kernel_module::chrdev::builder(cstr!("chrdev-tests"), 0..3)? + linux_kernel_module::chrdev::builder(cstr!("chrdev-tests"), 0..4)? .register_device::() .register_device::() .register_device::() + .register_device::() .build()?; Ok(ChrdevTestModule { _chrdev_registration: chrdev_registration, diff --git a/tests/chrdev/tests/tests.rs b/tests/chrdev/tests/tests.rs index 1ae12f76..8dcf0542 100644 --- a/tests/chrdev/tests/tests.rs +++ b/tests/chrdev/tests/tests.rs @@ -8,6 +8,7 @@ const DEVICE_NAME: &'static str = "chrdev-tests"; const READ_FILE_MINOR: libc::dev_t = 0; const SEEK_FILE_MINOR: libc::dev_t = 1; const WRITE_FILE_MINOR: libc::dev_t = 2; +const SYNC_FILE_MINOR: libc::dev_t = 3; #[test] fn test_mknod() { @@ -171,3 +172,55 @@ fn test_write() { assert_eq!(&buf, b"8"); }) } + +#[test] +fn test_fsync_unimplemented() { + with_kernel_module(|| { + let device_number = get_device_major_number(DEVICE_NAME); + let p = temporary_file_path(); + let _u = mknod(&p, device_number, READ_FILE_MINOR); + + let f = fs::OpenOptions::new().write(true).open(&p).unwrap(); + assert_eq!( + f.sync_all().unwrap_err().raw_os_error().unwrap(), + libc::EINVAL + ); + }) +} + +#[test] +fn test_fsync() { + with_kernel_module(|| { + let device_number = get_device_major_number(DEVICE_NAME); + let p = temporary_file_path(); + let _u = mknod(&p, device_number, SYNC_FILE_MINOR); + + let mut f = fs::OpenOptions::new() + .read(true) + .write(true) + .open(&p) + .unwrap(); + + let mut buf = [0; 2]; + f.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"11"); + + f.write(&[1, 2]).unwrap(); + + let mut buf = [0; 2]; + f.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"00"); + + f.sync_data().unwrap(); + + let mut buf = [0; 2]; + f.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"10"); + + f.sync_all().unwrap(); + + let mut buf = [0; 2]; + f.read_exact(&mut buf).unwrap(); + assert_eq!(&buf, b"11"); + }) +}