diff --git a/crates/bevy_ecs/Cargo.toml b/crates/bevy_ecs/Cargo.toml index ff86bb109c1b4..58bf7474c64e5 100644 --- a/crates/bevy_ecs/Cargo.toml +++ b/crates/bevy_ecs/Cargo.toml @@ -22,6 +22,7 @@ bevy_reflect = { path = "../bevy_reflect", version = "0.5.0", optional = true } bevy_tasks = { path = "../bevy_tasks", version = "0.5.0" } bevy_utils = { path = "../bevy_utils", version = "0.5.0" } bevy_ecs_macros = { path = "macros", version = "0.5.0" } +futures-lite = "1.4.0" async-channel = "1.4" bitflags = "1.2" diff --git a/crates/bevy_ecs/src/system/accessor.rs b/crates/bevy_ecs/src/system/accessor.rs new file mode 100644 index 0000000000000..f192928d9eb56 --- /dev/null +++ b/crates/bevy_ecs/src/system/accessor.rs @@ -0,0 +1,356 @@ +use std::{ + borrow::Cow, + future::Future, + marker::PhantomData, + sync::Arc, + task::{Poll, Waker}, +}; + +use async_channel::{Receiver, Sender}; +use bevy_ecs_macros::all_tuples; +use futures_lite::pin; +use parking_lot::Mutex; + +use crate::{ + archetype::{Archetype, ArchetypeComponentId}, + component::ComponentId, + prelude::World, + query::Access, +}; + +use super::{ + check_system_change_tick, System, SystemId, SystemParam, SystemParamFetch, SystemParamState, + SystemState, +}; + +pub struct Accessor { + channel: Sender>>, + _marker: PhantomData P>, +} + +impl Accessor

{ + pub fn new() -> (Self, AccessorRunnerSystem

) { + let (tx, rx) = async_channel::unbounded(); + ( + Accessor { + channel: tx, + _marker: Default::default(), + }, + AccessorRunnerSystem { + system_state: SystemState::with_name( + format!("Accessor system {}", std::any::type_name::

()).into(), + ), + param_state: None, + channel: rx, + _marker: Default::default(), + }, + ) + } +} + +impl Clone for Accessor

{ + fn clone(&self) -> Self { + Self { + channel: self.channel.clone(), + _marker: Default::default(), + } + } +} + +#[doc(hidden)] +pub trait AccessFn<'a, Out, Param: SystemParam, M>: Send + Sync + 'static { + /// # Safety + /// this is an internal trait that exists to bypass some limitations of rustc, please ignore it. + unsafe fn run( + self: Box, + state: &'a mut Param::Fetch, + system_state: &'a SystemState, + world: &'a World, + change_tick: u32, + ) -> Out; +} +pub struct SingleMarker; +impl<'a, Out, Func, P> AccessFn<'a, Out, P, SingleMarker> for Func +where + Func: FnOnce(P) -> Out + + FnOnce(<

::Fetch as SystemParamFetch<'a>>::Item) -> Out + + Send + + Sync + + 'static, + Out: 'static, + P: SystemParam, +{ + #[inline] + unsafe fn run( + self: Box, + state: &'a mut

