diff --git a/Cargo.lock b/Cargo.lock index d2ecdfad1..3668b99da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1457,6 +1457,29 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tokio" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" +dependencies = [ + "autocfg", + "num_cpus", + "pin-project-lite", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2dd85aeaba7b68df939bd357c6afb36c87951be9e80bf9c859f2fc3e9fca0fd" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml" version = "0.5.8" @@ -2119,6 +2142,7 @@ dependencies = [ "heck", "structopt", "test-helpers", + "tokio", "wasmtime", "wasmtime-wasi", "witx-bindgen-gen-core", @@ -2162,6 +2186,7 @@ dependencies = [ "async-trait", "bitflags", "thiserror", + "tokio", "tracing", "wasmtime", "witx-bindgen-wasmtime-impl", diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index 3cee7e2a9..5d37c69d1 100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -75,7 +75,6 @@ enum Intrinsic { Utf8EncodedLen, Slab, Promises, - WithCurrentPromise, } impl Intrinsic { @@ -102,7 +101,6 @@ impl Intrinsic { Intrinsic::Utf8EncodedLen => "UTF8_ENCODED_LEN", Intrinsic::Slab => "Slab", Intrinsic::Promises => "PROMISES", - Intrinsic::WithCurrentPromise => "with_current_promise", } } } @@ -603,10 +601,10 @@ impl Generator for Js { self.src.js(&src.js); if func.is_async { - // Note that `catch_closure` here is defined by the `CallInterface` - // instruction. - self.src.js("}, catch_closure);\n"); // `.then` block - self.src.js("});\n"); // `with_current_promise` block. + self.src.js("};\n"); // `const complete = ...` block + let promises = self.intrinsic(Intrinsic::Promises); + self.src.js(&promises); + self.src.js(".spawn_import(promise, complete);\n"); } self.src.js("}"); @@ -894,7 +892,17 @@ impl Generator for Js { self.src.js(&format!( " imports.canonical_abi['async_export_done'] = (ctx, ptr) => {{ - {}.remove(ctx)(ptr >>> 0) + {0}.async_export_done(ctx, ptr >>> 0) + }}; + imports.canonical_abi['event_new'] = (a, b) => {{ + const callback = {0}.table.get(a); + if (callback === null) {{ + throw new Error('table index is a null function'); + }} + return {0}.event_new(callback, b); + }}; + imports.canonical_abi['event_signal'] = (a, b) => {{ + {0}.event_signal(a, b); }}; ", promises @@ -962,6 +970,15 @@ impl Generator for Js { this._exports = this.instance.exports; "); + if any_async { + let promises = self.intrinsic(Intrinsic::Promises); + // TODO: hardcoding __indirect_function_table + self.src.js(&format!( + "{}.table = this._exports.__indirect_function_table;\n", + promises + )); + } + // Exported resources all get a finalization registry, and we // created them after instantiation so we can pass the raw wasm // export as the destructor callback. @@ -1936,20 +1953,10 @@ impl Bindgen for FunctionBindgen<'_> { results: wasm_results, } => { self.bind_results(wasm_results.len(), results); - let promises = self.gen.intrinsic(Intrinsic::Promises); - self.src.js(&format!( - "\ - await new Promise((resolve, reject) => {{ - const promise_ctx = {promises}.insert(val => {{ - if (typeof val !== 'number') - return reject(val); - resolve(\ - ", - promises = promises - )); + self.src.js("await new Promise((resolve, reject) => {\n"); if wasm_results.len() > 0 { - self.src.js("["); + self.src.js("resolve = val => {\nresolve(["); let operands = &["val".to_string()]; let mut results = Vec::new(); for (i, result) in wasm_results.iter().enumerate() { @@ -1965,16 +1972,12 @@ impl Bindgen for FunctionBindgen<'_> { self.load(method, (i * 8) as i32, operands, &mut results); self.src.js(&results.pop().unwrap()); } - self.src.js("]"); + self.src.js("])\n};\n"); // `resolve(...)` } - // Finish the blocks from above - self.src.js(");\n"); // `resolve(...)` - self.src.js("});\n"); // `promises.insert(...)` - - let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); - self.src.js(&with); - self.src.js("(promise_ctx, _prev => {\n"); + let promises = self.gen.intrinsic(Intrinsic::Promises); + self.src.js(&promises); + self.src.js(".spawn(resolve, reject, promise_ctx => {\n"); self.src.js(&self.src_object); self.src.js("._exports['"); self.src.js(&name); @@ -1984,7 +1987,7 @@ impl Bindgen for FunctionBindgen<'_> { self.src.js(", "); } self.src.js("promise_ctx);\n"); - self.src.js("});\n"); // call to `with` + self.src.js("});\n"); // call to `spawn` self.src.js("});\n"); // `await new Promise(...)` } @@ -2037,16 +2040,10 @@ impl Bindgen for FunctionBindgen<'_> { }; if func.is_async { - let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); - let promises = self.gen.intrinsic(Intrinsic::Promises); - self.src.js(&with); - self.src.js("(null, cur_promise => {\n"); - self.src.js(&format!( - "const catch_closure = e => {}.remove(cur_promise)(e);\n", - promises - )); + self.src.js("const promise = "); call(self); - self.src.js(".then(e => {\n"); + self.src.js(";\n"); + self.src.js("const complete = e => {\n"); if func.results.len() > 0 { bind_results(self); self.src.js("e;\n"); @@ -2079,31 +2076,25 @@ impl Bindgen for FunctionBindgen<'_> { } }, - Instruction::ReturnAsyncImport { .. } => { - // When we reenter webassembly successfully that means that the - // host's promise resolved without exception. Take the current - // promise index saved as part of `CallInterface` and update the - // `CUR_PROMISE` global with what's currently being executed. - // This'll get reset once the wasm returns again. - // - // Note that the name `cur_promise` used here is introduced in - // the `CallInterface` codegen above in the closure for - // `with_current_promise` which we're using here. - // - // TODO: hardcoding `__indirect_function_table` and no help if - // it's not actually defined. - self.gen.needs_get_export = true; - let with = self.gen.intrinsic(Intrinsic::WithCurrentPromise); + Instruction::CompletionCallback { .. } => { + // TODO: should verify the type of the function + let promises = self.gen.intrinsic(Intrinsic::Promises); self.src.js(&format!( - "\ - {with}(cur_promise, _prev => {{ - get_export(\"__indirect_function_table\").get({})({}); - }}); + " + const callback = {}.table.get({}); + if (callback === null) {{ + throw new Error('table index is a null function'); + }} ", - operands[0], - operands[1..].join(", "), - with = with, + promises, operands[0] )); + results.push("callback".to_string()); + } + + Instruction::ReturnAsyncImport { .. } => { + // TODO + self.src + .js(&format!("{}({});\n", operands[0], operands[1..].join(", "),)); } Instruction::I32Load { offset } => self.load("getInt32", *offset, operands, results), @@ -2390,12 +2381,19 @@ impl Js { } get(idx) { - if (idx >= this.list.length) + const ret = this.maybeGet(idx); + if (ret === null) throw new RangeError('handle index not valid'); + return ret; + } + + maybeGet(idx) { + if (idx >= this.list.length) + return null; const slot = this.list[idx]; if (slot.next === -1) return slot.val; - throw new RangeError('handle index not valid'); + return null; } remove(idx) { @@ -2409,18 +2407,140 @@ impl Js { } "), - Intrinsic::Promises => self.src.js("export const PROMISES = new Slab();\n"), - Intrinsic::WithCurrentPromise => self.src.js(" - let CUR_PROMISE = null; - export function with_current_promise(val, closure) { - const prev = CUR_PROMISE; - CUR_PROMISE = val; - try { - closure(prev); - } finally { - CUR_PROMISE = prev; + Intrinsic::Promises => self.src.js(" + class Coroutine { + constructor(resolve, reject) { + this.resolve = resolve; + this.reject = reject; + this.completion = null; + this.pending = 0; + } + + complete(val) { + if (this.completion === null) { + this.completion = val; + } else { + throw new Error('cannot complete coroutine twice'); + } + } + + check_completion() { + if (this.completion === null) { + if (this.pending === 0) { + throw new Error('blocked coroutine with 0 pending callbacks'); + } + return false; + } else { + this.resolve(this.completion); + return true; + } + } + } + + class Promises { + constructor() { + this.slab = new Slab(); + this.events = new Slab(); + this.current = null; + } + + spawn(resolve, reject, callback) { + const idx = this.slab.insert(new Coroutine(resolve, reject)); + this.set_current(idx, () => callback(idx)); + } + + // called when `promise` is the result of an async host + // call, where `complete` should be called on `.then(..)` + // after the promise is finished. + spawn_import(promise, complete) { + const current = this.current; + if (current === null) { + throw new Error('cannot call async import in sync export'); + } + const coroutine = this.slab.get(current); + coroutine.pending += 1; + promise.then( + e => { + coroutine.pending -= 1; + this.set_current(current, () => complete(e)); + }, + // on error this error is carried over to the + // original coroutine, if it's still present. The + // coroutine may already have failed due to some + // other reason in which case we continue to + // propagate this error in the hopes that someone + // else will log uncaught exceptions or similar. + e => { + const coroutine = this.slab.maybeGet(current); + if (coroutine !== null) { + this.slab.remove(current).reject(e); + } else { + throw e; + } + }, + ) + } + + // implementation of the corresponding wasm intrinsic, + // validates `id` and `val` internally + async_export_done(id, val) { + const coroutine = this.slab.maybeGet(id); + if (coroutine !== null) { + coroutine.complete(val); + } else { + throw new RangeError('invalid coroutine index'); + } + } + + set_current(val, closure) { + // assert that the coroutine is still present + this.slab.get(val); + const prev = this.current; + this.current = val; + try { + // execute some wasm code + closure(); + // if it finished successfully then determine if + // the coroutine has completed, or if there are 0 + // pending callbacks and it's not completed an + // exception is thrown here. + if (this.slab.get(val).check_completion()) { + this.slab.remove(val); + } + } catch (e) { + // any errors immediately result in aborting the + // coroutine and rejection of the coroutine's + // promise with the same error. + this.slab.remove(val).reject(e); + } finally { + this.current = prev; + } + } + + event_new(callback, data) { + const coroutine = this.slab.get(this.current); + coroutine.pending += 1; + return this.events.insert({ + coroutine: this.current, + callback, + data, + }); + } + + event_signal(idx, arg) { + const { coroutine, callback, data } = this.events.remove(idx); + queueMicrotask(() => { + const wasm = this.slab.maybeGet(coroutine); + if (wasm !== null) { + wasm.pending -= 1; + this.set_current(coroutine, () => { + callback(data, arg); + }); + } + }) } } + export const PROMISES = new Promises(); "), } } diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs index 826d20b95..20e857ea2 100644 --- a/crates/gen-rust-wasm/src/lib.rs +++ b/crates/gen-rust-wasm/src/lib.rs @@ -1461,6 +1461,7 @@ impl Bindgen for FunctionBindgen<'_> { operands[0], operands[1] )); } + Instruction::CompletionCallback { .. } => unreachable!(), Instruction::ReturnAsyncImport { .. } => unreachable!(), Instruction::I32Load { offset } => { diff --git a/crates/gen-spidermonkey/src/lib.rs b/crates/gen-spidermonkey/src/lib.rs index 32935abe2..6aed957bc 100644 --- a/crates/gen-spidermonkey/src/lib.rs +++ b/crates/gen-spidermonkey/src/lib.rs @@ -2176,6 +2176,7 @@ impl witx2::abi::Bindgen for Bindgen<'_, '_> { } witx2::abi::Instruction::ReturnAsyncExport { .. } => todo!(), + witx2::abi::Instruction::CompletionCallback { .. } => todo!(), witx2::abi::Instruction::ReturnAsyncImport { .. } => todo!(), witx2::abi::Instruction::Witx { instr: _ } => { diff --git a/crates/gen-wasmtime/Cargo.toml b/crates/gen-wasmtime/Cargo.toml index 8a4660e46..d20519957 100644 --- a/crates/gen-wasmtime/Cargo.toml +++ b/crates/gen-wasmtime/Cargo.toml @@ -20,6 +20,7 @@ test-helpers = { path = '../test-helpers', features = ['witx-bindgen-gen-wasmtim wasmtime = "0.30.0" wasmtime-wasi = "0.30.0" witx-bindgen-wasmtime = { path = '../wasmtime', features = ['tracing', 'async'] } +tokio = { version = "1.0", features = ['rt', 'sync', 'macros', 'rt-multi-thread', 'time'] } [features] old-witx-compat = ['witx-bindgen-gen-core/old-witx-compat'] diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 2863b7641..549bfdbc0 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -38,6 +38,7 @@ pub struct Wasmtime { trait_name: String, has_preview1_dtor: bool, sizes: SizeAlign, + any_async_func: bool, } enum NeededFunction { @@ -46,7 +47,7 @@ enum NeededFunction { } struct Import { - is_async: bool, + wrap_async: bool, name: String, trait_signature: String, num_wasm_params: usize, @@ -84,6 +85,13 @@ pub struct Opts { pub custom_error: bool, } +// TODO: with the introduction of `async function()` to witx this is no longer a +// good name. The purpose of this configuration is "should the wasmtime function +// be invoked with `call_async`" which is sort of a different form of async. +// Async with wasmtime involves stack switching and with fuel enables +// preemption, but async witx functions don't actually use async host functions +// as-defined-by-wasmtime. All that to say that this probably needs a better +// name. #[derive(Debug, Clone)] pub enum Async { None, @@ -208,6 +216,18 @@ impl Wasmtime { self.needs_custom_error_to_trap = true; FunctionRet::CustomToTrap } + + fn rebind_host(&self, _iface: &Interface) -> Option { + let mut rebind = String::new(); + if self.all_needed_handles.len() > 0 { + rebind.push_str("_tables, "); + } + if rebind != "" { + Some(format!("let (host, {}) = host;\n", rebind)) + } else { + None + } + } } impl RustGenerator for Wasmtime { @@ -216,6 +236,10 @@ impl RustGenerator for Wasmtime { // The default here is that only leaf values can be borrowed because // otherwise lists and such need to be copied into our own memory. TypeMode::LeafBorrowed("'a") + } else if self.any_async_func { + // Once `async` functions are in play then there's a task spawned + // that owns the reactor, and this means + TypeMode::Owned } else { // When we're calling wasm exports, however, there's no need to take // any ownership of anything from the host so everything is borrowed @@ -366,11 +390,15 @@ impl Generator for Wasmtime { self.types.analyze(iface); self.in_import = dir == Direction::Import; self.trait_name = iface.name.to_camel_case(); + self.src.push_str("#[allow(unused_imports)]\n"); + self.src.push_str("#[allow(unused_variables)]\n"); + self.src.push_str("#[allow(unused_mut)]\n"); self.src .push_str(&format!("pub mod {} {{\n", iface.name.to_snake_case())); self.src - .push_str("#[allow(unused_imports)]\nuse witx_bindgen_wasmtime::{wasmtime, anyhow};\n"); + .push_str("use witx_bindgen_wasmtime::{wasmtime, anyhow};\n"); self.sizes.fill(dir, iface); + self.any_async_func = iface.functions.iter().any(|f| f.is_async); } fn type_record( @@ -499,6 +527,15 @@ impl Generator for Wasmtime { let tyname = name.to_camel_case(); self.rustdoc(&iface.resources[ty].docs); self.src.push_str("#[derive(Debug)]\n"); + // TODO: for now in an async environment all of these handles are taken + // by-value in functions whereas in non-async environments everything is + // taken by reference except for destructors. This means that the + // take-by-ownership `drop` function is less meaningful in an async + // environment. This seems like a reasonable-ish way to manage this for + // now but this probably wants a better solution long-term. + if self.any_async_func { + self.src.push_str("#[derive(Clone, Copy)]\n"); + } self.src.push_str(&format!( "pub struct {}(witx_bindgen_wasmtime::rt::ResourceIndex);\n", tyname @@ -573,7 +610,6 @@ impl Generator for Wasmtime { // } fn export(&mut self, iface: &Interface, func: &Function) { - assert!(!func.is_async, "async not supported yet"); let prev = mem::take(&mut self.src); let is_dtor = self.types.is_preview1_dtor_func(func); @@ -621,7 +657,7 @@ impl Generator for Wasmtime { let mut fnsig = FnSig::default(); fnsig.private = true; - fnsig.async_ = self.opts.async_.includes(&func.name); + fnsig.async_ = self.opts.async_.includes(&func.name) && !func.is_async; fnsig.self_arg = Some(self_arg); self.print_docs_and_params( iface, @@ -635,20 +671,21 @@ impl Generator for Wasmtime { ); // The Rust return type may differ from the wasm return type based on // the `custom_error` configuration of this code generator. + self.push_str(" -> "); + if func.is_async { + self.push_str("std::pin::Pin "); - self.print_results(iface, func); - } + self.print_results(iface, func); } FunctionRet::CustomToTrap => { - self.push_str(" -> Result<"); + self.push_str("Result<"); self.print_results(iface, func); self.push_str(", Self::Error>"); } FunctionRet::CustomToError { ok, .. } => { - self.push_str(" -> Result<"); + self.push_str("Result<"); match ok { Some(ty) => self.print_ty(iface, &ty, TypeMode::Owned), None => self.push_str("()"), @@ -656,6 +693,9 @@ impl Generator for Wasmtime { self.push_str(", Self::Error>"); } } + if func.is_async { + self.push_str("> + Send>>"); + } self.in_trait = false; let trait_signature = mem::take(&mut self.src).into(); @@ -679,7 +719,9 @@ impl Generator for Wasmtime { // // If none of that happens, then this is fine to be sync because // everything is sync. - let is_async = if async_intrinsic_called || self.opts.async_.includes(&func.name) { + let finish_async_block = if !func.is_async + && (async_intrinsic_called || self.opts.async_.includes(&func.name)) + { self.src.push_str("Box::new(async move {\n"); true } else { @@ -716,7 +758,7 @@ impl Generator for Wasmtime { if needs_memory || needs_borrow_checker { self.src - .push_str("let memory = &get_memory(&mut caller, \"memory\")?;\n"); + .push_str("let memory = get_memory(&mut caller, \"memory\")?;\n"); self.needs_get_memory = true; } @@ -729,15 +771,19 @@ impl Generator for Wasmtime { } else { self.src.push_str("let host = get(caller.data_mut());\n"); } - - if self.all_needed_handles.len() > 0 { - self.src.push_str("let (host, _tables) = host;\n"); + if let Some(rebind) = self.rebind_host(iface) { + self.src.push_str(&rebind); } self.src.push_str(&String::from(src)); - if is_async { - self.src.push_str("})\n"); + if func.is_async { + self.src.push_str("})?; // finish `spawn_import`\n"); + self.src.push_str("Ok(())\n") + } + + if finish_async_block { + self.src.push_str("}) // end `Box::new(async move { ...`\n"); } self.src.push_str("}"); let closure = mem::replace(&mut self.src, prev).into(); @@ -746,7 +792,7 @@ impl Generator for Wasmtime { .entry(iface.name.to_string()) .or_insert(Vec::new()) .push(Import { - is_async, + wrap_async: finish_async_block, num_wasm_params: sig.params.len(), name: func.name.to_string(), closure, @@ -755,17 +801,26 @@ impl Generator for Wasmtime { } fn import(&mut self, iface: &Interface, func: &Function) { - assert!(!func.is_async, "async not supported yet"); let prev = mem::take(&mut self.src); // If anything is asynchronous on exports then everything must be // asynchronous, Wasmtime can't intermix async and sync calls because // it's unknown whether the wasm module will make an async host call. - let is_async = !self.opts.async_.is_none(); + let is_async = !self.opts.async_.is_none() || func.is_async; let mut sig = FnSig::default(); - sig.async_ = is_async; - sig.self_arg = Some("&self, mut caller: impl wasmtime::AsContextMut".to_string()); - self.print_docs_and_params(iface, func, TypeMode::AllBorrowed("'_"), &sig); + sig.async_ = is_async || self.any_async_func; + if self.any_async_func { + sig.self_arg = Some("&self".to_string()); + } else { + sig.self_arg = + Some("&self, mut caller: impl wasmtime::AsContextMut".to_string()); + } + let mode = if self.any_async_func { + TypeMode::Owned + } else { + TypeMode::AllBorrowed("'_") + }; + self.print_docs_and_params(iface, func, mode, &sig); self.push_str("-> Result<"); self.print_results(iface, func); self.push_str(", wasmtime::Trap> {\n"); @@ -780,6 +835,9 @@ impl Generator for Wasmtime { .map(|(name, _)| to_rust_ident(name).to_string()) .collect(); let mut f = FunctionBindgen::new(self, is_dtor, params); + if f.gen.any_async_func { + f.src.indent(2); + } iface.call( Direction::Export, LiftLower::LowerArgsLiftResults, @@ -793,6 +851,7 @@ impl Generator for Wasmtime { needs_buffer_transaction, closures, needs_functions, + needs_get_state, .. } = f; @@ -802,7 +861,7 @@ impl Generator for Wasmtime { .or_insert_with(Exports::default); for (name, func) in needs_functions { self.src - .push_str(&format!("let func_{0} = &self.{0};\n", name)); + .push_str(&format!("let func_{0} = self.{0};\n", name)); let get = format!( "instance.get_typed_func::<{}, _>(&mut store, \"{}\")?", func.cvt(), @@ -815,7 +874,7 @@ impl Generator for Wasmtime { assert!(!needs_borrow_checker); if needs_memory { - self.src.push_str("let memory = &self.memory;\n"); + self.src.push_str("let memory = self.memory;\n"); exports.fields.insert( "memory".to_string(), ( @@ -824,7 +883,7 @@ impl Generator for Wasmtime { .get_memory(&mut store, \"memory\") .ok_or_else(|| { anyhow::anyhow!(\"`memory` export not a memory\") - })? + })?\ " .to_string(), ), @@ -837,7 +896,54 @@ impl Generator for Wasmtime { .push_str("let mut buffer_transaction = self.buffer_glue.transaction();\n"); } + if needs_get_state { + self.src + .push_str("let get_state = self.get_state.clone();\n"); + } + + if func.is_async { + // If this function itself is async then we start off with an + // initial callback that gets an `async_cx` argument which is the + // integer descriptor for the generated future. + self.src.push_str(&format!( + " + let wasm_func = self.{}; + let start = witx_bindgen_wasmtime::rt::infer_start(move |mut caller, async_cx| {{ + let async_cx = async_cx as i32; + Box::pin(async move {{ + ", + to_rust_ident(&func.name), + )); + } else if self.any_async_func { + // Otherwise if any other function in this interface is async then + // it means all functions are invoked through a reactor task which + // means we need to start a standalone callback to get executed + // on the reactor task. + self.src.push_str(&format!( + " + let wasm_func = self.{}; + let start = witx_bindgen_wasmtime::rt::infer_standalone(move |mut caller| {{ + Box::pin(async move {{ + ", + to_rust_ident(&func.name), + )); + } else { + // And finally with no async functions involved everything is + // simply generated inline. + self.src + .push_str("let mut caller = caller.as_context_mut();\n"); + } + self.src.push_str(&String::from(src)); + + if func.is_async { + self.src + .push_str("self.handle.execute(start, complete).await\n"); + } else if self.any_async_func { + self.src + .push_str("self.handle.run_no_coroutine(start).await\n"); + } + self.src.push_str("}\n"); let func_body = mem::replace(&mut self.src, prev); if !is_dtor { @@ -890,14 +996,21 @@ impl Generator for Wasmtime { self.src.push_str("type "); self.src.push_str(&handle.to_camel_case()); self.src.push_str(": std::fmt::Debug"); - if is_async { + if is_async || self.any_async_func { self.src.push_str(" + Send + Sync"); } + if self.any_async_func { + self.src.push_str(" + 'static"); + } self.src.push_str(";\n"); } } if self.opts.custom_error { - self.src.push_str("type Error;\n"); + self.src.push_str("type Error"); + if self.any_async_func { + self.src.push_str(": Send + 'static"); + } + self.src.push_str(";\n"); if self.needs_custom_error_to_trap { self.src.push_str( "fn error_to_trap(&mut self, err: Self::Error) -> wasmtime::Trap;\n", @@ -946,30 +1059,36 @@ impl Generator for Wasmtime { self.src.push_str("> Default for "); self.src.push_str(&module_camel); self.src.push_str("Tables {\n"); - self.src.push_str("fn default() -> Self { Self {"); + self.src.push_str("fn default() -> Self {\nSelf {\n"); for handle in self.all_needed_handles.iter() { self.src.push_str(&handle.to_snake_case()); - self.src.push_str("_table: Default::default(),"); + self.src.push_str("_table: Default::default(),\n"); } - self.src.push_str("}}}"); + self.src.push_str("}\n}\n}\n"); } } for (module, funcs) in mem::take(&mut self.imports) { let module_camel = module.to_camel_case(); let is_async = !self.opts.async_.is_none(); - self.push_str("\npub fn add_to_linker(linker: &mut wasmtime::Linker"); + self.push_str("\n#[allow(path_statements)]\n"); + self.push_str("pub fn add_to_linker(linker: &mut wasmtime::Linker"); self.push_str(", get: impl Fn(&mut T) -> "); - if self.all_needed_handles.is_empty() { - self.push_str("&mut U"); + + let mut get_rets = vec!["&mut U".to_string()]; + if self.all_needed_handles.len() > 0 { + get_rets.push(format!("&mut {}Tables", module_camel)); + } + if get_rets.len() > 1 { + self.push_str(&format!("({})", get_rets.join(", "))); } else { - self.push_str(&format!("(&mut U, &mut {}Tables)", module_camel)); + self.push_str(&get_rets[0]); } - self.push_str("+ Send + Sync + Copy + 'static) -> anyhow::Result<()> \n"); + self.push_str(" + Send + Sync + Copy + 'static) -> anyhow::Result<()> \n"); self.push_str("where U: "); self.push_str(&module_camel); - if is_async { - self.push_str(", T: Send,"); + if is_async || self.any_async_func { + self.push_str(", T: Send + 'static,"); } self.push_str("\n{\n"); if self.needs_get_memory { @@ -979,7 +1098,7 @@ impl Generator for Wasmtime { self.push_str("use witx_bindgen_wasmtime::rt::get_func;\n"); } for f in funcs { - let method = if f.is_async { + let method = if f.wrap_async { format!("func_wrap{}_async", f.num_wasm_params) } else { String::from("func_wrap") @@ -996,14 +1115,14 @@ impl Generator for Wasmtime { \"canonical_abi\", \"resource_drop_{name}\", move |mut caller: wasmtime::Caller<'_, T>, handle: u32| {{ - let (host, tables) = get(caller.data_mut()); - let handle = tables + let data = get(caller.data_mut()); + let handle = data.1 .{snake}_table .remove(handle) .map_err(|e| {{ wasmtime::Trap::new(format!(\"failed to remove handle: {{}}\", e)) }})?; - host.drop_{snake}(handle); + data.0.drop_{snake}(handle); Ok(()) }} )?;\n", @@ -1031,9 +1150,7 @@ impl Generator for Wasmtime { ", ); self.push_str("#[derive(Default)]\n"); - self.push_str("pub struct "); - self.push_str(&name); - self.push_str("Data {\n"); + self.push_str(&format!("pub struct {}Data {{\n", name)); for r in self.exported_resources.iter() { self.src.push_str(&format!( " @@ -1046,13 +1163,17 @@ impl Generator for Wasmtime { } self.push_str("}\n"); - self.push_str("pub struct "); - self.push_str(&name); - self.push_str(" {\n"); + let get_state_ret = format!("&mut {}Data", name); + self.push_str(&format!("pub struct {} {{\n", name)); self.push_str(&format!( - "get_state: Box &mut {}Data + Send + Sync>,\n", - name + "get_state: std::sync::Arc {} + Send + Sync>,\n", + get_state_ret, )); + if self.any_async_func { + self.push_str(&format!( + "handle: witx_bindgen_wasmtime::rt::AsyncHandle,\n", + )); + } for (name, (ty, _)) in exports.fields.iter() { self.push_str(name); self.push_str(": "); @@ -1063,16 +1184,13 @@ impl Generator for Wasmtime { // self.push_str("buffer_glue: witx_bindgen_wasmtime::imports::BufferGlue,"); // } self.push_str("}\n"); - let bound = if self.opts.async_.is_none() { + let bound = if self.opts.async_.is_none() && !self.any_async_func { "" } else { - ": Send" + ": Send + 'static" }; self.push_str(&format!("impl {} {{\n", bound, name)); - if self.exported_resources.len() == 0 { - self.push_str("#[allow(unused_variables)]\n"); - } self.push_str(&format!( " /// Adds any intrinsics, if necessary for this exported wasm @@ -1083,10 +1201,10 @@ impl Generator for Wasmtime { /// the general store's state. pub fn add_to_linker( linker: &mut wasmtime::Linker, - get_state: impl Fn(&mut T) -> &mut {}Data + Send + Sync + Copy + 'static, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, ) -> anyhow::Result<()> {{ ", - name, + get_state_ret, )); for r in self.exported_resources.iter() { let (func_wrap, call, wait, prefix, suffix) = if self.opts.async_.is_none() { @@ -1155,6 +1273,49 @@ impl Generator for Wasmtime { suffix = suffix, )); } + if self.any_async_func { + self.src.push_str(&format!( + " + linker.func_wrap2_async( + \"canonical_abi\", + \"async_export_done\", + move |mut caller: wasmtime::Caller<'_, T>, cx: i32, ptr: i32| {{ + Box::new(async move {{ + let memory = witx_bindgen_wasmtime::rt::get_memory(&mut caller, \"memory\")?; + witx_bindgen_wasmtime::rt::Async::async_export_done( + caller, + cx, + ptr, + memory, + ).await + }}) + }}, + )?; + linker.func_wrap( + \"canonical_abi\", + \"event_new\", + move |mut caller: wasmtime::Caller<'_, T>, cb: u32, data: u32| {{ + witx_bindgen_wasmtime::rt::Async::event_new( + caller, + cb, + data, + ) + }}, + )?; + linker.func_wrap( + \"canonical_abi\", + \"event_signal\", + move |mut caller: wasmtime::Caller<'_, T>, handle: u32, arg: u32| {{ + witx_bindgen_wasmtime::rt::Async::event_signal( + caller, + handle, + arg, + ) + }}, + )?; + ", + )); + } // if self.needs_buffer_glue { // self.push_str( // " @@ -1202,36 +1363,44 @@ impl Generator for Wasmtime { } else { ("async ", "_async", ".await") }; - self.push_str(&format!( - " - /// Instantaites the provided `module` using the specified - /// parameters, wrapping up the result in a structure that - /// translates between wasm and the host. - /// - /// The `linker` provided will have intrinsics added to it - /// automatically, so it's not necessary to call - /// `add_to_linker` beforehand. This function will - /// instantiate the `module` otherwise using `linker`, and - /// both an instance of this structure and the underlying - /// `wasmtime::Instance` will be returned. - /// - /// The `get_state` parameter is used to access the - /// auxiliary state necessary for these wasm exports from - /// the general store state `T`. - pub {}fn instantiate( - mut store: impl wasmtime::AsContextMut, - module: &wasmtime::Module, - linker: &mut wasmtime::Linker, - get_state: impl Fn(&mut T) -> &mut {}Data + Send + Sync + Copy + 'static, - ) -> anyhow::Result<(Self, wasmtime::Instance)> {{ - Self::add_to_linker(linker, get_state)?; - let instance = linker.instantiate{}(&mut store, module){}?; - Ok((Self::new(store, &instance,get_state)?, instance)) - }} - ", - async_fn, name, instantiate, wait, - )); + if !self.any_async_func { + self.push_str(&format!( + " + /// Instantiates the provided `module` using the + /// specified parameters, wrapping up the result in a + /// structure that translates between wasm and the + /// host. + /// + /// The `linker` provided will have intrinsics added to + /// it automatically, so it's not necessary to call + /// `add_to_linker` beforehand. This function will + /// instantiate the `module` otherwise using `linker`, + /// and both an instance of this structure and the + /// underlying `wasmtime::Instance` will be returned. + /// + /// The `get_state` parameter is used to access the + /// auxiliary state necessary for these wasm exports + /// from the general store state `T`. + pub {}fn instantiate( + mut store: impl wasmtime::AsContextMut, + module: &wasmtime::Module, + linker: &mut wasmtime::Linker, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, + ) -> anyhow::Result<(Self, wasmtime::Instance)> {{ + Self::add_to_linker(linker, get_state)?; + let instance = linker.instantiate{}(&mut store, module){}?; + Ok((Self::new(store, &instance,get_state)?, instance)) + }} + ", + async_fn, get_state_ret, instantiate, wait, + )); + } + let store_ty = if self.any_async_func { + "wasmtime::Store" + } else { + "impl wasmtime::AsContextMut" + }; self.push_str(&format!( " /// Low-level creation wrapper for wrapping up the exports @@ -1243,14 +1412,17 @@ impl Generator for Wasmtime { /// returned structure which can be used to interact with /// the wasm module. pub fn new( - mut store: impl wasmtime::AsContextMut, + mut store: {store_ty}, instance: &wasmtime::Instance, - get_state: impl Fn(&mut T) -> &mut {}Data + Send + Sync + Copy + 'static, + get_state: impl Fn(&mut T) -> {} + Send + Sync + Copy + 'static, ) -> anyhow::Result {{ ", - name, + get_state_ret, + store_ty = store_ty, )); - self.push_str("let mut store = store.as_context_mut();\n"); + if !self.any_async_func { + self.push_str("let mut store = store.as_context_mut();\n"); + } assert!(!self.needs_get_func); for (name, (_, get)) in exports.fields.iter() { self.push_str("let "); @@ -1262,16 +1434,26 @@ impl Generator for Wasmtime { for r in self.exported_resources.iter() { self.src.push_str(&format!( " - get_state(store.data_mut()).dtor{} = \ - Some(instance.get_typed_func::(\ - &mut store, \ - \"canonical_abi_drop_{}\", \ - )?);\n + let dtor = instance.get_typed_func::(\ + &mut store, \ + \"canonical_abi_drop_{name}\", \ + )?; + let state = get_state(store.data_mut()); + state.dtor{idx} = Some(dtor); ", - r.index(), - iface.resources[*r].name, + idx = r.index(), + name = iface.resources[*r].name, )); } + if self.any_async_func { + self.push_str( + " + let table = instance.get_table(&mut store, \"__indirect_function_table\") + .ok_or_else(|| wasmtime::Trap::new(\"no exported function table\"))?; + let handle = witx_bindgen_wasmtime::rt::Async::spawn(store, table); + ", + ); + } self.push_str("Ok("); self.push_str(&name); self.push_str("{\n"); @@ -1279,7 +1461,10 @@ impl Generator for Wasmtime { self.push_str(name); self.push_str(",\n"); } - self.push_str("get_state: Box::new(get_state),\n"); + self.push_str("get_state: std::sync::Arc::new(get_state),\n"); + if self.any_async_func { + self.push_str("handle,\n"); + } self.push_str("\n})\n"); self.push_str("}\n"); @@ -1288,12 +1473,13 @@ impl Generator for Wasmtime { } for r in self.exported_resources.iter() { - let (async_fn, call, wait) = if self.opts.async_.is_none() { + let (async_fn, call, wait) = if self.opts.async_.is_none() && !self.any_async_func { ("", "call", "") } else { ("async ", "call_async", ".await") }; - self.src.push_str(&format!( + + self.src.push_str( " /// Drops the host-owned handle to the resource /// specified. @@ -1302,28 +1488,59 @@ impl Generator for Wasmtime { /// destructor for this type. This also may not run /// the destructor if there are still other references /// to this type. - pub {async}fn drop_{name_snake}( - &self, - mut store: impl wasmtime::AsContextMut, - val: {name_camel}, - ) -> Result<(), wasmtime::Trap> {{ - let mut store = store.as_context_mut(); - let data = (self.get_state)(store.data_mut()); - let wasm = match data.resource_slab{idx}.drop(val.0) {{ - Some(val) => val, - None => return Ok(()), - }}; - data.dtor{idx}.unwrap().{call}(&mut store, wasm){wait}?; - Ok(()) - }} ", - name_snake = iface.resources[*r].name.to_snake_case(), - name_camel = iface.resources[*r].name.to_camel_case(), + ); + let body = format!( + " + let state = get_state(store.data_mut()); + let wasm = match state.resource_slab{idx}.drop(val.0) {{ + Some(val) => val, + None => return Ok(()), + }}; + state.dtor{idx}.unwrap().{call}(&mut store, wasm){wait}?; + Ok(()) + ", idx = r.index(), - async = async_fn, call = call, wait = wait, - )); + ); + + if self.any_async_func { + self.src.push_str(&format!( + " + pub async fn drop_{name_snake}( + &self, + val: {name_camel}, + ) -> Result<(), wasmtime::Trap> {{ + let get_state = self.get_state.clone(); + self.handle.run_no_coroutine(move |mut store| Box::pin(async move {{ + {body} + }})).await + }} + ", + name_snake = iface.resources[*r].name.to_snake_case(), + name_camel = iface.resources[*r].name.to_camel_case(), + body = body, + )); + } else { + self.src.push_str(&format!( + " + pub {async}fn drop_{name_snake}( + &self, + mut store: impl wasmtime::AsContextMut, + val: {name_camel}, + ) -> Result<(), wasmtime::Trap> {{ + let mut store = store.as_context_mut(); + let get_state = &self.get_state; + {body} + }} + ", + name_snake = iface.resources[*r].name.to_snake_case(), + name_camel = iface.resources[*r].name.to_camel_case(), + body = body, + async = async_fn, + )); + } } self.push_str("}\n"); @@ -1408,6 +1625,7 @@ struct FunctionBindgen<'a> { needs_borrow_checker: bool, needs_memory: bool, needs_functions: HashMap, + needs_get_state: bool, } impl FunctionBindgen<'_> { @@ -1430,6 +1648,7 @@ impl FunctionBindgen<'_> { needs_functions: HashMap::new(), is_dtor, params, + needs_get_state: false, } } @@ -1452,7 +1671,7 @@ impl FunctionBindgen<'_> { self.push_str( "let (caller_memory, data) = memory.data_and_store_mut(&mut caller);\n", ); - self.push_str("let (_, _tables) = get(data);\n"); + self.push_str("let _tables = &mut get(data).1;\n"); } else { self.push_str("let caller_memory = memory.data_mut(&mut caller);\n"); } @@ -1505,6 +1724,22 @@ impl FunctionBindgen<'_> { mem, operands[1], offset, method, operands[0], extra )); } + + fn bind_results(&mut self, amt: usize, results: &mut Vec) { + if amt == 0 { + return; + } + + let tmp = self.tmp(); + self.push_str("let ("); + for i in 0..amt { + let arg = format!("result{}_{}", tmp, i); + self.push_str(&arg); + self.push_str(","); + results.push(arg); + } + self.push_str(") = "); + } } impl RustFunctionGenerator for FunctionBindgen<'_> { @@ -1686,11 +1921,13 @@ impl Bindgen for FunctionBindgen<'_> { } Instruction::I32FromBorrowedHandle { ty } => { let tmp = self.tmp(); + self.needs_get_state = true; self.push_str(&format!( " let obj{tmp} = {op}; - (self.get_state)(caller.as_context_mut().data_mut()).resource_slab{idx}.clone(obj{tmp}.0)?; - let handle{tmp} = (self.get_state)(caller.as_context_mut().data_mut()).index_slab{idx}.insert(obj{tmp}.0); + let state = get_state(caller.data_mut()); + state.resource_slab{idx}.clone(obj{tmp}.0)?; + let handle{tmp} = state.index_slab{idx}.insert(obj{tmp}.0); ", tmp = tmp, idx = ty.index(), @@ -1701,8 +1938,12 @@ impl Bindgen for FunctionBindgen<'_> { } Instruction::HandleOwnedFromI32 { ty } => { let tmp = self.tmp(); + self.needs_get_state = true; self.push_str(&format!( - "let handle{} = (self.get_state)(caller.as_context_mut().data_mut()).index_slab{}.remove({} as u32)?;\n", + " + let state = get_state(caller.data_mut()); + let handle{} = state.index_slab{}.remove({} as u32)?; + ", tmp, ty.index(), operands[0], @@ -1868,8 +2109,8 @@ impl Bindgen for FunctionBindgen<'_> { " copy_slice( &mut caller, - memory, - func_{}, + &memory, + &func_{}, ptr{tmp}, len{tmp}, {} )? ", @@ -1930,7 +2171,7 @@ impl Bindgen for FunctionBindgen<'_> { )); self.push_str(&format!("let base = {} + (i as i32) * {};\n", result, size)); self.push_str(&body); - self.push_str("}"); + self.push_str("\n}\n"); results.push(result); results.push(len); @@ -2067,19 +2308,13 @@ impl Bindgen for FunctionBindgen<'_> { name, sig, } => { - if sig.results.len() > 0 { - let tmp = self.tmp(); - self.push_str("let ("); - for i in 0..sig.results.len() { - let arg = format!("result{}_{}", tmp, i); - self.push_str(&arg); - self.push_str(","); - results.push(arg); - } - self.push_str(") = "); + self.bind_results(sig.results.len(), results); + if self.gen.any_async_func { + self.push_str("wasm_func"); + } else { + self.push_str("self."); + self.push_str(&to_rust_ident(name)); } - self.push_str("self."); - self.push_str(&to_rust_ident(name)); if self.gen.opts.async_.includes(name) { self.push_str(".call_async("); } else { @@ -2100,7 +2335,61 @@ impl Bindgen for FunctionBindgen<'_> { } Instruction::CallWasmAsyncImport { .. } => unimplemented!(), - Instruction::CallWasmAsyncExport { .. } => unimplemented!(), + + Instruction::CallWasmAsyncExport { + module: _, + name, + params: _, + results: wasm_results, + } => { + self.push_str("wasm_func"); + if self.gen.opts.async_.includes(name) { + self.push_str(".call_async("); + } else { + self.push_str(".call("); + } + self.push_str("&mut caller, ("); + for operand in operands { + self.push_str(operand); + self.push_str(", "); + } + self.push_str("async_cx,"); + self.push_str("))"); + if self.gen.opts.async_.includes(name) { + self.push_str(".await"); + } + self.push_str("?;\n"); + self.push_str("Ok(())\n"); + self.after_call = true; + self.caller_memory_available = false; // invalidated by call + + self.push_str("}) // finish Box::pin\n"); + self.push_str("}); // finish `let start = ...`\n"); + + // TODO: this is somewhat inefficient since it's an `Arc` clone + // that could be unnecessary. It's not clear whether this will + // get closed over in the completion callback below. Generated + // code may need this `get_state` in both the initial and + // completion callback though, and that's why it's cloned here + // too to ensure that there's two values to close over. Should + // figure out a better way to emit this so it's only done if + // necessary. + self.push_str("let get_state = self.get_state.clone();\n"); + + self.push_str( + " + let complete = witx_bindgen_wasmtime::rt::infer_complete(move |mut caller, ptr, memory| { + Box::pin(async move { + ", + ); + + let operands = ["ptr".to_string()]; + for (i, ty) in wasm_results.iter().enumerate() { + let ty = wasm_type(*ty); + let load = self.load((i as i32) * 8, ty, &operands); + results.push(load); + } + } Instruction::CallInterface { module: _, func } => { for (i, operand) in operands.iter().enumerate() { @@ -2134,7 +2423,31 @@ impl Bindgen for FunctionBindgen<'_> { call.push_str(&format!("param{}, ", i)); } call.push_str(")"); - if self.gen.opts.async_.includes(&func.name) { + + // If this is itself an async function then the future is first + // created. The actual await-ing happens inside of a separate + // future we create here and pass to the `_async_cx` which will + // manage execution of the future in connection with the + // original invocation of an async export. + // + // The `future` is `await`'d initially and its results are then + // moved into a completion callback which is processed once the + // store is available again. + if func.is_async { + self.push_str("let future = "); + self.push_str(&call); + self.push_str(";\n"); + self.push_str("witx_bindgen_wasmtime::rt::Async::spawn_import(async move {\n"); + self.push_str("let result = future.await;\n"); + call = format!("result"); + self.push_str("witx_bindgen_wasmtime::rt::box_callback(move |mut caller| {\n"); + self.push_str("Box::pin(async move {\n"); + self.push_str("let host = get(caller.data_mut());\n"); + if let Some(rebind) = self.gen.rebind_host(iface) { + self.push_str(&rebind); + } + self.push_str("drop(&mut *host);\n"); // ignore unused variable + } else if self.gen.opts.async_.includes(&func.name) { call.push_str(".await"); } @@ -2180,24 +2493,87 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::Return { amt, .. } => { let result = match amt { - 0 => format!("Ok(())\n"), - 1 => format!("Ok({})\n", operands[0]), - _ => format!("Ok(({}))\n", operands.join(", ")), + 0 => format!("()"), + 1 => format!("{}", operands[0]), + _ => format!("({})", operands.join(", ")), }; - match self.cleanup.take() { - Some(cleanup) => { - self.push_str("let ret = "); - self.push_str(&result); - self.push_str(";\n"); + if self.gen.any_async_func && !self.gen.in_import { + self.push_str("let ret = "); + self.push_str(&result); + self.push_str(";\n"); + if let Some(cleanup) = self.cleanup.take() { self.push_str(&cleanup); - self.push_str("ret"); } - None => self.push_str(&result), + self.push_str("Ok(ret)\n"); + self.push_str("}) // finish Box::pin\n"); + self.push_str("}); // finish `let complete`\n"); + } else { + let result = format!("Ok({})", result); + match self.cleanup.take() { + Some(cleanup) => { + self.push_str("let ret = "); + self.push_str(&result); + self.push_str(";\n"); + self.push_str(&cleanup); + self.push_str("ret"); + } + None => { + self.push_str(&result); + self.push_str("\n"); + } + } } } Instruction::ReturnAsyncExport { .. } => unimplemented!(), - Instruction::ReturnAsyncImport { .. } => unimplemented!(), + + Instruction::CompletionCallback { params, .. } => { + let mut tys = String::new(); + tys.push_str("i32,"); + for param in params.iter() { + tys.push_str(wasm_type(*param)); + tys.push_str(", "); + } + self.closures.push_str(&format!( + "\ + let completion_callback = + witx_bindgen_wasmtime::rt::Async::::function_table() + .get(&mut caller, {idx} as u32) + .ok_or_else(|| wasmtime::Trap::new(\"invalid function index\"))? + .funcref() + .ok_or_else(|| wasmtime::Trap::new(\"not a funcref table\"))? + .ok_or_else(|| wasmtime::Trap::new(\"callback was a null function\"))? + .typed::<({tys}), (), _>(&caller)?; + ", + idx = operands[0], + tys = tys, + )); + results.push(format!("completion_callback")); + } + + Instruction::ReturnAsyncImport { .. } => { + let mut result = operands[1..].join(", "); + result.push_str(","); + self.push_str(&format!("let ret = ({});\n", result)); + if let Some(cleanup) = self.cleanup.take() { + self.push_str(&cleanup); + } + + self.push_str(&operands[0]); + if self.gen.opts.async_.is_none() { + self.push_str(".call"); + } else { + self.push_str(".call_async"); + } + self.push_str("(&mut caller, ret)"); + if !self.gen.opts.async_.is_none() { + self.push_str(".await"); + } + self.push_str("\n"); + + self.push_str("}) // end Box:pin\n"); + self.push_str("}) // end box_callback\n"); + } Instruction::I32Load { offset } => results.push(self.load(*offset, "i32", operands)), Instruction::I32Load8U { offset } => { diff --git a/crates/gen-wasmtime/tests/codegen.rs b/crates/gen-wasmtime/tests/codegen.rs index 1941d7ea6..cce1256cb 100644 --- a/crates/gen-wasmtime/tests/codegen.rs +++ b/crates/gen-wasmtime/tests/codegen.rs @@ -9,9 +9,6 @@ mod exports { test_helpers::codegen_wasmtime_export!( "*.witx" - // TODO: implement async support - "!async_functions.witx" - // If you want to exclude a specific test you can include it here with // gitignore glob syntax: // @@ -28,9 +25,6 @@ mod imports { test_helpers::codegen_wasmtime_import!( "*.witx" - // TODO: implement async support - "!async_functions.witx" - // TODO: these use push/pull buffer which isn't implemented in the test // generator just yet "!wasi_next.witx" diff --git a/crates/rust-wasm/src/futures.rs b/crates/rust-wasm/src/futures.rs index 286df4ae7..70e4faaae 100644 --- a/crates/rust-wasm/src/futures.rs +++ b/crates/rust-wasm/src/futures.rs @@ -1,10 +1,11 @@ //! Helper library support for `async` witx functions, used for both -use std::cell::RefCell; +use self::event::{Event, Signal}; +use std::cell::{Cell, RefCell}; use std::future::Future; use std::mem; use std::pin::Pin; -use std::rc::{Rc, Weak}; +use std::rc::Rc; use std::sync::Arc; use std::task::*; @@ -19,12 +20,112 @@ pub unsafe extern "C" fn async_export_done(_ctx: i32, _ptr: i32) { panic!("only supported on wasm"); } -struct PollingWaker { - state: RefCell, +/// Runs the `future` provided to completion, polling the future whenever its +/// waker receives a call to `wake`. +pub fn execute(future: Pin>>) { + Task::execute(future) +} + +struct Task { + future: Pin>>, + waker: Arc, +} + +impl Task { + fn execute(future: Pin>>) { + Box::new(Task { + future, + waker: Arc::new(WasmWaker { + state: Cell::new(State::Woken), + }), + }) + .signal() + } +} + +impl Signal for Task { + fn signal(mut self: Box) { + // First, reset our state to `polling` to indicate that we're actively + // polling the future that we own. + let waker = self.waker.clone(); + match waker.state.replace(State::Polling) { + // This shouldn't be possible since if a waiting event is pending + // then we shouldn't be woken up to signal. + State::Waiting(_) => panic!("signaled but event is present"), + + // This also shouldn't be possible since if the previous state were + // polling then we shouldn't be restarting another round of polling. + State::Polling => panic!("poll-in-poll"), + + // This is the expected state, which is to say that we should be + // previously woken with some event having been consumed, which + // left a `Woken` marker here. + State::Woken => {} + } + + // Perform the Rust Dance to poll the future. + let rust_waker = waker.clone().into(); + let mut cx = Context::from_waker(&rust_waker); + match self.future.as_mut().poll(&mut cx) { + // If the future has finished there's nothing else left to do but + // destroy the future, so we do so here through the dtor for `self` + // in an early-return. + Poll::Ready(()) => return, + + // If the future isn't ready then logic below handles the wakeup + // procedure. + Poll::Pending => {} + } + + // Our future isn't ready but we should be scheduled to wait on some + // event from within the future. Configure the state of the waker + // after-the-fact to have an interface-types-provided "event" which, + // when woken, will basically re-invoke this method. + let event = Event::new(self); + match waker.state.replace(State::Waiting(event)) { + // This state shouldn't be possible because we're the only ones + // inserting a `Waiting` state here, so if something else set that + // it's highly unexpected. + State::Waiting(_) => unreachable!(), + + // This is the expected state most of the time where we're replacing + // the `Polling` state that was configured above. This means we've + // switched from polling-to-waiting so we can safely return now and + // wait for our result. + State::Polling => {} + + // This is a slightly tricky state where we received a `wake()` + // while we were polling. In this situation we replace the state + // back to `Woken` and signal the event ourselves. + State::Woken => { + let event = match waker.state.replace(State::Woken) { + State::Waiting(event) => event, + _ => unreachable!(), + }; + event.signal(); + } + } + } +} + +/// This is the internals of the `Waker` that's specific to wasm. +/// +/// For now this is pretty simple where this maintains a state enum where the +/// main interesting state is an "event" that gets a signal to start re-polling +/// the future. This event-based-wakeup has two consequences: +/// +/// * If the `wake()` comes from another Rust coroutine then we'll correctly +/// execute the Rust poll on the original coroutine's context. +/// * If the `wake()` comes from an async import completing then it means the +/// completion callback will do a tiny bit of work to signal the event, and +/// then the real work will happen later when the event's callback is +/// enqueued. +struct WasmWaker { + state: Cell, } enum State { - Waiting(Pin>>), + Waiting(Event), Polling, Woken, } @@ -34,83 +135,33 @@ enum State { // an alternative implementation for threaded WebAssembly when that comes about // to host runtimes off-the-web. #[cfg(not(target_feature = "atomics"))] -unsafe impl Send for PollingWaker {} +unsafe impl Send for WasmWaker {} #[cfg(not(target_feature = "atomics"))] -unsafe impl Sync for PollingWaker {} +unsafe impl Sync for WasmWaker {} -/// Runs the `future` provided to completion, polling the future whenever its -/// waker receives a call to `wake`. -pub fn execute(future: impl Future + 'static) { - let waker = Arc::new(PollingWaker { - state: RefCell::new(State::Waiting(Box::pin(future))), - }); - waker.wake() -} - -impl Wake for PollingWaker { +impl Wake for WasmWaker { fn wake(self: Arc) { - let mut state = self.state.borrow_mut(); - let mut future = match mem::replace(&mut *state, State::Polling) { - // We are the first wake to come in to wake-up this future. This - // means that we need to actually poll the future, so leave the - // `Polling` state in place. - State::Waiting(future) => future, - - // Otherwise the future is either already polling or it was already - // woken while it was being polled, in both instances we reset the - // state back to `Woken` and then we return. This means that the - // future is owned by some previous stack frame and will drive the - // future as necessary. - State::Polling | State::Woken => { - *state = State::Woken; - return; - } - }; - drop(state); - - // Create the futures waker/context from ourselves, used for polling. - let waker = self.clone().into(); - let mut cx = Context::from_waker(&waker); - loop { - match future.as_mut().poll(&mut cx) { - // The future is finished! By returning here we destroy the - // future and release all of its resources. - Poll::Ready(()) => break, - - // The future has work yet-to-do, so continue below. - Poll::Pending => {} - } - - let mut state = self.state.borrow_mut(); - match *state { - // This means that we were not woken while we were polling and - // the state is as it was when we took out the future before. By - // `Pending` being returned at this point we're guaranteed that - // our waker will be woken up at some point in the future, which - // will come look at this future again. This means that we - // simply store our future and return, since this call to `wake` - // is now finished. - State::Polling => { - *state = State::Waiting(future); - break; - } + match self.state.replace(State::Woken) { + // We found a waiting event, yay! Signal that to wake it up and then + // there's nothing much else for us to do. + State::Waiting(event) => event.signal(), - // This means that we received a call to `wake` while we were - // polling. Ideally we'd enqueue some sort of microtask-tick - // here or something like that but for now we just loop around - // and poll again. - State::Woken => {} + // this `wake` happened during the poll of the future itself, which + // is ok and the future will consume our `Woken` status when it's + // done polling. + State::Polling => {} - // This shouldn't be possible since we own the future, and no - // one else should insert another future here. - State::Waiting(_) => unreachable!(), - } + // This is perhaps a concurrent wake where we already woke up the + // main future. That's ok, we're still in the `Woken` state and it's + // still someone else's responsibility to manage wakeups at this + // point. + State::Woken => {} } } } pub struct Oneshot { - inner: Weak>, + inner: Rc>, } pub struct Sender { @@ -130,12 +181,17 @@ enum OneshotState { impl Oneshot { /// Returns a new "oneshot" channel as well as a completion callback. pub fn new() -> (Oneshot, Sender) { + // TODO: this oneshot implementation does not correctly handle "hangups" + // on either the sender or receiver side. This really only works with + // the exact codegen that we have right now and if it's used for + // anything else then this implementation needs to be updated (or this + // should use something off-the-shelf from the ecosystem) let inner = Rc::new(OneshotInner { state: RefCell::new(OneshotState::Start), }); ( Oneshot { - inner: Rc::downgrade(&inner), + inner: inner.clone(), }, Sender { inner }, ) @@ -146,13 +202,7 @@ impl Future for Oneshot { type Output = T; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let inner = match self.inner.upgrade() { - Some(inner) => inner, - // Technically this isn't possible in the initial draft of interface - // types unless there's some serious bug somewhere. - None => panic!("completion callback was canceled"), - }; - let mut state = inner.state.borrow_mut(); + let mut state = self.inner.state.borrow_mut(); match mem::replace(&mut *state, OneshotState::Start) { OneshotState::Done(t) => Poll::Ready(t), OneshotState::Waiting(_) | OneshotState::Start => { @@ -175,12 +225,7 @@ impl Sender { } pub fn send(self, val: T) { - let mut state = self.inner.state.borrow_mut(); - let prev = mem::replace(&mut *state, OneshotState::Done(val)); - // Must `drop` before the `wake` below because waking may induce - // polling which would induce another `borrow_mut` which would - // conflict with this `borrow_mut` otherwise. - drop(state); + let prev = mem::replace(&mut *self.inner.state.borrow_mut(), OneshotState::Done(val)); match prev { // nothing has polled the returned future just yet, so we just @@ -193,16 +238,72 @@ impl Sender { OneshotState::Waiting(waker) => waker.wake(), // Shouldn't be possible, this is the only closure that writes - // `Done` and this can only be invoked once. + // `Done` and this can only be invoked once. Additionally since + // `self` exists we shouldn't be closed yet which is only written in + // `Drop` OneshotState::Done(_) => unreachable!(), } } } -impl Drop for OneshotInner { - fn drop(&mut self) { - if let OneshotState::Waiting(waker) = &*self.state.borrow() { - waker.wake_by_ref(); +mod event { + use std::mem; + + #[cfg(target_arch = "wasm32")] + #[link(wasm_import_module = "canonical_abi")] + extern "C" { + fn event_new(cb: usize, cbdata: usize) -> u32; + fn event_signal(handle: u32, arg: u32); + } + + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn event_new(_: usize, _: usize) -> u32 { + unreachable!() + } + + #[cfg(not(target_arch = "wasm32"))] + unsafe extern "C" fn event_signal(_: u32, _: u32) { + unreachable!() + } + + pub struct Event(u32); + + pub trait Signal { + fn signal(self: Box); + } + + impl Event { + pub fn new(to_signal: Box) -> Event + where + S: Signal, + { + unsafe { + let to_signal = Box::into_raw(to_signal); + let handle = event_new(signal:: as usize, to_signal as usize); + return Event(handle); + } + + unsafe extern "C" fn signal(data: usize, is_drop: u32) { + let data = Box::from_raw(data as *mut S); + if is_drop == 0 { + data.signal(); + } + } + } + + pub fn signal(self) { + unsafe { + event_signal(self.0, 0); + mem::forget(self); + } + } + } + + impl Drop for Event { + fn drop(&mut self) { + unsafe { + event_signal(self.0, 1); + } } } } diff --git a/crates/test-rust-wasm/Cargo.toml b/crates/test-rust-wasm/Cargo.toml index 32b6402fb..dda16896d 100644 --- a/crates/test-rust-wasm/Cargo.toml +++ b/crates/test-rust-wasm/Cargo.toml @@ -51,3 +51,7 @@ test = false [[bin]] name = "async_functions" test = false + +[[bin]] +name = "async_raw" +test = false diff --git a/crates/test-rust-wasm/src/bin/async_raw.rs b/crates/test-rust-wasm/src/bin/async_raw.rs new file mode 100644 index 000000000..7d47247f5 --- /dev/null +++ b/crates/test-rust-wasm/src/bin/async_raw.rs @@ -0,0 +1,3 @@ +include!("../../../../tests/runtime/async_raw/wasm.rs"); + +fn main() {} diff --git a/crates/wasmtime/Cargo.toml b/crates/wasmtime/Cargo.toml index 81a769efd..6510a2e89 100644 --- a/crates/wasmtime/Cargo.toml +++ b/crates/wasmtime/Cargo.toml @@ -12,6 +12,7 @@ wasmtime = "0.30.0" witx-bindgen-wasmtime-impl = { path = "../wasmtime-impl", version = "0.1" } tracing-lib = { version = "0.1.26", optional = true, package = 'tracing' } async-trait = { version = "0.1.50", optional = true } +tokio = { version = "1.12", features = ['rt', 'sync'], optional = true } [features] # Enables generated code to emit events via the `tracing` crate whenever wasm is @@ -21,7 +22,7 @@ tracing = ['tracing-lib', 'witx-bindgen-wasmtime-impl/tracing'] # Enables async support for generated code, although when enabled this still # needs to be configured through the macro invocation. -async = ['async-trait', 'witx-bindgen-wasmtime-impl/async'] +async = ['async-trait', 'witx-bindgen-wasmtime-impl/async', 'tokio', 'wasmtime/async'] # Enables the ability to parse the old s-expression-based `*.witx` format. old-witx-compat = ['witx-bindgen-wasmtime-impl/old-witx-compat'] diff --git a/crates/wasmtime/src/futures.rs b/crates/wasmtime/src/futures.rs new file mode 100644 index 000000000..19a93a402 --- /dev/null +++ b/crates/wasmtime/src/futures.rs @@ -0,0 +1,907 @@ +use crate::slab::Slab; +use std::any::Any; +use std::cell::{Cell, RefCell}; +use std::future::Future; +use std::mem; +use std::pin::Pin; +use std::sync::{Arc, Weak}; +use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedSender}; +use tokio::task::JoinHandle; +use wasmtime::{AsContextMut, Caller, Memory, Store, StoreContextMut, Table, Trap, TypedFunc}; + +const MAX_EVENTS: usize = 1_000; + +pub struct Async { + function_table: Table, + + /// Channel used to send messages to the main event loop where this + /// `Async` is managed. + /// + /// Note that this is stored in a `Weak` pointer to not hold a strong + /// reference to it because this `Async` is owned by the event loop which + /// is otherwise terminated when the `Sender` is gone, hence if this were + /// a strong reference it would loop forever. + /// + /// The main strong reference to this channel is held by a generated struct + /// that will live in user code on some other task. Other references to this + /// sender, also weak though, will live in each imported async host function + /// when invoked. + sender: Weak>>, + + /// The list of active WebAssembly coroutines that are executing in this + /// event loop. + /// + /// Note that for now the term "coroutine" here is specifically used for the + /// interface-types notion of a coroutine and does not correspond to a + /// literal coroutine/fiber on the host. Interface types coroutines are only + /// implemented right now with the callback ABI, meaning there's no + /// coroutine in the sense of "there's a suspended host stack" somewhere. + /// Instead wasm retains all state necessary for resumption and such. + /// + /// This list of active coroutines will have one-per-export called and when + /// suspended the coroutines here are all guaranteed to have pending imports + /// they're waiting on. + /// + /// Note that internally `Coroutines` is simply a `Slab>` + /// and is only structured this way to have lookups via `&CoroutineId` + /// instead of `u32` as slabs do. + coroutines: RefCell>, + + events: RefCell>, + pending_events: RefCell>, + + /// The "currently active" coroutine. + /// + /// This is used to persist state in the host about what coroutine is + /// currently active so that when an import is called we can automatically + /// assign that import's "thread" of execution to the currently active + /// coroutine, adding it to the right import list. This enables keeping + /// track on the host for what imports are used where and what to cancel + /// whenever one coroutine aborts (if at all). + cur_wasm_coroutine: CoroutineId, + + /// The next unique ID to hand out to a coroutine. + /// + /// This is a monotonically increasing counter which is intended to be + /// unique for all coroutines for the lifetime of a program. This is a + /// generational index of sorts which prevents accidentally resuing slab + /// indices in the `coroutines` array. + cur_unique_id: Cell, + + receiver: Receiver>, +} + +/// An "integer" identifier for a coroutine. +/// +/// This is used to uniquely identify a logical coroutine of WebAssembly +/// execution, and internally contains the slab index it's stored at as well as +/// a unique generational ID. +#[derive(Copy, Clone)] +pub struct CoroutineId { + slab_index: u32, + unique_id: u64, +} + +struct Coroutines { + slab: Slab>, +} + +enum Message { + Execute(Start, Complete, UnboundedSender), + RunNoCoroutine(RunStandalone, UnboundedSender), + FinishImport(Callback, CoroutineId, u32), + Cancel(CoroutineId), +} + +struct Event { + callback: TypedFunc<(u32, u32), ()>, + coroutine: CoroutineId, + data: u32, +} + +struct Coroutine { + /// A unique ID for this coroutine which is used to ensure that even if this + /// coroutine's slab index is reused a `CoroutineId` uniquely points to one + /// logical coroutine. This mostly comes up where when a coroutine exits + /// early due to a trap we need to make sure that even if the slab slot is + /// reused we don't accidentally use some future coroutine for lingering + /// completion callbacks. + unique_id: u64, + + /// A list of spawned tasks corresponding to imported host functions that + /// this coroutine is waiting on. This list is appended to whenever an async + /// host function is invoked and it's removed from when the host function + /// completes (and the message gets back to the main loop). + /// + /// The primary purpose of this list is so that when a coroutine fails (via + /// a trap) that all of the spawned host work for the coroutine can exit + /// ASAP via an `abort()` signal on the `JoinHandle`. + pending_imports: Slab>, + + /// The number of imports or events that we're waiting on. + pending_callbacks: usize, + + /// A callback to invoke whenever a coroutine's `async_export_done` + /// completion callback is invoked. This is used by the host to deserialize + /// the results from WebAssembly (possibly doing things like wasm + /// malloc/free) and then sending the results on a channel. + /// + /// Typically this contains a `Sender` internally within this closure + /// which gets a message once all the wasm arguments have been successfully + /// deserialized. + complete: Option>, + + sender: UnboundedSender, + cancel_task: Option>, +} + +pub type HostFuture = Pin + Send>>; +pub type CoroutineResult = Result, Trap>; +pub type Start = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + u32, + ) -> Pin> + Send + 'a>> + + Send, +>; +pub type Callback = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) -> Pin> + Send + 'a>> + + Send, +>; +pub type Complete = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + wasmtime::Memory, + ) -> Pin + Send + 'a>> + + Send, +>; +pub type RunStandalone = Box< + dyn for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) -> Pin + Send + 'a>> + + Send, +>; + +impl Async { + /// Spawns a new task which will manage async execution of wasm within the + /// `store` provided. + pub fn spawn(mut store: Store, function_table: Table) -> AsyncHandle { + // This channel is the primary point of communication into the task that + // we're going to spawn. This'll be bounded to ensure it doesn't get + // overrun, and additionally the sender will be stored in an `Arc` to + // ensure that the returned handle is the only owning handle and the + // intenral weak handle held by `Async` doesn't keep it alive to let + // the task terminate gracefully. + let (sender, receiver) = mpsc::channel(5 /* TODO: should this be configurable? */); + let sender = Arc::new(sender); + let mut cx = Async { + function_table, + sender: Arc::downgrade(&sender), + coroutines: RefCell::new(Coroutines { + slab: Slab::default(), + }), + events: Default::default(), + pending_events: Default::default(), + cur_wasm_coroutine: CoroutineId { + slab_index: u32::MAX, + unique_id: u64::MAX, + }, + cur_unique_id: Cell::new(0), + receiver, + }; + + tokio::spawn(async move { cx.run(&mut store).await }); + AsyncHandle { sender } + } + + pub fn spawn_import( + future: impl Future> + Send + 'static, + ) -> Result<(), Trap> { + Self::with(|cx| { + let sender = cx.sender.clone(); + // Register a new pending import for the currently executing wasm + // coroutine. This will ensure that full completion of this + // coroutine is delayed until this import is resolved. + let coroutine_id = cx.cur_wasm_coroutine; + let mut coroutines = cx.coroutines.borrow_mut(); + let coroutine = coroutines + .get_mut(&coroutine_id) + .ok_or_else(|| Trap::new("cannot call async import from non-async export"))?; + let pending_import_id = coroutine.pending_imports.next_id(); + coroutine.pending_callbacks += 1; + + // Note that `tokio::spawn` is used here to allow the `future` for + // this import to execute in parallel, not just concurrently. The + // result is re-acquired via sending a message on our internal + // channel. + let task = tokio::spawn(async move { + let import_result = future.await; + + // If the main task has exited for some reason then it'll never + // receive our result, but that's ok since we're trying to + // complete a wasm import and if the main task isn't there to + // receive it there's nothing else to do with this result. This + // being an error is theoretically possible but should be rare. + if let Some(sender) = sender.upgrade() { + let send_result = sender + .send(Message::FinishImport( + import_result, + coroutine_id, + pending_import_id, + )) + .await; + drop(send_result); + } + }); + + let id = coroutine.pending_imports.insert(task); + assert_eq!(id, pending_import_id); + Ok(()) + }) + } + + /// Top level run-loop of the task which owns `Async`, typically spawned + /// as a separate task. + async fn run(&mut self, store: &mut Store) { + while self.process_message(store).await { + // ...continue on to the next message... + } + } + + async fn process_message(&mut self, store: &mut Store) -> bool { + let store = &mut store.as_context_mut(); + + // The "highest priority" messages are those of pending events. These + // are processed first before we get to the actual channel queue which + // may block further for events. + if let Some((event, arg)) = self.pending_events.get_mut().pop() { + return self + .execute_coroutine( + event.coroutine, + event.callback.call_async(store, (event.data, arg)), + ) + .await; + } + + // Wait for a message, but if there are no other messages then we're + // done processing messages + let coroutines = self.coroutines.get_mut(); + let msg = match self.receiver.recv().await { + Some(msg) => msg, + None => return false, + }; + match msg { + // This message is the start of a new task ("coroutine" in + // interface-types-vernacular) so we allocate a new task in our + // slab and set up its state. + // + // Note that we spawn a "helper" task here to send a message to + // our channel when the `sender` specified here is disconnected. + // That scenario means that this coroutine is cancelled and by + // sending a message into our channel we can start processing + // that. + Message::Execute(run, complete, sender) => { + let unique_id = self.cur_unique_id.get(); + self.cur_unique_id.set(unique_id + 1); + let coroutine_id = coroutines.next_id(unique_id); + let my_sender = self.sender.clone(); + let await_close_sender = sender.clone(); + let cancel_task = tokio::spawn(async move { + await_close_sender.closed().await; + // if the main task is gone one way or another we ignore + // the error here since no one's going to receive it + // anyway and all relevant work should be cancelled. + if let Some(sender) = my_sender.upgrade() { + drop(sender.send(Message::Cancel(coroutine_id)).await); + } + }); + coroutines.insert(Coroutine { + unique_id, + complete: Some(complete), + sender, + pending_imports: Slab::default(), + pending_callbacks: 1, + cancel_task: Some(cancel_task), + }); + self.execute_coroutine(coroutine_id, run(store, coroutine_id.slab_index)) + .await + } + + // This message means that we need to execute `run` specified + // which is a "non blocking"-in-the-coroutine-sense wasm + // function. This is basically "go run that single callback" and + // is currently only used for things like resource destructors. + // These aren't allowed to call blocking functions and a trap is + // generated if they try to call a blocking function (since + // there isn't a coroutine set up). + // + // Note that here we avoid allocating a coroutine entirely since + // this isn't actually a coroutine, which means that any attempt + // to call a blocking function will be met with failure (a + // trap). Additionally note that the actual execution of the + // wasm here is select'd against the closure of the `sender` + // here as well, since if the runtime becomes disinterested in + // the result of this async call we can interrupt and abort the + // wasm. + // + // Finally note that if the wasm completes but we fail to send + // the result of the wasm to the receiver then we ignore the + // error since that was basically a race between wasm exiting + // and the sender being closed. + // + // TODO: should this dropped result/error get logged/processed + // somewhere? + Message::RunNoCoroutine(run, sender) => { + tokio::select! { + r = tls::scope(self, run(store)) => { + let is_trap = r.is_err(); + let _ = sender.send(r); + + // Shut down this reactor if a trap happened because + // the instance is now in an indeterminate state. + if is_trap { + return false; + } + } + _ = sender.closed() => return false, + } + true + } + + // This message indicates that an import has completed and + // the completion callback for the wasm must be executed. + // This, plus the serialization of the arguments into wasm + // according to the canonical ABI, is represented by + // `run`. + // + // Note, though, that in some cases we don't actually run + // the completion callback. For example if a previous + // completion callback for this wasm task has failed with a + // trap we don't continue to run completion callbacks for + // the wasm task. This situation is indicated when the + // coroutine is not actually present in our `coroutines` + // list, so we do a lookup here before allowing execution. When + // the coroutine isn't present we simply skip this message which + // will run destructors for any relevant host values. + Message::FinishImport(run, coroutine_id, import_id) => { + let coroutine = match coroutines.get_mut(&coroutine_id) { + Some(c) => c, + None => return true, + }; + coroutine.pending_imports.remove(import_id).unwrap(); + self.execute_coroutine(coroutine_id, run(&mut store.into())) + .await + } + + // This message indicates that the specified coroutine has been + // cancelled, meaning that the sender which would send back the + // result of the coroutine is now a closed channel that we can + // no longer send a message along. Our response to this is to + // remove the coroutine, and its destructor will trigger further + // cancellation if necessary. + // + // Note that this message may race with the actual completion of + // the coroutine so we don't assert that the ID specified here + // is actually in our list. If a coroutine is removed though we + // assume that the wasm is now in an indeterminate state which + // results in aborting this reactor task. If nothing is removed + // then we assume the race was properly resolved and we skip + // this message. + Message::Cancel(coroutine_id) => { + if coroutines.remove(&coroutine_id).is_some() { + return false; + } + true + } + } + } + + async fn execute_coroutine( + &mut self, + coroutine_id: CoroutineId, + wasm_execution: impl Future>, + ) -> bool { + // Actually execute the WebAssembly callback. The call to + // `to_execute.run` here is what will actually execute WebAssembly + // asynchronously, and note that it's also executed within a + // `tls::scope` to ensure that the `tls::with` function will work + // for the duration of the future. + // + // Also note, though, that we want to be able to cancel the + // execution of this WebAssembly if the caller becomes disinterested + // in the result. This happens by using the `closed()` method on the + // channel back to the sender, and if that happens we abort wasm + // entirely and abort the whole coroutine by removing it later. + // + // If this wasm operations is aborted then we exit this loop + // entirely and tear down this reactor task. That triggers + // cancellation of all spawned sub-tasks and sibling coroutines, and + // the rationale for this is that we zapped wasm while it was + // executing so it's now in an indeterminate state and not one that + // we can resume. + // + // TODO: this is a `clone()`-per-callback which is probably cheap, + // but this is also a sort of wonky setup so this may wish to change + // in the future. + let coroutine = self.coroutines.get_mut().get_mut(&coroutine_id).unwrap(); + coroutine.pending_callbacks -= 1; + let cancel_signal = coroutine.sender.clone(); + let prev_coroutine_id = mem::replace(&mut self.cur_wasm_coroutine, coroutine_id); + let result = tokio::select! { + r = tls::scope(self, wasm_execution) => r, + _ = cancel_signal.closed() => return false, + }; + self.cur_wasm_coroutine = prev_coroutine_id; + + let coroutines = self.coroutines.get_mut(); + let coroutine = coroutines.get_mut(&coroutine_id).unwrap(); + if let Err(trap) = result { + // Our WebAssembly callback trapped. That means that this + // entire coroutine is now in a failure state. No further + // wasm callbacks will be invoked and the coroutine is + // removed from out internal list to invoke the failure + // callback, informing what trap caused the failure. + // + // Note that this reopens `coroutine_id.slab_index` to get + // possibly reused, intentionally so, which is why + // `CoroutineId` is a form of generational ID which is + // resilient to this form of reuse. In other words when we + // remove the result here if in the future a pending import + // for this coroutine completes we'll simply discard the + // message. + // + // Any error in sending the trap along the coroutine's channel + // is ignored since we can race with the coroutine getting + // dropped. + // + // TODO: should the trap still be sent somewhere? Is this ok to + // simply ignore? + // + // Finally we exit the reactor in this case because traps + // typically represent fatal conditions for wasm where we can't + // really resume since it may be in an indeterminate state (wasm + // can't handle traps itself), so after we inform the original + // coroutine of the original trap we break out and cancel all + // further execution. + let coroutine = coroutines.remove(&coroutine_id).unwrap(); + let _ = coroutine.sender.send(Err(trap)); + return false; + } else if coroutine.pending_callbacks == 0 { + // Our wasm callback succeeded, and there are no pending + // imports for this coroutine. + // + // In this state it means that the coroutine has completed + // since no further work can possibly happen for the + // coroutine. This means that we can safely remove it from + // our internal list. + // + // If the coroutine's completion wasn't ever signaled, + // however, then that indicates a bug in the wasm code + // itself. This bug is translated into a trap which will get + // reported to the caller to inform the original invocation + // of the export that the result of the coroutine never + // actually came about. + // + // Note that like above a failure to send a trap along the + // channel is ignored since we raced with the caller becoming + // disinterested in the result which is fine to happen at any + // time. + // + // TODO: should the trap still be sent somewhere? Is this ok to + // simply ignore? + // + // TODO: should this tear down the reactor as well, despite it + // being a synthetically created trap? + let coroutine = coroutines.remove(&coroutine_id).unwrap(); + if coroutine.complete.is_some() { + let _ = coroutine + .sender + .send(Err(Trap::new("completion callback never called"))); + } + } else { + // Our wasm callback succeeded, and there are pending + // imports for this coroutine. + // + // This means that the coroutine isn't finished yet so we + // simply turn the loop and wait for something else to + // happen. We'll next be executing WebAssembly when one of + // the coroutine's imports finish. + } + + true + } + + pub async fn async_export_done( + mut caller: Caller<'_, T>, + task_id: i32, + ptr: i32, + mem: Memory, + ) -> Result<(), Trap> { + // Extract the completion callback registered in Rust for the `task_id`. + // This will deserialize all of the canonical ABI results specified by + // `ptr`, and presumably send the result on some sort of channel back to the + // task that originally invoked the wasm. + let task_id = task_id as u32; + let complete = Self::with(|cx| { + let mut coroutines = cx.coroutines.borrow_mut(); + let coroutine = coroutines + .slab + .get_mut(task_id) + .ok_or_else(|| Trap::new("async context not valid"))?; + coroutine + .complete + .take() + .ok_or_else(|| Trap::new("async context not valid")) + })?; + + // Note that this is an async-enabled call to allow `call_async` for things + // like fuel in case the completion callback needs to invoke wasm + // asychronously for things like deallocation. + let result = complete(&mut caller.as_context_mut(), ptr, mem).await?; + + // With the final result of the coroutine we send this along the channel + // back to the original task which was waiting for the result. Note that + // this send may fail if we're racing with cancellation of this task, + // and if cancellation happens we translate that to a trap to ensure + // that wasm is cleaned up quickly (as oppose to waiting for the next + // yield point where it should get cleaned up anyway). + Self::with(|cx| { + let mut coroutines = cx.coroutines.borrow_mut(); + let coroutine = coroutines.slab.get_mut(task_id).unwrap(); + coroutine + .sender + .send(Ok(result)) + .map_err(|_| Trap::new("task has been cancelled")) + }) + } + + /// Implementation of the `event_new` canonical ABI intrinsic + pub fn event_new(mut caller: Caller<'_, T>, cb: u32, data: u32) -> Result { + Self::with(|cx| { + // First up validate `cb` to ensure it's actually a valid wasm + // callback to have given us. + let callback = cx + .function_table + .get(&mut caller, cb) + .ok_or_else(|| Trap::new("out-of bounds function index"))? + .funcref() + .ok_or_else(|| Trap::new("not a funcref table"))? + .ok_or_else(|| Trap::new("callback cannot be null"))? + .typed(&caller)?; + + // Next record that there's a pending callback because the coroutine + // is now possibly blocked on this event. + cx.coroutines + .borrow_mut() + .get_mut(&cx.cur_wasm_coroutine) + .unwrap() + .pending_callbacks += 1; + + // And here the event data is saved and returned back to wasm as an + // index. + Ok(cx.events.borrow_mut().insert(Event { + callback, + coroutine: cx.cur_wasm_coroutine, + data, + })) + }) + } + + /// Implementation of the `event_signal` canonical ABI intrinsic + pub fn event_signal(_caller: Caller<'_, T>, event: u32, arg: u32) -> Result<(), Trap> { + Self::with(|cx| { + // Validate that `event` is valid for this wasm module and then + // enqueue the event into an intenral list to get processed after + // this wasm is finished executing. + // + // Note that we don't decrement the pending imports count here but + // rather wait for the pending event message to get processed to + // actually decrement the count, lest the coroutine accidentally be + // declared done early. + let event = cx + .events + .borrow_mut() + .remove(event) + .ok_or_else(|| Trap::new("invalid event index"))?; + let mut pending_events = cx.pending_events.borrow_mut(); + if pending_events.len() >= MAX_EVENTS { + return Err(Trap::new("too many events created")); + } + pending_events.push((event, arg)); + Ok(()) + }) + } + + // TODO: this is a pretty bad interface to manage the table with... + pub fn function_table() -> Table { + Self::with(|cx| cx.function_table) + } + + fn with(f: impl FnOnce(&Async) -> R) -> R { + tls::with(|cx| f(cx.downcast_ref().unwrap())) + } +} + +impl Coroutines { + fn next_id(&self, unique_id: u64) -> CoroutineId { + CoroutineId { + unique_id, + slab_index: self.slab.next_id(), + } + } + + fn insert(&mut self, coroutine: Coroutine) -> CoroutineId { + let unique_id = coroutine.unique_id; + let slab_index = self.slab.insert(coroutine); + CoroutineId { + unique_id, + slab_index, + } + } + + fn get_mut(&mut self, id: &CoroutineId) -> Option<&mut Coroutine> { + let entry = self.slab.get_mut(id.slab_index)?; + if entry.unique_id == id.unique_id { + Some(entry) + } else { + None + } + } + + fn remove(&mut self, id: &CoroutineId) -> Option> { + let entry = self.slab.get_mut(id.slab_index)?; + if entry.unique_id == id.unique_id { + self.slab.remove(id.slab_index) + } else { + None + } + } +} + +impl Drop for Coroutine { + fn drop(&mut self) { + // When a coroutine is removed and dropped from the internal list of + // coroutines then we're no longer interested in any of the results for + // any of the spawned tasks. This means we can proactively cancel + // anything that this coroutine might be waiting on (imported functions) + // plus the task that's used to send a message to the "main loop" on + // cancellation. + if let Some(task) = &self.cancel_task { + task.abort(); + } + for task in self.pending_imports.iter() { + task.abort(); + } + } +} + +pub struct AsyncHandle { + sender: Arc>>, +} + +impl AsyncHandle { + /// Executes a new WebAssembly in the "reactor" that this handle is + /// connected to. + /// + /// This function will execute `start` as the initial callback for the + /// asynchronous WebAssembly to be executed. This closure receives the + /// `Store` via a handle as well as the coroutine ID that's associated + /// with this new coroutine. It's expected that this callback produces a + /// future which represents the execution of the initial WebAssembly + /// callback, handling all canonical ABI translations internally. + /// + /// The second `complete` callback is invoked when the wasm indicates that + /// it's finished executing (the `async_export_done` intrinsic wasm + /// import). This is expected to produce the final result of the function. + /// + /// This function is an `async` function which is expected to be `.await`'d. + /// If this function's future is dropped or cancelled then the coroutine + /// that this executes will also be dropped/cancelled. If a wasm trap + /// happens then that will be returned here and the coroutine will be + /// cancelled. + /// + /// Note that it is possible for wasm to invoke the completion callback and + /// still trap. In situations like that the trap is returned from this + /// function. + pub async fn execute( + &self, + start: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + u32, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + complete: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + wasmtime::Memory, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + ) -> Result + where + U: Send + 'static, + { + // Note that this channel should have at most 2 messages ever sent on it + // so it's easier to deal with an unbounded channel rather than a + // bounded channel. + let (tx, mut rx) = mpsc::unbounded_channel(); + + // Send a request to our "reactor task" which indicates that we'd like + // to start execution of a new WebAssembly coroutine. The start/complete + // callbacks provided here are the implementation of the canonical ABI + // for this particular coroutine. + // + // Note that failure to send here turns into a trap. This can happen + // when the reactor task is torn down, taking the receiver with it. When + // wasm traps this can happen, and this means that the wasm is no longer + // present for execution so we continue to propagate traps with a new + // synthetic trap here. + self.sender + .send(Message::Execute( + Box::new(start), + Box::new(move |store, ptr, mem| { + Box::pin(async move { + let val = complete(store, ptr, mem).await?; + Ok(Box::new(val) as Box) + }) + }), + tx, + )) + .await + .map_err(|_| Trap::new("wasm reactor task has gone away -- sibling trap?"))?; + + // This is a bit of a tricky dance. Once WebAssembly is requested to be + // executed there are a number of outcomes that can happen here: + // + // 1. The WebAssembly coroutine could complete successfully. This means + // that it eventually invokes the completion callback and no traps + // happened. In this case the completion value is sent on the channel + // and then when the wasm is all finished then the sending half of + // the channel is destroyed. + // + // 2. The WebAssembly coroutine could trap before invoking its + // completion callback. In this scenario the first message is a trap + // and there will be no second message because the coroutine is + // destroyed after a trap. + // + // 3. The WebAssembly coroutine could give us a completed value + // successfully, but then afterwards may trap. In this situation the + // first message received is the completed value of the coroutine, + // and the second message will be the trap that occurred. + // + // 4. Finally a the reactor coudl get torn down because of wasm hitting + // a trap (leaving it in an indeterminate state) or a bug in the + // reactor that panicked. + // + // Overall this leads us to two separate `.await` calls. The first + // `.await` receives the first message and "propagates" traps in (4) + // assuming that the reactor is gone due to a wasm trap. This first + // result is `Ok` in (1)/(3), and it's `Err` in the case of (2). + // + // The second `.await` will wait for the full completion of the + // coroutine in (1) but then receive `None`, should immediately receive + // `None` for (2), and will receive a trap with (3). In all situations + // we are guaranteed that after the second message the coroutine is + // deleted and cleaned up. + // + // Note that receiving `Ok` as the second message is not possible + // because the completion callback is invoked at most once and it's only + // invoked if no trap has happened, which means that a successful + // completion callback is guaranteed to be the first message. + // + // TODO: the time that passes between the first `.await` and the second + // `.await` is not exposed with this function's signature. This is + // simply a bland async function that returns the result, but embedders + // may want to process a successful result which later traps. This API + // should probably be redesigned to accommodate this. + let result = rx + .recv() + .await + .ok_or_else(|| Trap::new("wasm reactor task has gone away -- sibling trap?"))?; + match rx.recv().await { + Some(Err(trap)) => Err(trap), + Some(Ok(_)) => unreachable!(), + None => result.map(|e| *e.downcast().ok().unwrap()), + } + } + + pub async fn run_no_coroutine( + &self, + run: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + ) -> Result + where + U: Send + 'static, + { + let (tx, mut rx) = mpsc::unbounded_channel(); + self.sender + .send(Message::RunNoCoroutine( + Box::new(move |store| { + Box::pin(async move { + let val = run(store).await?; + Ok(Box::new(val) as Box) + }) + }), + tx, + )) + .await + .ok() + .expect("reactor task should be present"); + rx.recv() + .await + .unwrap() + .map(|e| *e.downcast().ok().unwrap()) + } +} + +mod tls { + use std::any::Any; + use std::cell::Cell; + use std::future::Future; + use std::pin::Pin; + use std::task::{Context, Poll}; + + thread_local!(static CUR: Cell<*const (dyn Any + Send)> = Cell::new(&0)); + + pub async fn scope( + val: &mut (dyn Any + Send + 'static), + future: impl Future, + ) -> T { + struct SetTls<'a, F> { + val: &'a mut (dyn Any + Send + 'static), + future: F, + } + + impl Future for SetTls<'_, F> { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let (val, future) = unsafe { + let inner = self.get_unchecked_mut(); + ( + Pin::new_unchecked(&mut inner.val), + Pin::new_unchecked(&mut inner.future), + ) + }; + + let x: &&mut (dyn Any + Send + 'static) = val.as_ref().get_ref(); + set(&**x, || future.poll(cx)) + } + } + + SetTls { val, future }.await + } + + pub fn set(val: &(dyn Any + Send + 'static), f: impl FnOnce() -> R) -> R { + return CUR.with(|slot| { + let prev = slot.replace(val); + let _reset = Reset(slot, prev); + f() + }); + + struct Reset<'a, T: Copy>(&'a Cell, T); + + impl Drop for Reset<'_, T> { + fn drop(&mut self) { + self.0.set(self.1); + } + } + } + + pub fn with(f: impl FnOnce(&(dyn Any + Send)) -> R) -> R { + CUR.with(|slot| { + let val = slot.get(); + unsafe { f(&*val) } + }) + } +} diff --git a/crates/wasmtime/src/lib.rs b/crates/wasmtime/src/lib.rs index 043ab2297..9db4d00dc 100644 --- a/crates/wasmtime/src/lib.rs +++ b/crates/wasmtime/src/lib.rs @@ -7,8 +7,11 @@ pub use tracing_lib as tracing; #[doc(hidden)] pub use {anyhow, bitflags, wasmtime}; +pub use futures::HostFuture; + mod error; pub mod exports; +mod futures; pub mod imports; mod le; mod region; @@ -32,6 +35,7 @@ unsafe impl Sync for RawMemory {} #[doc(hidden)] pub mod rt { + pub use crate::futures::{Async, AsyncHandle}; use crate::slab::Slab; use crate::{Endian, Le}; use std::mem; @@ -257,4 +261,61 @@ pub mod rt { Some(resource.wasm) } } + + use crate::futures::Callback; + use std::future::Future; + use std::pin::Pin; + + pub fn infer_start(callback: F) -> F + where + F: for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + u32, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + { + callback + } + + pub fn infer_complete(callback: F) -> F + where + F: for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + i32, + Memory, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + { + callback + } + + pub fn infer_standalone(callback: F) -> F + where + F: for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + { + callback + } + + pub fn box_callback( + callback: impl for<'a> FnOnce( + &'a mut StoreContextMut<'_, T>, + ) + -> Pin> + Send + 'a>> + + Send + + 'static, + ) -> Callback { + Box::new(callback) + } + + #[cfg(feature = "async")] + pub use tokio::sync::mpsc; } diff --git a/crates/wasmtime/src/slab.rs b/crates/wasmtime/src/slab.rs index 29db0c9ba..a18ec52d7 100644 --- a/crates/wasmtime/src/slab.rs +++ b/crates/wasmtime/src/slab.rs @@ -12,6 +12,10 @@ enum Entry { } impl Slab { + pub fn next_id(&self) -> u32 { + self.next as u32 + } + pub fn insert(&mut self, item: T) -> u32 { if self.next == self.storage.len() { self.storage.push(Entry::Empty { @@ -54,6 +58,13 @@ impl Slab { } } } + + pub fn iter(&self) -> impl Iterator { + self.storage.iter().filter_map(|i| match i { + Entry::Full(t) => Some(t), + Entry::Empty { .. } => None, + }) + } } impl Default for Slab { diff --git a/crates/witx2/src/abi.rs b/crates/witx2/src/abi.rs index e4a01eaba..fd009e74e 100644 --- a/crates/witx2/src/abi.rs +++ b/crates/witx2/src/abi.rs @@ -662,6 +662,20 @@ def_instruction! { /// the `async_export_done` intrinsic in the `canonical_abi` module. ReturnAsyncExport { func: &'a Function } : [2] => [0], + /// Validates a completion callback index as provided by wasm. + /// + /// This takes an `i32` argument which was provided by WebAssembly as an + /// index into the function table. This index should be a valid index + /// pointing to a valid function. The function should take the `params` + /// specified plus a leading `i32` parameter. The function should return + /// no values. + /// + /// This instruction should push an expression representing the + /// function, and the expression is later used as the first argument to + /// `ReturnAsyncImport` to actually get invoked in a later async + /// context. + CompletionCallback { func: &'a Function, params: &'a [WasmType] } : [1] => [1], + /// "Returns" from an asynchronous import. /// /// This is only used for host modules at this time, and @@ -1373,6 +1387,21 @@ impl<'a, B: Bindgen> Generator<'a, B> { self.emit(&Instruction::GetArg { nth }); } + // If we're invoking a completion callback then allow codegen to + // front-load validation of the function pointer argument to + // ensure we can continue successfully once we've committed to + // translating all the arguments and calling the host function. + let callback = if func.is_async && self.dir == Direction::Import { + self.emit(&Instruction::GetArg { + nth: sig.params.len() - 2, + }); + let params = sig.retptr.as_ref().unwrap(); + self.emit(&Instruction::CompletionCallback { func, params }); + Some(self.stack.pop().unwrap()) + } else { + None + }; + // Once everything is on the stack we can lift all arguments // one-by-one into their interface-types equivalent. self.lift_all(&func.params); @@ -1393,10 +1422,8 @@ impl<'a, B: Bindgen> Generator<'a, B> { Direction::Import => { assert_eq!(self.stack.len(), tys.len()); let operands = mem::take(&mut self.stack); - // function index to call - self.emit(&Instruction::GetArg { - nth: sig.params.len() - 2, - }); + // wasm function to call + self.stack.extend(callback); // environment for the function self.emit(&Instruction::GetArg { nth: sig.params.len() - 1, diff --git a/tests/codegen/async_functions.witx b/tests/codegen/async_functions.witx index e713eded4..b52d1821f 100644 --- a/tests/codegen/async_functions.witx +++ b/tests/codegen/async_functions.witx @@ -5,3 +5,15 @@ async_results: async function() -> (u32, string, list) resource async_resource { frob: async function() } + +fetch: async function(url: string) -> Response + +resource Response { + body: async function() -> list + status: function() -> u32 + status_text: function() -> string +} + +resource some_resource + +resource_to_resource: async function(x: some_resource) -> some_resource diff --git a/tests/runtime/async_functions/exports.witx b/tests/runtime/async_functions/exports.witx index 5f5010bb2..78724608f 100644 --- a/tests/runtime/async_functions/exports.witx +++ b/tests/runtime/async_functions/exports.witx @@ -2,3 +2,11 @@ thunk: async function() allocated_bytes: function() -> u32 test_concurrent: async function() + +concurrent_export: async function(idx: u32) + +infinite_loop_async: async function() +infinite_loop: function() + +call_import_then_trap: async function() +call_infinite_import: async function() diff --git a/tests/runtime/async_functions/host.rs b/tests/runtime/async_functions/host.rs new file mode 100644 index 000000000..f44a98c95 --- /dev/null +++ b/tests/runtime/async_functions/host.rs @@ -0,0 +1,342 @@ +use anyhow::Result; +use imports::*; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::sync::oneshot::{channel, Receiver, Sender}; +use wasmtime::{Config, Engine, Linker, Module, Store, TrapCode}; +use witx_bindgen_wasmtime::HostFuture; + +witx_bindgen_wasmtime::export!({ + paths: ["./tests/runtime/async_functions/imports.witx"], + async: *, +}); + +#[derive(Default)] +pub struct MyImports { + unblock1: Option>, + unblock2: Option>, + unblock3: Option>, + wait1: Option>, + wait2: Option>, + wait3: Option>, + + concurrent1: Option>, + concurrent2: Option>, + + iloop_close_on_drop: Option>, + iloop_entered: Option>, + + import_cancelled_signal: Option>, + import_cancelled_entered: Vec>, +} + +#[witx_bindgen_wasmtime::async_trait] +impl Imports for MyImports { + fn thunk(&mut self) -> HostFuture<()> { + Box::pin(async { + async {}.await; + }) + } + + fn concurrent1(&mut self, a: u32) -> HostFuture { + assert_eq!(a, 1); + self.unblock1.take(); + let wait = self.wait1.take().unwrap(); + Box::pin(async move { + wait.await.unwrap(); + a + 10 + }) + } + + fn concurrent2(&mut self, a: u32) -> HostFuture { + assert_eq!(a, 2); + self.unblock2.take(); + let wait = self.wait2.take().unwrap(); + Box::pin(async move { + wait.await.unwrap(); + a + 10 + }) + } + + fn concurrent3(&mut self, a: u32) -> HostFuture { + assert_eq!(a, 3); + self.unblock3.take(); + let wait = self.wait3.take().unwrap(); + Box::pin(async move { + wait.await.unwrap(); + a + 10 + }) + } + + fn concurrent_export_helper(&mut self, idx: u32) -> HostFuture<()> { + let rx = if idx == 0 { + self.concurrent1.take().unwrap() + } else { + self.concurrent2.take().unwrap() + }; + Box::pin(async move { + drop(rx.await); + }) + } + + async fn iloop_entered(&mut self) { + drop(self.iloop_entered.take()); + } + + fn import_to_cancel(&mut self) -> HostFuture<()> { + let signal = self.import_cancelled_signal.take(); + drop(self.import_cancelled_entered.pop()); + Box::pin(async move { + tokio::time::sleep(Duration::new(1_000, 0)).await; + drop(signal); + }) + } +} + +witx_bindgen_wasmtime::import!({ + async: *, + paths: ["./tests/runtime/async_functions/exports.witx"], +}); + +struct Context { + wasi: wasmtime_wasi::WasiCtx, + imports: MyImports, + exports: exports::ExportsData, +} + +fn run(wasm: &str) -> Result<()> { + let mut config = Config::new(); + config.async_support(true); + config.consume_fuel(true); + let engine = Engine::new(&config)?; + let module = Module::from_file(&engine, wasm)?; + let mut linker = Linker::::new(&engine); + imports::add_to_linker(&mut linker, |cx| &mut cx.imports)?; + wasmtime_wasi::add_to_linker(&mut linker, |cx| &mut cx.wasi)?; + exports::Exports::add_to_linker(&mut linker, |cx| &mut cx.exports)?; + + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(run_async(&engine, &module, &linker)) +} + +async fn run_async(engine: &Engine, module: &Module, linker: &Linker) -> Result<()> { + let instantiate = |imports| async { + let mut store = Store::new( + &engine, + Context { + wasi: crate::default_wasi(), + imports, + exports: Default::default(), + }, + ); + store.add_fuel(10_000)?; + store.out_of_fuel_async_yield(u64::MAX, 10_000); + let instance = linker.instantiate_async(&mut store, &module).await?; + exports::Exports::new(store, &instance, |cx| &mut cx.exports) + }; + + let mut import = MyImports::default(); + + // Initialize various channels which we use as synchronization points to + // test the concurrent aspect of async wasm. The first channels here are + // used to wait for host functions to get entered by the wasm. + let (a, wait1) = channel(); + import.unblock1 = Some(a); + let (a, wait2) = channel(); + import.unblock2 = Some(a); + let (a, wait3) = channel(); + import.unblock3 = Some(a); + let (tx, mut rx) = mpsc::channel::<()>(10); + let tx2 = tx.clone(); + tokio::spawn(async move { + assert!(wait1.await.is_err()); + drop(tx2); + }); + let tx2 = tx.clone(); + tokio::spawn(async move { + assert!(wait2.await.is_err()); + drop(tx2); + }); + tokio::spawn(async move { + assert!(wait3.await.is_err()); + drop(tx); + }); + + let (concurrent1, a) = channel(); + let (concurrent2, b) = channel(); + import.concurrent1 = Some(a); + import.concurrent2 = Some(b); + + // This second set of channels are used to unblock host futures that + // wasm calls, simulating work that returns back to the host and takes + // some time to complete. + let (unblock1, b) = channel(); + import.wait1 = Some(b); + let (unblock2, b) = channel(); + import.wait2 = Some(b); + let (unblock3, b) = channel(); + import.wait3 = Some(b); + + let exports = instantiate(import).await?; + exports.thunk().await?; + + let future = exports.test_concurrent(); + tokio::pin!(future); + + // wait for all three concurrent methods to get entered, where once this + // happens they'll all drop the handles to `rx`, meaning that when + // entered we'll see the `rx` channel get closed. + tokio::select! { + _ = &mut future => unreachable!(), + r = rx.recv() => assert!(r.is_none()), + } + + // Now we can "complete" the async task that each function was waiting + // on. Our original future shouldn't be done until they're all complete. + unblock3.send(()).unwrap(); + unblock2.send(()).unwrap(); + unblock1.send(()).unwrap(); + future.await?; + + // Test concurrent exports can be invoked, here we call wasm + // concurrently twice and complete the second one first, ensuring the + // first one isn't finished, and then we complete the first and assert + // it's done. + let a = exports.concurrent_export(0); + tokio::pin!(a); + let b = exports.concurrent_export(1); + drop(concurrent2); + tokio::select! { + r = &mut a => panic!("got result {:?}", r), + r = b => r.unwrap(), + } + drop(concurrent1); + a.await.unwrap(); + + // Cancelling an infinite loop drops the reactor and the reactor doesn't + // execute forever. This will only work if `tx`, owned by the reactor, will + // get dropped when we cancel execution of the infinite loop. + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.iloop_close_on_drop = Some(tx); + imports.iloop_entered = Some(tx2); + let exports = instantiate(imports).await?; + { + let iloop = exports.infinite_loop(); + tokio::pin!(iloop); + // execute the iloop long enough to get into wasm and we'll get the + // signal when the `rx2` channel is closed. + tokio::select! { + _ = &mut iloop => unreachable!(), + r = rx2 => assert!(r.is_err()), + } + } + assert!(rx.await.is_err()); + drop(exports); + + // Same as above, but an infinite loop in an async exported wasm function + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.iloop_close_on_drop = Some(tx); + imports.iloop_entered = Some(tx2); + let exports = instantiate(imports).await?; + { + let iloop = exports.infinite_loop_async(); + tokio::pin!(iloop); + // execute the iloop long enough to get into wasm and we'll get the + // signal when the `rx2` channel is closed. + tokio::select! { + _ = &mut iloop => unreachable!(), + r = rx2 => assert!(r.is_err()), + } + } + assert!(rx.await.is_err()); + drop(exports); + + // A trap from WebAssembly should result in cancelling all imported tasks. + // execute forever. This will only work if `tx`, owned by the reactor, will + // get dropped when we cancel execution of the infinite loop. + let (tx, rx) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_signal = Some(tx); + let trap = instantiate(imports) + .await? + .call_import_then_trap() + .await + .unwrap_err(); + assert!( + trap.trap_code() == Some(TrapCode::UnreachableCodeReached), + "bad error: {}", + trap + ); + assert!(rx.await.is_err()); + + // Dropping the owned version of the bindings should recursively tear down + // the reactor task since it's got nothing else to do at that point. + let (tx, rx) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_signal = Some(tx); + instantiate(imports).await?; + assert!(rx.await.is_err()); + + // Cancelling (dropping) the outer task transitively tears down the reactor + // and cancels imported tasks. + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_signal = Some(tx); + imports.import_cancelled_entered.push(tx2); + let exports = instantiate(imports).await?; + { + let f = exports.call_infinite_import(); + tokio::pin!(f); + // execute the wasm long enough to get into it and we'll get the + // signal when the `rx2` channel is closed. + tokio::select! { + _ = &mut f => unreachable!(), + r = rx2 => assert!(r.is_err()), + } + } + assert!(rx.await.is_err()); + drop(exports); + + // With multiple concurrent exports if one of them is cancelled then they + // all get cancelled. + let (tx, rx) = channel(); + let (tx2, rx2) = channel(); + let mut imports = MyImports::default(); + imports.import_cancelled_entered.push(tx); + imports.import_cancelled_entered.push(tx2); + let exports = instantiate(imports).await?; + let a = exports.call_infinite_import(); + let b = exports.call_infinite_import(); + { + tokio::pin!(a); + { + tokio::pin!(b); + // Run this select twice to ensure both futures get into the import within wasm. + tokio::select! { + _ = &mut a => unreachable!(), + _ = &mut b => unreachable!(), + r = rx2 => assert!(r.is_err()), + } + tokio::select! { + _ = &mut a => unreachable!(), + _ = &mut b => unreachable!(), + r = rx => assert!(r.is_err()), + } + // ... `b` is now dropped here + } + let err = a.await.unwrap_err(); + assert!( + err.to_string().contains("wasm reactor task has gone away"), + "bad error: {}", + err + ); + } + drop(exports); + + Ok(()) +} diff --git a/tests/runtime/async_functions/host.ts b/tests/runtime/async_functions/host.ts index 4381c7f76..880288992 100644 --- a/tests/runtime/async_functions/host.ts +++ b/tests/runtime/async_functions/host.ts @@ -20,6 +20,9 @@ async function run() { const [unblockConcurrent2, resolveUnblockConcurrent2] = promiseChannel(); const [unblockConcurrent3, resolveUnblockConcurrent3] = promiseChannel(); + const [unblockExport1, resolveUnblockExport1] = promiseChannel(); + const [unblockExport2, resolveUnblockExport2] = promiseChannel(); + const imports: Imports = { async thunk() { if (hit) { @@ -60,15 +63,29 @@ async function run() { console.log('concurrent3 returning to wasm'); return 13; }, + + async concurrentExportHelper(n) { + if (n === 0) { + await unblockExport1; + } else { + await unblockExport2; + } + }, + + iloopEntered() { + throw new Error('unsupported'); + }, + + importToCancel() { + throw new Error('unsupported'); + }, }; - let instance: WebAssembly.Instance; - addImportsToImports(importObj, imports, name => instance.exports[name]); + addImportsToImports(importObj, imports); const wasi = addWasiToImports(importObj); const wasm = new Exports(); await wasm.instantiate(getWasm(), importObj); wasi.start(wasm.instance); - instance = wasm.instance; const initBytes = wasm.allocatedBytes(); console.log("calling initial async function"); @@ -99,6 +116,13 @@ async function run() { console.log('waiting on host functions'); await concurrentWasm; console.log('concurrent wasm finished'); + + const a = wasm.concurrentExport(0); + const b = wasm.concurrentExport(1); + resolveUnblockExport2(); + await b; + resolveUnblockExport1(); + await a; } async function some_helper() {} diff --git a/tests/runtime/async_functions/imports.witx b/tests/runtime/async_functions/imports.witx index 4cf3657cb..2f15d6513 100644 --- a/tests/runtime/async_functions/imports.witx +++ b/tests/runtime/async_functions/imports.witx @@ -3,3 +3,9 @@ thunk: async function() concurrent1: async function(a: u32) -> u32 concurrent2: async function(a: u32) -> u32 concurrent3: async function(a: u32) -> u32 + +concurrent_export_helper: async function(idx: u32) + +iloop_entered: function() + +import_to_cancel: async function() diff --git a/tests/runtime/async_functions/wasm.rs b/tests/runtime/async_functions/wasm.rs index 273722e36..e43407ff8 100644 --- a/tests/runtime/async_functions/wasm.rs +++ b/tests/runtime/async_functions/wasm.rs @@ -20,4 +20,27 @@ impl exports::Exports for Exports { assert_eq!(futures_util::join!(a2, a3, a1), (12, 13, 11)); } + + async fn concurrent_export(idx: u32) { + imports::concurrent_export_helper(idx).await + } + + async fn infinite_loop_async() { + imports::iloop_entered(); + loop {} + } + + fn infinite_loop() { + imports::iloop_entered(); + loop {} + } + + async fn call_import_then_trap() { + let _f = imports::import_to_cancel(); + std::arch::wasm32::unreachable(); + } + + async fn call_infinite_import() { + imports::import_to_cancel().await; + } } diff --git a/tests/runtime/async_raw/exports.witx b/tests/runtime/async_raw/exports.witx new file mode 100644 index 000000000..c2fce2428 --- /dev/null +++ b/tests/runtime/async_raw/exports.witx @@ -0,0 +1,12 @@ +complete_immediately: async function() +completion_not_called: async function() +complete_twice: async function() +complete_then_trap: async function() +assert_coroutine_id_zero: async function() + +not_async_export_done: function() +not_async_calls_async: function() + +import_callback_null: async function() +import_callback_wrong_type: async function() +import_callback_bad_index: async function() diff --git a/tests/runtime/async_raw/host.rs b/tests/runtime/async_raw/host.rs new file mode 100644 index 000000000..649b07a95 --- /dev/null +++ b/tests/runtime/async_raw/host.rs @@ -0,0 +1,161 @@ +use anyhow::Result; +use wasmtime::{Config, Engine, Linker, Module, Store}; +use witx_bindgen_wasmtime::HostFuture; + +witx_bindgen_wasmtime::export!({ + async: *, + paths: ["./tests/runtime/async_raw/imports.witx"], +}); + +#[derive(Default)] +struct MyImports; + +impl imports::Imports for MyImports { + fn thunk(&mut self) -> HostFuture<()> { + Box::pin(async {}) + } +} + +witx_bindgen_wasmtime::import!({ + async: *, + paths: ["./tests/runtime/async_raw/exports.witx"], +}); + +fn run(wasm: &str) -> Result<()> { + struct Context { + wasi: wasmtime_wasi::WasiCtx, + imports: MyImports, + exports: exports::ExportsData, + } + + let mut config = Config::new(); + config.async_support(true); + let engine = Engine::new(&config)?; + let module = Module::from_file(&engine, wasm)?; + let mut linker = Linker::::new(&engine); + imports::add_to_linker(&mut linker, |cx| &mut cx.imports)?; + wasmtime_wasi::add_to_linker(&mut linker, |cx| &mut cx.wasi)?; + exports::Exports::add_to_linker(&mut linker, |cx| &mut cx.exports)?; + + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(async { + let instantiate = || async { + let mut store = Store::new( + &engine, + Context { + wasi: crate::default_wasi(), + imports: MyImports::default(), + exports: Default::default(), + }, + ); + let instance = linker.instantiate_async(&mut store, &module).await?; + exports::Exports::new(store, &instance, |cx| &mut cx.exports) + }; + + // it's ok to call the completion callback immediately, and the first + // coroutine is always zero. + let exports = instantiate().await?; + exports.complete_immediately().await?; + exports.assert_coroutine_id_zero().await?; + exports.assert_coroutine_id_zero().await?; + + // if the completion callback is never called that's a trap + let err = instantiate() + .await? + .completion_not_called() + .await + .unwrap_err(); + assert!( + err.to_string().contains("completion callback never called"), + "bad error: {}", + err, + ); + + // the completion callback can only be called once + let err = instantiate().await?.complete_twice().await.unwrap_err(); + assert!( + err.to_string().contains("async context not valid"), + "bad error: {}", + err, + ); + + // if the trap happens after the completion callback... something + // happens, for now a trap. + let err = instantiate().await?.complete_then_trap().await.unwrap_err(); + assert!( + err.trap_code() == Some(wasmtime::TrapCode::UnreachableCodeReached), + "bad error: {:?}", + err + ); + + // If a non-async export tries to call the completion callback for + // async exports that's an error. + let err = instantiate() + .await? + .not_async_export_done() + .await + .unwrap_err(); + assert!( + err.to_string().contains("async context not valid"), + "bad error: {}", + err, + ); + + // If a non-async export tries to call an async import that's an error. + let err = instantiate() + .await? + .not_async_calls_async() + .await + .unwrap_err(); + assert!( + err.to_string() + .contains("cannot call async import from non-async export"), + "bad error: {}", + err, + ); + + // The import callback specified cannot be null + let err = instantiate() + .await? + .import_callback_null() + .await + .unwrap_err(); + assert!( + err.to_string().contains("callback was a null function"), + "bad error: {}", + err, + ); + + // The import callback specified must have the right type. + let err = instantiate() + .await? + .import_callback_wrong_type() + .await + .unwrap_err(); + assert!( + err.to_string().contains("type mismatch with parameters"), + "bad error: {}", + err, + ); + + // The import callback specified must point to a valid table index + let err = instantiate() + .await? + .import_callback_bad_index() + .await + .unwrap_err(); + assert!( + err.to_string().contains("invalid function index"), + "bad error: {}", + err, + ); + + // when wasm traps due to one reason or another all future requests to + // execute wasm fail + let exports = instantiate().await?; + assert!(exports.import_callback_null().await.is_err()); + assert!(exports.complete_immediately().await.is_err()); + + Ok(()) + }) +} diff --git a/tests/runtime/async_raw/host.ts b/tests/runtime/async_raw/host.ts new file mode 100644 index 000000000..0487114d6 --- /dev/null +++ b/tests/runtime/async_raw/host.ts @@ -0,0 +1,52 @@ +import { addImportsToImports, Imports } from "./imports.js"; +import { Exports } from "./exports.js"; +import { getWasm, addWasiToImports } from "./helpers.js"; +// @ts-ignore +import * as assert from 'assert'; + +async function run() { + const importObj = {}; + const imports: Imports = { + async thunk() {} + }; + + async function instantiate() { + addImportsToImports(importObj, imports); + const wasi = addWasiToImports(importObj); + + const wasm = new Exports(); + await wasm.instantiate(getWasm(), importObj); + wasi.start(wasm.instance); + return wasm; + } + + let wasm = await instantiate(); + await wasm.completeImmediately(); + await wasm.assertCoroutineIdZero(); + await wasm.assertCoroutineIdZero(); + + wasm = await instantiate(); + await assert.rejects(wasm.completionNotCalled(), /blocked coroutine with 0 pending callbacks/); + + wasm = await instantiate(); + await assert.rejects(wasm.completeTwice(), /cannot complete coroutine twice/); + + wasm = await instantiate(); + await assert.rejects(wasm.completeThenTrap(), /unreachable/); + + wasm = await instantiate(); + assert.throws(() => wasm.notAsyncExportDone(), /invalid coroutine index/); + + wasm = await instantiate(); + await assert.rejects(wasm.importCallbackNull(), /table index is a null function/); + + // TODO: this is the wrong error from this, but it's not clear how best to do + // type-checks in JS... + wasm = await instantiate(); + await assert.rejects(wasm.importCallbackWrongType(), /0 pending callbacks/); + + wasm = await instantiate(); + await assert.rejects(wasm.importCallbackBadIndex(), RangeError); +} + +await run() diff --git a/tests/runtime/async_raw/imports.witx b/tests/runtime/async_raw/imports.witx new file mode 100644 index 000000000..c0e1bb5a9 --- /dev/null +++ b/tests/runtime/async_raw/imports.witx @@ -0,0 +1 @@ +thunk: async function() diff --git a/tests/runtime/async_raw/wasm.rs b/tests/runtime/async_raw/wasm.rs new file mode 100644 index 000000000..9140110e3 --- /dev/null +++ b/tests/runtime/async_raw/wasm.rs @@ -0,0 +1,82 @@ +use std::arch::wasm32; + +#[link(wasm_import_module = "canonical_abi")] +extern "C" { + pub fn async_export_done(ctx: i32, ptr: i32); +} + +#[link(wasm_import_module = "imports")] +extern "C" { + pub fn thunk(cb: i32, ptr: i32); +} + +#[no_mangle] +pub extern "C" fn complete_immediately(ctx: i32) { + unsafe { + async_export_done(ctx, 0); + } +} + +#[no_mangle] +pub extern "C" fn completion_not_called(_ctx: i32) {} + +#[no_mangle] +pub extern "C" fn complete_twice(ctx: i32) { + unsafe { + async_export_done(ctx, 0); + async_export_done(ctx, 0); + } +} + +#[no_mangle] +pub extern "C" fn complete_then_trap(ctx: i32) { + unsafe { + async_export_done(ctx, 0); + wasm32::unreachable(); + } +} + +#[no_mangle] +pub extern "C" fn assert_coroutine_id_zero(ctx: i32) { + unsafe { + assert_eq!(ctx, 0); + async_export_done(ctx, 0); + } +} + +#[no_mangle] +pub extern "C" fn not_async_export_done() { + unsafe { + async_export_done(0, 0); + } +} + +#[no_mangle] +pub extern "C" fn not_async_calls_async() { + extern "C" fn callback(_x: i32) {} + unsafe { + thunk(callback as i32, 0); + } +} + +#[no_mangle] +pub extern "C" fn import_callback_null(_cx: i32) { + unsafe { + thunk(0, 0); + } +} + +#[no_mangle] +pub extern "C" fn import_callback_wrong_type(_cx: i32) { + extern "C" fn callback() {} + unsafe { + thunk(callback as i32, 0); + } +} + +#[no_mangle] +pub extern "C" fn import_callback_bad_index(_cx: i32) { + unsafe { + thunk(i32::MAX, 0); + } +}