From 431f4c56981b71ef53be6adfedc6ef7bc49e54ce Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 28 Sep 2021 13:34:44 -0700 Subject: [PATCH 1/5] Implement async witx functions for a Wasmtime host This commit is the initial implementation of `async function()` in witx for the Wasmtime host runtime. This integration is unfortunately not as nice as the previous JS implementation, but the general goal is to hook up async witx functions to `async` functions in Rust in an at least somewhat idiomatic fashion. For exported functions not much has changed relative to the preexisting `async` support for Wasmtime. Functions are still exposed as `async` functions that take a store and the arguments as input. The major consequence of this is that only one async method call on a wasm module can be active at any point in time. The alternative to this design is to embed the `Store` within the bindings and have some sort of async mutex around it, but that doesn't feel appropriate at this time since it's not really how Wasmtime's idioms are designed (where you're typically given a store rather than embedding it). For imported functions, however, the implementation is significantly different from the previous `async` support for Wasmtime. The support no longer uses `async_trait`, so that multiple imports can be called simultaneously. This means that imports receive their arguments and return `'static` futures. This does not map to any supported mode of `async_trait`. This means that host bindings for these sorts of functions are likely going to be somewhat unidiomatic since they'll require cloning state into host callbacks as necessary and/or using things like `Rc<RefCell<...>>` or similar. Overall I'm not super happy with how this ended up. I feel that it's generally pretty un-ergonomic and confusing to use. Unfortunately though I don't really know how to make it better at this time. On the plus side it at least all works for the given use cases. 
There's almost surely a bug or two with the generated code but this should be at least a somewhat solid base to build on as we tweak things in the future. --- Cargo.lock | 23 ++ crates/gen-js/src/lib.rs | 15 +- crates/gen-rust-wasm/src/lib.rs | 1 + crates/gen-spidermonkey/src/lib.rs | 1 + crates/gen-wasmtime/Cargo.toml | 3 + crates/gen-wasmtime/src/lib.rs | 398 +++++++++++++++++++++----- crates/gen-wasmtime/tests/codegen.rs | 6 - crates/wasmtime/src/futures.rs | 295 +++++++++++++++++++ crates/wasmtime/src/lib.rs | 26 ++ crates/witx2/src/abi.rs | 35 ++- tests/runtime/async_functions/host.rs | 148 ++++++++++ 11 files changed, 860 insertions(+), 91 deletions(-) create mode 100644 crates/wasmtime/src/futures.rs create mode 100644 tests/runtime/async_functions/host.rs diff --git a/Cargo.lock b/Cargo.lock index d2ecdfad1..a2a6b0a66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -570,12 +570,32 @@ dependencies = [ "winapi", ] +[[package]] +name = "futures-channel" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-core" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" +[[package]] +name = "futures-executor" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-macro" version = "0.3.17" @@ -2116,6 +2136,9 @@ name = "witx-bindgen-gen-wasmtime" version = "0.1.0" dependencies = [ "anyhow", + "futures-channel", + "futures-executor", + "futures-util", "heck", "structopt", "test-helpers", diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index 3cee7e2a9..28bc365db 
100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -2079,6 +2079,16 @@ impl Bindgen for FunctionBindgen<'_> { } }, + Instruction::CompletionCallback { .. } => { + // TODO: shouldn't hardcode the function table name, should + // verify the table is present, and should verify the type of + // the function returned. + results.push(format!( + "get_export(\"__indirect_function_table\").get({})", + operands[0], + )); + } + Instruction::ReturnAsyncImport { .. } => { // When we reenter webassembly successfully that means that the // host's promise resolved without exception. Take the current @@ -2089,15 +2099,12 @@ impl Bindgen for FunctionBindgen<'_> { // Note that the name `cur_promise` used here is introduced in // the `CallInterface` codegen above in the closure for // `with_current_promise` which we're using here. - // - // TODO: hardcoding `__indirect_function_table` and no help if - // it's not actually defined. self.gen.needs_get_export = true; let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); self.src.js(&format!( "\ {with}(cur_promise, _prev => {{ - get_export(\"__indirect_function_table\").get({})({}); + {}({}); }}); ", operands[0], diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs index 826d20b95..20e857ea2 100644 --- a/crates/gen-rust-wasm/src/lib.rs +++ b/crates/gen-rust-wasm/src/lib.rs @@ -1461,6 +1461,7 @@ impl Bindgen for FunctionBindgen<'_> { operands[0], operands[1] )); } + Instruction::CompletionCallback { .. } => unreachable!(), Instruction::ReturnAsyncImport { .. } => unreachable!(), Instruction::I32Load { offset } => { diff --git a/crates/gen-spidermonkey/src/lib.rs b/crates/gen-spidermonkey/src/lib.rs index 32935abe2..6aed957bc 100644 --- a/crates/gen-spidermonkey/src/lib.rs +++ b/crates/gen-spidermonkey/src/lib.rs @@ -2176,6 +2176,7 @@ impl witx2::abi::Bindgen for Bindgen<'_, '_> { } witx2::abi::Instruction::ReturnAsyncExport { .. 
} => todo!(), + witx2::abi::Instruction::CompletionCallback { .. } => todo!(), witx2::abi::Instruction::ReturnAsyncImport { .. } => todo!(), witx2::abi::Instruction::Witx { instr: _ } => { diff --git a/crates/gen-wasmtime/Cargo.toml b/crates/gen-wasmtime/Cargo.toml index 8a4660e46..0e9d106be 100644 --- a/crates/gen-wasmtime/Cargo.toml +++ b/crates/gen-wasmtime/Cargo.toml @@ -20,6 +20,9 @@ test-helpers = { path = '../test-helpers', features = ['witx-bindgen-gen-wasmtim wasmtime = "0.30.0" wasmtime-wasi = "0.30.0" witx-bindgen-wasmtime = { path = '../wasmtime', features = ['tracing', 'async'] } +futures-executor = "0.3" +futures-channel = "0.3" +futures-util = "0.3" [features] old-witx-compat = ['witx-bindgen-gen-core/old-witx-compat'] diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 2863b7641..2438813ce 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -46,7 +46,7 @@ enum NeededFunction { } struct Import { - is_async: bool, + wrap_async: bool, name: String, trait_signature: String, num_wasm_params: usize, @@ -84,6 +84,13 @@ pub struct Opts { pub custom_error: bool, } +// TODO: with the introduction of `async function()` to witx this is no longer a +// good name. The purpose of this configuration is "should the wasmtime function +// be invoked with `call_async`" which is sort of a different form of async. +// Async with wasmtime involves stack switching and with fuel enables +// preemption, but async witx functions don't actually use async host functions +// as-defined-by-wasmtime. All that to say that this probably needs a better +// name. 
#[derive(Debug, Clone)] pub enum Async { None, @@ -366,10 +373,11 @@ impl Generator for Wasmtime { self.types.analyze(iface); self.in_import = dir == Direction::Import; self.trait_name = iface.name.to_camel_case(); + self.src.push_str("#[allow(unused_imports)]\n"); self.src .push_str(&format!("pub mod {} {{\n", iface.name.to_snake_case())); self.src - .push_str("#[allow(unused_imports)]\nuse witx_bindgen_wasmtime::{wasmtime, anyhow};\n"); + .push_str("use witx_bindgen_wasmtime::{wasmtime, anyhow};\n"); self.sizes.fill(dir, iface); } @@ -573,7 +581,6 @@ impl Generator for Wasmtime { // } fn export(&mut self, iface: &Interface, func: &Function) { - assert!(!func.is_async, "async not supported yet"); let prev = mem::take(&mut self.src); let is_dtor = self.types.is_preview1_dtor_func(func); @@ -621,7 +628,7 @@ impl Generator for Wasmtime { let mut fnsig = FnSig::default(); fnsig.private = true; - fnsig.async_ = self.opts.async_.includes(&func.name); + fnsig.async_ = self.opts.async_.includes(&func.name) && !func.is_async; fnsig.self_arg = Some(self_arg); self.print_docs_and_params( iface, @@ -635,20 +642,21 @@ impl Generator for Wasmtime { ); // The Rust return type may differ from the wasm return type based on // the `custom_error` configuration of this code generator. + self.push_str(" -> "); + if func.is_async { + self.push_str("std::pin::Pin "); - self.print_results(iface, func); - } + self.print_results(iface, func); } FunctionRet::CustomToTrap => { - self.push_str(" -> Result<"); + self.push_str("Result<"); self.print_results(iface, func); self.push_str(", Self::Error>"); } FunctionRet::CustomToError { ok, .. 
} => { - self.push_str(" -> Result<"); + self.push_str("Result<"); match ok { Some(ty) => self.print_ty(iface, &ty, TypeMode::Owned), None => self.push_str("()"), @@ -656,6 +664,9 @@ impl Generator for Wasmtime { self.push_str(", Self::Error>"); } } + if func.is_async { + self.push_str("> + Send>>"); + } self.in_trait = false; let trait_signature = mem::take(&mut self.src).into(); @@ -679,7 +690,9 @@ impl Generator for Wasmtime { // // If none of that happens, then this is fine to be sync because // everything is sync. - let is_async = if async_intrinsic_called || self.opts.async_.includes(&func.name) { + let finish_async_block = if !func.is_async + && (async_intrinsic_called || self.opts.async_.includes(&func.name)) + { self.src.push_str("Box::new(async move {\n"); true } else { @@ -716,7 +729,7 @@ impl Generator for Wasmtime { if needs_memory || needs_borrow_checker { self.src - .push_str("let memory = &get_memory(&mut caller, \"memory\")?;\n"); + .push_str("let memory = get_memory(&mut caller, \"memory\")?;\n"); self.needs_get_memory = true; } @@ -730,13 +743,26 @@ impl Generator for Wasmtime { self.src.push_str("let host = get(caller.data_mut());\n"); } + let mut rebind = String::new(); if self.all_needed_handles.len() > 0 { - self.src.push_str("let (host, _tables) = host;\n"); + rebind.push_str("_tables, "); + } + if iface.functions.iter().any(|f| f.is_async) { + rebind.push_str("_async_cx, "); + } + if rebind != "" { + self.src + .push_str(&format!("let (host, {}) = host;\n", rebind)); } self.src.push_str(&String::from(src)); - if is_async { + if func.is_async { + self.src.push_str("});\n"); // finish `register_async_import` + self.src.push_str("Ok(())\n") + } + + if finish_async_block { self.src.push_str("})\n"); } self.src.push_str("}"); @@ -746,7 +772,7 @@ impl Generator for Wasmtime { .entry(iface.name.to_string()) .or_insert(Vec::new()) .push(Import { - is_async, + wrap_async: finish_async_block, num_wasm_params: sig.params.len(), name: 
func.name.to_string(), closure, @@ -755,13 +781,12 @@ impl Generator for Wasmtime { } fn import(&mut self, iface: &Interface, func: &Function) { - assert!(!func.is_async, "async not supported yet"); let prev = mem::take(&mut self.src); // If anything is asynchronous on exports then everything must be // asynchronous, Wasmtime can't intermix async and sync calls because // it's unknown whether the wasm module will make an async host call. - let is_async = !self.opts.async_.is_none(); + let is_async = !self.opts.async_.is_none() || func.is_async; let mut sig = FnSig::default(); sig.async_ = is_async; sig.self_arg = Some("&self, mut caller: impl wasmtime::AsContextMut".to_string()); @@ -769,6 +794,7 @@ impl Generator for Wasmtime { self.push_str("-> Result<"); self.print_results(iface, func); self.push_str(", wasmtime::Trap> {\n"); + self.push_str("let mut caller = caller.as_context_mut();\n"); let is_dtor = self.types.is_preview1_dtor_func(func); if is_dtor { @@ -872,6 +898,7 @@ impl Generator for Wasmtime { } fn finish_one(&mut self, iface: &Interface, files: &mut Files) { + let any_async_func = iface.functions.iter().any(|f| f.is_async); for (module, funcs) in sorted_iter(&self.imports) { let module_camel = module.to_camel_case(); let is_async = !self.opts.async_.is_none(); @@ -897,7 +924,11 @@ impl Generator for Wasmtime { } } if self.opts.custom_error { - self.src.push_str("type Error;\n"); + self.src.push_str("type Error"); + if any_async_func { + self.src.push_str(": Send + 'static"); + } + self.src.push_str(";\n"); if self.needs_custom_error_to_trap { self.src.push_str( "fn error_to_trap(&mut self, err: Self::Error) -> wasmtime::Trap;\n", @@ -946,29 +977,38 @@ impl Generator for Wasmtime { self.src.push_str("> Default for "); self.src.push_str(&module_camel); self.src.push_str("Tables {\n"); - self.src.push_str("fn default() -> Self { Self {"); + self.src.push_str("fn default() -> Self {\nSelf {\n"); for handle in self.all_needed_handles.iter() { 
self.src.push_str(&handle.to_snake_case()); - self.src.push_str("_table: Default::default(),"); + self.src.push_str("_table: Default::default(),\n"); } - self.src.push_str("}}}"); + self.src.push_str("}\n}\n}\n"); } } for (module, funcs) in mem::take(&mut self.imports) { let module_camel = module.to_camel_case(); let is_async = !self.opts.async_.is_none(); + self.push_str("#[allow(path_statements)]\n"); self.push_str("\npub fn add_to_linker(linker: &mut wasmtime::Linker"); self.push_str(", get: impl Fn(&mut T) -> "); - if self.all_needed_handles.is_empty() { - self.push_str("&mut U"); + + let mut get_rets = vec!["&mut U".to_string()]; + if self.all_needed_handles.len() > 0 { + get_rets.push(format!("&mut {}Tables", module_camel)); + } + if any_async_func { + get_rets.push(format!("&mut witx_bindgen_wasmtime::rt::Async")); + } + if get_rets.len() > 1 { + self.push_str(&format!("({})", get_rets.join(", "))); } else { - self.push_str(&format!("(&mut U, &mut {}Tables)", module_camel)); + self.push_str(&get_rets[0]); } - self.push_str("+ Send + Sync + Copy + 'static) -> anyhow::Result<()> \n"); + self.push_str(" + Send + Sync + Copy + 'static) -> anyhow::Result<()> \n"); self.push_str("where U: "); self.push_str(&module_camel); - if is_async { + if is_async || any_async_func { self.push_str(", T: Send,"); } self.push_str("\n{\n"); @@ -979,7 +1019,7 @@ impl Generator for Wasmtime { self.push_str("use witx_bindgen_wasmtime::rt::get_func;\n"); } for f in funcs { - let method = if f.is_async { + let method = if f.wrap_async { format!("func_wrap{}_async", f.num_wasm_params) } else { String::from("func_wrap") @@ -996,14 +1036,14 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_drop_{name}\", move |mut caller: wasmtime::Caller<'_, T>, handle: u32| {{ - let (host, tables) = get(caller.data_mut()); - let handle = tables + let data = get(caller.data_mut()); + let handle = data.1 .{snake}_table .remove(handle) .map_err(|e| {{ wasmtime::Trap::new(format!(\"failed to 
remove handle: {{}}\", e)) }})?; - host.drop_{snake}(handle); + data.0.drop_{snake}(handle); Ok(()) }} )?;\n", @@ -1031,9 +1071,7 @@ impl Generator for Wasmtime { ", ); self.push_str("#[derive(Default)]\n"); - self.push_str("pub struct "); - self.push_str(&name); - self.push_str("Data {\n"); + self.push_str(&format!("pub struct {}Data {{\n", name)); for r in self.exported_resources.iter() { self.src.push_str(&format!( " @@ -1046,12 +1084,18 @@ impl Generator for Wasmtime { } self.push_str("}\n"); - self.push_str("pub struct "); - self.push_str(&name); - self.push_str(" {\n"); + let mut get_state_ret = format!("&mut {}Data", name); + let mut bind_state = "state"; + if any_async_func { + get_state_ret = + format!("({}, &mut witx_bindgen_wasmtime::Async)", get_state_ret); + bind_state = "(state, _)"; + } + + self.push_str(&format!("pub struct {} {{\n", name)); self.push_str(&format!( - "get_state: Box &mut {}Data + Send + Sync>,\n", - name + "get_state: Box {} + Send + Sync>,\n", + get_state_ret, )); for (name, (ty, _)) in exports.fields.iter() { self.push_str(name); @@ -1063,7 +1107,7 @@ impl Generator for Wasmtime { // self.push_str("buffer_glue: witx_bindgen_wasmtime::imports::BufferGlue,"); // } self.push_str("}\n"); - let bound = if self.opts.async_.is_none() { + let bound = if self.opts.async_.is_none() && !any_async_func { "" } else { ": Send" @@ -1083,10 +1127,10 @@ impl Generator for Wasmtime { /// the general store's state. 
pub fn add_to_linker( linker: &mut wasmtime::Linker, - get_state: impl Fn(&mut T) -> &mut {}Data + Send + Sync + Copy + 'static, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, ) -> anyhow::Result<()> {{ ", - name, + get_state_ret, )); for r in self.exported_resources.iter() { let (func_wrap, call, wait, prefix, suffix) = if self.opts.async_.is_none() { @@ -1106,7 +1150,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_drop_{name}\", move |mut caller: wasmtime::Caller<'_, T>, idx: u32| {prefix}{{ - let state = get_state(caller.data_mut()); + let {bind_state} = get_state(caller.data_mut()); let resource_idx = state.index_slab{idx}.remove(idx)?; let wasm = match state.resource_slab{idx}.drop(resource_idx) {{ Some(wasm) => wasm, @@ -1121,7 +1165,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_clone_{name}\", move |mut caller: wasmtime::Caller<'_, T>, idx: u32| {{ - let state = get_state(caller.data_mut()); + let {bind_state} = get_state(caller.data_mut()); let resource_idx = state.index_slab{idx}.get(idx)?; state.resource_slab{idx}.clone(resource_idx)?; Ok(state.index_slab{idx}.insert(resource_idx)) @@ -1131,7 +1175,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_get_{name}\", move |mut caller: wasmtime::Caller<'_, T>, idx: u32| {{ - let state = get_state(caller.data_mut()); + let {bind_state} = get_state(caller.data_mut()); let resource_idx = state.index_slab{idx}.get(idx)?; Ok(state.resource_slab{idx}.get(resource_idx)) }}, @@ -1140,7 +1184,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_new_{name}\", move |mut caller: wasmtime::Caller<'_, T>, val: i32| {{ - let state = get_state(caller.data_mut()); + let {bind_state} = get_state(caller.data_mut()); let resource_idx = state.resource_slab{idx}.insert(val); Ok(state.index_slab{idx}.insert(resource_idx)) }}, @@ -1153,6 +1197,23 @@ impl Generator for Wasmtime { wait = wait, prefix = prefix, suffix = suffix, + bind_state = bind_state, + 
)); + } + if iface.functions.iter().any(|f| f.is_async) { + self.src.push_str(&format!( + " + linker.func_wrap( + \"canonical_abi\", + \"async_export_done\", + move |mut caller: wasmtime::Caller<'_, T>, cx: i32, ptr: i32| {{ + let memory = witx_bindgen_wasmtime::rt::get_memory(&mut caller, \"memory\")?; + let (memory, state) = memory.data_and_store_mut(&mut caller); + let (_, async_) = get_state(state); + async_.async_export_done(cx, ptr, memory) + }}, + )?; + ", )); } // if self.needs_buffer_glue { @@ -1222,14 +1283,14 @@ impl Generator for Wasmtime { mut store: impl wasmtime::AsContextMut, module: &wasmtime::Module, linker: &mut wasmtime::Linker, - get_state: impl Fn(&mut T) -> &mut {}Data + Send + Sync + Copy + 'static, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, ) -> anyhow::Result<(Self, wasmtime::Instance)> {{ Self::add_to_linker(linker, get_state)?; let instance = linker.instantiate{}(&mut store, module){}?; Ok((Self::new(store, &instance,get_state)?, instance)) }} ", - async_fn, name, instantiate, wait, + async_fn, get_state_ret, instantiate, wait, )); self.push_str(&format!( @@ -1245,10 +1306,10 @@ impl Generator for Wasmtime { pub fn new( mut store: impl wasmtime::AsContextMut, instance: &wasmtime::Instance, - get_state: impl Fn(&mut T) -> &mut {}Data + Send + Sync + Copy + 'static, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, ) -> anyhow::Result {{ ", - name, + get_state_ret, )); self.push_str("let mut store = store.as_context_mut();\n"); assert!(!self.needs_get_func); @@ -1262,16 +1323,27 @@ impl Generator for Wasmtime { for r in self.exported_resources.iter() { self.src.push_str(&format!( " - get_state(store.data_mut()).dtor{} = \ - Some(instance.get_typed_func::(\ - &mut store, \ - \"canonical_abi_drop_{}\", \ - )?);\n + let dtor = instance.get_typed_func::(\ + &mut store, \ + \"canonical_abi_drop_{name}\", \ + )?; + let {bind_state} = get_state(store.data_mut()); + state.dtor{idx} = Some(dtor); ", - 
r.index(), - iface.resources[*r].name, + idx = r.index(), + name = iface.resources[*r].name, + bind_state = bind_state, )); } + if any_async_func { + self.push_str( + " + let table = instance.get_table(&mut store, \"__indirect_function_table\") + .ok_or_else(|| wasmtime::Trap::new(\"no exported function table\"))?; + get_state(store.data_mut()).1.set_table(table); + ", + ); + } self.push_str("Ok("); self.push_str(&name); self.push_str("{\n"); @@ -1308,12 +1380,12 @@ impl Generator for Wasmtime { val: {name_camel}, ) -> Result<(), wasmtime::Trap> {{ let mut store = store.as_context_mut(); - let data = (self.get_state)(store.data_mut()); - let wasm = match data.resource_slab{idx}.drop(val.0) {{ + let {bind_state} = (self.get_state)(store.data_mut()); + let wasm = match state.resource_slab{idx}.drop(val.0) {{ Some(val) => val, None => return Ok(()), }}; - data.dtor{idx}.unwrap().{call}(&mut store, wasm){wait}?; + state.dtor{idx}.unwrap().{call}(&mut store, wasm){wait}?; Ok(()) }} ", @@ -1323,6 +1395,7 @@ impl Generator for Wasmtime { async = async_fn, call = call, wait = wait, + bind_state = bind_state, )); } @@ -1452,7 +1525,7 @@ impl FunctionBindgen<'_> { self.push_str( "let (caller_memory, data) = memory.data_and_store_mut(&mut caller);\n", ); - self.push_str("let (_, _tables) = get(data);\n"); + self.push_str("let _tables = &mut get(data).1;\n"); } else { self.push_str("let caller_memory = memory.data_mut(&mut caller);\n"); } @@ -1505,6 +1578,22 @@ impl FunctionBindgen<'_> { mem, operands[1], offset, method, operands[0], extra )); } + + fn bind_results(&mut self, amt: usize, results: &mut Vec) { + if amt == 0 { + return; + } + + let tmp = self.tmp(); + self.push_str("let ("); + for i in 0..amt { + let arg = format!("result{}_{}", tmp, i); + self.push_str(&arg); + self.push_str(","); + results.push(arg); + } + self.push_str(") = "); + } } impl RustFunctionGenerator for FunctionBindgen<'_> { @@ -1685,27 +1774,35 @@ impl Bindgen for FunctionBindgen<'_> { } } 
Instruction::I32FromBorrowedHandle { ty } => { + let any_async = iface.functions.iter().any(|f| f.is_async); let tmp = self.tmp(); self.push_str(&format!( " let obj{tmp} = {op}; - (self.get_state)(caller.as_context_mut().data_mut()).resource_slab{idx}.clone(obj{tmp}.0)?; - let handle{tmp} = (self.get_state)(caller.as_context_mut().data_mut()).index_slab{idx}.insert(obj{tmp}.0); + let {bind_state} = (self.get_state)(caller.data_mut()); + state.resource_slab{idx}.clone(obj{tmp}.0)?; + let handle{tmp} = state.index_slab{idx}.insert(obj{tmp}.0); ", tmp = tmp, idx = ty.index(), op = operands[0], + bind_state = if any_async { "(state, _)" } else { "state" }, )); results.push(format!("handle{} as i32", tmp,)); } Instruction::HandleOwnedFromI32 { ty } => { + let any_async = iface.functions.iter().any(|f| f.is_async); let tmp = self.tmp(); self.push_str(&format!( - "let handle{} = (self.get_state)(caller.as_context_mut().data_mut()).index_slab{}.remove({} as u32)?;\n", + " + let {bind_state} = (self.get_state)(caller.data_mut()); + let handle{} = state.index_slab{}.remove({} as u32)?; + ", tmp, ty.index(), operands[0], + bind_state = if any_async { "(state, _)" } else { "state" }, )); let name = iface.resources[*ty].name.to_camel_case(); @@ -2067,17 +2164,7 @@ impl Bindgen for FunctionBindgen<'_> { name, sig, } => { - if sig.results.len() > 0 { - let tmp = self.tmp(); - self.push_str("let ("); - for i in 0..sig.results.len() { - let arg = format!("result{}_{}", tmp, i); - self.push_str(&arg); - self.push_str(","); - results.push(arg); - } - self.push_str(") = "); - } + self.bind_results(sig.results.len(), results); self.push_str("self."); self.push_str(&to_rust_ident(name)); if self.gen.opts.async_.includes(name) { @@ -2100,7 +2187,87 @@ impl Bindgen for FunctionBindgen<'_> { } Instruction::CallWasmAsyncImport { .. } => unimplemented!(), - Instruction::CallWasmAsyncExport { .. 
} => unimplemented!(), + + Instruction::CallWasmAsyncExport { + module: _, + name, + params: _, + results: wasm_results, + } => { + self.push_str(&format!( + "let mut raw_results = [0; {}];\n", + wasm_results.len() + )); + + // Move the func out of `self` since it's captured in the + // `'static` future and `self` isn't `'static`. + self.push_str(&format!("let wasm_func = self.{};\n", to_rust_ident(name))); + + // Move the arguments into their own temporaries. These are all + // scalar expressions but sometimes their results reference + // local variables. These all get captured in the `'static` + // future as well so we need to make sure no local variables are + // captured there. + let tmp = self.tmp(); + let mut args = String::new(); + for (i, arg) in operands.iter().enumerate() { + self.push_str(&format!("let arg{}_{} = {};\n", tmp, i, arg)); + args.push_str(&format!("arg{}_{}, ", tmp, i)); + } + args.push_str("async_cx,"); + + // Start the async call with various parameters passed in... + self.push_str( + "witx_bindgen_wasmtime::rt::Async::call_async_export(\ + &mut caller, \ + &mut raw_results, \ + &|t| (self.get_state)(t).1, \ + |caller, async_cx| {\n\ + ", + ); + + // Delegate to `call_async` or `call` as appropriate. + if self.gen.opts.async_.includes(name) { + self.push_str(&format!( + "Box::pin(async move {{ wasm_func.call_async(caller, ({})).await }})", + args, + )); + } else { + self.push_str(&format!( + "Box::pin(async move {{ wasm_func.call(caller, ({})) }})", + args, + )); + } + // Close the async call + self.push_str("\n}).await?;\n"); + + // Read all the results from the `raw_results` array, + // interpreting the 64-bit values as appropriate. 
+ let tmp = self.tmp(); + for (i, ty) in wasm_results.iter().enumerate() { + let name = format!("result{}_{}", tmp, i); + self.push_str(&format!("let {} = ", name)); + results.push(name); + match ty { + WasmType::I32 => { + self.push_str(&format!("raw_results[{}] as i32", i)); + } + WasmType::I64 => { + self.push_str(&format!("raw_results[{}]", i)); + } + WasmType::F32 => { + self.push_str(&format!("f32::from_bits(raw_results[{}] as u32)", i)); + } + WasmType::F64 => { + self.push_str(&format!("f64::from_bits(raw_results[{}] as u64)", i)); + } + } + self.push_str(";\n"); + } + + self.after_call = true; + self.caller_memory_available = false; // invalidated by call + } Instruction::CallInterface { module: _, func } => { for (i, operand) in operands.iter().enumerate() { @@ -2134,7 +2301,30 @@ impl Bindgen for FunctionBindgen<'_> { call.push_str(&format!("param{}, ", i)); } call.push_str(")"); - if self.gen.opts.async_.includes(&func.name) { + + // If this is itself an async function then the future is first + // created. The actual await-ing happens inside of a separate + // future we create here and pass to the `_async_cx` which will + // manage execution of the future in connection with the + // original invocation of an async export. + // + // The `future` is `await`'d initially and its results are then + // moved into a completion callback which is processed once the + // store is available again. 
+ if func.is_async { + self.push_str("let future = "); + self.push_str(&call); + self.push_str(";\n"); + self.push_str("_async_cx.register_async_import(async move {\n"); + self.push_str("let result = future.await;\n"); + call = format!("result"); + self.push_str( + "witx_bindgen_wasmtime::rt::box_future_callback(move |mut caller| {\n", + ); + self.push_str("witx_bindgen_wasmtime::rt::pin_result_future(async move {\n"); + self.push_str("let host = &mut get(caller.data_mut()).0;\n"); + self.push_str("drop(&mut *host);\n"); // ignore unused variable + } else if self.gen.opts.async_.includes(&func.name) { call.push_str(".await"); } @@ -2197,7 +2387,61 @@ impl Bindgen for FunctionBindgen<'_> { } Instruction::ReturnAsyncExport { .. } => unimplemented!(), - Instruction::ReturnAsyncImport { .. } => unimplemented!(), + + Instruction::CompletionCallback { params, .. } => { + let mut tys = String::new(); + tys.push_str("i32,"); + for param in params.iter() { + tys.push_str(wasm_type(*param)); + tys.push_str(", "); + } + self.closures.push_str(&format!( + "\ + let completion_callback = get(caller.data_mut()).{async_idx} + .table() + .get(&mut caller, {idx} as u32) + .ok_or_else(|| wasmtime::Trap::new(\"invalid function index\"))? + .funcref() + .ok_or_else(|| wasmtime::Trap::new(\"not a funcref table\"))? + .ok_or_else(|| wasmtime::Trap::new(\"callback was a null function\"))? + .typed::<({tys}), (), _>(&caller)?; + ", + idx = operands[0], + tys = tys, + async_idx = if self.gen.all_needed_handles.len() > 0 { + 2 + } else { + 1 + }, + )); + results.push(format!("completion_callback")); + } + + Instruction::ReturnAsyncImport { .. 
} => { + let mut result = operands[1..].join(", "); + result.push_str(","); + self.push_str(&format!("let ret = ({});\n", result)); + if let Some(cleanup) = self.cleanup.take() { + self.push_str(&cleanup); + } + + self.push_str(&operands[0]); + if self.gen.opts.async_.is_none() { + self.push_str(".call"); + } else { + self.push_str(".call_async"); + } + self.push_str("(&mut caller, ret)"); + if !self.gen.opts.async_.is_none() { + self.push_str(".await"); + } + self.push_str("\n"); + + // finish the `pin_result_future(...)` + self.push_str("})\n"); + // finish the `box_future_callback(...)` + self.push_str("})\n"); + } Instruction::I32Load { offset } => results.push(self.load(*offset, "i32", operands)), Instruction::I32Load8U { offset } => { diff --git a/crates/gen-wasmtime/tests/codegen.rs b/crates/gen-wasmtime/tests/codegen.rs index 1941d7ea6..cce1256cb 100644 --- a/crates/gen-wasmtime/tests/codegen.rs +++ b/crates/gen-wasmtime/tests/codegen.rs @@ -9,9 +9,6 @@ mod exports { test_helpers::codegen_wasmtime_export!( "*.witx" - // TODO: implement async support - "!async_functions.witx" - // If you want to exclude a specific test you can include it here with // gitignore glob syntax: // @@ -28,9 +25,6 @@ mod imports { test_helpers::codegen_wasmtime_import!( "*.witx" - // TODO: implement async support - "!async_functions.witx" - // TODO: these use push/pull buffer which isn't implemented in the test // generator just yet "!wasi_next.witx" diff --git a/crates/wasmtime/src/futures.rs b/crates/wasmtime/src/futures.rs new file mode 100644 index 000000000..5ca912428 --- /dev/null +++ b/crates/wasmtime/src/futures.rs @@ -0,0 +1,295 @@ +use crate::rt::RawMem; +use crate::slab::Slab; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use wasmtime::Table; +use wasmtime::{StoreContextMut, Trap}; + +pub struct Async { + exports: Slab, + + table: Option, + + /// List of imports that we're waiting on. 
+ /// + /// This is a list of async imports that have been called as part of calling + /// wasm and are registered here. When these imports complete they produce a + /// result which then itself produces another future. The result is given a + /// `StoreContextMut` and is expected to further execute WebAssembly, + /// translating the results of the async host import to wasm and then + /// invoking the wasm completion callback. When the wasm completion callback + /// is finished then the future is complete. + // + // TODO: should this be in `FutureState` because imports-called are a + // per-export thing? + imports: Vec> + Send>>>, +} + +impl Default for Async { + fn default() -> Async { + Async { + exports: Slab::default(), + imports: Vec::new(), + table: None, + } + } +} + +struct FutureState { + results: Vec, // TODO: shouldn't need to heap-allocate this + done: bool, +} + +pub type HostFuture = Pin + Send>>; + +/// The result of a host import. This is mostly synthesized by bindings and +/// represents that a host import produces a closure. The closure is given +/// context to execute WebAssembly and then the execution itself results in a +/// future. This returned future represents the completion of the WebAssembly +/// itself. +pub type ImportResult = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) -> Pin> + Send + 'a>> + + Send, +>; + +impl Async { + /// Implementation of the `async_export_done` canonical ABI function. + /// + /// The first two parameters are provided by wasm itself, and the `mem` is + /// the wasm's linear memory. The first parameter `cx` is the original value + /// returned by `start_async_export` and indicates which call to which + /// export is being completed. The `ptr` is a pointer into `mem` where the + /// encoded results are located. 
+    pub fn async_export_done(&mut self, cx: i32, ptr: i32, mem: &[u8]) -> Result<(), Trap> {
+        let cx = cx as u32;
+        let dst = self
+            .exports
+            .get_mut(cx)
+            .ok_or_else(|| Trap::new("async context not valid"))?;
+        // A context may only be completed once; a second completion is
+        // reported the same way as an unknown context.
+        if dst.done {
+            return Err(Trap::new("async context not valid"));
+        }
+        dst.done = true;
+        // Result `i` is read from `ptr + 8 * i`: each result occupies its own
+        // 8-byte slot. All of the offset arithmetic is checked so that a
+        // bogus pointer from wasm turns into a trap rather than wrapping
+        // around.
+        for (i, slot) in dst.results.iter_mut().enumerate() {
+            let slot_ptr = (i as u32)
+                .checked_mul(8)
+                .and_then(|offset| (ptr as u32).checked_add(offset))
+                .ok_or_else(|| Trap::new("pointer to async completion not valid"))?;
+            *slot = mem.load(slot_ptr as i32)?;
+        }
+        Ok(())
+    }
+
+    /// Registers a new future returned from an async import.
+    ///
+    /// This function is used when an async import is invoked by wasm. The
+    /// asynchronous import is represented as a future and when the future
+    /// completes it needs to call the completion callback in WebAssembly. The
+    /// invocation of the completion callback is represented by the output of
+    /// the future here, the `ImportResult` which is a closure that takes a
+    /// store context and invokes WebAssembly (further in an async fashion).
+    ///
+    /// Note that this doesn't actually do anything, it simply enqueues the
+    /// future internally. The future will actually be driven from the
+    /// `call_async_export` function below.
+    pub fn register_async_import(
+        &mut self,
+        future: impl Future<Output = ImportResult<T>> + Send + 'static,
+    ) {
+        self.imports.push(Box::pin(future));
+    }
+
+    /// Blocks on the completion of an asynchronous export.
+    ///
+    /// This function is used to await the result of an async export. In other
+    /// words this is used to wait for wasm to invoke the completion callback
+    /// with the `async_cx` specified.
+    ///
+    /// This will "block" for one of two reasons:
+    ///
+    /// * First is that an async import was called and the wasm's completion
+    ///   callback wasn't called yet. In this scenario this function will block
+    ///   on the completion of the async import.
+    ///
+    /// * Second is the execution of the wasm's own import completion callback.
+ /// This execution of WebAssembly may be asynchronous due to things like + /// fuel context switching or similar. + /// + /// This function invokes WebAssembly within `cx` and will not return until + /// the completion callback for `async_cx` is invoked. When the completion + /// callback is invoked the results of the callback are written into + /// `results`. The `get_state` method is used to extract an `Async` from + /// the store state within `cx`. + /// + /// This returns `Ok(())` when the completion callback was successfully + /// invoked, but it may also return `Err(trap)` if a trap happens while + /// executing a wasm completion callback for an import. + pub async fn call_async_export( + cx: &mut StoreContextMut<'_, T>, + results: &mut [i64], + get_state: &(dyn Fn(&mut T) -> &mut Async + Send + Sync), + invoke_wasm: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + ) + -> Pin> + Send + 'a>>, + ) -> Result<(), Trap> { + // First register a new export happening in our slab of running + // `exports` futures. + // + // NB: at this time due to taking `&mut StoreContextMut` as an argument to + // this function it means that the size of `exports` is at most one. In + // the future this will probably take some sort of async mutex and only + // hold the mutex when wasm is running to allow concurrent execution of + // wasm. + let async_cx = get_state(cx.data_mut()).exports.insert(FutureState { + results: vec![0; results.len()], + done: false, + }); + + // Once the registration is made we immediately construct the + // `WaitForAsyncExport` helper struct. The destructor of this struct + // will forcibly remove the registration we just made above to prevent + // leaking anything if the wasm future is dropped and forgotten about. + let waiter = WaitForAsyncExport { + cx, + async_cx, + get_state, + }; + + // Now that things are set up this is the original invocation of + // WebAssembly.
This invocation is itself asynchronous hence we await + // the result here. + invoke_wasm(waiter.cx, async_cx as i32).await?; + + // Once we've invoked the export then it's our job to wait for the + // `async_export_done` function to get invoked. That happens here as we + // observe the state of `async_cx` within `state.exports`. If it's not + // finished yet then that means that we need to wait for the next of a + // set of futures to complete (those in `state.imports`), which is + // deferred to the `WaitForNextFuture` helper struct. + loop { + let state = (waiter.get_state)(waiter.cx.data_mut()); + if state.exports.get(async_cx).unwrap().done { + break; + } + + let result = WaitForNextFuture { state }.await?; + + // TODO: while this is executing we're not polling the futures + // inside of `Async`, is that ok? Will this need to poll the + // future inside the state in parallel with the async wasm here to + // ensure that things work out as expected? + result(waiter.cx).await?; + } + + // If we're here then that means that the `async_export_done` function + // was called, which means that we can copy the results into the final + // `results` slice. + let future = (waiter.get_state)(waiter.cx.data_mut()) + .exports + .get(async_cx) + .unwrap(); + assert_eq!(results.len(), future.results.len()); + results.copy_from_slice(&future.results); + return Ok(()); + + /// This is a helper struct used to remove `async_cx` from `cx` on drop. + /// + /// This ensures that if any wasm returns a trap or if the future itself + /// is entirely dropped that we properly clean things up and don't leak + /// the export's async status and allow it to accidentally be + /// "completed" by someone else.
+ struct WaitForAsyncExport<'a, 'b, 'c, 'd, T> { + cx: &'a mut StoreContextMut<'b, T>, + async_cx: u32, + get_state: &'c (dyn Fn(&mut T) -> &mut Async + Send + Sync + 'd), + } + + impl Drop for WaitForAsyncExport<'_, '_, '_, '_, T> { + fn drop(&mut self) { + (self.get_state)(self.cx.data_mut()) + .exports + .remove(self.async_cx) + .unwrap(); + } + } + } + + /// Returns the previously configured function table via `set_table`. + // + // TODO: this probably isn't the right interface, need to figure out a + // better way to pass this table (and other intrinsics to the wasm instance) + // around. + pub fn table(&self) -> Table { + self.table.expect("table wasn't set yet") + } + + /// Stores a table to later get returned by `table()`. + // + // TODO: like `table`, this probably isn't the right interface + pub fn set_table(&mut self, table: Table) { + assert!(self.table.is_none(), "table already set"); + self.table = Some(table); + } +} + +struct WaitForNextFuture<'a, T> { + state: &'a mut Async, +} + +impl Future for WaitForNextFuture<'_, T> { + type Output = Result, Trap>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // If there aren't any active imports then that means that WebAssembly + // hasn't invoked the completion callback for an export but it also + // didn't call any imports to block on anything. That means that the + // completion callback will never be called which is an error, so + // simulate a trap happening and report it to the embedder. + if self.state.imports.len() == 0 { + return Poll::Ready(Err(Trap::new( + "wasm isn't waiting on any imports but export completion callback wasn't called", + ))); + } + + // If we have imports then we'll "block" this current future on one of + // the sub-futures within `self.state.imports`. By polling at least one + // future that means we'll get re-awakened whenever the sub-future is + // ready and we'll check here again. + // + // If anything is ready we return the first item that we get. 
This means + // that if any import has its result ready then we propagate the result + // outwards which will invoke the completion callback for that import's + // execution. If, after running the import completion callback, the + // export completion callback still hasn't been invoked then we'll come + // back here and look for other finished imports. + // + // TODO: this can theoretically exhibit quadratic behavior if the wasm + // calls tons and tons of imports. This should use a more intelligent + // future-polling mechanism to avoid re-polling everything we're not + // interested in every time. + for (i, import) in self.state.imports.iter_mut().enumerate() { + match import.as_mut().poll(cx) { + Poll::Ready(value) => { + drop(self.state.imports.swap_remove(i)); + return Poll::Ready(Ok(value)); + } + Poll::Pending => {} + } + } + + Poll::Pending + } +} + +fn _assert() { + fn _assert_send(_: &T) {} + + fn _test(x: &mut StoreContextMut<'_, ()>) { + let f = Async::<()>::call_async_export(x, &mut [], &|_| panic!(), |_, _| panic!()); + _assert_send(&f); + } +} diff --git a/crates/wasmtime/src/lib.rs b/crates/wasmtime/src/lib.rs index 043ab2297..3610de584 100644 --- a/crates/wasmtime/src/lib.rs +++ b/crates/wasmtime/src/lib.rs @@ -7,8 +7,11 @@ pub use tracing_lib as tracing; #[doc(hidden)] pub use {anyhow, bitflags, wasmtime}; +pub use futures::{Async, HostFuture}; + mod error; pub mod exports; +mod futures; pub mod imports; mod le; mod region; @@ -32,6 +35,7 @@ unsafe impl Sync for RawMemory {} #[doc(hidden)] pub mod rt { + pub use crate::futures::Async; use crate::slab::Slab; use crate::{Endian, Le}; use std::mem; @@ -257,4 +261,26 @@ pub mod rt { Some(resource.wasm) } } + + use std::future::Future; + use std::pin::Pin; + + /// Helper function to assist with type inference for async bits + pub fn pin_result_future<'a>( + future: impl Future> + Send + 'a, + ) -> Pin> + Send + 'a>> { + Box::pin(future) + } + + /// Helper function to assist with type inference for 
async bits + pub fn box_future_callback( + callback: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + ) -> crate::futures::ImportResult { + Box::new(callback) + } } diff --git a/crates/witx2/src/abi.rs b/crates/witx2/src/abi.rs index e4a01eaba..fd009e74e 100644 --- a/crates/witx2/src/abi.rs +++ b/crates/witx2/src/abi.rs @@ -662,6 +662,20 @@ def_instruction! { /// the `async_export_done` intrinsic in the `canonical_abi` module. ReturnAsyncExport { func: &'a Function } : [2] => [0], + /// Validates a completion callback index as provided by wasm. + /// + /// This takes an `i32` argument which was provided by WebAssembly as an + /// index into the function table. This index should be a valid index + /// pointing to a valid function. The function should take the `params` + /// specified plus a leading `i32` parameter. The function should return + /// no values. + /// + /// This instruction should push an expression representing the + /// function, and the expression is later used as the first argument to + /// `ReturnAsyncImport` to actually get invoked in a later async + /// context. + CompletionCallback { func: &'a Function, params: &'a [WasmType] } : [1] => [1], + /// "Returns" from an asynchronous import. /// /// This is only used for host modules at this time, and @@ -1373,6 +1387,21 @@ impl<'a, B: Bindgen> Generator<'a, B> { self.emit(&Instruction::GetArg { nth }); } + // If we're invoking a completion callback then allow codegen to + // front-load validation of the function pointer argument to + // ensure we can continue successfully once we've committed to + // translating all the arguments and calling the host function. 
+ let callback = if func.is_async && self.dir == Direction::Import { + self.emit(&Instruction::GetArg { + nth: sig.params.len() - 2, + }); + let params = sig.retptr.as_ref().unwrap(); + self.emit(&Instruction::CompletionCallback { func, params }); + Some(self.stack.pop().unwrap()) + } else { + None + }; + // Once everything is on the stack we can lift all arguments // one-by-one into their interface-types equivalent. self.lift_all(&func.params); @@ -1393,10 +1422,8 @@ impl<'a, B: Bindgen> Generator<'a, B> { Direction::Import => { assert_eq!(self.stack.len(), tys.len()); let operands = mem::take(&mut self.stack); - // function index to call - self.emit(&Instruction::GetArg { - nth: sig.params.len() - 2, - }); + // wasm function to call + self.stack.extend(callback); // environment for the function self.emit(&Instruction::GetArg { nth: sig.params.len() - 1, diff --git a/tests/runtime/async_functions/host.rs b/tests/runtime/async_functions/host.rs new file mode 100644 index 000000000..0f5fef5a5 --- /dev/null +++ b/tests/runtime/async_functions/host.rs @@ -0,0 +1,148 @@ +witx_bindgen_wasmtime::import!("./tests/runtime/async_functions/imports.witx"); + +use anyhow::Result; +use futures_channel::oneshot::{channel, Receiver, Sender}; +use futures_util::FutureExt; +use imports::*; +use wasmtime::{Engine, Linker, Module, Store}; +use witx_bindgen_wasmtime::{Async, HostFuture}; + +#[derive(Default)] +pub struct MyImports { + thunk_hit: bool, + unblock1: Option>, + unblock2: Option>, + unblock3: Option>, + wait1: Option>, + wait2: Option>, + wait3: Option>, +} + +impl Imports for MyImports { + fn thunk(&mut self) -> HostFuture<()> { + self.thunk_hit = true; + Box::pin(async { + async {}.await; + }) + } + + fn concurrent1(&mut self, a: u32) -> HostFuture { + assert_eq!(a, 1); + self.unblock1.take().unwrap().send(()).unwrap(); + let wait = self.wait1.take().unwrap(); + Box::pin(async move { + wait.await.unwrap(); + a + 10 + }) + } + + fn concurrent2(&mut self, a: u32) -> 
HostFuture { + assert_eq!(a, 2); + self.unblock2.take().unwrap().send(()).unwrap(); + let wait = self.wait2.take().unwrap(); + Box::pin(async move { + wait.await.unwrap(); + a + 10 + }) + } + + fn concurrent3(&mut self, a: u32) -> HostFuture { + assert_eq!(a, 3); + self.unblock3.take().unwrap().send(()).unwrap(); + let wait = self.wait3.take().unwrap(); + Box::pin(async move { + wait.await.unwrap(); + a + 10 + }) + } +} + +witx_bindgen_wasmtime::export!("./tests/runtime/async_functions/exports.witx"); + +fn run(wasm: &str) -> Result<()> { + struct Context { + wasi: wasmtime_wasi::WasiCtx, + imports: MyImports, + async_: Async, + exports: exports::ExportsData, + } + + let engine = Engine::default(); + let module = Module::from_file(&engine, wasm)?; + let mut linker = Linker::::new(&engine); + imports::add_imports_to_linker(&mut linker, |cx| (&mut cx.imports, &mut cx.async_))?; + wasmtime_wasi::add_to_linker(&mut linker, |cx| &mut cx.wasi)?; + + let mut store = Store::new( + &engine, + Context { + wasi: crate::default_wasi(), + imports: MyImports::default(), + async_: Default::default(), + exports: Default::default(), + }, + ); + let (exports, _instance) = + exports::Exports::instantiate(&mut store, &module, &mut linker, |cx| { + (&mut cx.exports, &mut cx.async_) + })?; + + let import = &mut store.data_mut().imports; + + // Initialize various channels which we use as synchronization points to + // test the concurrent aspect of async wasm. The first channels here are + // used to wait for host functions to get entered by the wasm. + let (a, mut wait1) = channel(); + import.unblock1 = Some(a); + let (a, mut wait2) = channel(); + import.unblock2 = Some(a); + let (a, mut wait3) = channel(); + import.unblock3 = Some(a); + + // This second set of channels are used to unblock host futures that wasm + // calls, simulating work that returns back to the host and takes some time + // to complete. 
+ let (unblock1, b) = channel(); + import.wait1 = Some(b); + let (unblock2, b) = channel(); + import.wait2 = Some(b); + let (unblock3, b) = channel(); + import.wait3 = Some(b); + + futures_executor::block_on(async { + exports.thunk(&mut store).await?; + assert!(store.data_mut().imports.thunk_hit); + + let mut future = Box::pin(exports.test_concurrent(&mut store)).fuse(); + + // wait for all three concurrent methods to get entered. Note that we + // poll the `future` while we're here as well to run any callbacks + // inside as necessary, but it shouldn't ever finish. + let mut done = 0; + while done < 3 { + futures_util::select! { + _ = future => unreachable!(), + r = wait1 => { r.unwrap(); done += 1; } + r = wait2 => { r.unwrap(); done += 1; } + r = wait3 => { r.unwrap(); done += 1; } + } + } + + // Now we can "complete" the async task that each function was waiting + // on. Our original future shouldn't be done until they're all complete. + unblock3.send(()).unwrap(); + futures_util::select! { + _ = future => unreachable!(), + default => {} + } + unblock2.send(()).unwrap(); + futures_util::select! 
{ + _ = future => unreachable!(), + default => {} + } + unblock1.send(()).unwrap(); + future.await?; + + Ok(()) + }) +} From d30e12429c3e24f1fae66290823e8cfdd6de3d11 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 6 Oct 2021 14:36:33 -0700 Subject: [PATCH 2/5] Add more Send/Sync as necessary for `async` func --- crates/gen-wasmtime/src/lib.rs | 2 +- tests/codegen/async_functions.witx | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 2438813ce..144d9f2b9 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -917,7 +917,7 @@ impl Generator for Wasmtime { self.src.push_str("type "); self.src.push_str(&handle.to_camel_case()); self.src.push_str(": std::fmt::Debug"); - if is_async { + if is_async || any_async_func { self.src.push_str(" + Send + Sync"); } self.src.push_str(";\n"); diff --git a/tests/codegen/async_functions.witx b/tests/codegen/async_functions.witx index e713eded4..53f0e531c 100644 --- a/tests/codegen/async_functions.witx +++ b/tests/codegen/async_functions.witx @@ -5,3 +5,12 @@ async_results: async function() -> (u32, string, list) resource async_resource { frob: async function() } + +fetch: async function(url: string) -> Response + +resource Response { + body: async function() -> list + status: function() -> u32 + status_text: function() -> string + +} From 93d3dd1dfb853dc723a0ac36e4197b1949eb001e Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 6 Oct 2021 14:42:36 -0700 Subject: [PATCH 3/5] Fix some lifetimes issues in generated async code * Associated types need `'static` to live across a captured future * Caller-source data is rebound when we reenter after the caller is available again. 
--- crates/gen-wasmtime/src/lib.rs | 36 ++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 144d9f2b9..56e2d11b3 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -215,6 +215,21 @@ impl Wasmtime { self.needs_custom_error_to_trap = true; FunctionRet::CustomToTrap } + + fn rebind_host(&self, iface: &Interface) -> Option { + let mut rebind = String::new(); + if self.all_needed_handles.len() > 0 { + rebind.push_str("_tables, "); + } + if iface.functions.iter().any(|f| f.is_async) { + rebind.push_str("_async_cx, "); + } + if rebind != "" { + Some(format!("let (host, {}) = host;\n", rebind)) + } else { + None + } + } } impl RustGenerator for Wasmtime { @@ -742,17 +757,8 @@ impl Generator for Wasmtime { } else { self.src.push_str("let host = get(caller.data_mut());\n"); } - - let mut rebind = String::new(); - if self.all_needed_handles.len() > 0 { - rebind.push_str("_tables, "); - } - if iface.functions.iter().any(|f| f.is_async) { - rebind.push_str("_async_cx, "); - } - if rebind != "" { - self.src - .push_str(&format!("let (host, {}) = host;\n", rebind)); + if let Some(rebind) = self.rebind_host(iface) { + self.src.push_str(&rebind); } self.src.push_str(&String::from(src)); @@ -920,6 +926,9 @@ impl Generator for Wasmtime { if is_async || any_async_func { self.src.push_str(" + Send + Sync"); } + if any_async_func { + self.src.push_str(" + 'static"); + } self.src.push_str(";\n"); } } @@ -2322,7 +2331,10 @@ impl Bindgen for FunctionBindgen<'_> { "witx_bindgen_wasmtime::rt::box_future_callback(move |mut caller| {\n", ); self.push_str("witx_bindgen_wasmtime::rt::pin_result_future(async move {\n"); - self.push_str("let host = &mut get(caller.data_mut()).0;\n"); + self.push_str("let host = get(caller.data_mut());\n"); + if let Some(rebind) = self.gen.rebind_host(iface) { + self.push_str(&rebind); + } self.push_str("drop(&mut 
*host);\n"); // ignore unused variable } else if self.gen.opts.async_.includes(&func.name) { call.push_str(".await"); From 388f91c9510bfb79d5343642e010215f5bf482f3 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 21 Oct 2021 13:52:14 -0700 Subject: [PATCH 4/5] Lots of updates. This is a very large number of updates which are borne out of recent discussions about the async model and how best to map it to Wasmtime. This resulted in a complete rearchitecture of Wasmtime's support for `async function` where the store is now owned by a "reactor" task and bindings are specifically for Tokio. More tests have been added, particularly around various failure conditions, and the JS async function bindings were also updated to handle more errors. --- Cargo.lock | 48 +- crates/gen-js/src/lib.rs | 215 +++-- crates/gen-wasmtime/Cargo.toml | 4 +- crates/gen-wasmtime/src/lib.rs | 512 ++++++---- crates/test-rust-wasm/Cargo.toml | 4 + crates/test-rust-wasm/src/bin/async_raw.rs | 3 + crates/wasmtime/Cargo.toml | 3 +- crates/wasmtime/src/futures.rs | 1001 +++++++++++++++----- crates/wasmtime/src/lib.rs | 55 +- crates/wasmtime/src/slab.rs | 11 + tests/codegen/async_functions.witx | 5 +- tests/runtime/async_functions/exports.witx | 8 + tests/runtime/async_functions/host.rs | 338 +++++-- tests/runtime/async_functions/host.ts | 26 + tests/runtime/async_functions/imports.witx | 6 + tests/runtime/async_functions/wasm.rs | 23 + tests/runtime/async_raw/exports.witx | 12 + tests/runtime/async_raw/host.rs | 161 ++++ tests/runtime/async_raw/host.ts | 54 ++ tests/runtime/async_raw/imports.witx | 1 + tests/runtime/async_raw/wasm.rs | 82 ++ 21 files changed, 1946 insertions(+), 626 deletions(-) create mode 100644 crates/test-rust-wasm/src/bin/async_raw.rs create mode 100644 tests/runtime/async_raw/exports.witx create mode 100644 tests/runtime/async_raw/host.rs create mode 100644 tests/runtime/async_raw/host.ts create mode 100644 tests/runtime/async_raw/imports.witx create mode 100644 
tests/runtime/async_raw/wasm.rs diff --git a/Cargo.lock b/Cargo.lock index a2a6b0a66..3668b99da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -570,32 +570,12 @@ dependencies = [ "winapi", ] -[[package]] -name = "futures-channel" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888" -dependencies = [ - "futures-core", -] - [[package]] name = "futures-core" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d" -[[package]] -name = "futures-executor" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45025be030969d763025784f7f355043dc6bc74093e4ecc5000ca4dc50d8745c" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - [[package]] name = "futures-macro" version = "0.3.17" @@ -1477,6 +1457,29 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tokio" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" +dependencies = [ + "autocfg", + "num_cpus", + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml" version = "0.5.8" @@ -2136,12 +2139,10 @@ name = "witx-bindgen-gen-wasmtime" version = "0.1.0" dependencies = [ "anyhow", - "futures-channel", - "futures-executor", - "futures-util", "heck", "structopt", "test-helpers", + "tokio", "wasmtime", "wasmtime-wasi", "witx-bindgen-gen-core", @@ -2185,6 +2186,7 @@ dependencies = [ "async-trait", "bitflags", "thiserror", + "tokio", "tracing", "wasmtime", 
"witx-bindgen-wasmtime-impl", diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index 28bc365db..44970822f 100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -75,7 +75,6 @@ enum Intrinsic { Utf8EncodedLen, Slab, Promises, - WithCurrentPromise, } impl Intrinsic { @@ -102,7 +101,6 @@ impl Intrinsic { Intrinsic::Utf8EncodedLen => "UTF8_ENCODED_LEN", Intrinsic::Slab => "Slab", Intrinsic::Promises => "PROMISES", - Intrinsic::WithCurrentPromise => "with_current_promise", } } } @@ -603,10 +601,10 @@ impl Generator for Js { self.src.js(&src.js); if func.is_async { - // Note that `catch_closure` here is defined by the `CallInterface` - // instruction. - self.src.js("}, catch_closure);\n"); // `.then` block - self.src.js("});\n"); // `with_current_promise` block. + self.src.js("};\n"); // `const complete = ...` block + let promises = self.intrinsic(Intrinsic::Promises); + self.src.js(&promises); + self.src.js(".spawn_import(promise, complete);\n"); } self.src.js("}"); @@ -894,7 +892,7 @@ impl Generator for Js { self.src.js(&format!( " imports.canonical_abi['async_export_done'] = (ctx, ptr) => {{ - {}.remove(ctx)(ptr >>> 0) + {}.async_export_done(ctx, ptr >>> 0) }}; ", promises @@ -1936,20 +1934,10 @@ impl Bindgen for FunctionBindgen<'_> { results: wasm_results, } => { self.bind_results(wasm_results.len(), results); - let promises = self.gen.intrinsic(Intrinsic::Promises); - self.src.js(&format!( - "\ - await new Promise((resolve, reject) => {{ - const promise_ctx = {promises}.insert(val => {{ - if (typeof val !== 'number') - return reject(val); - resolve(\ - ", - promises = promises - )); + self.src.js("await new Promise((resolve, reject) => {\n"); if wasm_results.len() > 0 { - self.src.js("["); + self.src.js("resolve = val => {\nresolve(["); let operands = &["val".to_string()]; let mut results = Vec::new(); for (i, result) in wasm_results.iter().enumerate() { @@ -1965,16 +1953,12 @@ impl Bindgen for FunctionBindgen<'_> { self.load(method, 
(i * 8) as i32, operands, &mut results); self.src.js(&results.pop().unwrap()); } - self.src.js("]"); + self.src.js("])\n};\n"); // `resolve(...)` } - // Finish the blocks from above - self.src.js(");\n"); // `resolve(...)` - self.src.js("});\n"); // `promises.insert(...)` - - let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); - self.src.js(&with); - self.src.js("(promise_ctx, _prev => {\n"); + let promises = self.gen.intrinsic(Intrinsic::Promises); + self.src.js(&promises); + self.src.js(".spawn(resolve, reject, promise_ctx => {\n"); self.src.js(&self.src_object); self.src.js("._exports['"); self.src.js(&name); @@ -1984,7 +1968,7 @@ impl Bindgen for FunctionBindgen<'_> { self.src.js(", "); } self.src.js("promise_ctx);\n"); - self.src.js("});\n"); // call to `with` + self.src.js("});\n"); // call to `spawn` self.src.js("});\n"); // `await new Promise(...)` } @@ -2037,16 +2021,10 @@ impl Bindgen for FunctionBindgen<'_> { }; if func.is_async { - let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); - let promises = self.gen.intrinsic(Intrinsic::Promises); - self.src.js(&with); - self.src.js("(null, cur_promise => {\n"); - self.src.js(&format!( - "const catch_closure = e => {}.remove(cur_promise)(e);\n", - promises - )); + self.src.js("const promise = "); call(self); - self.src.js(".then(e => {\n"); + self.src.js(";\n"); + self.src.js("const complete = e => {\n"); if func.results.len() > 0 { bind_results(self); self.src.js("e;\n"); @@ -2083,34 +2061,22 @@ impl Bindgen for FunctionBindgen<'_> { // TODO: shouldn't hardcode the function table name, should // verify the table is present, and should verify the type of // the function returned. 
- results.push(format!( - "get_export(\"__indirect_function_table\").get({})", - operands[0], + self.src.js(&format!( + " + const callback = get_export(\"__indirect_function_table\").get({}); + if (callback === null) + throw new Error('table index is a null function'); + ", + operands[0] )); + results.push("callback".to_string()); } Instruction::ReturnAsyncImport { .. } => { - // When we reenter webassembly successfully that means that the - // host's promise resolved without exception. Take the current - // promise index saved as part of `CallInterface` and update the - // `CUR_PROMISE` global with what's currently being executed. - // This'll get reset once the wasm returns again. - // - // Note that the name `cur_promise` used here is introduced in - // the `CallInterface` codegen above in the closure for - // `with_current_promise` which we're using here. + // TODO self.gen.needs_get_export = true; - let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); - self.src.js(&format!( - "\ - {with}(cur_promise, _prev => {{ - {}({}); - }}); - ", - operands[0], - operands[1..].join(", "), - with = with, - )); + self.src + .js(&format!("{}({});\n", operands[0], operands[1..].join(", "),)); } Instruction::I32Load { offset } => self.load("getInt32", *offset, operands, results), @@ -2397,12 +2363,19 @@ impl Js { } get(idx) { - if (idx >= this.list.length) + const ret = this.maybeGet(idx); + if (ret === null) throw new RangeError('handle index not valid'); + return ret; + } + + maybeGet(idx) { + if (idx >= this.list.length) + return null; const slot = this.list[idx]; if (slot.next === -1) return slot.val; - throw new RangeError('handle index not valid'); + return null; } remove(idx) { @@ -2416,18 +2389,116 @@ impl Js { } "), - Intrinsic::Promises => self.src.js("export const PROMISES = new Slab();\n"), - Intrinsic::WithCurrentPromise => self.src.js(" - let CUR_PROMISE = null; - export function with_current_promise(val, closure) { - const prev = CUR_PROMISE; - 
CUR_PROMISE = val; - try { - closure(prev); - } finally { - CUR_PROMISE = prev; + Intrinsic::Promises => self.src.js(" + class Coroutine { + constructor(resolve, reject) { + this.resolve = resolve; + this.reject = reject; + this.completion = null; + this.pending = 0; + } + + complete(val) { + if (this.completion === null) { + this.completion = val; + } else { + throw new Error('cannot complete coroutine twice'); + } + } + + check_completion() { + if (this.completion === null) { + if (this.pending === 0) { + throw new Error('blocked coroutine with 0 pending callbacks'); + } + return false; + } else { + this.resolve(this.completion); + return true; + } + } + } + + class Promises { + constructor() { + this.slab = new Slab(); + this.current = null; + } + + spawn(resolve, reject, callback) { + const idx = this.slab.insert(new Coroutine(resolve, reject)); + this.set_current(idx, () => callback(idx)); + } + + // called when `promise` is the result of an async host + // call, where `complete` should be called on `.then(..)` + // after the promise is finished. + spawn_import(promise, complete) { + const current = this.current; + if (current === null) { + throw new Error('cannot call async import in sync export'); + } + const coroutine = this.slab.get(current); + coroutine.pending += 1; + promise.then( + e => { + coroutine.pending -= 1; + this.set_current(current, () => complete(e)); + }, + // on error this error is carried over to the + // original coroutine, if it's still present. The + // coroutine may already have failed due to some + // other reason in which case we continue to + // propagate this error in the hopes that someone + // else will log uncaught exceptions or similar. 
+ e => { + const coroutine = this.slab.maybeGet(current); + if (coroutine !== null) { + this.slab.remove(current).reject(e); + } else { + throw e; + } + }, + ) + } + + // implementation of the corresponding wasm intrinsic, + // validates `id` and `val` internally + async_export_done(id, val) { + const coroutine = this.slab.maybeGet(id); + if (coroutine !== null) { + coroutine.complete(val); + } else { + throw new RangeError('invalid coroutine index'); + } + } + + set_current(val, closure) { + // assert that the coroutine is still present + this.slab.get(val); + const prev = this.current; + this.current = val; + try { + // execute some wasm code + closure(); + // if it finished successfully then determine if + // the coroutine has completed, or if there are 0 + // pending callbacks and it's not completed an + // exception is thrown here. + if (this.slab.get(val).check_completion()) { + this.slab.remove(val); + } + } catch (e) { + // any errors immediately result in aborting the + // coroutine and rejection of the coroutine's + // promise with the same error. 
+ this.slab.remove(val).reject(e); + } finally { + this.current = prev; + } } } + export const PROMISES = new Promises(); "), } } diff --git a/crates/gen-wasmtime/Cargo.toml b/crates/gen-wasmtime/Cargo.toml index 0e9d106be..d20519957 100644 --- a/crates/gen-wasmtime/Cargo.toml +++ b/crates/gen-wasmtime/Cargo.toml @@ -20,9 +20,7 @@ test-helpers = { path = '../test-helpers', features = ['witx-bindgen-gen-wasmtim wasmtime = "0.30.0" wasmtime-wasi = "0.30.0" witx-bindgen-wasmtime = { path = '../wasmtime', features = ['tracing', 'async'] } -futures-executor = "0.3" -futures-channel = "0.3" -futures-util = "0.3" +tokio = { version = "1.0", features = ['rt', 'sync', 'macros', 'rt-multi-thread', 'time'] } [features] old-witx-compat = ['witx-bindgen-gen-core/old-witx-compat'] diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 56e2d11b3..60b85e9db 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -38,6 +38,7 @@ pub struct Wasmtime { trait_name: String, has_preview1_dtor: bool, sizes: SizeAlign, + any_async_func: bool, } enum NeededFunction { @@ -216,14 +217,11 @@ impl Wasmtime { FunctionRet::CustomToTrap } - fn rebind_host(&self, iface: &Interface) -> Option { + fn rebind_host(&self, _iface: &Interface) -> Option { let mut rebind = String::new(); if self.all_needed_handles.len() > 0 { rebind.push_str("_tables, "); } - if iface.functions.iter().any(|f| f.is_async) { - rebind.push_str("_async_cx, "); - } if rebind != "" { Some(format!("let (host, {}) = host;\n", rebind)) } else { @@ -238,6 +236,10 @@ impl RustGenerator for Wasmtime { // The default here is that only leaf values can be borrowed because // otherwise lists and such need to be copied into our own memory. 
TypeMode::LeafBorrowed("'a") + } else if self.any_async_func { + // Once `async` functions are in play then there's a task spawned + // that owns the reactor, and this means arguments must be owned so they can be moved into the callbacks run on that task. + TypeMode::Owned } else { // When we're calling wasm exports, however, there's no need to take // any ownership of anything from the host so everything is borrowed @@ -389,11 +391,14 @@ impl Generator for Wasmtime { self.in_import = dir == Direction::Import; self.trait_name = iface.name.to_camel_case(); self.src.push_str("#[allow(unused_imports)]\n"); + self.src.push_str("#[allow(unused_variables)]\n"); + self.src.push_str("#[allow(unused_mut)]\n"); self.src .push_str(&format!("pub mod {} {{\n", iface.name.to_snake_case())); self.src .push_str("use witx_bindgen_wasmtime::{wasmtime, anyhow};\n"); self.sizes.fill(dir, iface); + self.any_async_func = iface.functions.iter().any(|f| f.is_async); } fn type_record( @@ -522,6 +527,15 @@ impl Generator for Wasmtime { let tyname = name.to_camel_case(); self.rustdoc(&iface.resources[ty].docs); self.src.push_str("#[derive(Debug)]\n"); + // TODO: for now in an async environment all of these handles are taken + // by-value in functions whereas in non-async environments everything is + // taken by reference except for destructors. This means that the + // take-by-ownership `drop` function is less meaningful in an async + // environment. This seems like a reasonable-ish way to manage this for + // now but this probably wants a better solution long-term.
+ if self.any_async_func { + self.src.push_str("#[derive(Clone, Copy)]\n"); + } self.src.push_str(&format!( "pub struct {}(witx_bindgen_wasmtime::rt::ResourceIndex);\n", tyname @@ -764,12 +778,12 @@ impl Generator for Wasmtime { self.src.push_str(&String::from(src)); if func.is_async { - self.src.push_str("});\n"); // finish `register_async_import` + self.src.push_str("})?; // finish `spawn_import`\n"); self.src.push_str("Ok(())\n") } if finish_async_block { - self.src.push_str("})\n"); + self.src.push_str("}) // end `Box::new(async move { ...`\n"); } self.src.push_str("}"); let closure = mem::replace(&mut self.src, prev).into(); @@ -794,13 +808,22 @@ impl Generator for Wasmtime { // it's unknown whether the wasm module will make an async host call. let is_async = !self.opts.async_.is_none() || func.is_async; let mut sig = FnSig::default(); - sig.async_ = is_async; - sig.self_arg = Some("&self, mut caller: impl wasmtime::AsContextMut".to_string()); - self.print_docs_and_params(iface, func, TypeMode::AllBorrowed("'_"), &sig); + sig.async_ = is_async || self.any_async_func; + if self.any_async_func { + sig.self_arg = Some("&self".to_string()); + } else { + sig.self_arg = + Some("&self, mut caller: impl wasmtime::AsContextMut".to_string()); + } + let mode = if self.any_async_func { + TypeMode::Owned + } else { + TypeMode::AllBorrowed("'_") + }; + self.print_docs_and_params(iface, func, mode, &sig); self.push_str("-> Result<"); self.print_results(iface, func); self.push_str(", wasmtime::Trap> {\n"); - self.push_str("let mut caller = caller.as_context_mut();\n"); let is_dtor = self.types.is_preview1_dtor_func(func); if is_dtor { @@ -812,6 +835,9 @@ impl Generator for Wasmtime { .map(|(name, _)| to_rust_ident(name).to_string()) .collect(); let mut f = FunctionBindgen::new(self, is_dtor, params); + if f.gen.any_async_func { + f.src.indent(2); + } iface.call( Direction::Export, LiftLower::LowerArgsLiftResults, @@ -825,6 +851,7 @@ impl Generator for Wasmtime { 
needs_buffer_transaction, closures, needs_functions, + needs_get_state, .. } = f; @@ -834,7 +861,7 @@ impl Generator for Wasmtime { .or_insert_with(Exports::default); for (name, func) in needs_functions { self.src - .push_str(&format!("let func_{0} = &self.{0};\n", name)); + .push_str(&format!("let func_{0} = self.{0};\n", name)); let get = format!( "instance.get_typed_func::<{}, _>(&mut store, \"{}\")?", func.cvt(), @@ -847,7 +874,7 @@ impl Generator for Wasmtime { assert!(!needs_borrow_checker); if needs_memory { - self.src.push_str("let memory = &self.memory;\n"); + self.src.push_str("let memory = self.memory;\n"); exports.fields.insert( "memory".to_string(), ( @@ -856,7 +883,7 @@ impl Generator for Wasmtime { .get_memory(&mut store, \"memory\") .ok_or_else(|| { anyhow::anyhow!(\"`memory` export not a memory\") - })? + })?\ " .to_string(), ), @@ -869,7 +896,54 @@ impl Generator for Wasmtime { .push_str("let mut buffer_transaction = self.buffer_glue.transaction();\n"); } + if needs_get_state { + self.src + .push_str("let get_state = self.get_state.clone();\n"); + } + + if func.is_async { + // If this function itself is async then we start off with an + // initial callback that gets an `async_cx` argument which is the + // integer descriptor for the generated future. + self.src.push_str(&format!( + " + let wasm_func = self.{}; + let start = witx_bindgen_wasmtime::rt::infer_start(move |mut caller, async_cx| {{ + let async_cx = async_cx as i32; + Box::pin(async move {{ + ", + to_rust_ident(&func.name), + )); + } else if self.any_async_func { + // Otherwise if any other function in this interface is async then + // it means all functions are invoked through a reactor task which + // means we need to start a standalone callback to get executed + // on the reactor task. 
+ self.src.push_str(&format!( + " + let wasm_func = self.{}; + let start = witx_bindgen_wasmtime::rt::infer_standalone(move |mut caller| {{ + Box::pin(async move {{ + ", + to_rust_ident(&func.name), + )); + } else { + // And finally with no async functions involved everything is + // simply generated inline. + self.src + .push_str("let mut caller = caller.as_context_mut();\n"); + } + self.src.push_str(&String::from(src)); + + if func.is_async { + self.src + .push_str("self.handle.execute(start, complete).await\n"); + } else if self.any_async_func { + self.src + .push_str("self.handle.run_no_coroutine(start).await\n"); + } + self.src.push_str("}\n"); let func_body = mem::replace(&mut self.src, prev); if !is_dtor { @@ -904,7 +978,6 @@ impl Generator for Wasmtime { } fn finish_one(&mut self, iface: &Interface, files: &mut Files) { - let any_async_func = iface.functions.iter().any(|f| f.is_async); for (module, funcs) in sorted_iter(&self.imports) { let module_camel = module.to_camel_case(); let is_async = !self.opts.async_.is_none(); @@ -923,10 +996,10 @@ impl Generator for Wasmtime { self.src.push_str("type "); self.src.push_str(&handle.to_camel_case()); self.src.push_str(": std::fmt::Debug"); - if is_async || any_async_func { + if is_async || self.any_async_func { self.src.push_str(" + Send + Sync"); } - if any_async_func { + if self.any_async_func { self.src.push_str(" + 'static"); } self.src.push_str(";\n"); @@ -934,7 +1007,7 @@ impl Generator for Wasmtime { } if self.opts.custom_error { self.src.push_str("type Error"); - if any_async_func { + if self.any_async_func { self.src.push_str(": Send + 'static"); } self.src.push_str(";\n"); @@ -998,17 +1071,14 @@ impl Generator for Wasmtime { for (module, funcs) in mem::take(&mut self.imports) { let module_camel = module.to_camel_case(); let is_async = !self.opts.async_.is_none(); - self.push_str("#[allow(path_statements)]\n"); - self.push_str("\npub fn add_to_linker(linker: &mut wasmtime::Linker"); + 
self.push_str("\n#[allow(path_statements)]\n"); + self.push_str("pub fn add_to_linker(linker: &mut wasmtime::Linker"); self.push_str(", get: impl Fn(&mut T) -> "); let mut get_rets = vec!["&mut U".to_string()]; if self.all_needed_handles.len() > 0 { get_rets.push(format!("&mut {}Tables", module_camel)); } - if any_async_func { - get_rets.push(format!("&mut witx_bindgen_wasmtime::rt::Async")); - } if get_rets.len() > 1 { self.push_str(&format!("({})", get_rets.join(", "))); } else { @@ -1017,8 +1087,8 @@ impl Generator for Wasmtime { self.push_str(" + Send + Sync + Copy + 'static) -> anyhow::Result<()> \n"); self.push_str("where U: "); self.push_str(&module_camel); - if is_async || any_async_func { - self.push_str(", T: Send,"); + if is_async || self.any_async_func { + self.push_str(", T: Send + 'static,"); } self.push_str("\n{\n"); if self.needs_get_memory { @@ -1093,19 +1163,17 @@ impl Generator for Wasmtime { } self.push_str("}\n"); - let mut get_state_ret = format!("&mut {}Data", name); - let mut bind_state = "state"; - if any_async_func { - get_state_ret = - format!("({}, &mut witx_bindgen_wasmtime::Async)", get_state_ret); - bind_state = "(state, _)"; - } - + let get_state_ret = format!("&mut {}Data", name); self.push_str(&format!("pub struct {} {{\n", name)); self.push_str(&format!( - "get_state: Box {} + Send + Sync>,\n", + "get_state: std::sync::Arc {} + Send + Sync>,\n", get_state_ret, )); + if self.any_async_func { + self.push_str(&format!( + "handle: witx_bindgen_wasmtime::rt::AsyncHandle,\n", + )); + } for (name, (ty, _)) in exports.fields.iter() { self.push_str(name); self.push_str(": "); @@ -1116,16 +1184,13 @@ impl Generator for Wasmtime { // self.push_str("buffer_glue: witx_bindgen_wasmtime::imports::BufferGlue,"); // } self.push_str("}\n"); - let bound = if self.opts.async_.is_none() && !any_async_func { + let bound = if self.opts.async_.is_none() && !self.any_async_func { "" } else { - ": Send" + ": Send + 'static" }; self.push_str(&format!("impl 
{} {{\n", bound, name)); - if self.exported_resources.len() == 0 { - self.push_str("#[allow(unused_variables)]\n"); - } self.push_str(&format!( " /// Adds any intrinsics, if necessary for this exported wasm @@ -1159,7 +1224,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_drop_{name}\", move |mut caller: wasmtime::Caller<'_, T>, idx: u32| {prefix}{{ - let {bind_state} = get_state(caller.data_mut()); + let state = get_state(caller.data_mut()); let resource_idx = state.index_slab{idx}.remove(idx)?; let wasm = match state.resource_slab{idx}.drop(resource_idx) {{ Some(wasm) => wasm, @@ -1174,7 +1239,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_clone_{name}\", move |mut caller: wasmtime::Caller<'_, T>, idx: u32| {{ - let {bind_state} = get_state(caller.data_mut()); + let state = get_state(caller.data_mut()); let resource_idx = state.index_slab{idx}.get(idx)?; state.resource_slab{idx}.clone(resource_idx)?; Ok(state.index_slab{idx}.insert(resource_idx)) @@ -1184,7 +1249,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_get_{name}\", move |mut caller: wasmtime::Caller<'_, T>, idx: u32| {{ - let {bind_state} = get_state(caller.data_mut()); + let state = get_state(caller.data_mut()); let resource_idx = state.index_slab{idx}.get(idx)?; Ok(state.resource_slab{idx}.get(resource_idx)) }}, @@ -1193,7 +1258,7 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_new_{name}\", move |mut caller: wasmtime::Caller<'_, T>, val: i32| {{ - let {bind_state} = get_state(caller.data_mut()); + let state = get_state(caller.data_mut()); let resource_idx = state.resource_slab{idx}.insert(val); Ok(state.index_slab{idx}.insert(resource_idx)) }}, @@ -1206,20 +1271,24 @@ impl Generator for Wasmtime { wait = wait, prefix = prefix, suffix = suffix, - bind_state = bind_state, )); } - if iface.functions.iter().any(|f| f.is_async) { + if self.any_async_func { self.src.push_str(&format!( " - linker.func_wrap( + linker.func_wrap2_async( 
\"canonical_abi\", \"async_export_done\", move |mut caller: wasmtime::Caller<'_, T>, cx: i32, ptr: i32| {{ - let memory = witx_bindgen_wasmtime::rt::get_memory(&mut caller, \"memory\")?; - let (memory, state) = memory.data_and_store_mut(&mut caller); - let (_, async_) = get_state(state); - async_.async_export_done(cx, ptr, memory) + Box::new(async move {{ + let memory = witx_bindgen_wasmtime::rt::get_memory(&mut caller, \"memory\")?; + witx_bindgen_wasmtime::rt::Async::async_export_done( + caller, + cx, + ptr, + memory, + ).await + }}) }}, )?; ", @@ -1272,36 +1341,44 @@ impl Generator for Wasmtime { } else { ("async ", "_async", ".await") }; - self.push_str(&format!( - " - /// Instantaites the provided `module` using the specified - /// parameters, wrapping up the result in a structure that - /// translates between wasm and the host. - /// - /// The `linker` provided will have intrinsics added to it - /// automatically, so it's not necessary to call - /// `add_to_linker` beforehand. This function will - /// instantiate the `module` otherwise using `linker`, and - /// both an instance of this structure and the underlying - /// `wasmtime::Instance` will be returned. - /// - /// The `get_state` parameter is used to access the - /// auxiliary state necessary for these wasm exports from - /// the general store state `T`. 
- pub {}fn instantiate( - mut store: impl wasmtime::AsContextMut, - module: &wasmtime::Module, - linker: &mut wasmtime::Linker, - get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, - ) -> anyhow::Result<(Self, wasmtime::Instance)> {{ - Self::add_to_linker(linker, get_state)?; - let instance = linker.instantiate{}(&mut store, module){}?; - Ok((Self::new(store, &instance,get_state)?, instance)) - }} - ", - async_fn, get_state_ret, instantiate, wait, - )); + if !self.any_async_func { + self.push_str(&format!( + " + /// Instantiates the provided `module` using the + /// specified parameters, wrapping up the result in a + /// structure that translates between wasm and the + /// host. + /// + /// The `linker` provided will have intrinsics added to + /// it automatically, so it's not necessary to call + /// `add_to_linker` beforehand. This function will + /// instantiate the `module` otherwise using `linker`, + /// and both an instance of this structure and the + /// underlying `wasmtime::Instance` will be returned. + /// + /// The `get_state` parameter is used to access the + /// auxiliary state necessary for these wasm exports + /// from the general store state `T`. + pub {}fn instantiate( + mut store: impl wasmtime::AsContextMut, + module: &wasmtime::Module, + linker: &mut wasmtime::Linker, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, + ) -> anyhow::Result<(Self, wasmtime::Instance)> {{ + Self::add_to_linker(linker, get_state)?; + let instance = linker.instantiate{}(&mut store, module){}?; + Ok((Self::new(store, &instance,get_state)?, instance)) + }} + ", + async_fn, get_state_ret, instantiate, wait, + )); + } + let store_ty = if self.any_async_func { + "wasmtime::Store" + } else { + "impl wasmtime::AsContextMut" + }; self.push_str(&format!( " /// Low-level creation wrapper for wrapping up the exports @@ -1313,14 +1390,17 @@ impl Generator for Wasmtime { /// returned structure which can be used to interact with /// the wasm module. 
pub fn new( - mut store: impl wasmtime::AsContextMut, + mut store: {store_ty}, instance: &wasmtime::Instance, get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, ) -> anyhow::Result {{ ", get_state_ret, + store_ty = store_ty, )); - self.push_str("let mut store = store.as_context_mut();\n"); + if !self.any_async_func { + self.push_str("let mut store = store.as_context_mut();\n"); + } assert!(!self.needs_get_func); for (name, (_, get)) in exports.fields.iter() { self.push_str("let "); @@ -1336,20 +1416,19 @@ impl Generator for Wasmtime { &mut store, \ \"canonical_abi_drop_{name}\", \ )?; - let {bind_state} = get_state(store.data_mut()); + let state = get_state(store.data_mut()); state.dtor{idx} = Some(dtor); ", idx = r.index(), name = iface.resources[*r].name, - bind_state = bind_state, )); } - if any_async_func { + if self.any_async_func { self.push_str( " let table = instance.get_table(&mut store, \"__indirect_function_table\") .ok_or_else(|| wasmtime::Trap::new(\"no exported function table\"))?; - get_state(store.data_mut()).1.set_table(table); + let handle = witx_bindgen_wasmtime::rt::Async::spawn(store, table); ", ); } @@ -1360,7 +1439,10 @@ impl Generator for Wasmtime { self.push_str(name); self.push_str(",\n"); } - self.push_str("get_state: Box::new(get_state),\n"); + self.push_str("get_state: std::sync::Arc::new(get_state),\n"); + if self.any_async_func { + self.push_str("handle,\n"); + } self.push_str("\n})\n"); self.push_str("}\n"); @@ -1369,12 +1451,13 @@ impl Generator for Wasmtime { } for r in self.exported_resources.iter() { - let (async_fn, call, wait) = if self.opts.async_.is_none() { + let (async_fn, call, wait) = if self.opts.async_.is_none() && !self.any_async_func { ("", "call", "") } else { ("async ", "call_async", ".await") }; - self.src.push_str(&format!( + + self.src.push_str( " /// Drops the host-owned handle to the resource /// specified. @@ -1383,29 +1466,59 @@ impl Generator for Wasmtime { /// destructor for this type. 
This also may not run /// the destructor if there are still other references /// to this type. - pub {async}fn drop_{name_snake}( - &self, - mut store: impl wasmtime::AsContextMut, - val: {name_camel}, - ) -> Result<(), wasmtime::Trap> {{ - let mut store = store.as_context_mut(); - let {bind_state} = (self.get_state)(store.data_mut()); - let wasm = match state.resource_slab{idx}.drop(val.0) {{ - Some(val) => val, - None => return Ok(()), - }}; - state.dtor{idx}.unwrap().{call}(&mut store, wasm){wait}?; - Ok(()) - }} ", - name_snake = iface.resources[*r].name.to_snake_case(), - name_camel = iface.resources[*r].name.to_camel_case(), + ); + let body = format!( + " + let state = get_state(store.data_mut()); + let wasm = match state.resource_slab{idx}.drop(val.0) {{ + Some(val) => val, + None => return Ok(()), + }}; + state.dtor{idx}.unwrap().{call}(&mut store, wasm){wait}?; + Ok(()) + ", idx = r.index(), - async = async_fn, call = call, wait = wait, - bind_state = bind_state, - )); + ); + + if self.any_async_func { + self.src.push_str(&format!( + " + pub async fn drop_{name_snake}( + &self, + val: {name_camel}, + ) -> Result<(), wasmtime::Trap> {{ + let get_state = self.get_state.clone(); + self.handle.run_no_coroutine(move |mut store| Box::pin(async move {{ + {body} + }})).await + }} + ", + name_snake = iface.resources[*r].name.to_snake_case(), + name_camel = iface.resources[*r].name.to_camel_case(), + body = body, + )); + } else { + self.src.push_str(&format!( + " + pub {async}fn drop_{name_snake}( + &self, + mut store: impl wasmtime::AsContextMut, + val: {name_camel}, + ) -> Result<(), wasmtime::Trap> {{ + let mut store = store.as_context_mut(); + let get_state = &self.get_state; + {body} + }} + ", + name_snake = iface.resources[*r].name.to_snake_case(), + name_camel = iface.resources[*r].name.to_camel_case(), + body = body, + async = async_fn, + )); + } } self.push_str("}\n"); @@ -1490,6 +1603,7 @@ struct FunctionBindgen<'a> { needs_borrow_checker: bool, 
needs_memory: bool, needs_functions: HashMap, + needs_get_state: bool, } impl FunctionBindgen<'_> { @@ -1512,6 +1626,7 @@ impl FunctionBindgen<'_> { needs_functions: HashMap::new(), is_dtor, params, + needs_get_state: false, } } @@ -1783,35 +1898,33 @@ impl Bindgen for FunctionBindgen<'_> { } } Instruction::I32FromBorrowedHandle { ty } => { - let any_async = iface.functions.iter().any(|f| f.is_async); let tmp = self.tmp(); + self.needs_get_state = true; self.push_str(&format!( " let obj{tmp} = {op}; - let {bind_state} = (self.get_state)(caller.data_mut()); + let state = get_state(caller.data_mut()); state.resource_slab{idx}.clone(obj{tmp}.0)?; let handle{tmp} = state.index_slab{idx}.insert(obj{tmp}.0); ", tmp = tmp, idx = ty.index(), op = operands[0], - bind_state = if any_async { "(state, _)" } else { "state" }, )); results.push(format!("handle{} as i32", tmp,)); } Instruction::HandleOwnedFromI32 { ty } => { - let any_async = iface.functions.iter().any(|f| f.is_async); let tmp = self.tmp(); + self.needs_get_state = true; self.push_str(&format!( " - let {bind_state} = (self.get_state)(caller.data_mut()); + let state = get_state(caller.data_mut()); let handle{} = state.index_slab{}.remove({} as u32)?; ", tmp, ty.index(), operands[0], - bind_state = if any_async { "(state, _)" } else { "state" }, )); let name = iface.resources[*ty].name.to_camel_case(); @@ -1974,8 +2087,8 @@ impl Bindgen for FunctionBindgen<'_> { " copy_slice( &mut caller, - memory, - func_{}, + &memory, + &func_{}, ptr{tmp}, len{tmp}, {} )? 
", @@ -2036,7 +2149,7 @@ impl Bindgen for FunctionBindgen<'_> { )); self.push_str(&format!("let base = {} + (i as i32) * {};\n", result, size)); self.push_str(&body); - self.push_str("}"); + self.push_str("\n}\n"); results.push(result); results.push(len); @@ -2174,8 +2287,12 @@ impl Bindgen for FunctionBindgen<'_> { sig, } => { self.bind_results(sig.results.len(), results); - self.push_str("self."); - self.push_str(&to_rust_ident(name)); + if self.gen.any_async_func { + self.push_str("wasm_func"); + } else { + self.push_str("self."); + self.push_str(&to_rust_ident(name)); + } if self.gen.opts.async_.includes(name) { self.push_str(".call_async("); } else { @@ -2203,79 +2320,53 @@ impl Bindgen for FunctionBindgen<'_> { params: _, results: wasm_results, } => { - self.push_str(&format!( - "let mut raw_results = [0; {}];\n", - wasm_results.len() - )); + self.push_str("wasm_func"); + if self.gen.opts.async_.includes(name) { + self.push_str(".call_async("); + } else { + self.push_str(".call("); + } + self.push_str("&mut caller, ("); + for operand in operands { + self.push_str(operand); + self.push_str(", "); + } + self.push_str("async_cx,"); + self.push_str("))"); + if self.gen.opts.async_.includes(name) { + self.push_str(".await"); + } + self.push_str("?;\n"); + self.push_str("Ok(())\n"); + self.after_call = true; + self.caller_memory_available = false; // invalidated by call - // Move the func out of `self` since it's captured in the - // `'static` future and `self` isn't `'static`. - self.push_str(&format!("let wasm_func = self.{};\n", to_rust_ident(name))); + self.push_str("}) // finish Box::pin\n"); + self.push_str("}); // finish `let start = ...`\n"); - // Move the arguments into their own temporaries. These are all - // scalar expressions but sometimes their results reference - // local variables. These all get captured in the `'static` - // future as well so we need to make sure no local variables are - // captured there. 
- let tmp = self.tmp(); - let mut args = String::new(); - for (i, arg) in operands.iter().enumerate() { - self.push_str(&format!("let arg{}_{} = {};\n", tmp, i, arg)); - args.push_str(&format!("arg{}_{}, ", tmp, i)); - } - args.push_str("async_cx,"); + // TODO: this is somewhat inefficient since it's an `Arc` clone + // that could be unnecessary. It's not clear whether this will + // get closed over in the completion callback below. Generated + // code may need this `get_state` in both the initial and + // completion callback though, and that's why it's cloned here + // too to ensure that there's two values to close over. Should + // figure out a better way to emit this so it's only done if + // necessary. + self.push_str("let get_state = self.get_state.clone();\n"); - // Start the async call with various parameters passed in... self.push_str( - "witx_bindgen_wasmtime::rt::Async::call_async_export(\ - &mut caller, \ - &mut raw_results, \ - &|t| (self.get_state)(t).1, \ - |caller, async_cx| {\n\ + " + let complete = witx_bindgen_wasmtime::rt::infer_complete(move |mut caller, ptr, memory| { + Box::pin(async move { ", ); - // Delegate to `call_async` or `call` as appropriate. - if self.gen.opts.async_.includes(name) { - self.push_str(&format!( - "Box::pin(async move {{ wasm_func.call_async(caller, ({})).await }})", - args, - )); - } else { - self.push_str(&format!( - "Box::pin(async move {{ wasm_func.call(caller, ({})) }})", - args, - )); - } - // Close the async call - self.push_str("\n}).await?;\n"); - - // Read all the results from the `raw_results` array, - // interpreting the 64-bit values as appropriate. 
- let tmp = self.tmp(); + let operands = ["ptr".to_string()]; for (i, ty) in wasm_results.iter().enumerate() { - let name = format!("result{}_{}", tmp, i); - self.push_str(&format!("let {} = ", name)); - results.push(name); - match ty { - WasmType::I32 => { - self.push_str(&format!("raw_results[{}] as i32", i)); - } - WasmType::I64 => { - self.push_str(&format!("raw_results[{}]", i)); - } - WasmType::F32 => { - self.push_str(&format!("f32::from_bits(raw_results[{}] as u32)", i)); - } - WasmType::F64 => { - self.push_str(&format!("f64::from_bits(raw_results[{}] as u64)", i)); - } - } - self.push_str(";\n"); + let ty = wasm_type(*ty); + let load = self.load((i as i32) * 8, ty, &operands); + results.push(load); } - - self.after_call = true; - self.caller_memory_available = false; // invalidated by call } Instruction::CallInterface { module: _, func } => { @@ -2324,13 +2415,11 @@ impl Bindgen for FunctionBindgen<'_> { self.push_str("let future = "); self.push_str(&call); self.push_str(";\n"); - self.push_str("_async_cx.register_async_import(async move {\n"); + self.push_str("witx_bindgen_wasmtime::rt::Async::spawn_import(async move {\n"); self.push_str("let result = future.await;\n"); call = format!("result"); - self.push_str( - "witx_bindgen_wasmtime::rt::box_future_callback(move |mut caller| {\n", - ); - self.push_str("witx_bindgen_wasmtime::rt::pin_result_future(async move {\n"); + self.push_str("witx_bindgen_wasmtime::rt::box_callback(move |mut caller| {\n"); + self.push_str("Box::pin(async move {\n"); self.push_str("let host = get(caller.data_mut());\n"); if let Some(rebind) = self.gen.rebind_host(iface) { self.push_str(&rebind); @@ -2382,19 +2471,35 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::Return { amt, .. 
} => { let result = match amt { - 0 => format!("Ok(())\n"), - 1 => format!("Ok({})\n", operands[0]), - _ => format!("Ok(({}))\n", operands.join(", ")), + 0 => format!("()"), + 1 => format!("{}", operands[0]), + _ => format!("({})", operands.join(", ")), }; - match self.cleanup.take() { - Some(cleanup) => { - self.push_str("let ret = "); - self.push_str(&result); - self.push_str(";\n"); + if self.gen.any_async_func && !self.gen.in_import { + self.push_str("let ret = "); + self.push_str(&result); + self.push_str(";\n"); + if let Some(cleanup) = self.cleanup.take() { self.push_str(&cleanup); - self.push_str("ret"); } - None => self.push_str(&result), + self.push_str("Ok(ret)\n"); + self.push_str("}) // finish Box::pin\n"); + self.push_str("}); // finish `let complete`\n"); + } else { + let result = format!("Ok({})", result); + match self.cleanup.take() { + Some(cleanup) => { + self.push_str("let ret = "); + self.push_str(&result); + self.push_str(";\n"); + self.push_str(&cleanup); + self.push_str("ret"); + } + None => { + self.push_str(&result); + self.push_str("\n"); + } + } } } @@ -2409,8 +2514,8 @@ impl Bindgen for FunctionBindgen<'_> { } self.closures.push_str(&format!( "\ - let completion_callback = get(caller.data_mut()).{async_idx} - .table() + let completion_callback = + witx_bindgen_wasmtime::rt::Async::::function_table() .get(&mut caller, {idx} as u32) .ok_or_else(|| wasmtime::Trap::new(\"invalid function index\"))? 
.funcref() @@ -2420,11 +2525,6 @@ impl Bindgen for FunctionBindgen<'_> { ", idx = operands[0], tys = tys, - async_idx = if self.gen.all_needed_handles.len() > 0 { - 2 - } else { - 1 - }, )); results.push(format!("completion_callback")); } @@ -2449,10 +2549,8 @@ impl Bindgen for FunctionBindgen<'_> { } self.push_str("\n"); - // finish the `pin_result_future(...)` - self.push_str("})\n"); - // finish the `box_future_callback(...)` - self.push_str("})\n"); + self.push_str("}) // end Box:pin\n"); + self.push_str("}) // end box_callback\n"); } Instruction::I32Load { offset } => results.push(self.load(*offset, "i32", operands)), diff --git a/crates/test-rust-wasm/Cargo.toml b/crates/test-rust-wasm/Cargo.toml index 32b6402fb..dda16896d 100644 --- a/crates/test-rust-wasm/Cargo.toml +++ b/crates/test-rust-wasm/Cargo.toml @@ -51,3 +51,7 @@ test = false [[bin]] name = "async_functions" test = false + +[[bin]] +name = "async_raw" +test = false diff --git a/crates/test-rust-wasm/src/bin/async_raw.rs b/crates/test-rust-wasm/src/bin/async_raw.rs new file mode 100644 index 000000000..7d47247f5 --- /dev/null +++ b/crates/test-rust-wasm/src/bin/async_raw.rs @@ -0,0 +1,3 @@ +include!("../../../../tests/runtime/async_raw/wasm.rs"); + +fn main() {} diff --git a/crates/wasmtime/Cargo.toml b/crates/wasmtime/Cargo.toml index 81a769efd..6510a2e89 100644 --- a/crates/wasmtime/Cargo.toml +++ b/crates/wasmtime/Cargo.toml @@ -12,6 +12,7 @@ wasmtime = "0.30.0" witx-bindgen-wasmtime-impl = { path = "../wasmtime-impl", version = "0.1" } tracing-lib = { version = "0.1.26", optional = true, package = 'tracing' } async-trait = { version = "0.1.50", optional = true } +tokio = { version = "1.12", features = ['rt', 'sync'], optional = true } [features] # Enables generated code to emit events via the `tracing` crate whenever wasm is @@ -21,7 +22,7 @@ tracing = ['tracing-lib', 'witx-bindgen-wasmtime-impl/tracing'] # Enables async support for generated code, although when enabled this still # needs to be 
configured through the macro invocation. -async = ['async-trait', 'witx-bindgen-wasmtime-impl/async'] +async = ['async-trait', 'witx-bindgen-wasmtime-impl/async', 'tokio', 'wasmtime/async'] # Enables the ability to parse the old s-expression-based `*.witx` format. old-witx-compat = ['witx-bindgen-wasmtime-impl/old-witx-compat'] diff --git a/crates/wasmtime/src/futures.rs b/crates/wasmtime/src/futures.rs index 5ca912428..936f53098 100644 --- a/crates/wasmtime/src/futures.rs +++ b/crates/wasmtime/src/futures.rs @@ -1,295 +1,822 @@ -use crate::rt::RawMem; use crate::slab::Slab; +use std::any::Any; +use std::cell::{Cell, RefCell}; use std::future::Future; +use std::mem; use std::pin::Pin; -use std::task::{Context, Poll}; -use wasmtime::Table; -use wasmtime::{StoreContextMut, Trap}; +use std::sync::{Arc, Weak}; +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedSender}; +use tokio::task::JoinHandle; +use wasmtime::{AsContextMut, Caller, Memory, Store, StoreContextMut, Table, Trap}; pub struct Async { - exports: Slab, + function_table: Table, - table: Option
, + /// Channel used to send messages to the main event loop where this + /// `Async` is managed. + /// + /// Note that this is stored in a `Weak` pointer to not hold a strong + /// reference to it because this `Async` is owned by the event loop which + /// is otherwise terminated when the `Sender` is gone, hence if this were + /// a strong reference it would loop forever. + /// + /// The main strong reference to this channel is held by a generated struct + /// that will live in user code on some other task. Other references to this + /// sender, also weak though, will live in each imported async host function + /// when invoked. + sender: Weak>>, + + /// The list of active WebAssembly coroutines that are executing in this + /// event loop. + /// + /// Note that for now the term "coroutine" here is specifically used for the + /// interface-types notion of a coroutine and does not correspond to a + /// literal coroutine/fiber on the host. Interface types coroutines are only + /// implemented right now with the callback ABI, meaning there's no + /// coroutine in the sense of "there's a suspended host stack" somewhere. + /// Instead wasm retains all state necessary for resumption and such. + /// + /// This list of active coroutines will have one-per-export called and when + /// suspended the coroutines here are all guaranteed to have pending imports + /// they're waiting on. + /// + /// Note that internally `Coroutines` is simply a `Slab>` + /// and is only structured this way to have lookups via `&CoroutineId` + /// instead of `u32` as slabs do. + coroutines: RefCell>, + + /// The "currently active" coroutine. + /// + /// This is used to persist state in the host about what coroutine is + /// currently active so that when an import is called we can automatically + /// assign that import's "thread" of execution to the currently active + /// coroutine, adding it to the right import list. 
This enables keeping + /// track on the host for what imports are used where and what to cancel + /// whenever one coroutine aborts (if at all). + cur_wasm_coroutine: CoroutineId, - /// List of imports that we're waiting on. + /// The next unique ID to hand out to a coroutine. /// - /// This is a list of async imports that have been called as part of calling - /// wasm and are registered here. When these imports complete they produce a - /// result which then itself produces another future. The result is given a - /// `StoreContextMut` and is expected to further execute WebAssembly, - /// translating the results of the async host import to wasm and then - /// invoking the wasm completion callback. When the wasm completion callback - /// is finished then the future is complete. - // - // TODO: should this be in `FutureState` because imports-called are a - // per-export thing? - imports: Vec> + Send>>>, + /// This is a monotonically increasing counter which is intended to be + /// unique for all coroutines for the lifetime of a program. This is a + /// generational index of sorts which prevents accidentally reusing slab + /// indices in the `coroutines` array. + cur_unique_id: Cell, } -impl Default for Async { - fn default() -> Async { - Async { - exports: Slab::default(), - imports: Vec::new(), - table: None, - } - } +/// An "integer" identifier for a coroutine. +/// +/// This is used to uniquely identify a logical coroutine of WebAssembly +/// execution, and internally contains the slab index it's stored at as well as +/// a unique generational ID.
+#[derive(Copy, Clone)] +pub struct CoroutineId { + slab_index: u32, + unique_id: u64, } -struct FutureState { - results: Vec, // TODO: shouldn't need to heap-allocate this - done: bool, +struct Coroutines { + slab: Slab>, } -pub type HostFuture = Pin + Send>>; +enum Message { + Execute(Start, Complete, UnboundedSender), + RunNoCoroutine(RunStandalone, UnboundedSender), + FinishImport(Callback, CoroutineId, u32), + Cancel(CoroutineId), +} + +struct Coroutine { + /// A unique ID for this coroutine which is used to ensure that even if this + /// coroutine's slab index is reused a `CoroutineId` uniquely points to one + /// logical coroutine. This mostly comes up where when a coroutine exits + /// early due to a trap we need to make sure that even if the slab slot is + /// reused we don't accidentally use some future coroutine for lingering + /// completion callbacks. + unique_id: u64, + + /// A list of spawned tasks corresponding to imported host functions that + /// this coroutine is waiting on. This list is appended to whenever an async + /// host function is invoked and it's removed from when the host function + /// completes (and the message gets back to the main loop). + /// + /// The primary purpose of this list is so that when a coroutine fails (via + /// a trap) that all of the spawned host work for the coroutine can exit + /// ASAP via an `abort()` signal on the `JoinHandle`. + pending_imports: Slab>, + + /// The number of imports that we're waiting on, corresponding to the number + /// of present entries in `pending_imports`. + num_pending_imports: usize, -/// The result of a host import. This is mostly synthesized by bindings and -/// represents that a host import produces a closure. The closure is given -/// context to execute WebAssembly and then the execution itself results in a -/// future. This returned future represents the completion of the WebAssembly -/// itself. 
-pub type ImportResult = Box< + /// A callback to invoke whenever a coroutine's `async_export_done` + /// completion callback is invoked. This is used by the host to deserialize + /// the results from WebAssembly (possibly doing things like wasm + /// malloc/free) and then sending the results on a channel. + /// + /// Typically this contains a `Sender` internally within this closure + /// which gets a message once all the wasm arguments have been successfully + /// deserialized. + complete: Option>, + + sender: UnboundedSender, + cancel_task: Option>, +} + +pub type HostFuture = Pin + Send>>; +pub type CoroutineResult = Result, Trap>; +pub type Start = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + u32, + ) -> Pin> + Send + 'a>> + + Send, +>; +pub type Callback = Box< dyn for<'a> FnOnce( &'a mut StoreContextMut<'_, T>, ) -> Pin> + Send + 'a>> + Send, >; +pub type Complete = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + wasmtime::Memory, + ) -> Pin + Send + 'a>> + + Send, +>; +pub type RunStandalone = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) -> Pin + Send + 'a>> + + Send, +>; -impl Async { - /// Implementation of the `async_export_done` canonical ABI function. - /// - /// The first two parameters are provided by wasm itself, and the `mem` is - /// the wasm's linear memory. The first parameter `cx` is the original value - /// returned by `start_async_export` and indicates which call to which - /// export is being completed. The `ptr` is a pointer into `mem` where the - /// encoded results are located. 
- pub fn async_export_done(&mut self, cx: i32, ptr: i32, mem: &[u8]) -> Result<(), Trap> { - let cx = cx as u32; - let dst = self - .exports - .get_mut(cx) - .ok_or_else(|| Trap::new("async context not valid"))?; - if dst.done { - return Err(Trap::new("async context not valid")); - } - dst.done = true; - for slot in dst.results.iter_mut() { - let ptr = (ptr as u32) - .checked_add(8) - .ok_or_else(|| Trap::new("pointer to async completion not valid"))?; - *slot = mem.load(ptr as i32)?; - } - Ok(()) +impl Async { + /// Spawns a new task which will manage async execution of wasm within the + /// `store` provided. + pub fn spawn(mut store: Store, function_table: Table) -> AsyncHandle + where + T: Send, + { + // This channel is the primary point of communication into the task that + // we're going to spawn. This'll be bounded to ensure it doesn't get + // overrun, and additionally the sender will be stored in an `Arc` to + // ensure that the returned handle is the only owning handle and the + // internal weak handle held by `Async` doesn't keep it alive to let + // the task terminate gracefully. + let (sender, receiver) = mpsc::channel(5 /* TODO: should this be configurable? */); + let sender = Arc::new(sender); + let mut cx = Async { + function_table, + sender: Arc::downgrade(&sender), + coroutines: RefCell::new(Coroutines { + slab: Slab::default(), + }), + cur_wasm_coroutine: CoroutineId { + slab_index: u32::MAX, + unique_id: u64::MAX, + }, + cur_unique_id: Cell::new(0), + }; + + tokio::spawn(async move { cx.run(&mut store.as_context_mut(), receiver).await }); + AsyncHandle { sender } } - /// Registers a new future returned from an async import. - /// - /// This function is used when an async import is invoked by wasm. The - /// asynchronous import is represented as a future and when the future - /// completes it needs to call the completion callback in WebAssembly.
The - /// invocation of the completion callback is represented by the output of - /// the future here, the `ImportResult` which is a closure that takes a - /// store context and invokes WebAssembly (further in an async fashion). - /// - /// Note that this doesn't actually do anything, it simply enqueues the - /// future internally. The future will actually be driven from the - /// `wait_for_async_export` function below. - pub fn register_async_import( + pub fn spawn_import( + future: impl Future> + Send + 'static, + ) -> Result<(), Trap> { + Self::with(|cx| { + let sender = cx.sender.clone(); + // Register a new pending import for the currently executing wasm + // coroutine. This will ensure that full completion of this + // coroutine is delayed until this import is resolved. + let coroutine_id = cx.cur_wasm_coroutine; + let mut coroutines = cx.coroutines.borrow_mut(); + let coroutine = coroutines + .get_mut(&coroutine_id) + .ok_or_else(|| Trap::new("cannot call async import from non-async export"))?; + let pending_import_id = coroutine.pending_imports.next_id(); + coroutine.num_pending_imports += 1; + + // Note that `tokio::spawn` is used here to allow the `future` for + // this import to execute in parallel, not just concurrently. The + // result is re-acquired via sending a message on our internal + // channel. + let task = tokio::spawn(async move { + let import_result = future.await; + + // If the main task has exited for some reason then it'll never + // receive our result, but that's ok since we're trying to + // complete a wasm import and if the main task isn't there to + // receive it there's nothing else to do with this result. This + // being an error is theoretically possible but should be rare. 
+ if let Some(sender) = sender.upgrade() { + let send_result = sender + .send(Message::FinishImport( + import_result, + coroutine_id, + pending_import_id, + )) + .await; + drop(send_result); + } + }); + + let id = coroutine.pending_imports.insert(task); + assert_eq!(id, pending_import_id); + Ok(()) + }) + } + + async fn run( &mut self, - future: impl Future> + Send + 'static, + store: &mut StoreContextMut<'_, T>, + mut receiver: Receiver>, ) { - self.imports.push(Box::pin(future)); - } + // Infinitely process messages on `receiver` which represent events such as + // requests to invoke an export or completion of an import which results in + // execution of a completion callback. + while let Some(msg) = receiver.recv().await { + let coroutines = self.coroutines.get_mut(); + let (to_execute, coroutine_id) = match msg { + // This message is the start of a new task ("coroutine" in + // interface-types-vernacular) so we allocate a new task in our + // slab and set up its state. + // + // Note that we spawn a "helper" task here to send a message to + // our channel when the `sender` specified here is disconnected. + // That scenario means that this coroutine is cancelled and by + // sending a message into our channel we can start processing + // that. + Message::Execute(run, complete, sender) => { + let unique_id = self.cur_unique_id.get(); + self.cur_unique_id.set(unique_id + 1); + let coroutine_id = coroutines.next_id(unique_id); + let my_sender = self.sender.clone(); + let await_close_sender = sender.clone(); + let cancel_task = tokio::spawn(async move { + await_close_sender.closed().await; + // if the main task is gone one way or another we ignore + // the error here since no one's going to receive it + // anyway and all relevant work should be cancelled. 
+ if let Some(sender) = my_sender.upgrade() { + drop(sender.send(Message::Cancel(coroutine_id)).await); + } + }); + coroutines.insert(Coroutine { + unique_id, + complete: Some(complete), + sender, + pending_imports: Slab::default(), + num_pending_imports: 0, + cancel_task: Some(cancel_task), + }); + (ToExecute::Start(run, coroutine_id.slab_index), coroutine_id) + } - /// Blocks on the completion of an asynchronous export. - /// - /// This function is used to await the result of an async export. In other - /// words this is used to wait for wasm to invoke the completion callback - /// with the `async_cx` specified. - /// - /// This will "block" for one of two reasons: - /// - /// * First is that an async import was called and the wasm's completion - /// callback wasn't called yet. In this scenario this function will block - /// on the completion of the async import. - /// - /// * Second is the execution of the wasm's own import completion callback. - /// This execution of WebAssembly may be asynchronous due to things like - /// fuel context switching or similar. - /// - /// This function invokes WebAssembly within `cx` and will not return until - /// the completion callback for `async_cx` is invoked. When the completion - /// callback is invoked the results of the callback are written into - /// `results`. The `get_state` method is used to extract an `Async` from - /// the store state within `cx`. - /// - /// This returns `Ok(())` when the completion callback was successfully - /// invoked, but it may also return `Err(trap)` if a trap happens while - /// executing a wasm completion callback for an import. 
- pub async fn call_async_export( - cx: &mut StoreContextMut<'_, T>, - results: &mut [i64], - get_state: &(dyn Fn(&mut T) -> &mut Async + Send + Sync), - invoke_wasm: impl for<'a> FnOnce( - &'a mut StoreContextMut<'_, T>, - i32, - ) - -> Pin> + Send + 'a>>, - ) -> Result<(), Trap> { - // First register a new export happening in our slab of running - // `exports` futures. - // - // NB: at this time due to take `&mut StoreContextMut` as an argument to - // this function it means that the size of `exports` is at most one. In - // the future this will probably take some sort of async mutex and only - // hold the mutex when wasm is running to allow concurrent execution of - // wasm. - let async_cx = get_state(cx.data_mut()).exports.insert(FutureState { - results: vec![0; results.len()], - done: false, - }); + // This message means that we need to execute `run` specified + // which is a "non blocking"-in-the-coroutine-sense wasm + // function. This is basically "go run that single callback" and + // is currently only used for things like resource destructors. + // These aren't allowed to call blocking functions and a trap is + // generated if they try to call a blocking function (since + // there isn't a coroutine set up). + // + // Note that here we avoid allocating a coroutine entirely since + // this isn't actually a coroutine, which means that any attempt + // to call a blocking function will be met with failure (a + // trap). Additionally note that the actual execution of the + // wasm here is select'd against the closure of the `sender` + // here as well, since if the runtime becomes disinterested in + // the result of this async call we can interrupt and abort the + // wasm. + // + // Finally note that if the wasm completes but we fail to send + // the result of the wasm to the receiver then we ignore the + // error since that was basically a race between wasm exiting + // and the sender being closed. 
+ // + // TODO: should this dropped result/error get logged/processed + // somewhere? + Message::RunNoCoroutine(run, sender) => { + tokio::select! { + r = tls::scope(self, run(store)) => { + let is_trap = r.is_err(); + let _ = sender.send(r); - // Once the registration is made we immediately construct the - // `WaitForAsyncExport` helper struct. The destructor of this struct - // will forcibly remove the registration we just made above to prevent - // leaking anything if the wasm future is dropped and forgotten about. - let waiter = WaitForAsyncExport { - cx, - async_cx, - get_state, - }; + // Shut down this reactor if a trap happened because + // the instance is now in an indeterminate state. + if is_trap { + break; + } + } + _ = sender.closed() => break, + } + continue; + } - // Now that things are set up this is the original invocation of - // WebAssembly. This invocation is itself asynchronous hence we await - // the result here. - invoke_wasm(waiter.cx, async_cx as i32).await?; - - // Once we've invoked the export then it's our job to wait for the - // `async_export_done` function to get invoked. That happens here as we - // observer the state of `async_cx` within `state.exports`. If it's not - // finished yet then that means that we need to wait for the next of a - // set of futures to complete (those in `state.imports`), which is - // deferred to the `WaitForNextFuture` helper struct. - loop { - let state = (waiter.get_state)(waiter.cx.data_mut()); - if state.exports.get(async_cx).unwrap().done { + // This message indicates that an import has completed and + // the completion callback for the wasm must be executed. + // This, plus the serialization of the arguments into wasm + // according to the canonical ABI, is represented by + // `run`. + // + // Note, though, that in some cases we don't actually run + // the completion callback. 
For example if a previous + // completion callback for this wasm task has failed with a + // trap we don't continue to run completion callbacks for + // the wasm task. This situation is indicated when the + // coroutine is not actually present in our `coroutines` + // list, so we do a lookup here before allowing execution. When + // the coroutine isn't present we simply skip this message which + // will run destructors for any relevant host values. + Message::FinishImport(run, coroutine_id, import_id) => { + let coroutine = match coroutines.get_mut(&coroutine_id) { + Some(c) => c, + None => continue, + }; + coroutine.pending_imports.remove(import_id).unwrap(); + coroutine.num_pending_imports -= 1; + (ToExecute::Callback(run), coroutine_id) + } + + // This message indicates that the specified coroutine has been + // cancelled, meaning that the sender which would send back the + // result of the coroutine is now a closed channel that we can + // no longer send a message along. Our response to this is to + // remove the coroutine, and its destructor will trigger further + // cancellation if necessary. + // + // Note that this message may race with the actual completion of + // the coroutine so we don't assert that the ID specified here + // is actually in our list. If a coroutine is removed though we + // assume that the wasm is now in an indeterminate state which + // results in aborting this reactor task. If nothing is removed + // then we assume the race was properly resolved and we skip + // this message. + Message::Cancel(coroutine_id) => { + if coroutines.remove(&coroutine_id).is_some() { + break; + } + continue; + } + }; + + // Actually execute the WebAssembly callback. The call to + // `to_execute.run` here is what will actually execute WebAssembly + // asynchronously, and note that it's also executed within a + // `tls::scope` to ensure that the `tls::with` function will work + // for the duration of the future. 
+ // + // Also note, though, that we want to be able to cancel the + // execution of this WebAssembly if the caller becomes disinterested + // in the result. This happens by using the `closed()` method on the + // channel back to the sender, and if that happens we abort wasm + // entirely and abort the whole coroutine by removing it later. + // + // If this wasm operation is aborted then we exit this loop + // entirely and tear down this reactor task. That triggers + // cancellation of all spawned sub-tasks and sibling coroutines, and + // the rationale for this is that we zapped wasm while it was + // executing so it's now in an indeterminate state and not one that + // we can resume. + // + // TODO: this is a `clone()`-per-callback which is probably cheap, + // but this is also a sort of wonky setup so this may wish to change + // in the future. + let cancel_signal = coroutines.get_mut(&coroutine_id).unwrap().sender.clone(); + let prev_coroutine_id = mem::replace(&mut self.cur_wasm_coroutine, coroutine_id); + let result = tokio::select! { + r = tls::scope(self, to_execute.run(store)) => r, + _ = cancel_signal.closed() => break, + }; + self.cur_wasm_coroutine = prev_coroutine_id; + + let coroutines = self.coroutines.get_mut(); + let coroutine = coroutines.get_mut(&coroutine_id).unwrap(); + if let Err(trap) = result { + // Our WebAssembly callback trapped. That means that this + // entire coroutine is now in a failure state. No further + // wasm callbacks will be invoked and the coroutine is + // removed from our internal list to invoke the failure + // callback, informing what trap caused the failure. + // + // Note that this reopens `coroutine_id.slab_index` to get + // possibly reused, intentionally so, which is why + // `CoroutineId` is a form of generational ID which is + // resilient to this form of reuse.
In other words when we + // remove the result here if in the future a pending import + // for this coroutine completes we'll simply discard the + // message. + // + // Any error in sending the trap along the coroutine's channel + // is ignored since we can race with the coroutine getting + // dropped. + // + // TODO: should the trap still be sent somewhere? Is this ok to + // simply ignore? + // + // Finally we exit the reactor in this case because traps + // typically represent fatal conditions for wasm where we can't + // really resume since it may be in an indeterminate state (wasm + // can't handle traps itself), so after we inform the original + // coroutine of the original trap we break out and cancel all + // further execution. + let coroutine = coroutines.remove(&coroutine_id).unwrap(); + let _ = coroutine.sender.send(Err(trap)); break; + } else if coroutine.num_pending_imports == 0 { + // Our wasm callback succeeded, and there are no pending + // imports for this coroutine. + // + // In this state it means that the coroutine has completed + // since no further work can possibly happen for the + // coroutine. This means that we can safely remove it from + // our internal list. + // + // If the coroutine's completion wasn't ever signaled, + // however, then that indicates a bug in the wasm code + // itself. This bug is translated into a trap which will get + // reported to the caller to inform the original invocation + // of the export that the result of the coroutine never + // actually came about. + // + // Note that like above a failure to send a trap along the + // channel is ignored since we raced with the caller becoming + // disinterested in the result which is fine to happen at any + // time. + // + // TODO: should the trap still be sent somewhere? Is this ok to + // simply ignore? + // + // TODO: should this tear down the reactor as well, despite it + // being a synthetically created trap? 
+ let coroutine = coroutines.remove(&coroutine_id).unwrap(); + if coroutine.complete.is_some() { + let _ = coroutine + .sender + .send(Err(Trap::new("completion callback never called"))); + } + } else { + // Our wasm callback succeeded, and there are pending + // imports for this coroutine. + // + // This means that the coroutine isn't finished yet so we + // simply turn the loop and wait for something else to + // happen. We'll next be executing WebAssembly when one of + // the coroutine's imports finishes. } + } + } - let result = WaitForNextFuture { state }.await?; + pub async fn async_export_done( + mut caller: Caller<'_, T>, + task_id: i32, + ptr: i32, + mem: Memory, + ) -> Result<(), Trap> { + // Extract the completion callback registered in Rust for the `task_id`. + // This will deserialize all of the canonical ABI results specified by + // `ptr`, and presumably send the result on some sort of channel back to the + // task that originally invoked the wasm. + let task_id = task_id as u32; + let complete = Self::with(|cx| { + let mut coroutines = cx.coroutines.borrow_mut(); + let coroutine = coroutines + .slab + .get_mut(task_id) + .ok_or_else(|| Trap::new("async context not valid"))?; + coroutine + .complete + .take() + .ok_or_else(|| Trap::new("async context not valid")) + })?; + + // Note that this is an async-enabled call to allow `call_async` for things + // like fuel in case the completion callback needs to invoke wasm + // asynchronously for things like deallocation. + let result = complete(&mut caller.as_context_mut(), ptr, mem).await?; + + // With the final result of the coroutine we send this along the channel + // back to the original task which was waiting for the result. Note that + // this send may fail if we're racing with cancellation of this task, + // and if cancellation happens we translate that to a trap to ensure + // that wasm is cleaned up quickly (as opposed to waiting for the next + // yield point where it should get cleaned up anyway).
+ Self::with(|cx| { + let mut coroutines = cx.coroutines.borrow_mut(); + let coroutine = coroutines.slab.get_mut(task_id).unwrap(); + coroutine + .sender + .send(Ok(result)) + .map_err(|_| Trap::new("task has been cancelled")) + }) + } + + // TODO: this is a pretty bad interface to manage the table with... + pub fn function_table() -> Table { + Self::with(|cx| cx.function_table) + } + + fn with(f: impl FnOnce(&Async) -> R) -> R { + tls::with(|cx| f(cx.downcast_ref().unwrap())) + } +} - // TODO: while this is executing we're not polling the futures - // inside of `Async`, is that ok? Will this need to poll the - // future inside the state in parallel with the async wasm here to - // ensure that things work out as expected. - result(waiter.cx).await?; +impl Coroutines { + fn next_id(&self, unique_id: u64) -> CoroutineId { + CoroutineId { + unique_id, + slab_index: self.slab.next_id(), } + } - // If we're here then that means that the `async_export_done` function - // was called, which means taht we can copy the results into the final - // `results` slice. - let future = (waiter.get_state)(waiter.cx.data_mut()) - .exports - .get(async_cx) - .unwrap(); - assert_eq!(results.len(), future.results.len()); - results.copy_from_slice(&future.results); - return Ok(()); - - /// This is a helper struct used to remove `async_cx` from `cx` on drop. - /// - /// This ensures that if any wasm returns a trap or if the future itself - /// is entirely dropped that we properly clean things up and don't leak - /// the export's async status and allow it to accidentally be - /// "completed" by someone else. 
- struct WaitForAsyncExport<'a, 'b, 'c, 'd, T> { - cx: &'a mut StoreContextMut<'b, T>, - async_cx: u32, - get_state: &'c (dyn Fn(&mut T) -> &mut Async + Send + Sync + 'd), + fn insert(&mut self, coroutine: Coroutine) -> CoroutineId { + let unique_id = coroutine.unique_id; + let slab_index = self.slab.insert(coroutine); + CoroutineId { + unique_id, + slab_index, } + } - impl Drop for WaitForAsyncExport<'_, '_, '_, '_, T> { - fn drop(&mut self) { - (self.get_state)(self.cx.data_mut()) - .exports - .remove(self.async_cx) - .unwrap(); - } + fn get_mut(&mut self, id: &CoroutineId) -> Option<&mut Coroutine> { + let entry = self.slab.get_mut(id.slab_index)?; + if entry.unique_id == id.unique_id { + Some(entry) + } else { + None } } - /// Returns the previously configured function table via `set_table`. - // - // TODO: this probably isn't the right interface, need to figure out a - // better way to pass this table (and other intrinsics to the wasm instance) - // around. - pub fn table(&self) -> Table { - self.table.expect("table wasn't set yet") + fn remove(&mut self, id: &CoroutineId) -> Option> { + let entry = self.slab.get_mut(id.slab_index)?; + if entry.unique_id == id.unique_id { + self.slab.remove(id.slab_index) + } else { + None + } } +} - /// Stores a table to later get returned by `table()`. - // - // TODO: like `table`, this probably isn't the right interface - pub fn set_table(&mut self, table: Table) { - assert!(self.table.is_none(), "table already set"); - self.table = Some(table); +impl Drop for Coroutine { + fn drop(&mut self) { + // When a coroutine is removed and dropped from the internal list of + // coroutines then we're no longer interested in any of the results for + // any of the spawned tasks. This means we can proactively cancel + // anything that this coroutine might be waiting on (imported functions) + // plus the task that's used to send a message to the "main loop" on + // cancellation. 
+ if let Some(task) = &self.cancel_task { + task.abort(); + } + for task in self.pending_imports.iter() { + task.abort(); + } } } -struct WaitForNextFuture<'a, T> { - state: &'a mut Async, +enum ToExecute { + Start(Start, u32), + Callback(Callback), } -impl Future for WaitForNextFuture<'_, T> { - type Output = Result, Trap>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - // If there aren't any active imports then that means that WebAssembly - // hasn't invoked the completion callback for an export but it also - // didn't call any imports to block on anything. That means that the - // completion callback will never be called which is an error, so - // simulate a trap happening and report it to the embedder. - if self.state.imports.len() == 0 { - return Poll::Ready(Err(Trap::new( - "wasm isn't waiting on any imports but export completion callback wasn't called", - ))); +impl ToExecute { + async fn run(self, store: &mut StoreContextMut<'_, T>) -> Result<(), Trap> { + match self { + ToExecute::Start(cb, val) => cb(store, val).await, + ToExecute::Callback(cb) => cb(store).await, } + } +} + +pub struct AsyncHandle { + sender: Arc>>, +} - // If we have imports then we'll "block" this current future on one of - // the sub-futures within `self.state.imports`. By polling at least one - // future that means we'll get re-awakened whenever the sub-future is - // ready and we'll check here again. +impl AsyncHandle { + /// Executes a new WebAssembly in the "reactor" that this handle is + /// connected to. + /// + /// This function will execute `start` as the initial callback for the + /// asynchronous WebAssembly to be executed. This closure receives the + /// `Store` via a handle as well as the coroutine ID that's associated + /// with this new coroutine. It's expected that this callback produces a + /// future which represents the execution of the initial WebAssembly + /// callback, handling all canonical ABI translations internally. 
+ /// + /// The second `complete` callback is invoked when the wasm indicates that + /// it's finished executing (the `async_export_done` intrinsic wasm + /// import). This is expected to produce the final result of the function. + /// + /// This function is an `async` function which is expected to be `.await`'d. + /// If this function's future is dropped or cancelled then the coroutine + /// that this executes will also be dropped/cancelled. If a wasm trap + /// happens then that will be returned here and the coroutine will be + /// cancelled. + /// + /// Note that it is possible for wasm to invoke the completion callback and + /// still trap. In situations like that the trap is returned from this + /// function. + pub async fn execute( + &self, + start: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + u32, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + complete: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + wasmtime::Memory, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + ) -> Result + where + U: Send + 'static, + { + // Note that this channel should have at most 2 messages ever sent on it + // so it's easier to deal with an unbounded channel rather than a + // bounded channel. + let (tx, mut rx) = mpsc::unbounded_channel(); + + // Send a request to our "reactor task" which indicates that we'd like + // to start execution of a new WebAssembly coroutine. The start/complete + // callbacks provided here are the implementation of the canonical ABI + // for this particular coroutine. // - // If anything is ready we return the first item that we get. This means - // that if any import has its result ready then we propagate the result - // outwards which will invoke the completion callback for that import's - // execution. If, after running the import completion callback, the - // export completion callback still hasn't been invoked then we'll come - // back here and look for other finished imports. 
+ // Note that failure to send here turns into a trap. This can happen + // when the reactor task is torn down, taking the receiver with it. When + // wasm traps this can happen, and this means that the wasm is no longer + // present for execution so we continue to propagate traps with a new + // synthetic trap here. + self.sender + .send(Message::Execute( + Box::new(start), + Box::new(move |store, ptr, mem| { + Box::pin(async move { + let val = complete(store, ptr, mem).await?; + Ok(Box::new(val) as Box) + }) + }), + tx, + )) + .await + .map_err(|_| Trap::new("wasm reactor task has gone away -- sibling trap?"))?; + + // This is a bit of a tricky dance. Once WebAssembly is requested to be + // executed there are a number of outcomes that can happen here: // - // TODO: this can theoretically exhibit quadratic behavior if the wasm - // calls tons and tons of imports. This should use a more intelligent - // future-polling mechanism to avoid re-polling everything we're not - // interested in every time. - for (i, import) in self.state.imports.iter_mut().enumerate() { - match import.as_mut().poll(cx) { - Poll::Ready(value) => { - drop(self.state.imports.swap_remove(i)); - return Poll::Ready(Ok(value)); - } - Poll::Pending => {} - } + // 1. The WebAssembly coroutine could complete successfully. This means + // that it eventually invokes the completion callback and no traps + // happened. In this case the completion value is sent on the channel + // and then when the wasm is all finished then the sending half of + // the channel is destroyed. + // + // 2. The WebAssembly coroutine could trap before invoking its + // completion callback. In this scenario the first message is a trap + // and there will be no second message because the coroutine is + // destroyed after a trap. + // + // 3. The WebAssembly coroutine could give us a completed value + // successfully, but then afterwards may trap. 
In this situation the + // first message received is the completed value of the coroutine, + // and the second message will be the trap that occurred. + // + // 4. Finally, the reactor could get torn down because of wasm hitting + // a trap (leaving it in an indeterminate state) or a bug in the + // reactor that panicked. + // + // Overall this leads us to two separate `.await` calls. The first + // `.await` receives the first message and "propagates" traps in (4) + // assuming that the reactor is gone due to a wasm trap. This first + // result is `Ok` in (1)/(3), and it's `Err` in the case of (2). + // + // The second `.await` will wait for the full completion of the + // coroutine in (1) but then receive `None`, should immediately receive + // `None` for (2), and will receive a trap with (3). In all situations + // we are guaranteed that after the second message the coroutine is + // deleted and cleaned up. + // + // Note that receiving `Ok` as the second message is not possible + // because the completion callback is invoked at most once and it's only + // invoked if no trap has happened, which means that a successful + // completion callback is guaranteed to be the first message. + // + // TODO: the time that passes between the first `.await` and the second + // `.await` is not exposed with this function's signature. This is + // simply a bland async function that returns the result, but embedders + // may want to process a successful result which later traps. This API + // should probably be redesigned to accommodate this. 
+ let result = rx + .recv() + .await + .ok_or_else(|| Trap::new("wasm reactor task has gone away -- sibling trap?"))?; + match rx.recv().await { + Some(Err(trap)) => Err(trap), + Some(Ok(_)) => unreachable!(), + None => result.map(|e| *e.downcast().ok().unwrap()), } + } - Poll::Pending + pub async fn run_no_coroutine( + &self, + run: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + ) -> Result + where + U: Send + 'static, + { + let (tx, mut rx) = mpsc::unbounded_channel(); + self.sender + .send(Message::RunNoCoroutine( + Box::new(move |store| { + Box::pin(async move { + let val = run(store).await?; + Ok(Box::new(val) as Box) + }) + }), + tx, + )) + .await + .ok() + .expect("reactor task should be present"); + rx.recv() + .await + .unwrap() + .map(|e| *e.downcast().ok().unwrap()) } } -fn _assert() { - fn _assert_send(_: &T) {} +mod tls { + use std::any::Any; + use std::cell::Cell; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + thread_local!(static CUR: Cell<*const (dyn Any + Send)> = Cell::new(&0)); + + pub async fn scope( + val: &mut (dyn Any + Send + 'static), + future: impl Future, + ) -> T { + struct SetTls<'a, F> { + val: &'a mut (dyn Any + Send + 'static), + future: F, + } + + impl Future for SetTls<'_, F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let (val, future) = unsafe { + let inner = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut inner.val), + Pin::new_unchecked(&mut inner.future), + ) + }; + + let x: &&mut (dyn Any + Send + 'static) = val.as_ref().get_ref(); + set(&**x, || future.poll(cx)) + } + } + + SetTls { val, future }.await + } + + pub fn set(val: &(dyn Any + Send + 'static), f: impl FnOnce() -> R) -> R { + return CUR.with(|slot| { + let prev = slot.replace(val); + let _reset = Reset(slot, prev); + f() + }); + + struct Reset<'a, T: Copy>(&'a Cell, T); + + impl Drop for Reset<'_, T> { + fn 
drop(&mut self) { + self.0.set(self.1); + } + } + } - fn _test(x: &mut StoreContextMut<'_, ()>) { - let f = Async::<()>::call_async_export(x, &mut [], &|_| panic!(), |_, _| panic!()); - _assert_send(&f); + pub fn with(f: impl FnOnce(&(dyn Any + Send)) -> R) -> R { + CUR.with(|slot| { + let val = slot.get(); + unsafe { f(&*val) } + }) } } diff --git a/crates/wasmtime/src/lib.rs b/crates/wasmtime/src/lib.rs index 3610de584..9db4d00dc 100644 --- a/crates/wasmtime/src/lib.rs +++ b/crates/wasmtime/src/lib.rs @@ -7,7 +7,7 @@ pub use tracing_lib as tracing; #[doc(hidden)] pub use {anyhow, bitflags, wasmtime}; -pub use futures::{Async, HostFuture}; +pub use futures::HostFuture; mod error; pub mod exports; @@ -35,7 +35,7 @@ unsafe impl Sync for RawMemory {} #[doc(hidden)] pub mod rt { - pub use crate::futures::Async; + pub use crate::futures::{Async, AsyncHandle}; use crate::slab::Slab; use crate::{Endian, Le}; use std::mem; @@ -262,25 +262,60 @@ pub mod rt { } } + use crate::futures::Callback; use std::future::Future; use std::pin::Pin; - /// Helper function to assist with type inference for async bits - pub fn pin_result_future<'a>( - future: impl Future> + Send + 'a, - ) -> Pin> + Send + 'a>> { - Box::pin(future) + pub fn infer_start(callback: F) -> F + where + F: for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + u32, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + { + callback + } + + pub fn infer_complete(callback: F) -> F + where + F: for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + Memory, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + { + callback } - /// Helper function to assist with type inference for async bits - pub fn box_future_callback( + pub fn infer_standalone(callback: F) -> F + where + F: for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + { + callback + } + + pub fn box_callback( callback: impl for<'a> FnOnce( &'a mut StoreContextMut<'_, T>, ) -> Pin> + Send + 'a>> + Send + 'static, 
- ) -> crate::futures::ImportResult { + ) -> Callback { Box::new(callback) } + + #[cfg(feature = "async")] + pub use tokio::sync::mpsc; } diff --git a/crates/wasmtime/src/slab.rs b/crates/wasmtime/src/slab.rs index 29db0c9ba..a18ec52d7 100644 --- a/crates/wasmtime/src/slab.rs +++ b/crates/wasmtime/src/slab.rs @@ -12,6 +12,10 @@ enum Entry { } impl Slab { + pub fn next_id(&self) -> u32 { + self.next as u32 + } + pub fn insert(&mut self, item: T) -> u32 { if self.next == self.storage.len() { self.storage.push(Entry::Empty { @@ -54,6 +58,13 @@ impl Slab { } } } + + pub fn iter(&self) -> impl Iterator { + self.storage.iter().filter_map(|i| match i { + Entry::Full(t) => Some(t), + Entry::Empty { .. } => None, + }) + } } impl Default for Slab { diff --git a/tests/codegen/async_functions.witx b/tests/codegen/async_functions.witx index 53f0e531c..b52d1821f 100644 --- a/tests/codegen/async_functions.witx +++ b/tests/codegen/async_functions.witx @@ -12,5 +12,8 @@ resource Response { body: async function() -> list status: function() -> u32 status_text: function() -> string - } + +resource some_resource + +resource_to_resource: async function(x: some_resource) -> some_resource diff --git a/tests/runtime/async_functions/exports.witx b/tests/runtime/async_functions/exports.witx index 5f5010bb2..78724608f 100644 --- a/tests/runtime/async_functions/exports.witx +++ b/tests/runtime/async_functions/exports.witx @@ -2,3 +2,11 @@ thunk: async function() allocated_bytes: function() -> u32 test_concurrent: async function() + +concurrent_export: async function(idx: u32) + +infinite_loop_async: async function() +infinite_loop: function() + +call_import_then_trap: async function() +call_infinite_import: async function() diff --git a/tests/runtime/async_functions/host.rs b/tests/runtime/async_functions/host.rs index 0f5fef5a5..f44a98c95 100644 --- a/tests/runtime/async_functions/host.rs +++ b/tests/runtime/async_functions/host.rs @@ -1,26 +1,38 @@ 
-witx_bindgen_wasmtime::import!("./tests/runtime/async_functions/imports.witx"); - use anyhow::Result; -use futures_channel::oneshot::{channel, Receiver, Sender}; -use futures_util::FutureExt; use imports::*; -use wasmtime::{Engine, Linker, Module, Store}; -use witx_bindgen_wasmtime::{Async, HostFuture}; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::sync::oneshot::{channel, Receiver, Sender}; +use wasmtime::{Config, Engine, Linker, Module, Store, TrapCode}; +use witx_bindgen_wasmtime::HostFuture; + +witx_bindgen_wasmtime::export!({ + paths: ["./tests/runtime/async_functions/imports.witx"], + async: *, +}); #[derive(Default)] pub struct MyImports { - thunk_hit: bool, unblock1: Option>, unblock2: Option>, unblock3: Option>, wait1: Option>, wait2: Option>, wait3: Option>, + + concurrent1: Option>, + concurrent2: Option>, + + iloop_close_on_drop: Option>, + iloop_entered: Option>, + + import_cancelled_signal: Option>, + import_cancelled_entered: Vec>, } +#[witx_bindgen_wasmtime::async_trait] impl Imports for MyImports { fn thunk(&mut self) -> HostFuture<()> { - self.thunk_hit = true; Box::pin(async { async {}.await; }) @@ -28,7 +40,7 @@ impl Imports for MyImports { fn concurrent1(&mut self, a: u32) -> HostFuture { assert_eq!(a, 1); - self.unblock1.take().unwrap().send(()).unwrap(); + self.unblock1.take(); let wait = self.wait1.take().unwrap(); Box::pin(async move { wait.await.unwrap(); @@ -38,7 +50,7 @@ impl Imports for MyImports { fn concurrent2(&mut self, a: u32) -> HostFuture { assert_eq!(a, 2); - self.unblock2.take().unwrap().send(()).unwrap(); + self.unblock2.take(); let wait = self.wait2.take().unwrap(); Box::pin(async move { wait.await.unwrap(); @@ -48,60 +60,116 @@ impl Imports for MyImports { fn concurrent3(&mut self, a: u32) -> HostFuture { assert_eq!(a, 3); - self.unblock3.take().unwrap().send(()).unwrap(); + self.unblock3.take(); let wait = self.wait3.take().unwrap(); Box::pin(async move { wait.await.unwrap(); a + 10 }) } -} 
-witx_bindgen_wasmtime::export!("./tests/runtime/async_functions/exports.witx"); + fn concurrent_export_helper(&mut self, idx: u32) -> HostFuture<()> { + let rx = if idx == 0 { + self.concurrent1.take().unwrap() + } else { + self.concurrent2.take().unwrap() + }; + Box::pin(async move { + drop(rx.await); + }) + } -fn run(wasm: &str) -> Result<()> { - struct Context { - wasi: wasmtime_wasi::WasiCtx, - imports: MyImports, - async_: Async, - exports: exports::ExportsData, + async fn iloop_entered(&mut self) { + drop(self.iloop_entered.take()); + } + + fn import_to_cancel(&mut self) -> HostFuture<()> { + let signal = self.import_cancelled_signal.take(); + drop(self.import_cancelled_entered.pop()); + Box::pin(async move { + tokio::time::sleep(Duration::new(1_000, 0)).await; + drop(signal); + }) } +} + +witx_bindgen_wasmtime::import!({ + async: *, + paths: ["./tests/runtime/async_functions/exports.witx"], +}); - let engine = Engine::default(); +struct Context { + wasi: wasmtime_wasi::WasiCtx, + imports: MyImports, + exports: exports::ExportsData, +} + +fn run(wasm: &str) -> Result<()> { + let mut config = Config::new(); + config.async_support(true); + config.consume_fuel(true); + let engine = Engine::new(&config)?; let module = Module::from_file(&engine, wasm)?; let mut linker = Linker::::new(&engine); - imports::add_imports_to_linker(&mut linker, |cx| (&mut cx.imports, &mut cx.async_))?; + imports::add_to_linker(&mut linker, |cx| &mut cx.imports)?; wasmtime_wasi::add_to_linker(&mut linker, |cx| &mut cx.wasi)?; + exports::Exports::add_to_linker(&mut linker, |cx| &mut cx.exports)?; - let mut store = Store::new( - &engine, - Context { - wasi: crate::default_wasi(), - imports: MyImports::default(), - async_: Default::default(), - exports: Default::default(), - }, - ); - let (exports, _instance) = - exports::Exports::instantiate(&mut store, &module, &mut linker, |cx| { - (&mut cx.exports, &mut cx.async_) - })?; + let rt = tokio::runtime::Runtime::new()?; + 
rt.block_on(run_async(&engine, &module, &linker)) +} + +async fn run_async(engine: &Engine, module: &Module, linker: &Linker) -> Result<()> { + let instantiate = |imports| async { + let mut store = Store::new( + &engine, + Context { + wasi: crate::default_wasi(), + imports, + exports: Default::default(), + }, + ); + store.add_fuel(10_000)?; + store.out_of_fuel_async_yield(u64::MAX, 10_000); + let instance = linker.instantiate_async(&mut store, &module).await?; + exports::Exports::new(store, &instance, |cx| &mut cx.exports) + }; - let import = &mut store.data_mut().imports; + let mut import = MyImports::default(); // Initialize various channels which we use as synchronization points to // test the concurrent aspect of async wasm. The first channels here are // used to wait for host functions to get entered by the wasm. - let (a, mut wait1) = channel(); + let (a, wait1) = channel(); import.unblock1 = Some(a); - let (a, mut wait2) = channel(); + let (a, wait2) = channel(); import.unblock2 = Some(a); - let (a, mut wait3) = channel(); + let (a, wait3) = channel(); import.unblock3 = Some(a); + let (tx, mut rx) = mpsc::channel::<()>(10); + let tx2 = tx.clone(); + tokio::spawn(async move { + assert!(wait1.await.is_err()); + drop(tx2); + }); + let tx2 = tx.clone(); + tokio::spawn(async move { + assert!(wait2.await.is_err()); + drop(tx2); + }); + tokio::spawn(async move { + assert!(wait3.await.is_err()); + drop(tx); + }); - // This second set of channels are used to unblock host futures that wasm - // calls, simulating work that returns back to the host and takes some time - // to complete. + let (concurrent1, a) = channel(); + let (concurrent2, b) = channel(); + import.concurrent1 = Some(a); + import.concurrent2 = Some(b); + + // This second set of channels are used to unblock host futures that + // wasm calls, simulating work that returns back to the host and takes + // some time to complete. 
let (unblock1, b) = channel(); import.wait1 = Some(b); let (unblock2, b) = channel(); @@ -109,40 +177,166 @@ fn run(wasm: &str) -> Result<()> { let (unblock3, b) = channel(); import.wait3 = Some(b); - futures_executor::block_on(async { - exports.thunk(&mut store).await?; - assert!(store.data_mut().imports.thunk_hit); - - let mut future = Box::pin(exports.test_concurrent(&mut store)).fuse(); - - // wait for all three concurrent methods to get entered. Note that we - // poll the `future` while we're here as well to run any callbacks - // inside as necessary, but it shouldn't ever finish. - let mut done = 0; - while done < 3 { - futures_util::select! { - _ = future => unreachable!(), - r = wait1 => { r.unwrap(); done += 1; } - r = wait2 => { r.unwrap(); done += 1; } - r = wait3 => { r.unwrap(); done += 1; } - } + let exports = instantiate(import).await?; + exports.thunk().await?; + + let future = exports.test_concurrent(); + tokio::pin!(future); + + // wait for all three concurrent methods to get entered, where once this + // happens they'll all drop the handles to `rx`, meaning that when + // entered we'll see the `rx` channel get closed. + tokio::select! { + _ = &mut future => unreachable!(), + r = rx.recv() => assert!(r.is_none()), + } + + // Now we can "complete" the async task that each function was waiting + // on. Our original future shouldn't be done until they're all complete. + unblock3.send(()).unwrap(); + unblock2.send(()).unwrap(); + unblock1.send(()).unwrap(); + future.await?; + + // Test concurrent exports can be invoked, here we call wasm + // concurrently twice and complete the second one first, ensuring the + // first one isn't finished, and then we complete the first and assert + // it's done. + let a = exports.concurrent_export(0); + tokio::pin!(a); + let b = exports.concurrent_export(1); + drop(concurrent2); + tokio::select! 
{ + r = &mut a => panic!("got result {:?}", r), + r = b => r.unwrap(), + } + drop(concurrent1); + a.await.unwrap(); + + // Cancelling an infinite loop drops the reactor and the reactor doesn't + // execute forever. This will only work if `tx`, owned by the reactor, will + // get dropped when we cancel execution of the infinite loop. + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.iloop_close_on_drop = Some(tx); + imports.iloop_entered = Some(tx2); + let exports = instantiate(imports).await?; + { + let iloop = exports.infinite_loop(); + tokio::pin!(iloop); + // execute the iloop long enough to get into wasm and we'll get the + // signal when the `rx2` channel is closed. + tokio::select! { + _ = &mut iloop => unreachable!(), + r = rx2 => assert!(r.is_err()), } + } + assert!(rx.await.is_err()); + drop(exports); - // Now we can "complete" the async task that each function was waiting - // on. Our original future shouldn't be done until they're all complete. - unblock3.send(()).unwrap(); - futures_util::select! { - _ = future => unreachable!(), - default => {} + // Same as above, but an infinite loop in an async exported wasm function + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.iloop_close_on_drop = Some(tx); + imports.iloop_entered = Some(tx2); + let exports = instantiate(imports).await?; + { + let iloop = exports.infinite_loop_async(); + tokio::pin!(iloop); + // execute the iloop long enough to get into wasm and we'll get the + // signal when the `rx2` channel is closed. + tokio::select! { + _ = &mut iloop => unreachable!(), + r = rx2 => assert!(r.is_err()), } - unblock2.send(()).unwrap(); - futures_util::select! { - _ = future => unreachable!(), - default => {} + } + assert!(rx.await.is_err()); + drop(exports); + + // A trap from WebAssembly should result in cancelling all imported tasks. + // execute forever. 
This will only work if `tx`, owned by the reactor, will + // get dropped when we cancel execution of the infinite loop. + let (tx, rx) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_signal = Some(tx); + let trap = instantiate(imports) + .await? + .call_import_then_trap() + .await + .unwrap_err(); + assert!( + trap.trap_code() == Some(TrapCode::UnreachableCodeReached), + "bad error: {}", + trap + ); + assert!(rx.await.is_err()); + + // Dropping the owned version of the bindings should recursively tear down + // the reactor task since it's got nothing else to do at that point. + let (tx, rx) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_signal = Some(tx); + instantiate(imports).await?; + assert!(rx.await.is_err()); + + // Cancelling (dropping) the outer task transitively tears down the reactor + // and cancels imported tasks. + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_signal = Some(tx); + imports.import_cancelled_entered.push(tx2); + let exports = instantiate(imports).await?; + { + let f = exports.call_infinite_import(); + tokio::pin!(f); + // execute the wasm long enough to get into it and we'll get the + // signal when the `rx2` channel is closed. + tokio::select! { + _ = &mut f => unreachable!(), + r = rx2 => assert!(r.is_err()), + } + } + assert!(rx.await.is_err()); + drop(exports); + + // With multiple concurrent exports if one of them is cancelled then they + // all get cancelled. 
+ let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_entered.push(tx); + imports.import_cancelled_entered.push(tx2); + let exports = instantiate(imports).await?; + let a = exports.call_infinite_import(); + let b = exports.call_infinite_import(); + { + tokio::pin!(a); + { + tokio::pin!(b); + // Run this select twice to ensure both futures get into the import within wasm. + tokio::select! { + _ = &mut a => unreachable!(), + _ = &mut b => unreachable!(), + r = rx2 => assert!(r.is_err()), + } + tokio::select! { + _ = &mut a => unreachable!(), + _ = &mut b => unreachable!(), + r = rx => assert!(r.is_err()), + } + // ... `b` is now dropped here } - unblock1.send(()).unwrap(); - future.await?; + let err = a.await.unwrap_err(); + assert!( + err.to_string().contains("wasm reactor task has gone away"), + "bad error: {}", + err + ); + } + drop(exports); - Ok(()) - }) + Ok(()) } diff --git a/tests/runtime/async_functions/host.ts b/tests/runtime/async_functions/host.ts index 4381c7f76..038ebda88 100644 --- a/tests/runtime/async_functions/host.ts +++ b/tests/runtime/async_functions/host.ts @@ -20,6 +20,9 @@ async function run() { const [unblockConcurrent2, resolveUnblockConcurrent2] = promiseChannel(); const [unblockConcurrent3, resolveUnblockConcurrent3] = promiseChannel(); + const [unblockExport1, resolveUnblockExport1] = promiseChannel(); + const [unblockExport2, resolveUnblockExport2] = promiseChannel(); + const imports: Imports = { async thunk() { if (hit) { @@ -60,6 +63,22 @@ async function run() { console.log('concurrent3 returning to wasm'); return 13; }, + + async concurrentExportHelper(n) { + if (n === 0) { + await unblockExport1; + } else { + await unblockExport2; + } + }, + + iloopEntered() { + throw new Error('unsupported'); + }, + + importToCancel() { + throw new Error('unsupported'); + }, }; let instance: WebAssembly.Instance; addImportsToImports(importObj, imports, name => 
instance.exports[name]); @@ -99,6 +118,13 @@ async function run() { console.log('waiting on host functions'); await concurrentWasm; console.log('concurrent wasm finished'); + + const a = wasm.concurrentExport(0); + const b = wasm.concurrentExport(1); + resolveUnblockExport2(); + await b; + resolveUnblockExport1(); + await a; } async function some_helper() {} diff --git a/tests/runtime/async_functions/imports.witx b/tests/runtime/async_functions/imports.witx index 4cf3657cb..2f15d6513 100644 --- a/tests/runtime/async_functions/imports.witx +++ b/tests/runtime/async_functions/imports.witx @@ -3,3 +3,9 @@ thunk: async function() concurrent1: async function(a: u32) -> u32 concurrent2: async function(a: u32) -> u32 concurrent3: async function(a: u32) -> u32 + +concurrent_export_helper: async function(idx: u32) + +iloop_entered: function() + +import_to_cancel: async function() diff --git a/tests/runtime/async_functions/wasm.rs b/tests/runtime/async_functions/wasm.rs index 273722e36..e43407ff8 100644 --- a/tests/runtime/async_functions/wasm.rs +++ b/tests/runtime/async_functions/wasm.rs @@ -20,4 +20,27 @@ impl exports::Exports for Exports { assert_eq!(futures_util::join!(a2, a3, a1), (12, 13, 11)); } + + async fn concurrent_export(idx: u32) { + imports::concurrent_export_helper(idx).await + } + + async fn infinite_loop_async() { + imports::iloop_entered(); + loop {} + } + + fn infinite_loop() { + imports::iloop_entered(); + loop {} + } + + async fn call_import_then_trap() { + let _f = imports::import_to_cancel(); + std::arch::wasm32::unreachable(); + } + + async fn call_infinite_import() { + imports::import_to_cancel().await; + } } diff --git a/tests/runtime/async_raw/exports.witx b/tests/runtime/async_raw/exports.witx new file mode 100644 index 000000000..c2fce2428 --- /dev/null +++ b/tests/runtime/async_raw/exports.witx @@ -0,0 +1,12 @@ +complete_immediately: async function() +completion_not_called: async function() +complete_twice: async function() +complete_then_trap: 
async function() +assert_coroutine_id_zero: async function() + +not_async_export_done: function() +not_async_calls_async: function() + +import_callback_null: async function() +import_callback_wrong_type: async function() +import_callback_bad_index: async function() diff --git a/tests/runtime/async_raw/host.rs b/tests/runtime/async_raw/host.rs new file mode 100644 index 000000000..649b07a95 --- /dev/null +++ b/tests/runtime/async_raw/host.rs @@ -0,0 +1,161 @@ +use anyhow::Result; +use wasmtime::{Config, Engine, Linker, Module, Store}; +use witx_bindgen_wasmtime::HostFuture; + +witx_bindgen_wasmtime::export!({ + async: *, + paths: ["./tests/runtime/async_raw/imports.witx"], +}); + +#[derive(Default)] +struct MyImports; + +impl imports::Imports for MyImports { + fn thunk(&mut self) -> HostFuture<()> { + Box::pin(async {}) + } +} + +witx_bindgen_wasmtime::import!({ + async: *, + paths: ["./tests/runtime/async_raw/exports.witx"], +}); + +fn run(wasm: &str) -> Result<()> { + struct Context { + wasi: wasmtime_wasi::WasiCtx, + imports: MyImports, + exports: exports::ExportsData, + } + + let mut config = Config::new(); + config.async_support(true); + let engine = Engine::new(&config)?; + let module = Module::from_file(&engine, wasm)?; + let mut linker = Linker::::new(&engine); + imports::add_to_linker(&mut linker, |cx| &mut cx.imports)?; + wasmtime_wasi::add_to_linker(&mut linker, |cx| &mut cx.wasi)?; + exports::Exports::add_to_linker(&mut linker, |cx| &mut cx.exports)?; + + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(async { + let instantiate = || async { + let mut store = Store::new( + &engine, + Context { + wasi: crate::default_wasi(), + imports: MyImports::default(), + exports: Default::default(), + }, + ); + let instance = linker.instantiate_async(&mut store, &module).await?; + exports::Exports::new(store, &instance, |cx| &mut cx.exports) + }; + + // it's ok to call the completion callback immediately, and the first + // coroutine is always zero. 
+ let exports = instantiate().await?; + exports.complete_immediately().await?; + exports.assert_coroutine_id_zero().await?; + exports.assert_coroutine_id_zero().await?; + + // if the completion callback is never called that's a trap + let err = instantiate() + .await? + .completion_not_called() + .await + .unwrap_err(); + assert!( + err.to_string().contains("completion callback never called"), + "bad error: {}", + err, + ); + + // the completion callback can only be called once + let err = instantiate().await?.complete_twice().await.unwrap_err(); + assert!( + err.to_string().contains("async context not valid"), + "bad error: {}", + err, + ); + + // if the trap happens after the completion callback... something + // happens, for now a trap. + let err = instantiate().await?.complete_then_trap().await.unwrap_err(); + assert!( + err.trap_code() == Some(wasmtime::TrapCode::UnreachableCodeReached), + "bad error: {:?}", + err + ); + + // If a non-async export tries to call the completion callback for + // async exports that's an error. + let err = instantiate() + .await? + .not_async_export_done() + .await + .unwrap_err(); + assert!( + err.to_string().contains("async context not valid"), + "bad error: {}", + err, + ); + + // If a non-async export tries to call an async import that's an error. + let err = instantiate() + .await? + .not_async_calls_async() + .await + .unwrap_err(); + assert!( + err.to_string() + .contains("cannot call async import from non-async export"), + "bad error: {}", + err, + ); + + // The import callback specified cannot be null + let err = instantiate() + .await? + .import_callback_null() + .await + .unwrap_err(); + assert!( + err.to_string().contains("callback was a null function"), + "bad error: {}", + err, + ); + + // The import callback specified must have the right type. + let err = instantiate() + .await? 
+ .import_callback_wrong_type() + .await + .unwrap_err(); + assert!( + err.to_string().contains("type mismatch with parameters"), + "bad error: {}", + err, + ); + + // The import callback specified must point to a valid table index + let err = instantiate() + .await? + .import_callback_bad_index() + .await + .unwrap_err(); + assert!( + err.to_string().contains("invalid function index"), + "bad error: {}", + err, + ); + + // when wasm traps due to one reason or another all future requests to + // execute wasm fail + let exports = instantiate().await?; + assert!(exports.import_callback_null().await.is_err()); + assert!(exports.complete_immediately().await.is_err()); + + Ok(()) + }) +} diff --git a/tests/runtime/async_raw/host.ts b/tests/runtime/async_raw/host.ts new file mode 100644 index 000000000..f759f04fc --- /dev/null +++ b/tests/runtime/async_raw/host.ts @@ -0,0 +1,54 @@ +import { addImportsToImports, Imports } from "./imports.js"; +import { Exports } from "./exports.js"; +import { getWasm, addWasiToImports } from "./helpers.js"; +// @ts-ignore +import * as assert from 'assert'; + +async function run() { + const importObj = {}; + const imports: Imports = { + async thunk() {} + }; + + async function instantiate() { + let instance: WebAssembly.Instance; + addImportsToImports(importObj, imports, name => instance.exports[name]); + const wasi = addWasiToImports(importObj); + + const wasm = new Exports(); + await wasm.instantiate(getWasm(), importObj); + wasi.start(wasm.instance); + instance = wasm.instance; + return wasm; + } + + let wasm = await instantiate(); + await wasm.completeImmediately(); + await wasm.assertCoroutineIdZero(); + await wasm.assertCoroutineIdZero(); + + wasm = await instantiate(); + await assert.rejects(wasm.completionNotCalled(), /blocked coroutine with 0 pending callbacks/); + + wasm = await instantiate(); + await assert.rejects(wasm.completeTwice(), /cannot complete coroutine twice/); + + wasm = await instantiate(); + await 
assert.rejects(wasm.completeThenTrap(), /unreachable/); + + wasm = await instantiate(); + assert.throws(() => wasm.notAsyncExportDone(), /invalid coroutine index/); + + wasm = await instantiate(); + await assert.rejects(wasm.importCallbackNull(), /table index is a null function/); + + // TODO: this is the wrong error from this, but it's not clear how best to do + // type-checks in JS... + wasm = await instantiate(); + await assert.rejects(wasm.importCallbackWrongType(), /0 pending callbacks/); + + wasm = await instantiate(); + await assert.rejects(wasm.importCallbackBadIndex(), RangeError); +} + +await run() diff --git a/tests/runtime/async_raw/imports.witx b/tests/runtime/async_raw/imports.witx new file mode 100644 index 000000000..c0e1bb5a9 --- /dev/null +++ b/tests/runtime/async_raw/imports.witx @@ -0,0 +1 @@ +thunk: async function() diff --git a/tests/runtime/async_raw/wasm.rs b/tests/runtime/async_raw/wasm.rs new file mode 100644 index 000000000..9140110e3 --- /dev/null +++ b/tests/runtime/async_raw/wasm.rs @@ -0,0 +1,82 @@ +use std::arch::wasm32; + +#[link(wasm_import_module = "canonical_abi")] +extern "C" { + pub fn async_export_done(ctx: i32, ptr: i32); +} + +#[link(wasm_import_module = "imports")] +extern "C" { + pub fn thunk(cb: i32, ptr: i32); +} + +#[no_mangle] +pub extern "C" fn complete_immediately(ctx: i32) { + unsafe { + async_export_done(ctx, 0); + } +} + +#[no_mangle] +pub extern "C" fn completion_not_called(_ctx: i32) {} + +#[no_mangle] +pub extern "C" fn complete_twice(ctx: i32) { + unsafe { + async_export_done(ctx, 0); + async_export_done(ctx, 0); + } +} + +#[no_mangle] +pub extern "C" fn complete_then_trap(ctx: i32) { + unsafe { + async_export_done(ctx, 0); + wasm32::unreachable(); + } +} + +#[no_mangle] +pub extern "C" fn assert_coroutine_id_zero(ctx: i32) { + unsafe { + assert_eq!(ctx, 0); + async_export_done(ctx, 0); + } +} + +#[no_mangle] +pub extern "C" fn not_async_export_done() { + unsafe { + async_export_done(0, 0); + } +} + 
+#[no_mangle] +pub extern "C" fn not_async_calls_async() { + extern "C" fn callback(_x: i32) {} + unsafe { + thunk(callback as i32, 0); + } +} + +#[no_mangle] +pub extern "C" fn import_callback_null(_cx: i32) { + unsafe { + thunk(0, 0); + } +} + +#[no_mangle] +pub extern "C" fn import_callback_wrong_type(_cx: i32) { + extern "C" fn callback() {} + unsafe { + thunk(callback as i32, 0); + } +} + +#[no_mangle] +pub extern "C" fn import_callback_bad_index(_cx: i32) { + unsafe { + thunk(i32::MAX, 0); + } +} From 62fe3e47f915b38fa7ccc328fbadb26cb886b165 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 28 Oct 2021 14:50:02 -0700 Subject: [PATCH 5/5] Add new intrinsics for "events" Enables cross-coroutine wakeups primarily. Downside is that the Rust implementation now "bounces" off an import completion callback back onto an event completion callback. --- crates/gen-js/src/lib.rs | 58 ++- crates/gen-wasmtime/src/lib.rs | 22 + crates/rust-wasm/src/futures.rs | 283 +++++++++---- crates/wasmtime/src/futures.rs | 587 +++++++++++++++----------- tests/runtime/async_functions/host.ts | 4 +- tests/runtime/async_raw/host.ts | 4 +- 6 files changed, 602 insertions(+), 356 deletions(-) diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index 44970822f..5d37c69d1 100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -892,7 +892,17 @@ impl Generator for Js { self.src.js(&format!( " imports.canonical_abi['async_export_done'] = (ctx, ptr) => {{ - {}.async_export_done(ctx, ptr >>> 0) + {0}.async_export_done(ctx, ptr >>> 0) + }}; + imports.canonical_abi['event_new'] = (a, b) => {{ + const callback = {0}.table.get(a); + if (callback === null) {{ + throw new Error('table index is a null function'); + }} + return {0}.event_new(callback, b); + }}; + imports.canonical_abi['event_signal'] = (a, b) => {{ + {0}.event_signal(a, b); }}; ", promises @@ -960,6 +970,15 @@ impl Generator for Js { this._exports = this.instance.exports; "); + if any_async { + let 
promises = self.intrinsic(Intrinsic::Promises); + // TODO: hardcoding __indirect_function_table + self.src.js(&format!( + "{}.table = this._exports.__indirect_function_table;\n", + promises + )); + } + // Exported resources all get a finalization registry, and we // created them after instantiation so we can pass the raw wasm // export as the destructor callback. @@ -2058,23 +2077,22 @@ impl Bindgen for FunctionBindgen<'_> { }, Instruction::CompletionCallback { .. } => { - // TODO: shouldn't hardcode the function table name, should - // verify the table is present, and should verify the type of - // the function returned. + // TODO: should verify the type of the function + let promises = self.gen.intrinsic(Intrinsic::Promises); self.src.js(&format!( " - const callback = get_export(\"__indirect_function_table\").get({}); - if (callback === null) + const callback = {}.table.get({}); + if (callback === null) {{ throw new Error('table index is a null function'); + }} ", - operands[0] + promises, operands[0] )); results.push("callback".to_string()); } Instruction::ReturnAsyncImport { .. 
} => { // TODO - self.gen.needs_get_export = true; self.src .js(&format!("{}({});\n", operands[0], operands[1..].join(", "),)); } @@ -2422,6 +2440,7 @@ impl Js { class Promises { constructor() { this.slab = new Slab(); + this.events = new Slab(); this.current = null; } @@ -2497,6 +2516,29 @@ impl Js { this.current = prev; } } + + event_new(callback, data) { + const coroutine = this.slab.get(this.current); + coroutine.pending += 1; + return this.events.insert({ + coroutine: this.current, + callback, + data, + }); + } + + event_signal(idx, arg) { + const { coroutine, callback, data } = this.events.remove(idx); + queueMicrotask(() => { + const wasm = this.slab.maybeGet(coroutine); + if (wasm !== null) { + wasm.pending -= 1; + this.set_current(coroutine, () => { + callback(data, arg); + }); + } + }) + } } export const PROMISES = new Promises(); "), diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 60b85e9db..549bfdbc0 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -1291,6 +1291,28 @@ impl Generator for Wasmtime { }}) }}, )?; + linker.func_wrap( + \"canonical_abi\", + \"event_new\", + move |mut caller: wasmtime::Caller<'_, T>, cb: u32, data: u32| {{ + witx_bindgen_wasmtime::rt::Async::event_new( + caller, + cb, + data, + ) + }}, + )?; + linker.func_wrap( + \"canonical_abi\", + \"event_signal\", + move |mut caller: wasmtime::Caller<'_, T>, handle: u32, arg: u32| {{ + witx_bindgen_wasmtime::rt::Async::event_signal( + caller, + handle, + arg, + ) + }}, + )?; ", )); } diff --git a/crates/rust-wasm/src/futures.rs b/crates/rust-wasm/src/futures.rs index 286df4ae7..70e4faaae 100644 --- a/crates/rust-wasm/src/futures.rs +++ b/crates/rust-wasm/src/futures.rs @@ -1,10 +1,11 @@ //! 
Helper library support for `async` witx functions, used for both -use std::cell::RefCell; +use self::event::{Event, Signal}; +use std::cell::{Cell, RefCell}; use std::future::Future; use std::mem; use std::pin::Pin; -use std::rc::{Rc, Weak}; +use std::rc::Rc; use std::sync::Arc; use std::task::*; @@ -19,12 +20,112 @@ pub unsafe extern "C" fn async_export_done(_ctx: i32, _ptr: i32) { panic!("only supported on wasm"); } -struct PollingWaker { - state: RefCell, +/// Runs the `future` provided to completion, polling the future whenever its +/// waker receives a call to `wake`. +pub fn execute(future: Pin>>) { + Task::execute(future) +} + +struct Task { + future: Pin>>, + waker: Arc, +} + +impl Task { + fn execute(future: Pin>>) { + Box::new(Task { + future, + waker: Arc::new(WasmWaker { + state: Cell::new(State::Woken), + }), + }) + .signal() + } +} + +impl Signal for Task { + fn signal(mut self: Box) { + // First, reset our state to `polling` to indicate that we're actively + // polling the future that we own. + let waker = self.waker.clone(); + match waker.state.replace(State::Polling) { + // This shouldn't be possible since if a waiting event is pending + // then we shouldn't be woken up to signal. + State::Waiting(_) => panic!("signaled but event is present"), + + // This also shouldn't be possible since if the previous state were + // polling then we shouldn't be restarting another round of polling. + State::Polling => panic!("poll-in-poll"), + + // This is the expected state, which is to say that we should be + // previously woken with some event having been consumed, which + // left a `Woken` marker here. + State::Woken => {} + } + + // Perform the Rust Dance to poll the future. 
+ let rust_waker = waker.clone().into(); + let mut cx = Context::from_waker(&rust_waker); + match self.future.as_mut().poll(&mut cx) { + // If the future has finished there's nothing else left to do but + // destroy the future, so we do so here through the dtor for `self` + // in an early-return. + Poll::Ready(()) => return, + + // If the future isn't ready then logic below handles the wakeup + // procedure. + Poll::Pending => {} + } + + // Our future isn't ready but we should be scheduled to wait on some + // event from within the future. Configure the state of the waker + // after-the-fact to have an interface-types-provided "event" which, + // when woken, will basically re-invoke this method. + let event = Event::new(self); + match waker.state.replace(State::Waiting(event)) { + // This state shouldn't be possible because we're the only ones + // inserting a `Waiting` state here, so if something else set that + // it's highly unexpected. + State::Waiting(_) => unreachable!(), + + // This is the expected state most of the time where we're replacing + // the `Polling` state that was configured above. This means we've + // switched from polling-to-waiting so we can safely return now and + // wait for our result. + State::Polling => {} + + // This is a slightly tricky state where we received a `wake()` + // while we were polling. In this situation we replace the state + // back to `Woken` and signal the event ourselves. + State::Woken => { + let event = match waker.state.replace(State::Woken) { + State::Waiting(event) => event, + _ => unreachable!(), + }; + event.signal(); + } + } + } +} + +/// This is the internals of the `Waker` that's specific to wasm. +/// +/// For now this is pretty simple where this maintains a state enum where the +/// main interesting state is an "event" that gets a signal to start re-polling +/// the future. 
This event-based-wakeup has two consequences: +/// +/// * If the `wake()` comes from another Rust coroutine then we'll correctly +/// execute the Rust poll on the original coroutine's context. +/// * If the `wake()` comes from an async import completing then it means the +/// completion callback will do a tiny bit of work to signal the event, and +/// then the real work will happen later when the event's callback is +/// enqueued. +struct WasmWaker { + state: Cell, } enum State { - Waiting(Pin>>), + Waiting(Event), Polling, Woken, } @@ -34,83 +135,33 @@ enum State { // an alternative implementation for threaded WebAssembly when that comes about // to host runtimes off-the-web. #[cfg(not(target_feature = "atomics"))] -unsafe impl Send for PollingWaker {} +unsafe impl Send for WasmWaker {} #[cfg(not(target_feature = "atomics"))] -unsafe impl Sync for PollingWaker {} +unsafe impl Sync for WasmWaker {} -/// Runs the `future` provided to completion, polling the future whenever its -/// waker receives a call to `wake`. -pub fn execute(future: impl Future + 'static) { - let waker = Arc::new(PollingWaker { - state: RefCell::new(State::Waiting(Box::pin(future))), - }); - waker.wake() -} - -impl Wake for PollingWaker { +impl Wake for WasmWaker { fn wake(self: Arc) { - let mut state = self.state.borrow_mut(); - let mut future = match mem::replace(&mut *state, State::Polling) { - // We are the first wake to come in to wake-up this future. This - // means that we need to actually poll the future, so leave the - // `Polling` state in place. - State::Waiting(future) => future, - - // Otherwise the future is either already polling or it was already - // woken while it was being polled, in both instances we reset the - // state back to `Woken` and then we return. This means that the - // future is owned by some previous stack frame and will drive the - // future as necessary. 
- State::Polling | State::Woken => { - *state = State::Woken; - return; - } - }; - drop(state); - - // Create the futures waker/context from ourselves, used for polling. - let waker = self.clone().into(); - let mut cx = Context::from_waker(&waker); - loop { - match future.as_mut().poll(&mut cx) { - // The future is finished! By returning here we destroy the - // future and release all of its resources. - Poll::Ready(()) => break, - - // The future has work yet-to-do, so continue below. - Poll::Pending => {} - } - - let mut state = self.state.borrow_mut(); - match *state { - // This means that we were not woken while we were polling and - // the state is as it was when we took out the future before. By - // `Pending` being returned at this point we're guaranteed that - // our waker will be woken up at some point in the future, which - // will come look at this future again. This means that we - // simply store our future and return, since this call to `wake` - // is now finished. - State::Polling => { - *state = State::Waiting(future); - break; - } + match self.state.replace(State::Woken) { + // We found a waiting event, yay! Signal that to wake it up and then + // there's nothing much else for us to do. + State::Waiting(event) => event.signal(), - // This means that we received a call to `wake` while we were - // polling. Ideally we'd enqueue some sort of microtask-tick - // here or something like that but for now we just loop around - // and poll again. - State::Woken => {} + // this `wake` happened during the poll of the future itself, which + // is ok and the future will consume our `Woken` status when it's + // done polling. + State::Polling => {} - // This shouldn't be possible since we own the future, and no - // one else should insert another future here. - State::Waiting(_) => unreachable!(), - } + // This is perhaps a concurrent wake where we already woke up the + // main future. 
That's ok, we're still in the `Woken` state and it's + // still someone else's responsibility to manage wakeups at this + // point. + State::Woken => {} } } } pub struct Oneshot { - inner: Weak>, + inner: Rc>, } pub struct Sender { @@ -130,12 +181,17 @@ enum OneshotState { impl Oneshot { /// Returns a new "oneshot" channel as well as a completion callback. pub fn new() -> (Oneshot, Sender) { + // TODO: this oneshot implementation does not correctly handle "hangups" + // on either the sender or receiver side. This really only works with + // the exact codegen that we have right now and if it's used for + // anything else then this implementation needs to be updated (or this + // should use something off-the-shelf from the ecosystem) let inner = Rc::new(OneshotInner { state: RefCell::new(OneshotState::Start), }); ( Oneshot { - inner: Rc::downgrade(&inner), + inner: inner.clone(), }, Sender { inner }, ) @@ -146,13 +202,7 @@ impl Future for Oneshot { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let inner = match self.inner.upgrade() { - Some(inner) => inner, - // Technically this isn't possible in the initial draft of interface - // types unless there's some serious bug somewhere. - None => panic!("completion callback was canceled"), - }; - let mut state = inner.state.borrow_mut(); + let mut state = self.inner.state.borrow_mut(); match mem::replace(&mut *state, OneshotState::Start) { OneshotState::Done(t) => Poll::Ready(t), OneshotState::Waiting(_) | OneshotState::Start => { @@ -175,12 +225,7 @@ impl Sender { } pub fn send(self, val: T) { - let mut state = self.inner.state.borrow_mut(); - let prev = mem::replace(&mut *state, OneshotState::Done(val)); - // Must `drop` before the `wake` below because waking may induce - // polling which would induce another `borrow_mut` which would - // conflict with this `borrow_mut` otherwise. 
- drop(state); + let prev = mem::replace(&mut *self.inner.state.borrow_mut(), OneshotState::Done(val)); match prev { // nothing has polled the returned future just yet, so we just @@ -193,16 +238,72 @@ impl Sender { OneshotState::Waiting(waker) => waker.wake(), // Shouldn't be possible, this is the only closure that writes - // `Done` and this can only be invoked once. + // `Done` and this can only be invoked once. Additionally since + // `self` exists we shouldn't be closed yet which is only written in + // `Drop` OneshotState::Done(_) => unreachable!(), } } } -impl Drop for OneshotInner { - fn drop(&mut self) { - if let OneshotState::Waiting(waker) = &*self.state.borrow() { - waker.wake_by_ref(); +mod event { + use std::mem; + + #[cfg(target_arch = "wasm32")] + #[link(wasm_import_module = "canonical_abi")] + extern "C" { + fn event_new(cb: usize, cbdata: usize) -> u32; + fn event_signal(handle: u32, arg: u32); + } + + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn event_new(_: usize, _: usize) -> u32 { + unreachable!() + } + + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn event_signal(_: u32, _: u32) { + unreachable!() + } + + pub struct Event(u32); + + pub trait Signal { + fn signal(self: Box); + } + + impl Event { + pub fn new(to_signal: Box) -> Event + where + S: Signal, + { + unsafe { + let to_signal = Box::into_raw(to_signal); + let handle = event_new(signal:: as usize, to_signal as usize); + return Event(handle); + } + + unsafe extern "C" fn signal(data: usize, is_drop: u32) { + let data = Box::from_raw(data as *mut S); + if is_drop == 0 { + data.signal(); + } + } + } + + pub fn signal(self) { + unsafe { + event_signal(self.0, 0); + mem::forget(self); + } + } + } + + impl Drop for Event { + fn drop(&mut self) { + unsafe { + event_signal(self.0, 1); + } } } } diff --git a/crates/wasmtime/src/futures.rs b/crates/wasmtime/src/futures.rs index 936f53098..19a93a402 100644 --- a/crates/wasmtime/src/futures.rs +++ 
b/crates/wasmtime/src/futures.rs @@ -7,7 +7,9 @@ use std::pin::Pin; use std::sync::{Arc, Weak}; use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedSender}; use tokio::task::JoinHandle; -use wasmtime::{AsContextMut, Caller, Memory, Store, StoreContextMut, Table, Trap}; +use wasmtime::{AsContextMut, Caller, Memory, Store, StoreContextMut, Table, Trap, TypedFunc}; + +const MAX_EVENTS: usize = 1_000; pub struct Async { function_table: Table, @@ -45,6 +47,9 @@ pub struct Async { /// instead of `u32` as slabs do. coroutines: RefCell>, + events: RefCell>, + pending_events: RefCell>, + /// The "currently active" coroutine. /// /// This is used to persist state in the host about what coroutine is @@ -62,6 +67,8 @@ pub struct Async { /// generational index of sorts which prevents accidentally resuing slab /// indices in the `coroutines` array. cur_unique_id: Cell, + + receiver: Receiver>, } /// An "integer" identifier for a coroutine. @@ -86,6 +93,12 @@ enum Message { Cancel(CoroutineId), } +struct Event { + callback: TypedFunc<(u32, u32), ()>, + coroutine: CoroutineId, + data: u32, +} + struct Coroutine { /// A unique ID for this coroutine which is used to ensure that even if this /// coroutine's slab index is reused a `CoroutineId` uniquely points to one @@ -105,9 +118,8 @@ struct Coroutine { /// ASAP via an `abort()` signal on the `JoinHandle`. pending_imports: Slab>, - /// The number of imports that we're waiting on, corresponding to the number - /// of present entries in `pending_imports`. - num_pending_imports: usize, + /// The number of imports or events that we're waiting on. + pending_callbacks: usize, /// A callback to invoke whenever a coroutine's `async_export_done` /// completion callback is invoked. This is used by the host to deserialize @@ -153,13 +165,10 @@ pub type RunStandalone = Box< + Send, >; -impl Async { +impl Async { /// Spawns a new task which will manage async execution of wasm within the /// `store` provided. 
- pub fn spawn(mut store: Store, function_table: Table) -> AsyncHandle - where - T: Send, - { + pub fn spawn(mut store: Store, function_table: Table) -> AsyncHandle { // This channel is the primary point of communication into the task that // we're going to spawn. This'll be bounded to ensure it doesn't get // overrun, and additionally the sender will be stored in an `Arc` to @@ -174,14 +183,17 @@ impl Async { coroutines: RefCell::new(Coroutines { slab: Slab::default(), }), + events: Default::default(), + pending_events: Default::default(), cur_wasm_coroutine: CoroutineId { slab_index: u32::MAX, unique_id: u64::MAX, }, cur_unique_id: Cell::new(0), + receiver, }; - tokio::spawn(async move { cx.run(&mut store.as_context_mut(), receiver).await }); + tokio::spawn(async move { cx.run(&mut store).await }); AsyncHandle { sender } } @@ -199,7 +211,7 @@ impl Async { .get_mut(&coroutine_id) .ok_or_else(|| Trap::new("cannot call async import from non-async export"))?; let pending_import_id = coroutine.pending_imports.next_id(); - coroutine.num_pending_imports += 1; + coroutine.pending_callbacks += 1; // Note that `tokio::spawn` is used here to allow the `future` for // this import to execute in parallel, not just concurrently. The @@ -231,245 +243,275 @@ impl Async { }) } - async fn run( - &mut self, - store: &mut StoreContextMut<'_, T>, - mut receiver: Receiver>, - ) { - // Infinitely process messages on `receiver` which represent events such as - // requests to invoke an export or completion of an import which results in - // execution of a completion callback. - while let Some(msg) = receiver.recv().await { - let coroutines = self.coroutines.get_mut(); - let (to_execute, coroutine_id) = match msg { - // This message is the start of a new task ("coroutine" in - // interface-types-vernacular) so we allocate a new task in our - // slab and set up its state. 
- // - // Note that we spawn a "helper" task here to send a message to - // our channel when the `sender` specified here is disconnected. - // That scenario means that this coroutine is cancelled and by - // sending a message into our channel we can start processing - // that. - Message::Execute(run, complete, sender) => { - let unique_id = self.cur_unique_id.get(); - self.cur_unique_id.set(unique_id + 1); - let coroutine_id = coroutines.next_id(unique_id); - let my_sender = self.sender.clone(); - let await_close_sender = sender.clone(); - let cancel_task = tokio::spawn(async move { - await_close_sender.closed().await; - // if the main task is gone one way or another we ignore - // the error here since no one's going to receive it - // anyway and all relevant work should be cancelled. - if let Some(sender) = my_sender.upgrade() { - drop(sender.send(Message::Cancel(coroutine_id)).await); - } - }); - coroutines.insert(Coroutine { - unique_id, - complete: Some(complete), - sender, - pending_imports: Slab::default(), - num_pending_imports: 0, - cancel_task: Some(cancel_task), - }); - (ToExecute::Start(run, coroutine_id.slab_index), coroutine_id) - } + /// Top level run-loop of the task which owns `Async`, typically spawned + /// as a separate task. + async fn run(&mut self, store: &mut Store) { + while self.process_message(store).await { + // ...continue on to the next message... + } + } + + async fn process_message(&mut self, store: &mut Store) -> bool { + let store = &mut store.as_context_mut(); + + // The "highest priority" messages are those of pending events. These + // are processed first before we get to the actual channel queue which + // may block further for events. 
+ if let Some((event, arg)) = self.pending_events.get_mut().pop() { + return self + .execute_coroutine( + event.coroutine, + event.callback.call_async(store, (event.data, arg)), + ) + .await; + } + + // Wait for a message, but if there are no other messages then we're + // done processing messages + let coroutines = self.coroutines.get_mut(); + let msg = match self.receiver.recv().await { + Some(msg) => msg, + None => return false, + }; + match msg { + // This message is the start of a new task ("coroutine" in + // interface-types-vernacular) so we allocate a new task in our + // slab and set up its state. + // + // Note that we spawn a "helper" task here to send a message to + // our channel when the `sender` specified here is disconnected. + // That scenario means that this coroutine is cancelled and by + // sending a message into our channel we can start processing + // that. + Message::Execute(run, complete, sender) => { + let unique_id = self.cur_unique_id.get(); + self.cur_unique_id.set(unique_id + 1); + let coroutine_id = coroutines.next_id(unique_id); + let my_sender = self.sender.clone(); + let await_close_sender = sender.clone(); + let cancel_task = tokio::spawn(async move { + await_close_sender.closed().await; + // if the main task is gone one way or another we ignore + // the error here since no one's going to receive it + // anyway and all relevant work should be cancelled. + if let Some(sender) = my_sender.upgrade() { + drop(sender.send(Message::Cancel(coroutine_id)).await); + } + }); + coroutines.insert(Coroutine { + unique_id, + complete: Some(complete), + sender, + pending_imports: Slab::default(), + pending_callbacks: 1, + cancel_task: Some(cancel_task), + }); + self.execute_coroutine(coroutine_id, run(store, coroutine_id.slab_index)) + .await + } - // This message means that we need to execute `run` specified - // which is a "non blocking"-in-the-coroutine-sense wasm - // function. 
This is basically "go run that single callback" and - // is currently only used for things like resource destructors. - // These aren't allowed to call blocking functions and a trap is - // generated if they try to call a blocking function (since - // there isn't a coroutine set up). - // - // Note that here we avoid allocating a coroutine entirely since - // this isn't actually a coroutine, which means that any attempt - // to call a blocking function will be met with failure (a - // trap). Additionally note that the actual execution of the - // wasm here is select'd against the closure of the `sender` - // here as well, since if the runtime becomes disinterested in - // the result of this async call we can interrupt and abort the - // wasm. - // - // Finally note that if the wasm completes but we fail to send - // the result of the wasm to the receiver then we ignore the - // error since that was basically a race between wasm exiting - // and the sender being closed. - // - // TODO: should this dropped result/error get logged/processed - // somewhere? - Message::RunNoCoroutine(run, sender) => { - tokio::select! { - r = tls::scope(self, run(store)) => { - let is_trap = r.is_err(); - let _ = sender.send(r); - - // Shut down this reactor if a trap happened because - // the instance is now in an indeterminate state. - if is_trap { - break; - } + // This message means that we need to execute `run` specified + // which is a "non blocking"-in-the-coroutine-sense wasm + // function. This is basically "go run that single callback" and + // is currently only used for things like resource destructors. + // These aren't allowed to call blocking functions and a trap is + // generated if they try to call a blocking function (since + // there isn't a coroutine set up). + // + // Note that here we avoid allocating a coroutine entirely since + // this isn't actually a coroutine, which means that any attempt + // to call a blocking function will be met with failure (a + // trap). 
Additionally note that the actual execution of the + // wasm here is select'd against the closure of the `sender` + // here as well, since if the runtime becomes disinterested in + // the result of this async call we can interrupt and abort the + // wasm. + // + // Finally note that if the wasm completes but we fail to send + // the result of the wasm to the receiver then we ignore the + // error since that was basically a race between wasm exiting + // and the sender being closed. + // + // TODO: should this dropped result/error get logged/processed + // somewhere? + Message::RunNoCoroutine(run, sender) => { + tokio::select! { + r = tls::scope(self, run(store)) => { + let is_trap = r.is_err(); + let _ = sender.send(r); + + // Shut down this reactor if a trap happened because + // the instance is now in an indeterminate state. + if is_trap { + return false; } - _ = sender.closed() => break, } - continue; + _ = sender.closed() => return false, } + true + } - // This message indicates that an import has completed and - // the completion callback for the wasm must be executed. - // This, plus the serialization of the arguments into wasm - // according to the canonical ABI, is represented by - // `run`. - // - // Note, though, that in some cases we don't actually run - // the completion callback. For example if a previous - // completion callback for this wasm task has failed with a - // trap we don't continue to run completion callbacks for - // the wasm task. This situation is indicated when the - // coroutine is not actually present in our `coroutines` - // list, so we do a lookup here before allowing execution. When - // the coroutine isn't present we simply skip this message which - // will run destructors for any relevant host values. 
- Message::FinishImport(run, coroutine_id, import_id) => { - let coroutine = match coroutines.get_mut(&coroutine_id) { - Some(c) => c, - None => continue, - }; - coroutine.pending_imports.remove(import_id).unwrap(); - coroutine.num_pending_imports -= 1; - (ToExecute::Callback(run), coroutine_id) - } + // This message indicates that an import has completed and + // the completion callback for the wasm must be executed. + // This, plus the serialization of the arguments into wasm + // according to the canonical ABI, is represented by + // `run`. + // + // Note, though, that in some cases we don't actually run + // the completion callback. For example if a previous + // completion callback for this wasm task has failed with a + // trap we don't continue to run completion callbacks for + // the wasm task. This situation is indicated when the + // coroutine is not actually present in our `coroutines` + // list, so we do a lookup here before allowing execution. When + // the coroutine isn't present we simply skip this message which + // will run destructors for any relevant host values. + Message::FinishImport(run, coroutine_id, import_id) => { + let coroutine = match coroutines.get_mut(&coroutine_id) { + Some(c) => c, + None => return true, + }; + coroutine.pending_imports.remove(import_id).unwrap(); + self.execute_coroutine(coroutine_id, run(&mut store.into())) + .await + } - // This message indicates that the specified coroutine has been - // cancelled, meaning that the sender which would send back the - // result of the coroutine is now a closed channel that we can - // no longer send a message along. Our response to this is to - // remove the coroutine, and its destructor will trigger further - // cancellation if necessary. - // - // Note that this message may race with the actual completion of - // the coroutine so we don't assert that the ID specified here - // is actually in our list. 
If a coroutine is removed though we - // assume that the wasm is now in an indeterminate state which - // results in aborting this reactor task. If nothing is removed - // then we assume the race was properly resolved and we skip - // this message. - Message::Cancel(coroutine_id) => { - if coroutines.remove(&coroutine_id).is_some() { - break; - } - continue; + // This message indicates that the specified coroutine has been + // cancelled, meaning that the sender which would send back the + // result of the coroutine is now a closed channel that we can + // no longer send a message along. Our response to this is to + // remove the coroutine, and its destructor will trigger further + // cancellation if necessary. + // + // Note that this message may race with the actual completion of + // the coroutine so we don't assert that the ID specified here + // is actually in our list. If a coroutine is removed though we + // assume that the wasm is now in an indeterminate state which + // results in aborting this reactor task. If nothing is removed + // then we assume the race was properly resolved and we skip + // this message. + Message::Cancel(coroutine_id) => { + if coroutines.remove(&coroutine_id).is_some() { + return false; } - }; + true + } + } + } - // Actually execute the WebAssembly callback. The call to - // `to_execute.run` here is what will actually execute WebAssembly - // asynchronously, and note that it's also executed within a - // `tls::scope` to ensure that the `tls::with` function will work - // for the duration of the future. + async fn execute_coroutine( + &mut self, + coroutine_id: CoroutineId, + wasm_execution: impl Future>, + ) -> bool { + // Actually execute the WebAssembly callback. The call to + // `to_execute.run` here is what will actually execute WebAssembly + // asynchronously, and note that it's also executed within a + // `tls::scope` to ensure that the `tls::with` function will work + // for the duration of the future. 
+ // + // Also note, though, that we want to be able to cancel the + // execution of this WebAssembly if the caller becomes disinterested + // in the result. This happens by using the `closed()` method on the + // channel back to the sender, and if that happens we abort wasm + // entirely and abort the whole coroutine by removing it later. + // + // If this wasm operation is aborted then we exit this loop + // entirely and tear down this reactor task. That triggers + // cancellation of all spawned sub-tasks and sibling coroutines, and + // the rationale for this is that we zapped wasm while it was + // executing so it's now in an indeterminate state and not one that + // we can resume. + // + // TODO: this is a `clone()`-per-callback which is probably cheap, + // but this is also a sort of wonky setup so this may wish to change + // in the future. + let coroutine = self.coroutines.get_mut().get_mut(&coroutine_id).unwrap(); + coroutine.pending_callbacks -= 1; + let cancel_signal = coroutine.sender.clone(); + let prev_coroutine_id = mem::replace(&mut self.cur_wasm_coroutine, coroutine_id); + let result = tokio::select! { + r = tls::scope(self, wasm_execution) => r, + _ = cancel_signal.closed() => return false, + }; + self.cur_wasm_coroutine = prev_coroutine_id; + + let coroutines = self.coroutines.get_mut(); + let coroutine = coroutines.get_mut(&coroutine_id).unwrap(); + if let Err(trap) = result { + // Our WebAssembly callback trapped. That means that this + // entire coroutine is now in a failure state. No further + // wasm callbacks will be invoked and the coroutine is + // removed from our internal list to invoke the failure + // callback, informing what trap caused the failure. // - // Also note, though, that we want to be able to cancel the - // execution of this WebAssembly if the caller becomes disinterested - // in the result.
This happens by using the `closed()` method on the - // channel back to the sender, and if that happens we abort wasm - // entirely and abort the whole coroutine by removing it later. + // Note that this reopens `coroutine_id.slab_index` to get + // possibly reused, intentionally so, which is why + // `CoroutineId` is a form of generational ID which is + // resilient to this form of reuse. In other words when we + // remove the result here if in the future a pending import + // for this coroutine completes we'll simply discard the + // message. // - // If this wasm operations is aborted then we exit this loop - // entirely and tear down this reactor task. That triggers - // cancellation of all spawned sub-tasks and sibling coroutines, and - // the rationale for this is that we zapped wasm while it was - // executing so it's now in an indeterminate state and not one that - // we can resume. + // Any error in sending the trap along the coroutine's channel + // is ignored since we can race with the coroutine getting + // dropped. // - // TODO: this is a `clone()`-per-callback which is probably cheap, - // but this is also a sort of wonky setup so this may wish to change - // in the future. - let cancel_signal = coroutines.get_mut(&coroutine_id).unwrap().sender.clone(); - let prev_coroutine_id = mem::replace(&mut self.cur_wasm_coroutine, coroutine_id); - let result = tokio::select! { - r = tls::scope(self, to_execute.run(store)) => r, - _ = cancel_signal.closed() => break, - }; - self.cur_wasm_coroutine = prev_coroutine_id; - - let coroutines = self.coroutines.get_mut(); - let coroutine = coroutines.get_mut(&coroutine_id).unwrap(); - if let Err(trap) = result { - // Our WebAssembly callback trapped. That means that this - // entire coroutine is now in a failure state. No further - // wasm callbacks will be invoked and the coroutine is - // removed from out internal list to invoke the failure - // callback, informing what trap caused the failure. 
- // - // Note that this reopens `coroutine_id.slab_index` to get - // possibly reused, intentionally so, which is why - // `CoroutineId` is a form of generational ID which is - // resilient to this form of reuse. In other words when we - // remove the result here if in the future a pending import - // for this coroutine completes we'll simply discard the - // message. - // - // Any error in sending the trap along the coroutine's channel - // is ignored since we can race with the coroutine getting - // dropped. - // - // TODO: should the trap still be sent somewhere? Is this ok to - // simply ignore? - // - // Finally we exit the reactor in this case because traps - // typically represent fatal conditions for wasm where we can't - // really resume since it may be in an indeterminate state (wasm - // can't handle traps itself), so after we inform the original - // coroutine of the original trap we break out and cancel all - // further execution. - let coroutine = coroutines.remove(&coroutine_id).unwrap(); - let _ = coroutine.sender.send(Err(trap)); - break; - } else if coroutine.num_pending_imports == 0 { - // Our wasm callback succeeded, and there are no pending - // imports for this coroutine. - // - // In this state it means that the coroutine has completed - // since no further work can possibly happen for the - // coroutine. This means that we can safely remove it from - // our internal list. - // - // If the coroutine's completion wasn't ever signaled, - // however, then that indicates a bug in the wasm code - // itself. This bug is translated into a trap which will get - // reported to the caller to inform the original invocation - // of the export that the result of the coroutine never - // actually came about. - // - // Note that like above a failure to send a trap along the - // channel is ignored since we raced with the caller becoming - // disinterested in the result which is fine to happen at any - // time. 
- // - // TODO: should the trap still be sent somewhere? Is this ok to - // simply ignore? - // - // TODO: should this tear down the reactor as well, despite it - // being a synthetically created trap? - let coroutine = coroutines.remove(&coroutine_id).unwrap(); - if coroutine.complete.is_some() { - let _ = coroutine - .sender - .send(Err(Trap::new("completion callback never called"))); - } - } else { - // Our wasm callback succeeded, and there are pending - // imports for this coroutine. - // - // This means that the coroutine isn't finished yet so we - // simply turn the loop and wait for something else to - // happen. We'll next be executing WebAssembly when one of - // the coroutine's imports finish. + // TODO: should the trap still be sent somewhere? Is this ok to + // simply ignore? + // + // Finally we exit the reactor in this case because traps + // typically represent fatal conditions for wasm where we can't + // really resume since it may be in an indeterminate state (wasm + // can't handle traps itself), so after we inform the original + // coroutine of the original trap we break out and cancel all + // further execution. + let coroutine = coroutines.remove(&coroutine_id).unwrap(); + let _ = coroutine.sender.send(Err(trap)); + return false; + } else if coroutine.pending_callbacks == 0 { + // Our wasm callback succeeded, and there are no pending + // imports for this coroutine. + // + // In this state it means that the coroutine has completed + // since no further work can possibly happen for the + // coroutine. This means that we can safely remove it from + // our internal list. + // + // If the coroutine's completion wasn't ever signaled, + // however, then that indicates a bug in the wasm code + // itself. This bug is translated into a trap which will get + // reported to the caller to inform the original invocation + // of the export that the result of the coroutine never + // actually came about. 
+ // + // Note that like above a failure to send a trap along the + // channel is ignored since we raced with the caller becoming + // disinterested in the result which is fine to happen at any + // time. + // + // TODO: should the trap still be sent somewhere? Is this ok to + // simply ignore? + // + // TODO: should this tear down the reactor as well, despite it + // being a synthetically created trap? + let coroutine = coroutines.remove(&coroutine_id).unwrap(); + if coroutine.complete.is_some() { + let _ = coroutine + .sender + .send(Err(Trap::new("completion callback never called"))); } + } else { + // Our wasm callback succeeded, and there are pending + // imports for this coroutine. + // + // This means that the coroutine isn't finished yet so we + // simply turn the loop and wait for something else to + // happen. We'll next be executing WebAssembly when one of + // the coroutine's imports finish. } + + true } pub async fn async_export_done( @@ -516,6 +558,63 @@ impl Async { }) } + /// Implementation of the `event_new` canonical ABI intrinsic + pub fn event_new(mut caller: Caller<'_, T>, cb: u32, data: u32) -> Result { + Self::with(|cx| { + // First up validate `cb` to ensure it's actually a valid wasm + // callback to have given us. + let callback = cx + .function_table + .get(&mut caller, cb) + .ok_or_else(|| Trap::new("out-of bounds function index"))? + .funcref() + .ok_or_else(|| Trap::new("not a funcref table"))? + .ok_or_else(|| Trap::new("callback cannot be null"))? + .typed(&caller)?; + + // Next record that there's a pending callback because the coroutine + // is now possibly blocked on this event. + cx.coroutines + .borrow_mut() + .get_mut(&cx.cur_wasm_coroutine) + .unwrap() + .pending_callbacks += 1; + + // And here the event data is saved and returned back to wasm as an + // index. 
+ Ok(cx.events.borrow_mut().insert(Event { + callback, + coroutine: cx.cur_wasm_coroutine, + data, + })) + }) + } + + /// Implementation of the `event_signal` canonical ABI intrinsic + pub fn event_signal(_caller: Caller<'_, T>, event: u32, arg: u32) -> Result<(), Trap> { + Self::with(|cx| { + // Validate that `event` is valid for this wasm module and then + // enqueue the event into an internal list to get processed after + // this wasm is finished executing. + // + // Note that we don't decrement the pending imports count here but + // rather wait for the pending event message to get processed to + // actually decrement the count, lest the coroutine accidentally be + // declared done early. + let event = cx + .events + .borrow_mut() + .remove(event) + .ok_or_else(|| Trap::new("invalid event index"))?; + let mut pending_events = cx.pending_events.borrow_mut(); + if pending_events.len() >= MAX_EVENTS { + return Err(Trap::new("too many events created")); + } + pending_events.push((event, arg)); + Ok(()) + }) + } + + // TODO: this is a pretty bad interface to manage the table with... 
pub fn function_table() -> Table { Self::with(|cx| cx.function_table) @@ -579,20 +678,6 @@ impl Drop for Coroutine { } } -enum ToExecute { - Start(Start, u32), - Callback(Callback), -} - -impl ToExecute { - async fn run(self, store: &mut StoreContextMut<'_, T>) -> Result<(), Trap> { - match self { - ToExecute::Start(cb, val) => cb(store, val).await, - ToExecute::Callback(cb) => cb(store).await, - } - } -} - pub struct AsyncHandle { sender: Arc>>, } diff --git a/tests/runtime/async_functions/host.ts b/tests/runtime/async_functions/host.ts index 038ebda88..880288992 100644 --- a/tests/runtime/async_functions/host.ts +++ b/tests/runtime/async_functions/host.ts @@ -80,14 +80,12 @@ async function run() { throw new Error('unsupported'); }, }; - let instance: WebAssembly.Instance; - addImportsToImports(importObj, imports, name => instance.exports[name]); + addImportsToImports(importObj, imports); const wasi = addWasiToImports(importObj); const wasm = new Exports(); await wasm.instantiate(getWasm(), importObj); wasi.start(wasm.instance); - instance = wasm.instance; const initBytes = wasm.allocatedBytes(); console.log("calling initial async function"); diff --git a/tests/runtime/async_raw/host.ts b/tests/runtime/async_raw/host.ts index f759f04fc..0487114d6 100644 --- a/tests/runtime/async_raw/host.ts +++ b/tests/runtime/async_raw/host.ts @@ -11,14 +11,12 @@ async function run() { }; async function instantiate() { - let instance: WebAssembly.Instance; - addImportsToImports(importObj, imports, name => instance.exports[name]); + addImportsToImports(importObj, imports); const wasi = addWasiToImports(importObj); const wasm = new Exports(); await wasm.instantiate(getWasm(), importObj); wasi.start(wasm.instance); - instance = wasm.instance; return wasm; }