::Fetch, + system_state: &'a SystemState, + world: &'a World, + change_tick: u32, + ) -> Out { + let param = <

::Fetch as SystemParamFetch<'a>>::get_param( + state, + system_state, + world, + change_tick, + ); + self(param) + } +} + +macro_rules! impl_system_function { + ($($param: ident),*) => { + #[allow(non_snake_case)] + impl<'a, Out, Func, $($param: SystemParam),*> AccessFn<'a, Out, ($($param,)*), ()> for Func + where + Func: + FnOnce($($param),*) -> Out + + FnOnce($(<<$param as SystemParam>::Fetch as SystemParamFetch<'a>>::Item),*) -> Out + Send + Sync + 'static, + Out: 'static + { + #[inline] + unsafe fn run( + self: Box, + state: &'a mut <($($param,)*) as SystemParam>::Fetch, + system_state: &'a SystemState, world: &'a World, + change_tick: u32, + ) -> Out { + let ($($param,)*) = <<($($param,)*) as SystemParam>::Fetch as SystemParamFetch<'a>>::get_param(state, system_state, world, change_tick); + self($($param),*) + } + } + }; +} + +all_tuples!(impl_system_function, 0, 12, F); + +impl Accessor

{ + pub fn access( + &mut self, + sync: impl for<'a> AccessFn<'a, R, P, M> + Send + Sync, + ) -> impl Future + Send + Sync + 'static { + AccessFuture { + state: AccessFutureState::FirstPoll { + boxed: Box::new(sync), + tx: self.channel.clone(), + }, + } + } +} + +struct ParallelAccess { + inner: Arc AccessFn<'a, Out, P, M> + Send + Sync>>>>, + tx: Sender, + waker: Waker, +} + +trait GenericAccess: Send + Sync + 'static { + unsafe fn run( + self: Box, + param_state: &mut P::Fetch, + system_state: &SystemState, + world: &World, + change_tick: u32, + ); +} + +impl GenericAccess

for ParallelAccess +where + P: SystemParam + 'static, + Out: Send + Sync + 'static, +{ + unsafe fn run( + self: Box, + param_state: &mut P::Fetch, + state: &SystemState, + world: &World, + change_tick: u32, + ) { + if let Some(sync) = self.inner.lock().take() { + self.tx + .try_send(sync.run(param_state, state, world, change_tick)) + .unwrap(); + } + self.waker.wake(); + } +} + +enum AccessFutureState { + FirstPoll { + boxed: Box AccessFn<'a, R, P, M> + Send + Sync + 'static>, + tx: Sender>>, + }, + WaitingForCompletion( + Receiver, + Arc AccessFn<'a, R, P, M> + Send + Sync + 'static>>>>, + ), +} + +pub struct AccessFuture { + state: AccessFutureState, +} + +impl Future for AccessFuture +where + P: SystemParam + 'static, + R: Send + Sync + 'static, + M: 'static, +{ + type Output = R; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match &mut self.state { + AccessFutureState::FirstPoll { .. } => { + let (tx, rx) = async_channel::bounded(1); + let arc = Arc::new(Mutex::new(None)); + if let AccessFutureState::FirstPoll { boxed, tx: mtx } = std::mem::replace( + &mut self.state, + AccessFutureState::WaitingForCompletion(rx, arc.clone()), + ) { + *arc.lock() = Some(boxed); + let msg = ParallelAccess { + inner: arc, + tx, + waker: cx.waker().clone(), + }; + let boxed: Box> = Box::new(msg); + mtx.try_send(boxed).unwrap(); + Poll::Pending + } else { + unreachable!() + } + } + AccessFutureState::WaitingForCompletion(rx, _) => { + let future = rx.recv(); + pin!(future); + future.poll(cx).map(|v| v.unwrap()) + } + } + } +} + +pub struct AccessorRunnerSystem { + system_state: SystemState, + param_state: Option, + channel: Receiver>>, + _marker: PhantomData P>, +} + +impl System for AccessorRunnerSystem

{ + type In = (); + type Out = (); + + fn name(&self) -> Cow<'static, str> { + self.system_state.name.clone() + } + + fn id(&self) -> SystemId { + self.system_state.id + } + + fn archetype_component_access(&self) -> &Access { + &self.system_state.archetype_component_access + } + + unsafe fn run_unsafe(&mut self, _: Self::In, world: &World) -> Self::Out { + loop { + match self.channel.try_recv() { + Ok(sync) => { + let change_tick = world.increment_change_tick(); + sync.run( + &mut self.param_state.as_mut().unwrap(), + &self.system_state, + world, + change_tick, + ); + self.system_state.last_change_tick = change_tick; + } + Err(async_channel::TryRecvError::Closed) => panic!( + "`AccessorRunnerSystem` called but all relevant accessors have been dropped!" + ), + Err(async_channel::TryRecvError::Empty) => break, + } + } + } + + fn initialize(&mut self, world: &mut World) { + self.param_state = Some(::init( + world, + &mut self.system_state, + ::default_config(), + )) + } + + fn apply_buffers(&mut self, world: &mut World) { + let param_state = self.param_state.as_mut().unwrap(); + param_state.apply(world); + } + + fn component_access(&self) -> &Access { + &self.system_state.component_access_set.combined_access() + } + + fn new_archetype(&mut self, archetype: &Archetype) { + let param_state = self.param_state.as_mut().unwrap(); + param_state.new_archetype(archetype, &mut self.system_state); + } + + fn is_send(&self) -> bool { + self.system_state.is_send() + } + + #[inline] + fn check_change_tick(&mut self, change_tick: u32) { + check_system_change_tick( + &mut self.system_state.last_change_tick, + change_tick, + self.system_state.name.as_ref(), + ); + } +} + +#[cfg(test)] +mod test { + use bevy_tasks::TaskPool; + + use crate::{ + prelude::{Res, ResMut, World}, + schedule::{Stage, SystemStage}, + }; + + use super::Accessor; + + #[test] + fn simple_test() { + let mut world = World::new(); + let ctp = TaskPool::new(); + world.insert_resource("hi".to_string()); + world.insert_resource(3u32); + let (mut accessor, system) = Accessor::<(Res, ResMut)>::new(); + let mut stage = SystemStage::parallel(); + stage.add_system(system); + let _a = accessor.clone(); + ctp.spawn(async move { + accessor + .access(|(r, mut s): (Res, ResMut)| { + assert_eq!(*r, 3); + *s = "hello".into(); + }) + .await; + }) + .detach(); + + let start = std::time::Instant::now(); + loop { + stage.run(&mut world); + if world.get_resource::().unwrap() == "hello" { + break; + } else if std::time::Instant::now() - start > std::time::Duration::from_secs_f32(0.1) { + panic!("timeout!"); + } + } + } +} diff --git a/crates/bevy_ecs/src/system/into_system.rs b/crates/bevy_ecs/src/system/into_system.rs index 4a48a3878eec8..1e64aa7bf92db 100644 --- a/crates/bevy_ecs/src/system/into_system.rs +++ b/crates/bevy_ecs/src/system/into_system.rs @@ -33,6 +33,17 @@ impl SystemState { } } + pub fn with_name(name: Cow<'static, str>) -> Self { + Self { + name, + archetype_component_access: Access::default(), + component_access_set: FilteredAccessSet::default(), + is_send: true, + id: SystemId::new(), + last_change_tick: 0, + } + } + #[inline] pub fn is_send(&self) -> bool { self.is_send diff --git a/crates/bevy_ecs/src/system/mod.rs b/crates/bevy_ecs/src/system/mod.rs index b10ec03c6506e..ecdd6a56ab125 100644 --- a/crates/bevy_ecs/src/system/mod.rs +++ b/crates/bevy_ecs/src/system/mod.rs @@ -1,3 +1,4 @@ +mod accessor; mod commands; mod exclusive_system; mod into_system; @@ -7,6 +8,7 @@ mod system; mod system_chaining; mod system_param; +pub use accessor::*; pub use commands::*; pub use exclusive_system::*; pub use into_system::*;