From 92ca41ef9f822c3b10bc98568e36ad710e249e88 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 2 May 2022 08:39:50 -0700 Subject: [PATCH 1/3] Split out the `Option` and `Expected` types from `Variant` This commit is like prior PRs to split out specializations of types into their own AST type to avoid conflicting with the main type (in this case `variant`). I originally thought these two types would be relatively simple but this is probably one of the more complicated transitions, as evidenced by the lines changed here. The main churn was that variants already have a significant amount of code to support them and this is in some places "duplicating" code for option/expected and in other cases splitting what was already an if/else. Overall I think that the generated code gets a little better since it's clear when something is and `option` vs `expected` now rather than trying to have everything shoehorned into one. Notably the C code generator now generates descriptive fields like `bool is_some` or `bool is_err` instead of a bland `uint8_t tag` with some comments about how to use it. --- crates/gen-c/src/lib.rs | 428 +++++++++++++---- crates/gen-core/src/lib.rs | 34 +- crates/gen-js/src/lib.rs | 448 +++++++++++++----- crates/gen-markdown/src/lib.rs | 68 ++- crates/gen-rust-wasm/src/lib.rs | 99 +++- crates/gen-rust/src/lib.rs | 101 ++-- crates/gen-spidermonkey/src/lib.rs | 32 +- crates/gen-wasmtime-py/src/lib.rs | 405 ++++++++++------ crates/gen-wasmtime/src/lib.rs | 94 +++- crates/gen-wasmtime/tests/codegen.rs | 2 +- crates/parser/src/abi.rs | 433 +++++++++++------ crates/parser/src/ast.rs | 57 +-- crates/parser/src/ast/resolve.rs | 28 ++ crates/parser/src/lib.rs | 19 +- crates/parser/src/sizealign.rs | 23 +- crates/parser/tests/all.rs | 10 + crates/parser/tests/ui/functions.wit.result | 27 +- crates/parser/tests/ui/types.wit | 6 +- crates/parser/tests/ui/types.wit.result | 82 +--- crates/parser/tests/ui/wasi-clock.wit.result | 14 +- crates/parser/tests/ui/wasi-http.wit.result | 27 +- crates/test-helpers/src/lib.rs | 26 +- .../modules/crates/variants/variants.wit | 12 +- crates/wasmlink/src/adapter/call.rs | 217 ++++++--- crates/wasmlink/src/module.rs | 2 + crates/wasmlink/tests/run.rs | 1 + crates/wasmlink/tests/variants.wit | 12 +- crates/wit-component/src/decoding.rs | 54 +-- crates/wit-component/src/encoding.rs | 81 +++- crates/wit-component/src/printing.rs | 98 ++-- .../tests/interfaces/variants/variants.wat | 133 +++--- .../tests/interfaces/variants/variants.wit | 6 +- crates/wit-component/tests/roundtrip.rs | 8 +- tests/codegen/variants.wit | 14 +- tests/runtime/flavorful/exports.wit | 4 +- tests/runtime/flavorful/host.ts | 2 +- tests/runtime/flavorful/imports.wit | 6 +- tests/runtime/flavorful/wasm.c | 28 +- tests/runtime/handles/wasm.c | 10 +- tests/runtime/variants/exports.wit | 2 +- tests/runtime/variants/host.ts | 4 +- tests/runtime/variants/imports.wit | 4 +- tests/runtime/variants/wasm.c | 39 +- 43 files changed, 2099 insertions(+), 1101 deletions(-) diff --git a/crates/gen-c/src/lib.rs b/crates/gen-c/src/lib.rs index 52bde1061..023103263 100644 --- a/crates/gen-c/src/lib.rs +++ b/crates/gen-c/src/lib.rs @@ -166,6 +166,8 @@ impl C { Type::Id(id) => match &iface.types[*id].kind { TypeDefKind::Type(t) => self.is_arg_by_pointer(iface, t), TypeDefKind::Variant(_) => true, + TypeDefKind::Option(_) => true, + TypeDefKind::Expected(_) => true, TypeDefKind::Enum(_) => false, TypeDefKind::Flags(_) => false, TypeDefKind::Tuple(_) | TypeDefKind::Record(_) | TypeDefKind::List(_) => true, @@ -235,26 +237,22 @@ impl C { } Type::Id(id) => { let ty = &iface.types[*id]; - if let Some(name) = &ty.name { - self.print_namespace(iface); - self.src.h(&name.to_snake_case()); - self.src.h("_t"); - return; - } - match &ty.kind { - TypeDefKind::Type(t) => self.print_ty(iface, t), - TypeDefKind::Flags(_) - | TypeDefKind::Tuple(_) - | TypeDefKind::Enum(_) - | TypeDefKind::Record(_) - | TypeDefKind::List(_) - | TypeDefKind::Variant(_) => { - self.public_anonymous_types.insert(*id); - self.private_anonymous_types.remove(id); + match &ty.name { + Some(name) => { self.print_namespace(iface); - self.print_ty_name(iface, &Type::Id(*id)); + self.src.h(&name.to_snake_case()); self.src.h("_t"); } + None => match &ty.kind { + TypeDefKind::Type(t) => self.print_ty(iface, t), + _ => { + self.public_anonymous_types.insert(*id); + self.private_anonymous_types.remove(id); + self.print_namespace(iface); + self.print_ty_name(iface, &Type::Id(*id)); + self.src.h("_t"); + } + }, } } } @@ -284,7 +282,10 @@ impl C { } match &ty.kind { TypeDefKind::Type(t) => self.print_ty_name(iface, t), - TypeDefKind::Record(_) | TypeDefKind::Flags(_) | TypeDefKind::Enum(_) => { + TypeDefKind::Record(_) + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Variant(_) => { unimplemented!() } TypeDefKind::Tuple(t) => { @@ -295,24 +296,15 @@ impl C { self.print_ty_name(iface, ty); } } - TypeDefKind::Variant(v) => { - if let Some(ty) = v.as_option() { - self.src.h("option_"); - self.print_ty_name(iface, ty); - } else if let Some((ok, err)) = v.as_expected() { - self.src.h("expected_"); - match ok { - Some(t) => self.print_ty_name(iface, t), - None => self.src.h("void"), - } - self.src.h("_"); - match err { - Some(t) => self.print_ty_name(iface, t), - None => self.src.h("void"), - } - } else { - unimplemented!(); - } + TypeDefKind::Option(ty) => { + self.src.h("option_"); + self.print_ty_name(iface, ty); + } + TypeDefKind::Expected(e) => { + self.src.h("expected_"); + self.print_ty_name(iface, &e.ok); + self.src.h("_"); + self.print_ty_name(iface, &e.err); } TypeDefKind::List(t) => { self.src.h("list_"); @@ -331,7 +323,8 @@ impl C { TypeDefKind::Type(_) | TypeDefKind::Flags(_) | TypeDefKind::Record(_) - | TypeDefKind::Enum(_) => { + | TypeDefKind::Enum(_) + | TypeDefKind::Variant(_) => { unreachable!() } TypeDefKind::Tuple(t) => { @@ -343,35 +336,30 @@ impl C { } self.src.h("}"); } - TypeDefKind::Variant(v) => { - if let Some(t) = v.as_option() { - self.src.h("struct {\n"); - self.src.h("\ - // `true` if `val` is present, `false` otherwise - bool tag; - "); + TypeDefKind::Option(t) => { + self.src.h("struct {\n"); + self.src.h("bool is_some;\n"); + if !self.is_empty_type(iface, t) { self.print_ty(iface, t); self.src.h(" val;\n"); - self.src.h("}"); - } else if let Some((ok, err)) = v.as_expected() { - self.src.h("struct { - // 0 if `val` is `ok`, 1 otherwise - uint8_t tag; - union { - "); - if let Some(ok) = ok { - self.print_ty(iface, ok); - self.src.h(" ok;\n"); - } - if let Some(err) = err { - self.print_ty(iface, err); - self.src.h(" err;\n"); - } - self.src.h("} val;\n"); - self.src.h("}"); - } else { - unimplemented!(); } + self.src.h("}"); + } + TypeDefKind::Expected(e) => { + self.src.h("struct { + bool is_err; + union { + "); + if !self.is_empty_type(iface, &e.ok) { + self.print_ty(iface, &e.ok); + self.src.h(" ok;\n"); + } + if !self.is_empty_type(iface, &e.err) { + self.print_ty(iface, &e.err); + self.src.h(" err;\n"); + } + self.src.h("} val;\n"); + self.src.h("}"); } TypeDefKind::List(t) => { self.src.h("struct {\n"); @@ -392,6 +380,7 @@ impl C { fn is_empty_type(&self, iface: &Interface, ty: &Type) -> bool { let id = match ty { Type::Id(id) => *id, + Type::Unit => return true, _ => return false, }; match &iface.types[id].kind { @@ -515,6 +504,24 @@ impl C { } self.src.c("}\n"); } + + TypeDefKind::Option(t) => { + self.src.c("if (ptr->is_some) {\n"); + self.free(iface, t, "&ptr->val"); + self.src.c("}\n"); + } + + TypeDefKind::Expected(e) => { + self.src.c("if (!ptr->is_err) {\n"); + if self.owns_anything(iface, &e.ok) { + self.free(iface, &e.ok, "&ptr->val.ok"); + } + if self.owns_anything(iface, &e.err) { + self.src.c("} else {\n"); + self.free(iface, &e.err, "&ptr->val.err"); + } + self.src.c("}\n"); + } } self.src.c("}\n"); } @@ -538,6 +545,10 @@ impl C { .iter() .filter_map(|c| c.ty.as_ref()) .any(|t| self.owns_anything(iface, t)), + TypeDefKind::Option(t) => self.owns_anything(iface, t), + TypeDefKind::Expected(e) => { + self.owns_anything(iface, &e.ok) || self.owns_anything(iface, &e.err) + } } } @@ -607,44 +618,44 @@ impl Return { self.scalar = Some(Scalar::Type(*orig_ty)); } - TypeDefKind::Variant(r) => { - // Unpack optional returns where a boolean discriminant is - // returned and then the actual type returned is returned - // through a return pointer. - if let Some(ty) = r.as_option() { - self.scalar = Some(Scalar::OptionBool(*ty)); - self.retptrs.push(*ty); - return; - } - - // Unpack `expected` returns where `E` looks like an enum - // so we can return that in the scalar return and have `T` get - // returned through the normal returns. - if let Some((ok, err)) = r.as_expected() { - if let Some(Type::Id(err)) = err { - if let TypeDefKind::Enum(e) = &iface.types[*err].kind { - self.scalar = Some(Scalar::ExpectedEnum { - err: *err, - max_err: e.cases.len(), - }); - if let Some(ok) = ok { - self.splat_tuples(iface, ok, ok); - } - return; - } + // Unpack optional returns where a boolean discriminant is + // returned and then the actual type returned is returned + // through a return pointer. + TypeDefKind::Option(ty) => { + self.scalar = Some(Scalar::OptionBool(*ty)); + self.retptrs.push(*ty); + } + + // Unpack `expected` returns where `E` looks like an enum + // so we can return that in the scalar return and have `T` get + // returned through the normal returns. + TypeDefKind::Expected(e) => { + if let Type::Id(err) = e.err { + if let TypeDefKind::Enum(enum_) = &iface.types[err].kind { + self.scalar = Some(Scalar::ExpectedEnum { + err, + max_err: enum_.cases.len(), + }); + self.splat_tuples(iface, &e.ok, &e.ok); + return; } } - // If all that failed then just return the variant via a normal + // otherwise just return the variant via a normal // return pointer self.retptrs.push(*orig_ty); } + + TypeDefKind::Variant(_) => { + self.retptrs.push(*orig_ty); + } } } fn splat_tuples(&mut self, iface: &Interface, ty: &Type, orig_ty: &Type) { let id = match ty { Type::Id(id) => *id, + Type::Unit => return, _ => { self.retptrs.push(*orig_ty); return; @@ -802,6 +813,64 @@ impl Generator for C { .insert(id, mem::replace(&mut self.src.header, prev)); } + fn type_option( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + payload: &Type, + docs: &Docs, + ) { + let prev = mem::take(&mut self.src.header); + self.docs(docs); + self.names.insert(&name.to_snake_case()).unwrap(); + self.src.h("typedef struct {\n"); + self.src.h("bool is_some;\n"); + if !self.is_empty_type(iface, payload) { + self.print_ty(iface, payload); + self.src.h(" val;\n"); + } + self.src.h("} "); + self.print_namespace(iface); + self.src.h(&name.to_snake_case()); + self.src.h("_t;\n"); + + self.types + .insert(id, mem::replace(&mut self.src.header, prev)); + } + + fn type_expected( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + expected: &Expected, + docs: &Docs, + ) { + let prev = mem::take(&mut self.src.header); + self.docs(docs); + self.names.insert(&name.to_snake_case()).unwrap(); + self.src.h("typedef struct {\n"); + self.src.h("bool is_err;\n"); + self.src.h("union {\n"); + if !self.is_empty_type(iface, &expected.ok) { + self.print_ty(iface, &expected.ok); + self.src.h(" ok;\n"); + } + if !self.is_empty_type(iface, &expected.err) { + self.print_ty(iface, &expected.err); + self.src.h(" err;\n"); + } + self.src.h("} val;\n"); + self.src.h("} "); + self.print_namespace(iface); + self.src.h(&name.to_snake_case()); + self.src.h("_t;\n"); + + self.types + .insert(id, mem::replace(&mut self.src.header, prev)); + } + fn type_enum(&mut self, iface: &Interface, id: TypeId, name: &str, enum_: &Enum, docs: &Docs) { let prev = mem::take(&mut self.src.header); self.docs(docs); @@ -1589,6 +1658,177 @@ impl Bindgen for FunctionBindgen<'_> { results.push(result); } + Instruction::OptionLower { + results: result_types, + payload, + .. + } => { + let (mut some, some_results) = self.blocks.pop().unwrap(); + let (mut none, none_results) = self.blocks.pop().unwrap(); + let some_payload = self.payloads.pop().unwrap(); + let _none_payload = self.payloads.pop().unwrap(); + + let mut variant_results = Vec::new(); + for (i, ty) in result_types.iter().enumerate() { + let name = self.locals.tmp("option"); + results.push(name.clone()); + self.src.push_str(wasm_type(*ty)); + self.src.push_str(" "); + self.src.push_str(&name); + self.src.push_str(";\n"); + let some_result = &some_results[i]; + some.push_str(&format!("{name} = {some_result};\n")); + let none_result = &none_results[i]; + none.push_str(&format!("{name} = {none_result};\n")); + variant_results.push(name); + } + + let op0 = &operands[0]; + let ty = self.gen.type_string(iface, payload); + let bind_some = if self.gen.is_empty_type(iface, payload) { + String::new() + } else { + format!("const {ty} *{some_payload} = &({op0}).val;") + }; + self.src.push_str(&format!( + " + if (({op0}).is_some) {{ + {bind_some} + {some} + }} else {{ + {none} + }} + " + )); + } + + Instruction::OptionLift { payload, ty, .. } => { + let (some, some_results) = self.blocks.pop().unwrap(); + let (none, none_results) = self.blocks.pop().unwrap(); + assert!(none_results.is_empty()); + assert!(some_results.len() == 1); + let some_result = &some_results[0]; + + let ty = self.gen.type_string(iface, &Type::Id(*ty)); + let result = self.locals.tmp("option"); + self.src.push_str(&format!("{ty} {result};\n")); + let op0 = &operands[0]; + let set_some = if self.gen.is_empty_type(iface, payload) { + String::new() + } else { + format!("{result}.val = {some_result};") + }; + self.src.push_str(&format!( + "switch ({op0}) {{ + case 0: {{ + {result}.is_some = false; + {none} + break; + }} + case 1: {{ + {result}.is_some = true; + {some} + {set_some} + break; + }} + }}" + )); + results.push(result); + } + + Instruction::ExpectedLower { + results: result_types, + expected, + .. + } => { + let (mut err, err_results) = self.blocks.pop().unwrap(); + let (mut ok, ok_results) = self.blocks.pop().unwrap(); + let err_payload = self.payloads.pop().unwrap(); + let ok_payload = self.payloads.pop().unwrap(); + + let mut variant_results = Vec::new(); + for (i, ty) in result_types.iter().enumerate() { + let name = self.locals.tmp("expected"); + results.push(name.clone()); + self.src.push_str(wasm_type(*ty)); + self.src.push_str(" "); + self.src.push_str(&name); + self.src.push_str(";\n"); + let ok_result = &ok_results[i]; + ok.push_str(&format!("{name} = {ok_result};\n")); + let err_result = &err_results[i]; + err.push_str(&format!("{name} = {err_result};\n")); + variant_results.push(name); + } + + let op0 = &operands[0]; + let ok_ty = self.gen.type_string(iface, &expected.ok); + let err_ty = self.gen.type_string(iface, &expected.err); + let bind_ok = if self.gen.is_empty_type(iface, &expected.ok) { + String::new() + } else { + format!("const {ok_ty} *{ok_payload} = &({op0}).val.ok;") + }; + let bind_err = if self.gen.is_empty_type(iface, &expected.err) { + String::new() + } else { + format!("const {err_ty} *{err_payload} = &({op0}).val.err;") + }; + self.src.push_str(&format!( + " + if (({op0}).is_err) {{ + {bind_err} + {err} + }} else {{ + {bind_ok} + {ok} + }} + " + )); + } + + Instruction::ExpectedLift { expected, ty, .. } => { + let (err, err_results) = self.blocks.pop().unwrap(); + assert!(err_results.len() == 1); + let err_result = &err_results[0]; + let (ok, ok_results) = self.blocks.pop().unwrap(); + assert!(ok_results.len() == 1); + let ok_result = &ok_results[0]; + + let result = self.locals.tmp("expected"); + let set_ok = if self.gen.is_empty_type(iface, &expected.ok) { + String::new() + } else { + format!("{result}.val.ok = {ok_result};") + }; + let set_err = if self.gen.is_empty_type(iface, &expected.err) { + String::new() + } else { + format!("{result}.val.err = {err_result};") + }; + + let ty = self.gen.type_string(iface, &Type::Id(*ty)); + self.src.push_str(&format!("{ty} {result};\n")); + let op0 = &operands[0]; + self.src.push_str(&format!( + "switch ({op0}) {{ + case 0: {{ + {result}.is_err = false; + {ok} + {set_ok} + break; + }} + case 1: {{ + {result}.is_err = true; + {err} + {set_err} + break; + }} + }}" + )); + results.push(result); + } + Instruction::EnumLower { .. } => results.push(format!("(int32_t) {}", operands[0])), Instruction::EnumLift { .. } => results.push(operands.pop().unwrap()), @@ -1721,7 +1961,7 @@ impl Bindgen for FunctionBindgen<'_> { self.src.push_str(&format!( " {ty} {ret}; - {ret}.tag = {tag}; + {ret}.is_some = {tag}; {ret}.val = {val}; ", ty = option_ty, @@ -1756,10 +1996,10 @@ impl Bindgen for FunctionBindgen<'_> { " {ty} {ret}; if ({tag} <= {max}) {{ - {ret}.tag = 1; + {ret}.is_err = true; {ret}.val.err = {tag}; }} else {{ - {ret}.tag = 0; + {ret}.is_err = false; {set_ok} }} ", @@ -1804,7 +2044,7 @@ impl Bindgen for FunctionBindgen<'_> { self.store_in_retptrs(&[format!("{}.val", variant)]); self.src.push_str("return "); self.src.push_str(&variant); - self.src.push_str(".tag;\n"); + self.src.push_str(".is_some;\n"); } Some(Scalar::ExpectedEnum { .. }) => { assert_eq!(operands.len(), 1); @@ -1813,7 +2053,7 @@ impl Bindgen for FunctionBindgen<'_> { self.store_in_retptrs(&[format!("{}.val.ok", variant)]); } self.src - .push_str(&format!("return {}.tag ? {0}.val.err : -1;\n", variant,)); + .push_str(&format!("return {}.is_err ? {0}.val.err : -1;\n", variant)); } }, Instruction::Return { amt, .. } => { diff --git a/crates/gen-core/src/lib.rs b/crates/gen-core/src/lib.rs index 57eb5207c..fb412365b 100644 --- a/crates/gen-core/src/lib.rs +++ b/crates/gen-core/src/lib.rs @@ -65,6 +65,22 @@ pub trait Generator { variant: &Variant, docs: &Docs, ); + fn type_option( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + payload: &Type, + docs: &Docs, + ); + fn type_expected( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + expected: &Expected, + docs: &Docs, + ); fn type_enum(&mut self, iface: &Interface, id: TypeId, name: &str, enum_: &Enum, docs: &Docs); fn type_resource(&mut self, iface: &Interface, ty: ResourceId); fn type_alias(&mut self, iface: &Interface, id: TypeId, name: &str, ty: &Type, docs: &Docs); @@ -97,6 +113,8 @@ pub trait Generator { TypeDefKind::Variant(variant) => { self.type_variant(iface, id, name, variant, &ty.docs) } + TypeDefKind::Option(t) => self.type_option(iface, id, name, t, &ty.docs), + TypeDefKind::Expected(e) => self.type_expected(iface, id, name, e, &ty.docs), TypeDefKind::List(t) => self.type_list(iface, id, name, t, &ty.docs), TypeDefKind::Type(t) => self.type_alias(iface, id, name, t, &ty.docs), } @@ -215,6 +233,13 @@ impl Types { TypeDefKind::Type(ty) => { info = self.type_info(iface, ty); } + TypeDefKind::Option(ty) => { + info = self.type_info(iface, ty); + } + TypeDefKind::Expected(e) => { + info = self.type_info(iface, &e.ok); + info |= self.type_info(iface, &e.err); + } } self.type_info.insert(ty, info); return info; @@ -252,8 +277,13 @@ impl Types { } } } - TypeDefKind::List(ty) => self.set_param_result_ty(iface, ty, param, result), - TypeDefKind::Type(ty) => self.set_param_result_ty(iface, ty, param, result), + TypeDefKind::List(ty) | TypeDefKind::Type(ty) | TypeDefKind::Option(ty) => { + self.set_param_result_ty(iface, ty, param, result) + } + TypeDefKind::Expected(e) => { + self.set_param_result_ty(iface, &e.ok, param, result); + self.set_param_result_ty(iface, &e.err, param, result); + } } } diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index c553f5724..dc2bbf984 100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -173,32 +173,26 @@ impl Js { TypeDefKind::Record(_) => panic!("anonymous record"), TypeDefKind::Flags(_) => panic!("anonymous flags"), TypeDefKind::Enum(_) => panic!("anonymous enum"), - TypeDefKind::Variant(v) => { - if self.is_nullable_option(iface, v) { - self.print_ty(iface, v.cases[1].ty.as_ref().unwrap()); - self.src.ts(" | null"); - } else if let Some(t) = v.as_option() { + TypeDefKind::Option(t) => { + if self.maybe_null(iface, t) { self.needs_ty_option = true; self.src.ts("Option<"); self.print_ty(iface, t); self.src.ts(">"); - } else if let Some((ok, err)) = v.as_expected() { - self.needs_ty_result = true; - self.src.ts("Result<"); - match ok { - Some(ok) => self.print_ty(iface, ok), - None => self.src.ts("undefined"), - } - self.src.ts(", "); - match err { - Some(err) => self.print_ty(iface, err), - None => self.src.ts("undefined"), - } - self.src.ts(">"); } else { - panic!("anonymous variant"); + self.print_ty(iface, t); + self.src.ts(" | null"); } } + TypeDefKind::Expected(e) => { + self.needs_ty_result = true; + self.src.ts("Result<"); + self.print_ty(iface, &e.ok); + self.src.ts(", "); + self.print_ty(iface, &e.err); + self.src.ts(">"); + } + TypeDefKind::Variant(_) => panic!("anonymous variant"), TypeDefKind::List(v) => self.print_list(iface, v), } } @@ -305,15 +299,46 @@ impl Js { return i.name().to_string(); } - pub fn get_nullable_option<'a>(&self, iface: &'a Interface, ty: &Type) -> Option<&'a Type> { - iface.get_variant(ty).and_then(|v| v.as_option()) + /// Returns whether `null` is a valid value of type `ty` + fn maybe_null(&self, iface: &Interface, ty: &Type) -> bool { + self.as_nullable(iface, ty).is_some() } - pub fn is_nullable_option(&self, iface: &Interface, variant: &Variant) -> bool { - variant.as_option().map_or(false, |ty| { - self.get_nullable_option(iface, ty) - .map_or(true, |ty| self.get_nullable_option(iface, ty).is_none()) - }) + /// Tests whether `ty` can be represented with `null`, and if it can then + /// the "other type" is returned. If `Some` is returned that means that `ty` + /// is `null | `. If `None` is returned that means that `null` can't + /// be used to represent `ty`. + fn as_nullable<'a>(&self, iface: &'a Interface, ty: &'a Type) -> Option<&'a Type> { + let id = match ty { + Type::Id(id) => *id, + _ => return None, + }; + match &iface.types[id].kind { + // If `ty` points to an `option`, then `ty` can be represented + // with `null` if `t` itself can't be represented with null. For + // example `option>` can't be represented with `null` + // since that's ambiguous if it's `none` or `some(none)`. + // + // Note, oddly enough, that `option>>` can be + // represented as `null` since: + // + // * `null` => `none` + // * `{ tag: "none" }` => `some(none)` + // * `{ tag: "some", val: null }` => `some(some(none))` + // * `{ tag: "some", val: 1 }` => `some(some(some(1)))` + // + // It's doubtful anyone would actually rely on that though due to + // how confusing it is. + TypeDefKind::Option(t) => { + if !self.maybe_null(iface, t) { + Some(t) + } else { + None + } + } + TypeDefKind::Type(t) => self.as_nullable(iface, t), + _ => None, + } } } @@ -338,7 +363,7 @@ impl Generator for Js { for field in record.fields.iter() { self.docs(&field.docs); let (option_str, ty) = self - .get_nullable_option(iface, &field.ty) + .as_nullable(iface, &field.ty) .map_or(("", &field.ty), |ty| ("?", ty)); self.src .ts(&format!("{}{}: ", field.name.to_mixed_case(), option_str)); @@ -400,39 +425,73 @@ impl Generator for Js { docs: &Docs, ) { self.docs(docs); - if self.is_nullable_option(iface, variant) { - self.src - .ts(&format!("export type {} = ", name.to_camel_case())); - self.print_ty(iface, variant.cases[1].ty.as_ref().unwrap()); - self.src.ts("| null;\n"); - } else { - self.src - .ts(&format!("export type {} = ", name.to_camel_case())); - for (i, case) in variant.cases.iter().enumerate() { - if i > 0 { - self.src.ts(" | "); - } - self.src - .ts(&format!("{}_{}", name, case.name).to_camel_case()); + self.src + .ts(&format!("export type {} = ", name.to_camel_case())); + for (i, case) in variant.cases.iter().enumerate() { + if i > 0 { + self.src.ts(" | "); } - self.src.ts(";\n"); - for case in variant.cases.iter() { - self.docs(&case.docs); - self.src.ts(&format!( - "export interface {} {{\n", - format!("{}_{}", name, case.name).to_camel_case() - )); - self.src.ts("tag: \""); - self.src.ts(&case.name); - self.src.ts("\",\n"); - if let Some(ty) = &case.ty { - self.src.ts("val: "); - self.print_ty(iface, ty); - self.src.ts(",\n"); - } - self.src.ts("}\n"); + self.src + .ts(&format!("{}_{}", name, case.name).to_camel_case()); + } + self.src.ts(";\n"); + for case in variant.cases.iter() { + self.docs(&case.docs); + self.src.ts(&format!( + "export interface {} {{\n", + format!("{}_{}", name, case.name).to_camel_case() + )); + self.src.ts("tag: \""); + self.src.ts(&case.name); + self.src.ts("\",\n"); + if let Some(ty) = &case.ty { + self.src.ts("val: "); + self.print_ty(iface, ty); + self.src.ts(",\n"); } + self.src.ts("}\n"); + } + } + + fn type_option( + &mut self, + iface: &Interface, + _id: TypeId, + name: &str, + payload: &Type, + docs: &Docs, + ) { + self.docs(docs); + let name = name.to_camel_case(); + self.src.ts(&format!("export type {name} = ")); + if self.maybe_null(iface, payload) { + self.needs_ty_option = true; + self.src.ts("Option<"); + self.print_ty(iface, payload); + self.src.ts(">"); + } else { + self.print_ty(iface, payload); + self.src.ts(" | null"); } + self.src.ts(";\n"); + } + + fn type_expected( + &mut self, + iface: &Interface, + _id: TypeId, + name: &str, + expected: &Expected, + docs: &Docs, + ) { + self.docs(docs); + let name = name.to_camel_case(); + self.needs_ty_result = true; + self.src.ts(&format!("export type {name} = Result<")); + self.print_ty(iface, &expected.ok); + self.src.ts(", "); + self.print_ty(iface, &expected.err); + self.src.ts(">;\n"); } fn type_enum( @@ -1308,11 +1367,9 @@ impl Bindgen for FunctionBindgen<'_> { } } - Instruction::UnitLower => { - assert_eq!(operands, &[""]); - } + Instruction::UnitLower => {} Instruction::UnitLift => { - results.push("".to_string()); + results.push("undefined".to_string()); } Instruction::BoolFromI32 => { @@ -1502,30 +1559,15 @@ impl Bindgen for FunctionBindgen<'_> { results.push(format!("variant{}_{}", tmp, i)); } - let expr_to_match = if self.gen.is_nullable_option(iface, variant) { - format!("variant{}", tmp) - } else { - format!("variant{}.tag", tmp) - }; + let expr_to_match = format!("variant{}.tag", tmp); self.src.js(&format!("switch ({}) {{\n", expr_to_match)); - let mut use_default = true; for (case, (block, block_results)) in variant.cases.iter().zip(blocks) { - if self.gen.is_nullable_option(iface, variant) { - if case.ty.is_none() { - self.src.js("case null: {\n"); - } else { - self.src.js("default: {\n"); - self.src.js(&format!("const e = variant{};\n", tmp)); - use_default = false; - } - } else { - self.src - .js(&format!("case \"{}\": {{\n", case.name.as_str())); - if case.ty.is_some() { - self.src.js(&format!("const e = variant{}.val;\n", tmp)); - } - }; + self.src + .js(&format!("case \"{}\": {{\n", case.name.as_str())); + if case.ty.is_some() { + self.src.js(&format!("const e = variant{}.val;\n", tmp)); + } self.src.js(&block); for (i, result) in block_results.iter().enumerate() { @@ -1534,23 +1576,12 @@ impl Bindgen for FunctionBindgen<'_> { } self.src.js("break;\n}\n"); } - if use_default { - let variant_name = name.map(|s| s.to_camel_case()); - let variant_name = variant_name.as_deref().unwrap_or_else(|| { - if variant.as_expected().is_some() { - "expected" - } else if variant.as_option().is_some() { - "option" - } else { - unimplemented!() - } - }); - self.src.js("default:\n"); - self.src.js(&format!( - "throw new RangeError(\"invalid variant specified for {}\");\n", - variant_name - )); - } + let variant_name = name.to_camel_case(); + self.src.js("default:\n"); + self.src.js(&format!( + "throw new RangeError(\"invalid variant specified for {}\");\n", + variant_name + )); self.src.js("}\n"); } @@ -1570,38 +1601,18 @@ impl Bindgen for FunctionBindgen<'_> { self.src.js(&format!("case {}: {{\n", i)); self.src.js(&block); - if self.gen.is_nullable_option(iface, variant) { - if case.ty.is_none() { - assert!(block_results.is_empty()); - self.src.js(&format!("variant{} = null;\n", tmp)); - } else { - assert!(block_results.len() == 1); - self.src - .js(&format!("variant{} = {};\n", tmp, block_results[0])); - } + self.src.js(&format!("variant{} = {{\n", tmp)); + self.src.js(&format!("tag: \"{}\",\n", case.name.as_str())); + if case.ty.is_some() { + assert!(block_results.len() == 1); + self.src.js(&format!("val: {},\n", block_results[0])); } else { - self.src.js(&format!("variant{} = {{\n", tmp)); - self.src.js(&format!("tag: \"{}\",\n", case.name.as_str())); - if case.ty.is_some() { - assert!(block_results.len() == 1); - self.src.js(&format!("val: {},\n", block_results[0])); - } else { - assert!(block_results.is_empty()); - } - self.src.js("};\n"); + assert!(block_results.is_empty()); } + self.src.js("};\n"); self.src.js("break;\n}\n"); } - let variant_name = name.map(|s| s.to_camel_case()); - let variant_name = variant_name.as_deref().unwrap_or_else(|| { - if variant.as_expected().is_some() { - "expected" - } else if variant.as_option().is_some() { - "option" - } else { - unimplemented!() - } - }); + let variant_name = name.to_camel_case(); self.src.js("default:\n"); self.src.js(&format!( "throw new RangeError(\"invalid variant discriminant for {}\");\n", @@ -1611,6 +1622,189 @@ impl Bindgen for FunctionBindgen<'_> { results.push(format!("variant{}", tmp)); } + Instruction::OptionLower { + payload, + results: result_types, + .. + } => { + let (mut some, some_results) = self.blocks.pop().unwrap(); + let (mut none, none_results) = self.blocks.pop().unwrap(); + + let tmp = self.tmp(); + self.src + .js(&format!("const variant{tmp} = {};\n", operands[0])); + + for i in 0..result_types.len() { + self.src.js(&format!("let variant{tmp}_{i};\n")); + results.push(format!("variant{tmp}_{i}")); + + let some_result = &some_results[i]; + let none_result = &none_results[i]; + some.push_str(&format!("variant{tmp}_{i} = {some_result};\n")); + none.push_str(&format!("variant{tmp}_{i} = {none_result};\n")); + } + + if self.gen.maybe_null(iface, payload) { + self.src.js(&format!( + " + switch (variant{tmp}.tag) {{ + case \"none\": {{ + {none} + break; + }} + case \"some\": {{ + const e = variant{tmp}.val; + {some} + break; + }} + default: {{ + throw new RangeError(\"invalid variant specified for option\"); + }} + }} + " + )); + } else { + self.src.js(&format!( + " + switch (variant{tmp}) {{ + case null: {{ + {none} + break; + }} + default: {{ + const e = variant{tmp}; + {some} + break; + }} + }} + " + )); + } + } + + Instruction::OptionLift { payload, .. } => { + let (some, some_results) = self.blocks.pop().unwrap(); + let (none, none_results) = self.blocks.pop().unwrap(); + assert!(none_results.is_empty()); + assert!(some_results.len() == 1); + let some_result = &some_results[0]; + + let tmp = self.tmp(); + + self.src.js(&format!("let variant{tmp};\n")); + self.src.js(&format!("switch ({}) {{\n", operands[0])); + + if self.gen.maybe_null(iface, payload) { + self.src.js(&format!( + " + case 0: {{ + {none} + variant{tmp} = {{ tag: \"none\" }}; + break; + }} + case 1: {{ + {some} + variant{tmp} = {{ tag: \"some\", val: {some_result} }}; + break; + }} + ", + )); + } else { + self.src.js(&format!( + " + case 0: {{ + {none} + variant{tmp} = null; + break; + }} + case 1: {{ + {some} + variant{tmp} = {some_result}; + break; + }} + ", + )); + } + self.src.js(" + default: + throw new RangeError(\"invalid variant discriminant for option\"); + "); + self.src.js("}\n"); + results.push(format!("variant{tmp}")); + } + + Instruction::ExpectedLower { + results: result_types, + .. + } => { + let (mut err, err_results) = self.blocks.pop().unwrap(); + let (mut ok, ok_results) = self.blocks.pop().unwrap(); + + let tmp = self.tmp(); + self.src + .js(&format!("const variant{tmp} = {};\n", operands[0])); + + for i in 0..result_types.len() { + self.src.js(&format!("let variant{tmp}_{i};\n")); + results.push(format!("variant{tmp}_{i}")); + + let ok_result = &ok_results[i]; + let err_result = &err_results[i]; + ok.push_str(&format!("variant{tmp}_{i} = {ok_result};\n")); + err.push_str(&format!("variant{tmp}_{i} = {err_result};\n")); + } + + self.src.js(&format!( + " + switch (variant{tmp}.tag) {{ + case \"ok\": {{ + const e = variant{tmp}.val; + {ok} + break; + }} + case \"err\": {{ + const e = variant{tmp}.val; + {err} + break; + }} + default: {{ + throw new RangeError(\"invalid variant specified for expected\"); + }} + }} + " + )); + } + + Instruction::ExpectedLift { .. } => { + let (err, err_results) = self.blocks.pop().unwrap(); + let (ok, ok_results) = self.blocks.pop().unwrap(); + let err_result = &err_results[0]; + let ok_result = &ok_results[0]; + let tmp = self.tmp(); + let op0 = &operands[0]; + self.src.js(&format!( + " + let variant{tmp}; + switch ({op0}) {{ + case 0: {{ + {ok} + variant{tmp} = {{ tag: \"ok\", val: {ok_result} }}; + break; + }} + case 1: {{ + {err} + variant{tmp} = {{ tag: \"err\", val: {err_result} }}; + break; + }} + default: {{ + throw new RangeError(\"invalid variant discriminant for expected\"); + }} + }} + ", + )); + results.push(format!("variant{tmp}")); + } + Instruction::EnumLower { name, .. } => { let tmp = self.tmp(); self.src diff --git a/crates/gen-markdown/src/lib.rs b/crates/gen-markdown/src/lib.rs index 00c75f617..2672acfd6 100644 --- a/crates/gen-markdown/src/lib.rs +++ b/crates/gen-markdown/src/lib.rs @@ -78,29 +78,23 @@ impl Markdown { } self.src.push_str(")"); } - TypeDefKind::Record(_) | TypeDefKind::Flags(_) | TypeDefKind::Enum(_) => { + TypeDefKind::Record(_) + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Variant(_) => { unreachable!() } - TypeDefKind::Variant(v) => { - if let Some(t) = v.as_option() { - self.src.push_str("option<"); - self.print_ty(iface, t, false); - self.src.push_str(">"); - } else if let Some((ok, err)) = v.as_expected() { - self.src.push_str("expected<"); - match ok { - Some(t) => self.print_ty(iface, t, false), - None => self.src.push_str("_"), - } - self.src.push_str(", "); - match err { - Some(t) => self.print_ty(iface, t, false), - None => self.src.push_str("_"), - } - self.src.push_str(">"); - } else { - unreachable!() - } + TypeDefKind::Option(t) => { + self.src.push_str("option<"); + self.print_ty(iface, t, false); + self.src.push_str(">"); + } + TypeDefKind::Expected(e) => { + self.src.push_str("expected<"); + self.print_ty(iface, &e.ok, false); + self.src.push_str(", "); + self.print_ty(iface, &e.err, false); + self.src.push_str(">"); } TypeDefKind::List(t) => { self.src.push_str("list<"); @@ -303,6 +297,38 @@ impl Generator for Markdown { } } + fn type_option( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + payload: &Type, + docs: &Docs, + ) { + self.print_type_header(name); + self.src.push_str("option<"); + self.print_ty(iface, payload, false); + self.src.push_str(">"); + self.print_type_info(id, docs); + } + + fn type_expected( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + expected: &Expected, + docs: &Docs, + ) { + self.print_type_header(name); + self.src.push_str("expected<"); + self.print_ty(iface, &expected.ok, false); + self.src.push_str(", "); + self.print_ty(iface, &expected.err, false); + self.src.push_str(">"); + self.print_type_info(id, docs); + } + fn type_resource(&mut self, iface: &Interface, ty: ResourceId) { drop((iface, ty)); } diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs index c90bcfe04..3865fefec 100644 --- a/crates/gen-rust-wasm/src/lib.rs +++ b/crates/gen-rust-wasm/src/lib.rs @@ -236,6 +236,28 @@ impl Generator for RustWasm { self.print_typedef_variant(iface, id, variant, docs); } + fn type_option( + &mut self, + iface: &Interface, + id: TypeId, + _name: &str, + payload: &Type, + docs: &Docs, + ) { + self.print_typedef_option(iface, id, payload, docs); + } + + fn type_expected( + &mut self, + iface: &Interface, + id: TypeId, + _name: &str, + expected: &Expected, + docs: &Docs, + ) { + self.print_typedef_expected(iface, id, expected, docs); + } + fn type_enum( &mut self, _iface: &Interface, @@ -998,11 +1020,9 @@ impl Bindgen for FunctionBindgen<'_> { // In unchecked mode when this type is a named enum then we know we // defined the type so we can transmute directly into it. - Instruction::VariantLift { - name: Some(name), - variant, - .. - } if variant.cases.iter().all(|c| c.ty.is_none()) && unchecked => { + Instruction::VariantLift { name, variant, .. } + if variant.cases.iter().all(|c| c.ty.is_none()) && unchecked => + { self.blocks.drain(self.blocks.len() - variant.cases.len()..); let mut result = format!("core::mem::transmute::<_, "); result.push_str(&name.to_camel_case()); @@ -1039,6 +1059,75 @@ impl Bindgen for FunctionBindgen<'_> { results.push(result); } + Instruction::OptionLower { + results: result_types, + .. + } => { + let some = self.blocks.pop().unwrap(); + let none = self.blocks.pop().unwrap(); + self.let_results(result_types.len(), results); + let operand = &operands[0]; + self.push_str(&format!( + "match {operand} {{ + Some(e) => {{ {some} }}, + None => {{ {none} }}, + }};" + )); + } + + Instruction::OptionLift { .. } => { + let some = self.blocks.pop().unwrap(); + let none = self.blocks.pop().unwrap(); + assert_eq!(none, "()"); + let operand = &operands[0]; + let invalid = if unchecked { + "std::hint::unreachable_unchecked()" + } else { + "panic!(\"invalid enum discriminant\")" + }; + results.push(format!( + "match {operand} {{ + 0 => None, + 1 => Some({some}), + _ => {invalid}, + }}" + )); + } + + Instruction::ExpectedLower { + results: result_types, + .. + } => { + let err = self.blocks.pop().unwrap(); + let ok = self.blocks.pop().unwrap(); + self.let_results(result_types.len(), results); + let operand = &operands[0]; + self.push_str(&format!( + "match {operand} {{ + Ok(e) => {{ {ok} }}, + Err(e) => {{ {err} }}, + }};" + )); + } + + Instruction::ExpectedLift { .. } => { + let err = self.blocks.pop().unwrap(); + let ok = self.blocks.pop().unwrap(); + let operand = &operands[0]; + let invalid = if unchecked { + "std::hint::unreachable_unchecked()" + } else { + "panic!(\"invalid enum discriminant\")" + }; + results.push(format!( + "match {operand} {{ + 0 => Ok({ok}), + 1 => Err({err}), + _ => {invalid}, + }}" + )); + } + Instruction::EnumLower { enum_, name, .. } => { let mut result = format!("match {} {{\n", operands[0]); let name = name.to_camel_case(); diff --git a/crates/gen-rust/src/lib.rs b/crates/gen-rust/src/lib.rs index 0880d21ce..0c5fc27a8 100644 --- a/crates/gen-rust/src/lib.rs +++ b/crates/gen-rust/src/lib.rs @@ -227,6 +227,8 @@ pub trait RustGenerator { match ty { TypeDefKind::Variant(_) | TypeDefKind::Record(_) + | TypeDefKind::Option(_) + | TypeDefKind::Expected(_) | TypeDefKind::List(_) | TypeDefKind::Flags(_) | TypeDefKind::Enum(_) @@ -242,31 +244,21 @@ pub trait RustGenerator { match &ty.kind { TypeDefKind::List(t) => self.print_list(iface, t, mode), - // Variants can be printed natively if they're `Option` or - // `Result`, otherwise they must be named for now. - TypeDefKind::Variant(v) => match v.as_expected() { - Some((ok, err)) => { - self.push_str("Result<"); - match ok { - Some(ty) => self.print_ty(iface, ty, mode), - None => self.push_str("()"), - } - self.push_str(","); - match err { - Some(ty) => self.print_ty(iface, ty, mode), - None => self.push_str("()"), - } - self.push_str(">"); - } - None => match v.as_option() { - Some(ty) => { - self.push_str("Option<"); - self.print_ty(iface, ty, mode); - self.push_str(">"); - } - None => panic!("unsupported anonymous variant"), - }, - }, + TypeDefKind::Option(t) => { + self.push_str("Option<"); + self.print_ty(iface, t, mode); + self.push_str(">"); + } + + TypeDefKind::Expected(e) => { + self.push_str("Result<"); + self.print_ty(iface, &e.ok, mode); + self.push_str(","); + self.print_ty(iface, &e.err, mode); + self.push_str(">"); + } + + TypeDefKind::Variant(_) => panic!("unsupported anonymous variant"), // Tuple-like records are mapped directly to Rust tuples of // types. Note the trailing comma after each member to @@ -458,29 +450,6 @@ pub trait RustGenerator { for (name, mode) in self.modes_of(iface, id) { self.rustdoc(docs); let lt = self.lifetime_for(&info, mode); - if let Some(ty) = variant.as_option() { - self.push_str(&format!("pub type {}", name)); - self.print_generics(&info, lt, true); - self.push_str("= Option<"); - self.print_ty(iface, ty, mode); - self.push_str(">;\n"); - continue; - } else if let Some((ok, err)) = variant.as_expected() { - self.push_str(&format!("pub type {}", name)); - self.print_generics(&info, lt, true); - self.push_str("= Result<"); - match ok { - Some(ty) => self.print_ty(iface, ty, mode), - None => self.push_str("()"), - } - self.push_str(","); - match err { - Some(ty) => self.print_ty(iface, ty, mode), - None => self.push_str("()"), - } - self.push_str(">;\n"); - continue; - } if !info.owns_data() { self.push_str("#[derive(Clone, Copy)]\n"); } else if !info.has_handle { @@ -534,6 +503,42 @@ pub trait RustGenerator { } } + fn print_typedef_option(&mut self, iface: &Interface, id: TypeId, payload: &Type, docs: &Docs) { + let info = self.info(id); + + for (name, mode) in self.modes_of(iface, id) { + self.rustdoc(docs); + let lt = self.lifetime_for(&info, mode); + self.push_str(&format!("pub type {}", name)); + self.print_generics(&info, lt, true); + self.push_str("= Option<"); + self.print_ty(iface, payload, mode); + self.push_str(">;\n"); + } + } + + fn print_typedef_expected( + &mut self, + iface: &Interface, + id: TypeId, + expected: &Expected, + docs: &Docs, + ) { + let info = self.info(id); + + for (name, mode) in self.modes_of(iface, id) { + self.rustdoc(docs); + let lt = self.lifetime_for(&info, mode); + self.push_str(&format!("pub type {}", name)); + self.print_generics(&info, lt, true); + self.push_str("= Result<"); + self.print_ty(iface, &expected.ok, mode); + self.push_str(","); + self.print_ty(iface, &expected.err, mode); + self.push_str(">;\n"); + } + } + fn print_typedef_enum(&mut self, name: &str, enum_: &Enum, docs: &Docs) { // TODO: should this perhaps be an attribute in the wit file? let is_error = name.contains("errno"); diff --git a/crates/gen-spidermonkey/src/lib.rs b/crates/gen-spidermonkey/src/lib.rs index 3f8494d96..d4a6e847e 100644 --- a/crates/gen-spidermonkey/src/lib.rs +++ b/crates/gen-spidermonkey/src/lib.rs @@ -16,8 +16,8 @@ use wasm_encoder::Instruction; use wit_bindgen_gen_core::{ wit_parser::{ abi::{self, AbiVariant, WasmSignature, WasmType}, - Docs, Enum, Flags, Function, Interface, Record, ResourceId, SizeAlign, Tuple, Type, TypeId, - Variant, + Docs, Enum, Expected, Flags, Function, Interface, Record, ResourceId, SizeAlign, Tuple, + Type, TypeId, Variant, }, Direction, Files, Generator, }; @@ -988,6 +988,30 @@ impl Generator for SpiderMonkeyWasm<'_> { todo!() } + fn type_option( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + payload: &Type, + docs: &Docs, + ) { + let _ = (iface, id, name, payload, docs); + todo!() + } + + fn type_expected( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + expected: &Expected, + docs: &Docs, + ) { + let _ = (iface, id, name, expected, docs); + todo!() + } + fn type_enum(&mut self, iface: &Interface, id: TypeId, name: &str, enum_: &Enum, docs: &Docs) { let _ = (iface, id, name, enum_, docs); todo!() @@ -1931,6 +1955,10 @@ impl abi::Bindgen for Bindgen<'_, '_> { name: _, ty: _, } => todo!(), + abi::Instruction::OptionLower { .. } => todo!(), + abi::Instruction::OptionLift { .. } => todo!(), + abi::Instruction::ExpectedLower { .. } => todo!(), + abi::Instruction::ExpectedLift { .. } => todo!(), abi::Instruction::EnumLower { enum_: _, name: _, diff --git a/crates/gen-wasmtime-py/src/lib.rs b/crates/gen-wasmtime-py/src/lib.rs index 362a154ca..d4b3c93e0 100644 --- a/crates/gen-wasmtime-py/src/lib.rs +++ b/crates/gen-wasmtime-py/src/lib.rs @@ -411,31 +411,25 @@ impl WasmtimePy { match &ty.kind { TypeDefKind::Type(t) => self.print_ty(iface, t), TypeDefKind::Tuple(t) => self.print_tuple(iface, t), - TypeDefKind::Record(_) | TypeDefKind::Flags(_) | TypeDefKind::Enum(_) => { + TypeDefKind::Record(_) + | TypeDefKind::Flags(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Variant(_) => { unreachable!() } - TypeDefKind::Variant(v) => { - if let Some(t) = v.as_option() { - self.pyimport("typing", "Optional"); - self.src.push_str("Optional["); - self.print_ty(iface, t); - self.src.push_str("]"); - } else if let Some((ok, err)) = v.as_expected() { - self.needs_expected = true; - self.src.push_str("Expected["); - match ok { - Some(t) => self.print_ty(iface, t), - None => self.src.push_str("None"), - } - self.src.push_str(", "); - match err { - Some(t) => self.print_ty(iface, t), - None => self.src.push_str("None"), - } - self.src.push_str("]"); - } else { - unreachable!() - } + TypeDefKind::Option(t) => { + self.pyimport("typing", "Optional"); + self.src.push_str("Optional["); + self.print_ty(iface, t); + self.src.push_str("]"); + } + TypeDefKind::Expected(e) => { + self.needs_expected = true; + self.src.push_str("Expected["); + self.print_ty(iface, &e.ok); + self.src.push_str(", "); + self.print_ty(iface, &e.err); + self.src.push_str("]"); } TypeDefKind::List(t) => self.print_list(iface, t), } @@ -647,58 +641,70 @@ impl Generator for WasmtimePy { docs: &Docs, ) { self.docs(docs); - if let Some(t) = variant.as_option() { - self.pyimport("typing", "Optional"); - self.src - .push_str(&format!("{} = Optional[", name.to_camel_case())); - self.print_ty(iface, t); - self.src.push_str("]\n"); - } else if let Some((ok, err)) = variant.as_expected() { - self.needs_expected = true; - self.src - .push_str(&format!("{} = Expected[", name.to_camel_case())); - match ok { - Some(t) => self.print_ty(iface, t), - None => self.src.push_str("None"), - } - self.src.push_str(", "); - match err { - Some(t) => self.print_ty(iface, t), - None => self.src.push_str("None"), - } - self.src.push_str("]\n"); - } else { - self.pyimport("dataclasses", "dataclass"); - let mut cases = Vec::new(); - for case in variant.cases.iter() { - self.docs(&case.docs); - self.src.push_str("@dataclass\n"); - let name = format!("{}{}", name.to_camel_case(), case.name.to_camel_case()); - self.src.push_str(&format!("class {}:\n", name)); - self.indent(); - match &case.ty { - Some(ty) => { - self.src.push_str("value: "); - self.print_ty(iface, ty); - self.src.push_str("\n"); - } - None => self.src.push_str("pass\n"), + self.pyimport("dataclasses", "dataclass"); + let mut cases = Vec::new(); + for case in variant.cases.iter() { + self.docs(&case.docs); + self.src.push_str("@dataclass\n"); + let name = format!("{}{}", name.to_camel_case(), case.name.to_camel_case()); + self.src.push_str(&format!("class {}:\n", name)); + self.indent(); + match &case.ty { + Some(ty) => { + self.src.push_str("value: "); + self.print_ty(iface, ty); + self.src.push_str("\n"); } - self.deindent(); - self.src.push_str("\n"); - cases.push(name); + None => self.src.push_str("pass\n"), } - - self.pyimport("typing", "Union"); - self.src.push_str(&format!( - "{} = Union[{}]\n", - name.to_camel_case(), - cases.join(", "), - )); + self.deindent(); + self.src.push_str("\n"); + cases.push(name); } + + self.pyimport("typing", "Union"); + self.src.push_str(&format!( + "{} = Union[{}]\n", + name.to_camel_case(), + cases.join(", "), + )); self.src.push_str("\n"); } + fn type_option( + &mut self, + iface: &Interface, + _id: TypeId, + name: &str, + payload: &Type, + docs: &Docs, + ) { + self.docs(docs); + self.pyimport("typing", "Optional"); + self.src + .push_str(&format!("{} = Optional[", name.to_camel_case())); + self.print_ty(iface, payload); + self.src.push_str("]\n\n"); + } + + fn type_expected( + &mut self, + iface: &Interface, + _id: TypeId, + name: &str, + expected: &Expected, + docs: &Docs, + ) { + self.docs(docs); + self.needs_expected = true; + self.src + .push_str(&format!("{} = Expected[", name.to_camel_case())); + self.print_ty(iface, &expected.ok); + self.src.push_str(", "); + self.print_ty(iface, &expected.err); + self.src.push_str("]\n\n"); + } + fn type_enum( &mut self, _iface: &Interface, @@ -1450,11 +1456,9 @@ impl Bindgen for FunctionBindgen<'_> { } } - Instruction::UnitLower => { - assert_eq!(operands, &["".to_string()]); - } + Instruction::UnitLower => {} Instruction::UnitLift => { - results.push("".to_string()); + results.push("None".to_string()); } Instruction::BoolFromI32 => { let op = self.locals.tmp("operand"); @@ -1605,45 +1609,26 @@ impl Bindgen for FunctionBindgen<'_> { results.push(self.locals.tmp("variant")); } - let mut needs_else = true; for (i, ((case, (block, block_results)), payload)) in variant.cases.iter().zip(blocks).zip(payloads).enumerate() { if i == 0 { self.src.push_str("if "); - } else if i == 1 && variant.as_option().is_some() { - needs_else = false; - self.src.push_str("else:\n"); } else { self.src.push_str("elif "); } - if variant.as_option().is_some() { - if i == 0 { - self.src.push_str(&format!("{} is None:\n", operands[0])); - } - } else { - self.src.push_str(&format!( - "isinstance({}, {}{}):\n", - operands[0], - if variant.as_expected().is_some() { - String::new() - } else { - name.unwrap().to_camel_case() - }, - case.name.to_camel_case() - )); - } + self.src.push_str(&format!( + "isinstance({}, {}{}):\n", + operands[0], + name.to_camel_case(), + case.name.to_camel_case() + )); self.src.indent(2); if case.ty.is_some() { - if variant.as_option().is_some() { - self.src - .push_str(&format!("{} = {}\n", payload, operands[0])); - } else { - self.src - .push_str(&format!("{} = {}.value\n", payload, operands[0])); - } + self.src + .push_str(&format!("{} = {}.value\n", payload, operands[0])); } self.src.push_str(&block); @@ -1653,25 +1638,14 @@ impl Bindgen for FunctionBindgen<'_> { } self.src.deindent(2); } - if needs_else { - let variant_name = name.map(|s| s.to_camel_case()); - let variant_name = variant_name.as_deref().unwrap_or_else(|| { - if variant.as_expected().is_some() { - "expected" - } else if variant.as_option().is_some() { - "option" - } else { - unimplemented!() - } - }); - self.src.push_str("else:\n"); - self.src.indent(2); - self.src.push_str(&format!( - "raise TypeError(\"invalid variant specified for {}\")\n", - variant_name - )); - self.src.deindent(2); - } + let variant_name = name.to_camel_case(); + self.src.push_str("else:\n"); + self.src.indent(2); + self.src.push_str(&format!( + "raise TypeError(\"invalid variant specified for {}\")\n", + variant_name + )); + self.src.deindent(2); } Instruction::VariantLift { @@ -1700,51 +1674,24 @@ impl Bindgen for FunctionBindgen<'_> { self.src.indent(2); self.src.push_str(&block); - if variant.as_option().is_some() { - if case.ty.is_none() { - assert!(block_results.is_empty()); - self.src.push_str(&format!("{} = None\n", result)); - } else { - assert!(block_results.len() == 1); - self.src - .push_str(&format!("{} = {}\n", result, block_results[0])); - } + self.src.push_str(&format!( + "{} = {}{}(", + result, + name.to_camel_case(), + case.name.to_camel_case() + )); + if case.ty.is_some() { + assert!(block_results.len() == 1); + self.src.push_str(&block_results[0]); } else { - self.src.push_str(&format!( - "{} = {}{}(", - result, - if variant.as_expected().is_some() { - String::new() - } else { - name.unwrap().to_camel_case() - }, - case.name.to_camel_case() - )); - if case.ty.is_some() { - assert!(block_results.len() == 1); - self.src.push_str(&block_results[0]); - } else { - assert!(block_results.is_empty()); - if variant.as_expected().is_some() { - self.src.push_str("None"); - } - } - self.src.push_str(")\n"); + assert!(block_results.is_empty()); } + self.src.push_str(")\n"); self.src.deindent(2); } self.src.push_str("else:\n"); self.src.indent(2); - let variant_name = name.map(|s| s.to_camel_case()); - let variant_name = variant_name.as_deref().unwrap_or_else(|| { - if variant.as_expected().is_some() { - "expected" - } else if variant.as_option().is_some() { - "option" - } else { - unimplemented!() - } - }); + let variant_name = name.to_camel_case(); self.src.push_str(&format!( "raise TypeError(\"invalid variant discriminant for {}\")\n", variant_name @@ -1753,6 +1700,148 @@ impl Bindgen for FunctionBindgen<'_> { results.push(result); } + Instruction::OptionLower { + results: result_types, + .. + } => { + let (some, some_results) = self.blocks.pop().unwrap(); + let (none, none_results) = self.blocks.pop().unwrap(); + let some_payload = self.payloads.pop().unwrap(); + let _none_payload = self.payloads.pop().unwrap(); + + for _ in 0..result_types.len() { + results.push(self.locals.tmp("variant")); + } + + let op0 = &operands[0]; + self.src.push_str(&format!("if {op0} is None:\n")); + + self.src.indent(2); + self.src.push_str(&none); + for (dst, result) in results.iter().zip(&none_results) { + self.src.push_str(&format!("{dst} = {result}\n")); + } + self.src.deindent(2); + self.src.push_str("else:\n"); + self.src.indent(2); + self.src.push_str(&format!("{some_payload} = {op0}\n")); + self.src.push_str(&some); + for (dst, result) in results.iter().zip(&some_results) { + self.src.push_str(&format!("{dst} = {result}\n")); + } + self.src.deindent(2); + } + + Instruction::OptionLift { ty, .. } => { + let (some, some_results) = self.blocks.pop().unwrap(); + let (none, none_results) = self.blocks.pop().unwrap(); + assert!(none_results.is_empty()); + assert!(some_results.len() == 1); + let some_result = &some_results[0]; + + let result = self.locals.tmp("option"); + self.src.push_str(&format!( + "{result}: {}\n", + self.gen.type_string(iface, &Type::Id(*ty)), + )); + + let op0 = &operands[0]; + self.src.push_str(&format!("if {op0} == 0:\n")); + self.src.indent(2); + self.src.push_str(&none); + self.src.push_str(&format!("{result} = None\n")); + self.src.deindent(2); + self.src.push_str(&format!("elif {op0} == 1:\n")); + self.src.indent(2); + self.src.push_str(&some); + self.src.push_str(&format!("{result} = {some_result}\n")); + self.src.deindent(2); + + self.src.push_str("else:\n"); + self.src.indent(2); + self.src + .push_str("raise TypeError(\"invalid variant discriminant for option\")\n"); + self.src.deindent(2); + + results.push(result); + } + + Instruction::ExpectedLower { + results: result_types, + .. + } => { + let (err, err_results) = self.blocks.pop().unwrap(); + let (ok, ok_results) = self.blocks.pop().unwrap(); + let err_payload = self.payloads.pop().unwrap(); + let ok_payload = self.payloads.pop().unwrap(); + + for _ in 0..result_types.len() { + results.push(self.locals.tmp("variant")); + } + + let op0 = &operands[0]; + self.src.push_str(&format!("if isinstance({op0}, Ok):\n")); + + self.src.indent(2); + self.src.push_str(&format!("{ok_payload} = {op0}.value\n")); + self.src.push_str(&ok); + for (dst, result) in results.iter().zip(&ok_results) { + self.src.push_str(&format!("{dst} = {result}\n")); + } + self.src.deindent(2); + self.src + .push_str(&format!("elif isinstance({op0}, Err):\n")); + self.src.indent(2); + self.src.push_str(&format!("{err_payload} = {op0}.value\n")); + self.src.push_str(&err); + for (dst, result) in results.iter().zip(&err_results) { + self.src.push_str(&format!("{dst} = {result}\n")); + } + self.src.deindent(2); + self.src.push_str("else:\n"); + self.src.indent(2); + self.src.push_str(&format!( + "raise TypeError(\"invalid variant specified for expected\")\n", + )); + self.src.deindent(2); + } + + Instruction::ExpectedLift { ty, .. } => { + let (err, err_results) = self.blocks.pop().unwrap(); + let (ok, ok_results) = self.blocks.pop().unwrap(); + assert!(err_results.len() == 1); + let err_result = &err_results[0]; + assert!(ok_results.len() == 1); + let ok_result = &ok_results[0]; + + let result = self.locals.tmp("expected"); + self.src.push_str(&format!( + "{result}: {}\n", + self.gen.type_string(iface, &Type::Id(*ty)), + )); + + let op0 = &operands[0]; + self.src.push_str(&format!("if {op0} == 0:\n")); + self.src.indent(2); + self.src.push_str(&ok); + self.src.push_str(&format!("{result} = Ok({ok_result})\n")); + self.src.deindent(2); + self.src.push_str(&format!("elif {op0} == 1:\n")); + self.src.indent(2); + self.src.push_str(&err); + self.src + .push_str(&format!("{result} = Err({err_result})\n")); + self.src.deindent(2); + + self.src.push_str("else:\n"); + self.src.indent(2); + self.src + .push_str("raise TypeError(\"invalid variant discriminant for expected\")\n"); + self.src.deindent(2); + + results.push(result); + } + Instruction::EnumLower { .. } => results.push(format!("({}).value", operands[0])), Instruction::EnumLift { name, .. } => { diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index b809fef84..a15fb6b5b 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -441,6 +441,28 @@ impl Generator for Wasmtime { self.print_typedef_variant(iface, id, variant, docs); } + fn type_option( + &mut self, + iface: &Interface, + id: TypeId, + _name: &str, + payload: &Type, + docs: &Docs, + ) { + self.print_typedef_option(iface, id, payload, docs); + } + + fn type_expected( + &mut self, + iface: &Interface, + id: TypeId, + _name: &str, + expected: &Expected, + docs: &Docs, + ) { + self.print_typedef_expected(iface, id, expected, docs); + } + fn type_enum( &mut self, _iface: &Interface, @@ -1631,16 +1653,7 @@ impl Bindgen for FunctionBindgen<'_> { self.variant_lift_case(iface, *ty, variant, case, &block, &mut result); result.push_str(",\n"); } - let variant_name = name.map(|s| s.to_camel_case()); - let variant_name = variant_name.as_deref().unwrap_or_else(|| { - if variant.as_expected().is_some() { - "Result" - } else if variant.as_option().is_some() { - "Option" - } else { - unimplemented!() - } - }); + let variant_name = name.to_camel_case(); result.push_str("_ => return Err(invalid_variant(\""); result.push_str(&variant_name); result.push_str("\")),\n"); @@ -1649,6 +1662,67 @@ impl Bindgen for FunctionBindgen<'_> { self.gen.needs_invalid_variant = true; } + Instruction::OptionLower { + results: result_types, + .. + } => { + let some = self.blocks.pop().unwrap(); + let none = self.blocks.pop().unwrap(); + self.let_results(result_types.len(), results); + let operand = &operands[0]; + self.push_str(&format!( + "match {operand} {{ + Some(e) => {{ {some} }}, + None => {{ {none} }}, + }};" + )); + } + + Instruction::OptionLift { .. } => { + let some = self.blocks.pop().unwrap(); + let none = self.blocks.pop().unwrap(); + assert_eq!(none, "()"); + let operand = &operands[0]; + results.push(format!( + "match {operand} {{ + 0 => None, + 1 => Some({some}), + _ => return Err(invalid_variant(\"option\")), + }}" + )); + self.gen.needs_invalid_variant = true; + } + + Instruction::ExpectedLower { + results: result_types, + .. + } => { + let err = self.blocks.pop().unwrap(); + let ok = self.blocks.pop().unwrap(); + self.let_results(result_types.len(), results); + let operand = &operands[0]; + self.push_str(&format!( + "match {operand} {{ + Ok(e) => {{ {ok} }}, + Err(e) => {{ {err} }}, + }};" + )); + } + + Instruction::ExpectedLift { .. } => { + let err = self.blocks.pop().unwrap(); + let ok = self.blocks.pop().unwrap(); + let operand = &operands[0]; + results.push(format!( + "match {operand} {{ + 0 => Ok({ok}), + 1 => Err({err}), + _ => return Err(invalid_variant(\"expected\")), + }}" + )); + self.gen.needs_invalid_variant = true; + } + Instruction::EnumLower { .. } => { results.push(format!("{} as i32", operands[0])); } diff --git a/crates/gen-wasmtime/tests/codegen.rs b/crates/gen-wasmtime/tests/codegen.rs index 4bdc1e6e0..54d05c57d 100644 --- a/crates/gen-wasmtime/tests/codegen.rs +++ b/crates/gen-wasmtime/tests/codegen.rs @@ -96,7 +96,7 @@ mod custom_errors { wit_bindgen_wasmtime::export!({ src["x"]: " foo: function() - bar: function() -> expected<_, u32> + bar: function() -> expected enum errno { bad1, bad2, diff --git a/crates/parser/src/abi.rs b/crates/parser/src/abi.rs index 1ad6c228a..0aeb6c80c 100644 --- a/crates/parser/src/abi.rs +++ b/crates/parser/src/abi.rs @@ -1,7 +1,7 @@ use crate::sizealign::align_to; use crate::{ - Enum, Flags, FlagsRepr, Function, Int, Interface, Record, ResourceId, Tuple, Type, TypeDefKind, - TypeId, Variant, + Enum, Expected, Flags, FlagsRepr, Function, Int, Interface, Record, ResourceId, Tuple, Type, + TypeDefKind, TypeId, Variant, }; use std::mem; @@ -554,7 +554,7 @@ def_instruction! { /// from the stack to produce `nresults` of items. VariantLower { variant: &'a Variant, - name: Option<&'a str>, + name: &'a str, ty: TypeId, results: &'a [WasmType], } : [1] => [results.len()], @@ -564,7 +564,7 @@ def_instruction! { /// from the stack to produce a final variant. VariantLift { variant: &'a Variant, - name: Option<&'a str>, + name: &'a str, ty: TypeId, } : [1] => [1], @@ -582,6 +582,32 @@ def_instruction! { ty: TypeId, } : [1] => [1], + /// TODO + OptionLower { + payload: &'a Type, + ty: TypeId, + results: &'a [WasmType], + } : [1] => [results.len()], + + /// TODO + OptionLift { + payload: &'a Type, + ty: TypeId, + } : [1] => [1], + + /// TODO + ExpectedLower { + expected: &'a Expected, + ty: TypeId, + results: &'a [WasmType], + } : [1] => [results.len()], + + /// TODO + ExpectedLift { + expected: &'a Expected, + ty: TypeId, + } : [1] => [1], + // calling/control flow /// Represents a call to a raw WebAssembly API. The module/name are @@ -957,36 +983,55 @@ impl Interface { TypeDefKind::Variant(v) => { result.push(v.tag.into()); - let start = result.len(); - let mut temp = Vec::new(); + self.push_wasm_variants(variant, v.cases.iter().map(|c| c.ty.as_ref()), result); + } - // Push each case's type onto a temporary vector, and then - // merge that vector into our final list starting at - // `start`. Note that this requires some degree of - // "unification" so we can handle things like `Result` where that turns into `[i32 i32]` where the second - // `i32` might be the `f32` bitcasted. - for case in v.cases.iter() { - let ty = match &case.ty { - Some(ty) => ty, - None => continue, - }; - self.push_wasm(variant, ty, &mut temp); + TypeDefKind::Enum(e) => result.push(e.tag().into()), - for (i, ty) in temp.drain(..).enumerate() { - match result.get_mut(start + i) { - Some(prev) => *prev = join(*prev, ty), - None => result.push(ty), - } - } - } + TypeDefKind::Option(t) => { + result.push(WasmType::I32); + self.push_wasm_variants(variant, [None, Some(t)], result); } - TypeDefKind::Enum(e) => result.push(e.tag().into()), + TypeDefKind::Expected(e) => { + result.push(WasmType::I32); + self.push_wasm_variants(variant, [Some(&e.ok), Some(&e.err)], result); + } }, } } + fn push_wasm_variants<'a>( + &self, + variant: AbiVariant, + tys: impl IntoIterator>, + result: &mut Vec, + ) { + let mut temp = Vec::new(); + let start = result.len(); + + // Push each case's type onto a temporary vector, and then + // merge that vector into our final list starting at + // `start`. Note that this requires some degree of + // "unification" so we can handle things like `Result` where that turns into `[i32 i32]` where the second + // `i32` might be the `f32` bitcasted. + for case in tys { + let ty = match case { + Some(ty) => ty, + None => continue, + }; + self.push_wasm(variant, ty, &mut temp); + + for (i, ty) in temp.drain(..).enumerate() { + match result.get_mut(start + i) { + Some(prev) => *prev = join(*prev, ty), + None => result.push(ty), + } + } + } + } + /// Generates an abstract sequence of instructions which represents this /// function being adapted as an imported function. /// @@ -1473,57 +1518,13 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Variant(v) => { - let mut results = Vec::new(); - let mut temp = Vec::new(); - let mut casts = Vec::new(); - self.iface.push_wasm(self.variant, ty, &mut results); - for (i, case) in v.cases.iter().enumerate() { - self.push_block(); - self.emit(&VariantPayloadName); - let payload_name = self.stack.pop().unwrap(); - self.emit(&I32Const { val: i as i32 }); - let mut pushed = 1; - if let Some(ty) = &case.ty { - // Using the payload of this block we lower the type to - // raw wasm values. - self.stack.push(payload_name.clone()); - self.lower(ty); - - // Determine the types of all the wasm values we just - // pushed, and record how many. If we pushed too few - // then we'll need to push some zeros after this. - temp.truncate(0); - self.iface.push_wasm(self.variant, ty, &mut temp); - pushed += temp.len(); - - // For all the types pushed we may need to insert some - // bitcasts. This will go through and cast everything - // to the right type to ensure all blocks produce the - // same set of results. - casts.truncate(0); - for (actual, expected) in temp.iter().zip(&results[1..]) { - casts.push(cast(*actual, *expected)); - } - if casts.iter().any(|c| *c != Bitcast::None) { - self.emit(&Bitcasts { casts: &casts }); - } - } - - // If we haven't pushed enough items in this block to match - // what other variants are pushing then we need to push - // some zeros. - if pushed < results.len() { - self.emit(&ConstZero { - tys: &results[pushed..], - }); - } - self.finish_block(results.len()); - } + let results = + self.lower_variant_arms(ty, v.cases.iter().map(|c| c.ty.as_ref())); self.emit(&VariantLower { variant: v, ty: id, results: &results, - name: self.iface.types[id].name.as_deref(), + name: self.iface.types[id].name.as_deref().unwrap(), }); } TypeDefKind::Enum(enum_) => { @@ -1533,10 +1534,81 @@ impl<'a, B: Bindgen> Generator<'a, B> { name: self.iface.types[id].name.as_deref().unwrap(), }); } + TypeDefKind::Option(t) => { + let results = self.lower_variant_arms(ty, [None, Some(t)]); + self.emit(&OptionLower { + payload: t, + ty: id, + results: &results, + }); + } + TypeDefKind::Expected(e) => { + let results = self.lower_variant_arms(ty, [Some(&e.ok), Some(&e.err)]); + self.emit(&ExpectedLower { + expected: e, + ty: id, + results: &results, + }); + } }, } } + fn lower_variant_arms<'b>( + &mut self, + ty: &Type, + cases: impl IntoIterator>, + ) -> Vec { + use Instruction::*; + let mut results = Vec::new(); + let mut temp = Vec::new(); + let mut casts = Vec::new(); + self.iface.push_wasm(self.variant, ty, &mut results); + for (i, case) in cases.into_iter().enumerate() { + self.push_block(); + self.emit(&VariantPayloadName); + let payload_name = self.stack.pop().unwrap(); + self.emit(&I32Const { val: i as i32 }); + let mut pushed = 1; + if let Some(ty) = case { + // Using the payload of this block we lower the type to + // raw wasm values. + self.stack.push(payload_name.clone()); + self.lower(ty); + + // Determine the types of all the wasm values we just + // pushed, and record how many. If we pushed too few + // then we'll need to push some zeros after this. + temp.truncate(0); + self.iface.push_wasm(self.variant, ty, &mut temp); + pushed += temp.len(); + + // For all the types pushed we may need to insert some + // bitcasts. This will go through and cast everything + // to the right type to ensure all blocks produce the + // same set of results. + casts.truncate(0); + for (actual, expected) in temp.iter().zip(&results[1..]) { + casts.push(cast(*actual, *expected)); + } + if casts.iter().any(|c| *c != Bitcast::None) { + self.emit(&Bitcasts { casts: &casts }); + } + } + + // If we haven't pushed enough items in this block to match + // what other variants are pushing then we need to push + // some zeros. + if pushed < results.len() { + self.emit(&ConstZero { + tys: &results[pushed..], + }); + } + self.finish_block(results.len()); + } + results + } + fn list_realloc(&self) -> Option<&'static str> { // Lowering parameters calling a wasm import means // we don't need to pass ownership, but we pass @@ -1650,43 +1722,11 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Variant(v) => { - let mut params = Vec::new(); - let mut temp = Vec::new(); - let mut casts = Vec::new(); - self.iface.push_wasm(self.variant, ty, &mut params); - let block_inputs = self - .stack - .drain(self.stack.len() + 1 - params.len()..) - .collect::>(); - for case in v.cases.iter() { - self.push_block(); - if let Some(ty) = &case.ty { - // Push only the values we need for this variant onto - // the stack. - temp.truncate(0); - self.iface.push_wasm(self.variant, ty, &mut temp); - self.stack - .extend(block_inputs[..temp.len()].iter().cloned()); - - // Cast all the types we have on the stack to the actual - // types needed for this variant, if necessary. - casts.truncate(0); - for (actual, expected) in temp.iter().zip(¶ms[1..]) { - casts.push(cast(*expected, *actual)); - } - if casts.iter().any(|c| *c != Bitcast::None) { - self.emit(&Bitcasts { casts: &casts }); - } - - // Then recursively lift this variant's payload. - self.lift(ty); - } - self.finish_block(case.ty.is_some() as usize); - } + self.lift_variant_arms(ty, v.cases.iter().map(|c| c.ty.as_ref())); self.emit(&VariantLift { variant: v, ty: id, - name: self.iface.types[id].name.as_deref(), + name: self.iface.types[id].name.as_deref().unwrap(), }); } @@ -1697,10 +1737,63 @@ impl<'a, B: Bindgen> Generator<'a, B> { name: self.iface.types[id].name.as_deref().unwrap(), }); } + + TypeDefKind::Option(t) => { + self.lift_variant_arms(ty, [None, Some(t)]); + self.emit(&OptionLift { payload: t, ty: id }); + } + + TypeDefKind::Expected(e) => { + self.lift_variant_arms(ty, [Some(&e.ok), Some(&e.err)]); + self.emit(&ExpectedLift { + expected: e, + ty: id, + }); + } }, } } + fn lift_variant_arms<'b>( + &mut self, + ty: &Type, + cases: impl IntoIterator>, + ) { + let mut params = Vec::new(); + let mut temp = Vec::new(); + let mut casts = Vec::new(); + self.iface.push_wasm(self.variant, ty, &mut params); + let block_inputs = self + .stack + .drain(self.stack.len() + 1 - params.len()..) + .collect::>(); + for case in cases { + self.push_block(); + if let Some(ty) = case { + // Push only the values we need for this variant onto + // the stack. + temp.truncate(0); + self.iface.push_wasm(self.variant, ty, &mut temp); + self.stack + .extend(block_inputs[..temp.len()].iter().cloned()); + + // Cast all the types we have on the stack to the actual + // types needed for this variant, if necessary. + casts.truncate(0); + for (actual, expected) in temp.iter().zip(¶ms[1..]) { + casts.push(cast(*expected, *actual)); + } + if casts.iter().any(|c| *c != Bitcast::None) { + self.emit(&Instruction::Bitcasts { casts: &casts }); + } + + // Then recursively lift this variant's payload. + self.lift(ty); + } + self.finish_block(case.is_some() as usize); + } + } + fn list_free(&self) -> Option<&'static str> { // Lifting the arguments of a defined import means that, if // possible, the caller still retains ownership and we don't @@ -1780,25 +1873,40 @@ impl<'a, B: Bindgen> Generator<'a, B> { // payload we write the payload after the discriminant, aligned up // to the type's alignment. TypeDefKind::Variant(v) => { - let payload_offset = offset + (self.bindgen.sizes().payload_offset(v) as i32); - for (i, case) in v.cases.iter().enumerate() { - self.push_block(); - self.emit(&VariantPayloadName); - let payload_name = self.stack.pop().unwrap(); - self.emit(&I32Const { val: i as i32 }); - self.stack.push(addr.clone()); - self.store_intrepr(offset, v.tag); - if let Some(ty) = &case.ty { - self.stack.push(payload_name.clone()); - self.write_to_memory(ty, addr.clone(), payload_offset); - } - self.finish_block(0); - } + self.write_variant_arms_to_memory( + offset, + addr, + v.tag, + v.cases.iter().map(|c| c.ty.as_ref()), + ); self.emit(&VariantLower { variant: v, ty: id, results: &[], - name: self.iface.types[id].name.as_deref(), + name: self.iface.types[id].name.as_deref().unwrap(), + }); + } + + TypeDefKind::Option(t) => { + self.write_variant_arms_to_memory(offset, addr, Int::U8, [None, Some(t)]); + self.emit(&OptionLower { + payload: t, + ty: id, + results: &[], + }); + } + + TypeDefKind::Expected(e) => { + self.write_variant_arms_to_memory( + offset, + addr, + Int::U8, + [Some(&e.ok), Some(&e.err)], + ); + self.emit(&ExpectedLower { + expected: e, + ty: id, + results: &[], }); } @@ -1811,6 +1919,30 @@ impl<'a, B: Bindgen> Generator<'a, B> { } } + fn write_variant_arms_to_memory<'b>( + &mut self, + offset: i32, + addr: B::Operand, + tag: Int, + cases: impl IntoIterator> + Clone, + ) { + let payload_offset = + offset + (self.bindgen.sizes().payload_offset(tag, cases.clone()) as i32); + for (i, case) in cases.into_iter().enumerate() { + self.push_block(); + self.emit(&Instruction::VariantPayloadName); + let payload_name = self.stack.pop().unwrap(); + self.emit(&Instruction::I32Const { val: i as i32 }); + self.stack.push(addr.clone()); + self.store_intrepr(offset, tag); + if let Some(ty) = case { + self.stack.push(payload_name.clone()); + self.write_to_memory(ty, addr.clone(), payload_offset); + } + self.finish_block(0); + } + } + fn write_list_to_memory(&mut self, ty: &Type, addr: B::Operand, offset: i32) { // After lowering the list there's two i32 values on the stack // which we write into memory, writing the pointer into the low address @@ -1916,21 +2048,34 @@ impl<'a, B: Bindgen> Generator<'a, B> { // individual block is pretty simple and just reads the payload type // from the corresponding offset if one is available. TypeDefKind::Variant(variant) => { - self.stack.push(addr.clone()); - self.load_intrepr(offset, variant.tag); - let payload_offset = - offset + (self.bindgen.sizes().payload_offset(variant) as i32); - for case in variant.cases.iter() { - self.push_block(); - if let Some(ty) = &case.ty { - self.read_from_memory(ty, addr.clone(), payload_offset); - } - self.finish_block(case.ty.is_some() as usize); - } + self.read_variant_arms_to_memory( + offset, + addr, + variant.tag, + variant.cases.iter().map(|c| c.ty.as_ref()), + ); self.emit(&VariantLift { variant, ty: id, - name: self.iface.types[id].name.as_deref(), + name: self.iface.types[id].name.as_deref().unwrap(), + }); + } + + TypeDefKind::Option(t) => { + self.read_variant_arms_to_memory(offset, addr, Int::U8, [None, Some(t)]); + self.emit(&OptionLift { payload: t, ty: id }); + } + + TypeDefKind::Expected(e) => { + self.read_variant_arms_to_memory( + offset, + addr, + Int::U8, + [Some(&e.ok), Some(&e.err)], + ); + self.emit(&ExpectedLift { + expected: e, + ty: id, }); } @@ -1943,6 +2088,26 @@ impl<'a, B: Bindgen> Generator<'a, B> { } } + fn read_variant_arms_to_memory<'b>( + &mut self, + offset: i32, + addr: B::Operand, + tag: Int, + cases: impl IntoIterator> + Clone, + ) { + self.stack.push(addr.clone()); + self.load_intrepr(offset, tag); + let payload_offset = + offset + (self.bindgen.sizes().payload_offset(tag, cases.clone()) as i32); + for case in cases { + self.push_block(); + if let Some(ty) = case { + self.read_from_memory(ty, addr.clone(), payload_offset); + } + self.finish_block(case.is_some() as usize); + } + } + fn read_list_from_memory(&mut self, ty: &Type, addr: B::Operand, offset: i32) { // Read the pointer/len and then perform the standard lifting // proceses. diff --git a/crates/parser/src/ast.rs b/crates/parser/src/ast.rs index 3f1499847..51ac2445d 100644 --- a/crates/parser/src/ast.rs +++ b/crates/parser/src/ast.rs @@ -95,6 +95,8 @@ enum Type<'a> { Variant(Variant<'a>), Tuple(Vec>), Enum(Enum<'a>), + Option(Box>), + Expected(Expected<'a>), } struct Record<'a> { @@ -138,6 +140,11 @@ struct EnumCase<'a> { name: Id<'a>, } +struct Expected<'a> { + ok: Box>, + err: Box>, +} + pub struct Value<'a> { docs: Docs<'a>, name: Id<'a>, @@ -488,59 +495,21 @@ impl<'a> Type<'a> { } // option - Some((span, Token::Option_)) => { + Some((_span, Token::Option_)) => { tokens.expect(Token::LessThan)?; let ty = Type::parse(tokens)?; tokens.expect(Token::GreaterThan)?; - Ok(Type::Variant(Variant { - tag: None, - span, - cases: vec![ - Case { - docs: Docs::default(), - name: "none".into(), - ty: None, - }, - Case { - docs: Docs::default(), - name: "some".into(), - ty: Some(ty), - }, - ], - })) + Ok(Type::Option(Box::new(ty))) } // expected - Some((span, Token::Expected)) => { + Some((_span, Token::Expected)) => { tokens.expect(Token::LessThan)?; - let ok = if tokens.eat(Token::Underscore)? { - None - } else { - Some(Type::parse(tokens)?) - }; + let ok = Box::new(Type::parse(tokens)?); tokens.expect(Token::Comma)?; - let err = if tokens.eat(Token::Underscore)? { - None - } else { - Some(Type::parse(tokens)?) - }; + let err = Box::new(Type::parse(tokens)?); tokens.expect(Token::GreaterThan)?; - Ok(Type::Variant(Variant { - tag: None, - span, - cases: vec![ - Case { - docs: Docs::default(), - name: "ok".into(), - ty: ok, - }, - Case { - docs: Docs::default(), - name: "err".into(), - ty: err, - }, - ], - })) + Ok(Type::Expected(Expected { ok, err })) } // `foo` diff --git a/crates/parser/src/ast/resolve.rs b/crates/parser/src/ast/resolve.rs index 0af4c9831..308e8642e 100644 --- a/crates/parser/src/ast/resolve.rs +++ b/crates/parser/src/ast/resolve.rs @@ -25,6 +25,8 @@ enum Key { Tuple(Vec), Enum(Vec), List(Type), + Option(Type), + Expected(Type, Type), } impl Resolver { @@ -230,6 +232,11 @@ impl Resolver { cases: e.cases.clone(), }), TypeDefKind::List(t) => TypeDefKind::List(self.copy_type(dep_name, dep, *t)), + TypeDefKind::Option(t) => TypeDefKind::Option(self.copy_type(dep_name, dep, *t)), + TypeDefKind::Expected(e) => TypeDefKind::Expected(Expected { + ok: self.copy_type(dep_name, dep, e.ok), + err: self.copy_type(dep_name, dep, e.err), + }), }, }; let id = self.types.alloc(ty); @@ -447,6 +454,11 @@ impl Resolver { .collect::>>()?; TypeDefKind::Enum(Enum { cases }) } + super::Type::Option(ty) => TypeDefKind::Option(self.resolve_type(ty)?), + super::Type::Expected(e) => TypeDefKind::Expected(Expected { + ok: self.resolve_type(&e.ok)?, + err: self.resolve_type(&e.err)?, + }), }) } @@ -509,6 +521,8 @@ impl Resolver { Key::Enum(r.cases.iter().map(|f| f.name.clone()).collect::>()) } TypeDefKind::List(ty) => Key::List(*ty), + TypeDefKind::Option(t) => Key::Option(*t), + TypeDefKind::Expected(e) => Key::Expected(e.ok, e.err), }; let types = &mut self.types; let id = self @@ -674,6 +688,20 @@ impl Resolver { } } + TypeDefKind::Option(t) => { + if let Type::Id(id) = *t { + self.validate_type_not_recursive(span, id, visiting, valid)? + } + } + TypeDefKind::Expected(e) => { + if let Type::Id(id) = e.ok { + self.validate_type_not_recursive(span, id, visiting, valid)? + } + if let Type::Id(id) = e.err { + self.validate_type_not_recursive(span, id, visiting, valid)? + } + } + TypeDefKind::Flags(_) | TypeDefKind::List(_) | TypeDefKind::Type(_) diff --git a/crates/parser/src/lib.rs b/crates/parser/src/lib.rs index 54c6ea0d8..1843e0057 100644 --- a/crates/parser/src/lib.rs +++ b/crates/parser/src/lib.rs @@ -50,6 +50,8 @@ pub enum TypeDefKind { Tuple(Tuple), Variant(Variant), Enum(Enum), + Option(Type), + Expected(Expected), List(Type), Type(Type), } @@ -219,6 +221,12 @@ impl Enum { } } +#[derive(Debug)] +pub struct Expected { + pub ok: Type, + pub err: Type, +} + #[derive(Clone, Default, Debug)] pub struct Docs { pub contents: Option, @@ -418,6 +426,11 @@ impl Interface { } } } + TypeDefKind::Option(ty) => self.topo_visit_ty(ty, list, visited), + TypeDefKind::Expected(e) => { + self.topo_visit_ty(&e.ok, list, visited); + self.topo_visit_ty(&e.err, list, visited); + } } list.push(id); } @@ -445,7 +458,11 @@ impl Interface { Type::Bool | Type::Char | Type::Handle(_) | Type::String => false, Type::Id(id) => match &self.types[*id].kind { - TypeDefKind::List(_) | TypeDefKind::Variant(_) | TypeDefKind::Enum(_) => false, + TypeDefKind::List(_) + | TypeDefKind::Variant(_) + | TypeDefKind::Enum(_) + | TypeDefKind::Option(_) + | TypeDefKind::Expected(_) => false, TypeDefKind::Type(t) => self.all_bits_valid(t), TypeDefKind::Record(r) => r.fields.iter().all(|f| self.all_bits_valid(&f.ty)), TypeDefKind::Tuple(t) => t.types.iter().all(|t| self.all_bits_valid(t)), diff --git a/crates/parser/src/sizealign.rs b/crates/parser/src/sizealign.rs index 46db20a41..eca0e7809 100644 --- a/crates/parser/src/sizealign.rs +++ b/crates/parser/src/sizealign.rs @@ -1,4 +1,4 @@ -use crate::{FlagsRepr, Int, Interface, Type, TypeDef, TypeDefKind, Variant}; +use crate::{FlagsRepr, Int, Interface, Type, TypeDef, TypeDefKind}; #[derive(Default)] pub struct SizeAlign { @@ -40,6 +40,15 @@ impl SizeAlign { (size, align) } TypeDefKind::Enum(e) => int_size_align(e.tag()), + TypeDefKind::Option(t) => { + let align = self.align(t); + (align_to(1, align) + self.size(t), align) + } + TypeDefKind::Expected(e) => { + let align = self.align(&e.ok).max(self.align(&e.err)); + let size = self.size(&e.ok).max(self.size(&e.err)); + (align_to(1, align) + size, align) + } } } @@ -78,14 +87,18 @@ impl SizeAlign { .collect() } - pub fn payload_offset(&self, variant: &Variant) -> usize { + pub fn payload_offset<'a>( + &self, + tag: Int, + cases: impl IntoIterator>, + ) -> usize { let mut max_align = 1; - for c in variant.cases.iter() { - if let Some(ty) = &c.ty { + for c in cases { + if let Some(ty) = c { max_align = max_align.max(self.align(ty)); } } - let tag_size = int_size_align(variant.tag).0; + let tag_size = int_size_align(tag).0; align_to(tag_size, max_align) } diff --git a/crates/parser/tests/all.rs b/crates/parser/tests/all.rs index 0246a396f..8e1a93cd6 100644 --- a/crates/parser/tests/all.rs +++ b/crates/parser/tests/all.rs @@ -203,6 +203,11 @@ fn to_json(i: &Interface) -> String { Tuple { types: Vec, }, + Option(String), + Expected { + ok: String, + err: String, + }, List(String), } @@ -293,6 +298,11 @@ fn to_json(i: &Interface) -> String { .map(|f| (f.name.clone(), f.ty.as_ref().map(translate_type))) .collect(), }, + TypeDefKind::Option(t) => Type::Option(translate_type(t)), + TypeDefKind::Expected(e) => Type::Expected { + ok: translate_type(&e.ok), + err: translate_type(&e.err), + }, TypeDefKind::List(ty) => Type::List(translate_type(ty)), } } diff --git a/crates/parser/tests/ui/functions.wit.result b/crates/parser/tests/ui/functions.wit.result index f48ba3d83..c2c5e86e1 100644 --- a/crates/parser/tests/ui/functions.wit.result +++ b/crates/parser/tests/ui/functions.wit.result @@ -11,32 +11,13 @@ }, { "idx": 1, - "variant": { - "cases": [ - [ - "none", - null - ], - [ - "some", - "u32" - ] - ] - } + "option": "u32" }, { "idx": 2, - "variant": { - "cases": [ - [ - "ok", - "u32" - ], - [ - "err", - "float32" - ] - ] + "expected": { + "ok": "u32", + "err": "float32" } } ], diff --git a/crates/parser/tests/ui/types.wit b/crates/parser/tests/ui/types.wit index b28ceb875..2a933bb9f 100644 --- a/crates/parser/tests/ui/types.wit +++ b/crates/parser/tests/ui/types.wit @@ -15,9 +15,9 @@ type t12 = list type t13 = string type t14 = option type t15 = expected -type t16 = expected<_, u32> -type t17 = expected -type t18 = expected<_, _> +type t16 = expected +type t17 = expected +type t18 = expected type t19 = handle x record t20 {} record t21 { a: u32 } diff --git a/crates/parser/tests/ui/types.wit.result b/crates/parser/tests/ui/types.wit.result index 2213e764d..19149ecb7 100644 --- a/crates/parser/tests/ui/types.wit.result +++ b/crates/parser/tests/ui/types.wit.result @@ -77,81 +77,38 @@ { "idx": 14, "name": "t14", - "variant": { - "cases": [ - [ - "none", - null - ], - [ - "some", - "u32" - ] - ] - } + "option": "u32" }, { "idx": 15, "name": "t15", - "variant": { - "cases": [ - [ - "ok", - "u32" - ], - [ - "err", - "u32" - ] - ] + "expected": { + "ok": "u32", + "err": "u32" } }, { "idx": 16, "name": "t16", - "variant": { - "cases": [ - [ - "ok", - null - ], - [ - "err", - "u32" - ] - ] + "expected": { + "ok": "unit", + "err": "u32" } }, { "idx": 17, "name": "t17", - "variant": { - "cases": [ - [ - "ok", - "u32" - ], - [ - "err", - null - ] - ] + "expected": { + "ok": "u32", + "err": "unit" } }, { "idx": 18, "name": "t18", - "variant": { - "cases": [ - [ - "ok", - null - ], - [ - "err", - null - ] - ] + "expected": { + "ok": "unit", + "err": "unit" } }, { @@ -488,18 +445,7 @@ }, { "idx": 51, - "variant": { - "cases": [ - [ - "none", - null - ], - [ - "some", - "u32" - ] - ] - } + "option": "u32" }, { "idx": 52, diff --git a/crates/parser/tests/ui/wasi-clock.wit.result b/crates/parser/tests/ui/wasi-clock.wit.result index 3cbfe5895..ee1e14ce6 100644 --- a/crates/parser/tests/ui/wasi-clock.wit.result +++ b/crates/parser/tests/ui/wasi-clock.wit.result @@ -105,17 +105,9 @@ }, { "idx": 3, - "variant": { - "cases": [ - [ - "ok", - "type-1" - ], - [ - "err", - "type-2" - ] - ] + "expected": { + "ok": "type-1", + "err": "type-2" } } ], diff --git a/crates/parser/tests/ui/wasi-http.wit.result b/crates/parser/tests/ui/wasi-http.wit.result index 7142b2a35..037e009ea 100644 --- a/crates/parser/tests/ui/wasi-http.wit.result +++ b/crates/parser/tests/ui/wasi-http.wit.result @@ -58,18 +58,7 @@ }, { "idx": 6, - "variant": { - "cases": [ - [ - "none", - null - ], - [ - "some", - "string" - ] - ] - } + "option": "string" }, { "idx": 7, @@ -77,17 +66,9 @@ }, { "idx": 8, - "variant": { - "cases": [ - [ - "ok", - "u64" - ], - [ - "err", - "type-4" - ] - ] + "expected": { + "ok": "u64", + "err": "type-4" } } ], diff --git a/crates/test-helpers/src/lib.rs b/crates/test-helpers/src/lib.rs index a10556863..a8305b2aa 100644 --- a/crates/test-helpers/src/lib.rs +++ b/crates/test-helpers/src/lib.rs @@ -182,27 +182,19 @@ pub fn codegen_rust_wasm_export(input: TokenStream) -> TokenStream { TypeDefKind::Flags(_) => panic!("unknown flags"), TypeDefKind::Enum(_) => panic!("unknown enum"), TypeDefKind::Record(_) => panic!("unknown record"), + TypeDefKind::Variant(_) => panic!("unknown variant"), TypeDefKind::Tuple(t) => { let fields = t.types.iter().map(|ty| quote_ty(param, iface, ty)); quote::quote! { (#(#fields,)*) } } - TypeDefKind::Variant(v) => { - if let Some(ty) = v.as_option() { - let ty = quote_ty(param, iface, ty); - quote::quote! { Option<#ty> } - } else if let Some((ok, err)) = v.as_expected() { - let ok = match ok { - Some(ok) => quote_ty(param, iface, ok), - None => quote::quote! { () }, - }; - let err = match err { - Some(err) => quote_ty(param, iface, err), - None => quote::quote! { () }, - }; - quote::quote! { Result<#ok, #err> } - } else { - panic!("unknown variant"); - } + TypeDefKind::Option(ty) => { + let ty = quote_ty(param, iface, ty); + quote::quote! { Option<#ty> } + } + TypeDefKind::Expected(e) => { + let ok = quote_ty(param, iface, &e.ok); + let err = quote_ty(param, iface, &e.err); + quote::quote! { Result<#ok, #err> } } } } diff --git a/crates/test-modules/modules/crates/variants/variants.wit b/crates/test-modules/modules/crates/variants/variants.wit index 340270575..525bc1a4c 100644 --- a/crates/test-modules/modules/crates/variants/variants.wit +++ b/crates/test-modules/modules/crates/variants/variants.wit @@ -97,17 +97,17 @@ casts: function( > expected-arg: function( - a: expected<_, _>, - b: expected<_, e1>, - c: expected, + a: expected, + b: expected, + c: expected, d: expected, tuple<>>, e: expected, f: expected>, ) expected-result: function() -> tuple< - expected<_, _>, - expected<_, e1>, - expected, + expected, + expected, + expected, expected, tuple<>>, expected, expected>, diff --git a/crates/wasmlink/src/adapter/call.rs b/crates/wasmlink/src/adapter/call.rs index ebafaec3b..9ab1fda42 100644 --- a/crates/wasmlink/src/adapter/call.rs +++ b/crates/wasmlink/src/adapter/call.rs @@ -485,44 +485,36 @@ impl<'a> CallAdapter<'a> { TypeDefKind::Enum(_) => { params.next().unwrap(); } - TypeDefKind::Variant(v) => { - let discriminant = params.next().unwrap(); - let mut count = 0; - let mut cases = Vec::new(); - for (i, c) in v.cases.iter().enumerate() { - if let Some(ty) = &c.ty { - let mut iter = params.clone(); - let mut operands = Vec::new(); - - Self::push_operands( - interface, - sizes, - ty, - &mut iter, - mode, - locals_count, - &mut operands, - ); - - if !operands.is_empty() { - cases.push((i as u32, operands)); - } - - count = std::cmp::max(count, params.len() - iter.len()); - } - } - - if !cases.is_empty() { - operands.push(Operand::Variant { - discriminant: (mode.create_value_ref(discriminant), v.tag.into()), - cases, - }); - } - - for _ in 0..count { - params.next().unwrap(); - } - } + TypeDefKind::Variant(v) => Self::push_variant_operands( + interface, + sizes, + v.tag, + v.cases.iter().map(|c| c.ty.as_ref()), + params, + mode, + locals_count, + operands, + ), + TypeDefKind::Option(t) => Self::push_variant_operands( + interface, + sizes, + Int::U8, + [None, Some(t)], + params, + mode, + locals_count, + operands, + ), + TypeDefKind::Expected(e) => Self::push_variant_operands( + interface, + sizes, + Int::U8, + [Some(&e.ok), Some(&e.err)], + params, + mode, + locals_count, + operands, + ), }, Type::String => { let addr = params.next().unwrap(); @@ -557,12 +549,63 @@ impl<'a> CallAdapter<'a> { name: interface.resources[*id].name.as_str(), }); } + Type::Unit => {} _ => { params.next().unwrap(); } } } + fn push_variant_operands<'b, T>( + interface: &'a WitInterface, + sizes: &SizeAlign, + tag: Int, + all_cases: impl IntoIterator>, + params: &mut T, + mode: PushMode, + locals_count: &mut u32, + operands: &mut Vec>, + ) where + T: ExactSizeIterator + Clone, + { + let discriminant = params.next().unwrap(); + let mut count = 0; + let mut cases = Vec::new(); + for (i, c) in all_cases.into_iter().enumerate() { + if let Some(ty) = c { + let mut iter = params.clone(); + let mut operands = Vec::new(); + + Self::push_operands( + interface, + sizes, + ty, + &mut iter, + mode, + locals_count, + &mut operands, + ); + + if !operands.is_empty() { + cases.push((i as u32, operands)); + } + + count = std::cmp::max(count, params.len() - iter.len()); + } + } + + if !cases.is_empty() { + operands.push(Operand::Variant { + discriminant: (mode.create_value_ref(discriminant), tag.into()), + cases, + }); + } + + for _ in 0..count { + params.next().unwrap(); + } + } + fn push_element_operands( interface: &'a WitInterface, sizes: &SizeAlign, @@ -651,35 +694,36 @@ impl<'a> CallAdapter<'a> { } } TypeDefKind::Enum(_) => {} - TypeDefKind::Variant(v) => { - let payload_offset = sizes.payload_offset(v) as u32; - - let mut cases = Vec::new(); - for (i, c) in v.cases.iter().enumerate() { - if let Some(ty) = &c.ty { - let mut operands = Vec::new(); - Self::push_element_operands( - interface, - sizes, - ty, - offset + payload_offset, - mode, - locals_count, - &mut operands, - ); - if !operands.is_empty() { - cases.push((i as u32, operands)); - } - } - } - - if !cases.is_empty() { - operands.push(Operand::Variant { - discriminant: (ValueRef::ElementOffset(offset), v.tag.into()), - cases, - }); - } - } + TypeDefKind::Variant(v) => Self::push_variant_element_operands( + interface, + sizes, + offset, + v.tag, + v.cases.iter().map(|c| c.ty.as_ref()), + mode, + locals_count, + operands, + ), + TypeDefKind::Option(t) => Self::push_variant_element_operands( + interface, + sizes, + offset, + Int::U8, + [None, Some(t)], + mode, + locals_count, + operands, + ), + TypeDefKind::Expected(e) => Self::push_variant_element_operands( + interface, + sizes, + offset, + Int::U8, + [Some(&e.ok), Some(&e.err)], + mode, + locals_count, + operands, + ), }, Type::String => { // Every list copied needs a source and destination local @@ -709,6 +753,45 @@ impl<'a> CallAdapter<'a> { } } + fn push_variant_element_operands<'b>( + interface: &'a WitInterface, + sizes: &SizeAlign, + offset: u32, + tag: Int, + all_cases: impl IntoIterator> + Clone, + mode: PushMode, + locals_count: &mut u32, + operands: &mut Vec>, + ) { + let payload_offset = sizes.payload_offset(tag, all_cases.clone()) as u32; + + let mut cases = Vec::new(); + for (i, c) in all_cases.into_iter().enumerate() { + if let Some(ty) = c { + let mut operands = Vec::new(); + Self::push_element_operands( + interface, + sizes, + ty, + offset + payload_offset, + mode, + locals_count, + &mut operands, + ); + if !operands.is_empty() { + cases.push((i as u32, operands)); + } + } + } + + if !cases.is_empty() { + operands.push(Operand::Variant { + discriminant: (ValueRef::ElementOffset(offset), tag.into()), + cases, + }); + } + } + fn emit_store_from_base( &self, function: &mut wasm_encoder::Function, diff --git a/crates/wasmlink/src/module.rs b/crates/wasmlink/src/module.rs index b0142d523..c256a0b7a 100644 --- a/crates/wasmlink/src/module.rs +++ b/crates/wasmlink/src/module.rs @@ -56,6 +56,8 @@ fn has_list(interface: &WitInterface, ty: &WitType) -> bool { .map(|t| has_list(interface, t)) .unwrap_or(false) }), + TypeDefKind::Option(t) => has_list(interface, t), + TypeDefKind::Expected(e) => has_list(interface, &e.ok) || has_list(interface, &e.err), TypeDefKind::Enum(_) => false, }, _ => false, diff --git a/crates/wasmlink/tests/run.rs b/crates/wasmlink/tests/run.rs index 4d63678d7..a226c8fef 100644 --- a/crates/wasmlink/tests/run.rs +++ b/crates/wasmlink/tests/run.rs @@ -28,6 +28,7 @@ fn wasmlink_file_tests() -> Result<()> { let entry = entry?; let path = entry.path(); + println!("{:?}", path); match ( path.file_stem().and_then(OsStr::to_str), diff --git a/crates/wasmlink/tests/variants.wit b/crates/wasmlink/tests/variants.wit index 340270575..525bc1a4c 100644 --- a/crates/wasmlink/tests/variants.wit +++ b/crates/wasmlink/tests/variants.wit @@ -97,17 +97,17 @@ casts: function( > expected-arg: function( - a: expected<_, _>, - b: expected<_, e1>, - c: expected, + a: expected, + b: expected, + c: expected, d: expected, tuple<>>, e: expected, f: expected>, ) expected-result: function() -> tuple< - expected<_, _>, - expected<_, e1>, - expected, + expected, + expected, + expected, expected, tuple<>>, expected, expected>, diff --git a/crates/wit-component/src/decoding.rs b/crates/wit-component/src/decoding.rs index f4e3d5d8d..b2818b922 100644 --- a/crates/wit-component/src/decoding.rs +++ b/crates/wit-component/src/decoding.rs @@ -514,26 +514,11 @@ impl<'a> InterfaceDecoder<'a> { fn decode_option( &mut self, name: Option, - ty: &types::InterfaceTypeRef, + payload: &types::InterfaceTypeRef, ) -> Result { - let variant = Variant { - cases: vec![ - Case { - docs: Docs::default(), - name: "none".to_string(), - ty: None, - }, - Case { - docs: Docs::default(), - name: "some".to_string(), - ty: Some(self.decode_type(ty)?), - }, - ], - tag: Variant::infer_tag(2), - }; - + let payload = self.decode_type(payload)?; Ok(Type::Id( - self.alloc_type(name, TypeDefKind::Variant(variant)), + self.alloc_type(name, TypeDefKind::Option(payload)), )) } @@ -541,33 +526,14 @@ impl<'a> InterfaceDecoder<'a> { &mut self, name: Option, ok: &types::InterfaceTypeRef, - error: &types::InterfaceTypeRef, + err: &types::InterfaceTypeRef, ) -> Result { - let variant = Variant { - cases: vec![ - Case { - docs: Docs::default(), - name: "ok".to_string(), - ty: match ok { - types::InterfaceTypeRef::Primitive(PrimitiveInterfaceType::Unit) => None, - _ => Some(self.decode_type(ok)?), - }, - }, - Case { - docs: Docs::default(), - name: "err".to_string(), - ty: match error { - types::InterfaceTypeRef::Primitive(PrimitiveInterfaceType::Unit) => None, - _ => Some(self.decode_type(error)?), - }, - }, - ], - tag: Variant::infer_tag(2), - }; - - Ok(Type::Id( - self.alloc_type(name, TypeDefKind::Variant(variant)), - )) + let ok = self.decode_type(ok)?; + let err = self.decode_type(err)?; + Ok(Type::Id(self.alloc_type( + name, + TypeDefKind::Expected(Expected { ok, err }), + ))) } fn alloc_type(&mut self, name: Option, kind: TypeDefKind) -> TypeId { diff --git a/crates/wit-component/src/encoding.rs b/crates/wit-component/src/encoding.rs index 88fc9af40..61d8a4e3c 100644 --- a/crates/wit-component/src/encoding.rs +++ b/crates/wit-component/src/encoding.rs @@ -13,8 +13,8 @@ use wasm_encoder::*; use wasmparser::{Validator, WasmFeatures}; use wit_parser::{ abi::{AbiVariant, WasmSignature, WasmType}, - Enum, Flags, Function, FunctionKind, Interface, Record, Tuple, Type, TypeDef, TypeDefKind, - Variant, + Enum, Expected, Flags, Function, FunctionKind, Interface, Record, Tuple, Type, TypeDef, + TypeDefKind, Variant, }; const INDIRECT_TABLE_NAME: &str = "$imports"; @@ -227,6 +227,27 @@ impl Hash for TypeDefKey<'_> { } .hash(state); } + TypeDefKind::Option(ty) => { + state.write_u8(7); + TypeKey { + interface: self.interface, + ty: *ty, + } + .hash(state); + } + TypeDefKind::Expected(e) => { + state.write_u8(8); + TypeKey { + interface: self.interface, + ty: e.ok, + } + .hash(state); + TypeKey { + interface: self.interface, + ty: e.err, + } + .hash(state); + } } } } @@ -407,6 +428,8 @@ impl<'a> TypeEncoder<'a> { TypeDefKind::Tuple(t) => self.encode_tuple(interface, instance, t)?, TypeDefKind::Flags(r) => self.encode_flags(r)?, TypeDefKind::Variant(v) => self.encode_variant(interface, instance, v)?, + TypeDefKind::Option(t) => self.encode_option(interface, instance, t)?, + TypeDefKind::Expected(e) => self.encode_expected(interface, instance, e)?, TypeDefKind::Enum(e) => self.encode_enum(e)?, TypeDefKind::List(ty) => { let ty = self.encode_type(interface, instance, ty)?; @@ -501,29 +524,6 @@ impl<'a> TypeEncoder<'a> { instance: &mut Option>, variant: &Variant, ) -> Result { - if let Some(ty) = variant.as_option() { - let ty = self.encode_type(interface, instance, ty)?; - let index = self.types.len(); - let encoder = self.types.interface_type(); - encoder.option(ty); - return Ok(InterfaceTypeRef::Type(index)); - } - - if let Some((ok, error)) = variant.as_expected() { - let ok = ok - .map(|ty| self.encode_type(interface, instance, ty)) - .transpose()? - .unwrap_or(InterfaceTypeRef::Primitive(PrimitiveInterfaceType::Unit)); - let error = error - .map(|ty| self.encode_type(interface, instance, ty)) - .transpose()? - .unwrap_or(InterfaceTypeRef::Primitive(PrimitiveInterfaceType::Unit)); - let index = self.types.len(); - let encoder = self.types.interface_type(); - encoder.expected(ok, error); - return Ok(InterfaceTypeRef::Type(index)); - } - if variant.is_union() { let tys = variant .cases @@ -558,6 +558,33 @@ impl<'a> TypeEncoder<'a> { Ok(InterfaceTypeRef::Type(index)) } + fn encode_option( + &mut self, + interface: &'a Interface, + instance: &mut Option>, + payload: &Type, + ) -> Result { + let ty = self.encode_type(interface, instance, payload)?; + let index = self.types.len(); + let encoder = self.types.interface_type(); + encoder.option(ty); + Ok(InterfaceTypeRef::Type(index)) + } + + fn encode_expected( + &mut self, + interface: &'a Interface, + instance: &mut Option>, + expected: &Expected, + ) -> Result { + let ok = self.encode_type(interface, instance, &expected.ok)?; + let error = self.encode_type(interface, instance, &expected.err)?; + let index = self.types.len(); + let encoder = self.types.interface_type(); + encoder.expected(ok, error); + Ok(InterfaceTypeRef::Type(index)) + } + fn encode_enum(&mut self, enum_: &Enum) -> Result { let index = self.types.len(); let encoder = self.types.interface_type(); @@ -632,6 +659,10 @@ impl RequiredOptions { } TypeDefKind::Tuple(t) => Self::for_types(interface, t.types.iter()), TypeDefKind::Flags(_) => Self::None, + TypeDefKind::Option(t) => Self::for_type(interface, t), + TypeDefKind::Expected(e) => { + Self::for_type(interface, &e.ok) | Self::for_type(interface, &e.err) + } TypeDefKind::Variant(v) => { Self::for_types(interface, v.cases.iter().filter_map(|c| c.ty.as_ref())) } diff --git a/crates/wit-component/src/printing.rs b/crates/wit-component/src/printing.rs index 28aafbe97..83f22d07f 100644 --- a/crates/wit-component/src/printing.rs +++ b/crates/wit-component/src/printing.rs @@ -1,7 +1,9 @@ use anyhow::{bail, Result}; use std::collections::HashSet; use std::fmt::Write; -use wit_parser::{Enum, Flags, Interface, Record, Tuple, Type, TypeDefKind, TypeId, Variant}; +use wit_parser::{ + Enum, Expected, Flags, Interface, Record, Tuple, Type, TypeDefKind, TypeId, Variant, +}; /// A utility for printing WebAssembly interface definitions to a string. #[derive(Default)] @@ -72,6 +74,12 @@ impl InterfacePrinter { TypeDefKind::Tuple(t) => { self.print_tuple_type(interface, t)?; } + TypeDefKind::Option(t) => { + self.print_option_type(interface, t)?; + } + TypeDefKind::Expected(t) => { + self.print_expected_type(interface, t)?; + } TypeDefKind::Record(_) => { bail!("interface has an unnamed record type"); } @@ -81,8 +89,8 @@ impl InterfacePrinter { TypeDefKind::Enum(_) => { bail!("interface has unnamed enum type") } - TypeDefKind::Variant(v) => { - self.print_variant_type(interface, v)?; + TypeDefKind::Variant(_) => { + bail!("interface has unnamed variant type") } TypeDefKind::List(ty) => { self.output.push_str("list<"); @@ -112,30 +120,20 @@ impl InterfacePrinter { Ok(()) } - fn print_variant_type(&mut self, interface: &Interface, variant: &Variant) -> Result<()> { - if let Some(ty) = variant.as_option() { - self.output.push_str("option<"); - self.print_type_name(interface, ty)?; - self.output.push('>'); - return Ok(()); - } - - if let Some((ok, err)) = variant.as_expected() { - self.output.push_str("expected<"); - match ok { - Some(ty) => self.print_type_name(interface, ty)?, - None => self.output.push('_'), - } - self.output.push_str(", "); - match err { - Some(ty) => self.print_type_name(interface, ty)?, - None => self.output.push('_'), - } - self.output.push('>'); - return Ok(()); - } + fn print_option_type(&mut self, interface: &Interface, payload: &Type) -> Result<()> { + self.output.push_str("option<"); + self.print_type_name(interface, payload)?; + self.output.push('>'); + Ok(()) + } - bail!("interface has an unnamed variant type"); + fn print_expected_type(&mut self, interface: &Interface, expected: &Expected) -> Result<()> { + self.output.push_str("expected<"); + self.print_type_name(interface, &expected.ok)?; + self.output.push_str(", "); + self.print_type_name(interface, &expected.err)?; + self.output.push('>'); + Ok(()) } fn declare_type(&mut self, interface: &Interface, ty: &Type) -> Result<()> { @@ -172,6 +170,12 @@ impl InterfacePrinter { TypeDefKind::Variant(v) => { self.declare_variant(interface, ty.name.as_deref(), v)? } + TypeDefKind::Option(t) => { + self.declare_option(interface, ty.name.as_deref(), t)? + } + TypeDefKind::Expected(e) => { + self.declare_expected(interface, ty.name.as_deref(), e)? + } TypeDefKind::Enum(e) => self.declare_enum(ty.name.as_deref(), e)?, TypeDefKind::List(inner) => { self.declare_list(interface, ty.name.as_deref(), inner)? @@ -262,15 +266,6 @@ impl InterfacePrinter { } } - if variant.as_option().is_some() || variant.as_expected().is_some() { - if let Some(name) = name { - write!(&mut self.output, "type {} = ", name)?; - self.print_variant_type(interface, variant)?; - self.output.push_str("\n\n"); - } - return Ok(()); - } - match name { Some(name) => { if variant.is_union() { @@ -301,6 +296,39 @@ impl InterfacePrinter { } } + fn declare_option( + &mut self, + interface: &Interface, + name: Option<&str>, + payload: &Type, + ) -> Result<()> { + self.declare_type(interface, payload)?; + + if let Some(name) = name { + write!(&mut self.output, "type {} = ", name)?; + self.print_option_type(interface, payload)?; + self.output.push_str("\n\n"); + } + Ok(()) + } + + fn declare_expected( + &mut self, + interface: &Interface, + name: Option<&str>, + expected: &Expected, + ) -> Result<()> { + self.declare_type(interface, &expected.ok)?; + self.declare_type(interface, &expected.err)?; + + if let Some(name) = name { + write!(&mut self.output, "type {} = ", name)?; + self.print_expected_type(interface, expected)?; + self.output.push_str("\n\n"); + } + Ok(()) + } + fn declare_enum(&mut self, name: Option<&str>, enum_: &Enum) -> Result<()> { let name = match name { Some(name) => name, diff --git a/crates/wit-component/tests/interfaces/variants/variants.wat b/crates/wit-component/tests/interfaces/variants/variants.wat index caba3815c..9f256be2e 100644 --- a/crates/wit-component/tests/interfaces/variants/variants.wat +++ b/crates/wit-component/tests/interfaces/variants/variants.wat @@ -18,46 +18,61 @@ (type (;16;) (option (type 0))) (type (;17;) (option float32)) (type (;18;) (option (type 3))) - (type (;19;) (option (type 12))) - (type (;20;) (func (param "a" (type 12)) (param "b" (type 14)) (param "c" (type 15)) (param "d" (type 16)) (param "e" (type 17)) (param "f" (type 18)) (param "g" (type 19)))) - (type (;21;) (tuple (type 12) (type 14) (type 15) (type 16) (type 17) (type 18) (type 19))) - (type (;22;) (func (result (type 21)))) - (type (;23;) (variant (case "a" s32) (case "b" float32))) - (type (;24;) (variant (case "a" float64) (case "b" float32))) - (type (;25;) (variant (case "a" float64) (case "b" u64))) - (type (;26;) (variant (case "a" u32) (case "b" s64))) - (type (;27;) (variant (case "a" float32) (case "b" s64))) - (type (;28;) (tuple float32 u32)) - (type (;29;) (tuple u32 u32)) - (type (;30;) (variant (case "a" (type 28)) (case "b" (type 29)))) - (type (;31;) (tuple (type 23) (type 24) (type 25) (type 26) (type 27) (type 30))) - (type (;32;) (func (param "a" (type 23)) (param "b" (type 24)) (param "c" (type 25)) (param "d" (type 26)) (param "e" (type 27)) (param "f" (type 30)) (result (type 31)))) - (type (;33;) (expected unit unit)) - (type (;34;) (expected unit (type 0))) - (type (;35;) (expected (type 0) unit)) - (type (;36;) (expected (type 13) (type 13))) - (type (;37;) (expected u32 (type 7))) - (type (;38;) (list u8)) - (type (;39;) (expected string (type 38))) - (type (;40;) (func (param "a" (type 33)) (param "b" (type 34)) (param "c" (type 35)) (param "d" (type 36)) (param "e" (type 37)) (param "f" (type 39)))) - (type (;41;) (tuple (type 33) (type 34) (type 35) (type 36) (type 37) (type 39))) - (type (;42;) (func (result (type 41)))) - (type (;43;) (enum "bad1" "bad2")) - (type (;44;) (expected s32 (type 43))) - (type (;45;) (func (result (type 44)))) - (type (;46;) (expected unit (type 43))) - (type (;47;) (func (result (type 46)))) - (type (;48;) (expected (type 43) (type 43))) - (type (;49;) (func (result (type 48)))) - (type (;50;) (tuple s32 u32)) - (type (;51;) (expected (type 50) (type 43))) - (type (;52;) (func (result (type 51)))) - (type (;53;) (option s32)) - (type (;54;) (func (result (type 53)))) - (type (;55;) (option (type 43))) - (type (;56;) (func (result (type 55)))) - (type (;57;) (expected u32 s32)) - (type (;58;) (func (result (type 57)))) + (type (;19;) (option bool)) + (type (;20;) (option (type 19))) + (type (;21;) (func (param "a" (type 12)) (param "b" (type 14)) (param "c" (type 15)) (param "d" (type 16)) (param "e" (type 17)) (param "f" (type 18)) (param "g" (type 20)))) + (type (;22;) (option bool)) + (type (;23;) (option (type 13))) + (type (;24;) (option u32)) + (type (;25;) (option (type 0))) + (type (;26;) (option float32)) + (type (;27;) (option (type 3))) + (type (;28;) (option bool)) + (type (;29;) (option (type 28))) + (type (;30;) (tuple (type 22) (type 23) (type 24) (type 25) (type 26) (type 27) (type 29))) + (type (;31;) (func (result (type 30)))) + (type (;32;) (variant (case "a" s32) (case "b" float32))) + (type (;33;) (variant (case "a" float64) (case "b" float32))) + (type (;34;) (variant (case "a" float64) (case "b" u64))) + (type (;35;) (variant (case "a" u32) (case "b" s64))) + (type (;36;) (variant (case "a" float32) (case "b" s64))) + (type (;37;) (tuple float32 u32)) + (type (;38;) (tuple u32 u32)) + (type (;39;) (variant (case "a" (type 37)) (case "b" (type 38)))) + (type (;40;) (tuple (type 32) (type 33) (type 34) (type 35) (type 36) (type 39))) + (type (;41;) (func (param "a" (type 32)) (param "b" (type 33)) (param "c" (type 34)) (param "d" (type 35)) (param "e" (type 36)) (param "f" (type 39)) (result (type 40)))) + (type (;42;) (expected unit unit)) + (type (;43;) (expected unit (type 0))) + (type (;44;) (expected (type 0) unit)) + (type (;45;) (expected (type 13) (type 13))) + (type (;46;) (expected u32 (type 7))) + (type (;47;) (list u8)) + (type (;48;) (expected string (type 47))) + (type (;49;) (func (param "a" (type 42)) (param "b" (type 43)) (param "c" (type 44)) (param "d" (type 45)) (param "e" (type 46)) (param "f" (type 48)))) + (type (;50;) (expected unit unit)) + (type (;51;) (expected unit (type 0))) + (type (;52;) (expected (type 0) unit)) + (type (;53;) (expected (type 13) (type 13))) + (type (;54;) (expected u32 (type 7))) + (type (;55;) (expected string (type 47))) + (type (;56;) (tuple (type 50) (type 51) (type 52) (type 53) (type 54) (type 55))) + (type (;57;) (func (result (type 56)))) + (type (;58;) (enum "bad1" "bad2")) + (type (;59;) (expected s32 (type 58))) + (type (;60;) (func (result (type 59)))) + (type (;61;) (expected unit (type 58))) + (type (;62;) (func (result (type 61)))) + (type (;63;) (expected (type 58) (type 58))) + (type (;64;) (func (result (type 63)))) + (type (;65;) (tuple s32 u32)) + (type (;66;) (expected (type 65) (type 58))) + (type (;67;) (func (result (type 66)))) + (type (;68;) (option s32)) + (type (;69;) (func (result (type 68)))) + (type (;70;) (option (type 58))) + (type (;71;) (func (result (type 70)))) + (type (;72;) (expected u32 s32)) + (type (;73;) (func (result (type 72)))) (export "e1" (type 0)) (export "e1-arg" (type 1)) (export "e1-result" (type 2)) @@ -70,23 +85,23 @@ (export "v1-result" (type 9)) (export "bool-arg" (type 10)) (export "bool-result" (type 11)) - (export "option-arg" (type 20)) - (export "option-result" (type 22)) - (export "casts1" (type 23)) - (export "casts2" (type 24)) - (export "casts3" (type 25)) - (export "casts4" (type 26)) - (export "casts5" (type 27)) - (export "casts6" (type 30)) - (export "casts" (type 32)) - (export "expected-arg" (type 40)) - (export "expected-result" (type 42)) - (export "my-errno" (type 43)) - (export "return-expected-sugar" (type 45)) - (export "return-expected-sugar2" (type 47)) - (export "return-expected-sugar3" (type 49)) - (export "return-expected-sugar4" (type 52)) - (export "return-option-sugar" (type 54)) - (export "return-option-sugar2" (type 56)) - (export "expected-simple" (type 58)) + (export "option-arg" (type 21)) + (export "option-result" (type 31)) + (export "casts1" (type 32)) + (export "casts2" (type 33)) + (export "casts3" (type 34)) + (export "casts4" (type 35)) + (export "casts5" (type 36)) + (export "casts6" (type 39)) + (export "casts" (type 41)) + (export "expected-arg" (type 49)) + (export "expected-result" (type 57)) + (export "my-errno" (type 58)) + (export "return-expected-sugar" (type 60)) + (export "return-expected-sugar2" (type 62)) + (export "return-expected-sugar3" (type 64)) + (export "return-expected-sugar4" (type 67)) + (export "return-option-sugar" (type 69)) + (export "return-option-sugar2" (type 71)) + (export "expected-simple" (type 73)) ) \ No newline at end of file diff --git a/crates/wit-component/tests/interfaces/variants/variants.wit b/crates/wit-component/tests/interfaces/variants/variants.wit index 461bd1cc4..eeaa8cc59 100644 --- a/crates/wit-component/tests/interfaces/variants/variants.wit +++ b/crates/wit-component/tests/interfaces/variants/variants.wit @@ -77,13 +77,13 @@ option-result: function() -> tuple, option>, option, o casts: function(a: casts1, b: casts2, c: casts3, d: casts4, e: casts5, f: casts6) -> tuple -expected-arg: function(a: expected<_, _>, b: expected<_, e1>, c: expected, d: expected, tuple<>>, e: expected, f: expected>) +expected-arg: function(a: expected, b: expected, c: expected, d: expected, tuple<>>, e: expected, f: expected>) -expected-result: function() -> tuple, expected<_, e1>, expected, expected, tuple<>>, expected, expected>> +expected-result: function() -> tuple, expected, expected, expected, tuple<>>, expected, expected>> return-expected-sugar: function() -> expected -return-expected-sugar2: function() -> expected<_, my-errno> +return-expected-sugar2: function() -> expected return-expected-sugar3: function() -> expected diff --git a/crates/wit-component/tests/roundtrip.rs b/crates/wit-component/tests/roundtrip.rs index 2bda797af..f00ad7183 100644 --- a/crates/wit-component/tests/roundtrip.rs +++ b/crates/wit-component/tests/roundtrip.rs @@ -27,7 +27,7 @@ fn roundtrip_interfaces() -> Result<()> { let test_case = path.file_stem().unwrap().to_str().unwrap(); let wit_path = path.join(test_case).with_extension("wit"); - let interface = Interface::parse_file(&wit_path)?; + let interface = Interface::parse_file(&wit_path).context("failed to parse `wit` file")?; let encoder = ComponentEncoder::default() .interface(&interface) @@ -41,10 +41,12 @@ fn roundtrip_interfaces() -> Result<()> { ) })?; - let interface = decode_interface_component(&bytes)?; + let interface = decode_interface_component(&bytes).context("failed to decode bytes")?; let mut printer = InterfacePrinter::default(); - let output = printer.print(&interface)?; + let output = printer + .print(&interface) + .context("failed to print interface")?; if std::env::var_os("BLESS").is_some() { fs::write(&wit_path, output)?; diff --git a/tests/codegen/variants.wit b/tests/codegen/variants.wit index f1ba28ccf..7857e9216 100644 --- a/tests/codegen/variants.wit +++ b/tests/codegen/variants.wit @@ -97,17 +97,17 @@ casts: function( > expected-arg: function( - a: expected<_, _>, - b: expected<_, e1>, - c: expected, + a: expected, + b: expected, + c: expected, d: expected, tuple<>>, e: expected, f: expected>, ) expected-result: function() -> tuple< - expected<_, _>, - expected<_, e1>, - expected, + expected, + expected, + expected, expected, tuple<>>, expected, expected>, @@ -119,7 +119,7 @@ enum my-errno { } return-expected-sugar: function() -> expected -return-expected-sugar2: function() -> expected<_, my-errno> +return-expected-sugar2: function() -> expected return-expected-sugar3: function() -> expected return-expected-sugar4: function() -> expected, my-errno> return-option-sugar: function() -> option diff --git a/tests/runtime/flavorful/exports.wit b/tests/runtime/flavorful/exports.wit index 65efcaa91..29ccb1fc2 100644 --- a/tests/runtime/flavorful/exports.wit +++ b/tests/runtime/flavorful/exports.wit @@ -12,7 +12,7 @@ list-in-record3: function(a: list-in-record3) -> list-in-record3 list-in-record4: function(a: list-in-alias) -> list-in-alias type list-in-variant1-v1 = option -type list-in-variant1-v2 = expected<_, string> +type list-in-variant1-v2 = expected union list-in-variant1-v3 { string, float32 } list-in-variant1: function(a: list-in-variant1-v1, b: list-in-variant1-v2, c: list-in-variant1-v3) @@ -23,7 +23,7 @@ type list-in-variant3 = option list-in-variant3: function(a: list-in-variant3) -> list-in-variant3 enum my-errno { success, a, b } -errno-result: function() -> expected<_, my-errno> +errno-result: function() -> expected type list-typedef = string type list-typedef2 = list diff --git a/tests/runtime/flavorful/host.ts b/tests/runtime/flavorful/host.ts index 02c464f8c..ebde61d42 100644 --- a/tests/runtime/flavorful/host.ts +++ b/tests/runtime/flavorful/host.ts @@ -38,7 +38,7 @@ async function run() { listOfVariants(bools, results, enums) { assert.deepStrictEqual(bools, [true, false]); - assert.deepStrictEqual(results, [{ tag: 'ok' }, { tag: 'err' }]); + assert.deepStrictEqual(results, [{ tag: 'ok', val: undefined }, { tag: 'err', val: undefined }]); assert.deepStrictEqual(enums, [MyErrno.Success, MyErrno.A]); return [ [false, true], diff --git a/tests/runtime/flavorful/imports.wit b/tests/runtime/flavorful/imports.wit index a31bb5b23..b517b94bf 100644 --- a/tests/runtime/flavorful/imports.wit +++ b/tests/runtime/flavorful/imports.wit @@ -10,7 +10,7 @@ list-in-record3: function(a: list-in-record3) -> list-in-record3 list-in-record4: function(a: list-in-alias) -> list-in-alias type list-in-variant1-v1 = option -type list-in-variant1-v2 = expected<_, string> +type list-in-variant1-v2 = expected union list-in-variant1-v3 { string, float32 } list-in-variant1: function(a: list-in-variant1-v1, b: list-in-variant1-v2, c: list-in-variant1-v3) @@ -21,11 +21,11 @@ type list-in-variant3 = option list-in-variant3: function(a: list-in-variant3) -> list-in-variant3 enum my-errno { success, a, b } -errno-result: function() -> expected<_, my-errno> +errno-result: function() -> expected type list-typedef = string type list-typedef2 = list type list-typedef3 = list list-typedefs: function(a: list-typedef, c: list-typedef3) -> tuple -list-of-variants: function(a: list, b: list>, c: list) -> tuple, list>, list> +list-of-variants: function(a: list, b: list>, c: list) -> tuple, list>, list> diff --git a/tests/runtime/flavorful/wasm.c b/tests/runtime/flavorful/wasm.c index deefddedd..ad1fa695c 100644 --- a/tests/runtime/flavorful/wasm.c +++ b/tests/runtime/flavorful/wasm.c @@ -36,9 +36,9 @@ void exports_test_imports() { imports_list_in_variant1_v1_t a; imports_list_in_variant1_v2_t b; imports_list_in_variant1_v3_t c; - a.tag = IMPORTS_LIST_IN_VARIANT1_V1_SOME; + a.is_some = true; imports_string_set(&a.val, "foo"); - b.tag = IMPORTS_LIST_IN_VARIANT1_V2_ERR; + b.is_err = true; imports_string_set(&b.val.err, "bar"); c.tag = IMPORTS_LIST_IN_VARIANT1_V3_0; imports_string_set(&c.val.f0, "baz"); @@ -54,7 +54,7 @@ void exports_test_imports() { { imports_list_in_variant3_t a; - a.tag = IMPORTS_LIST_IN_VARIANT3_SOME; + a.is_some = true; imports_string_set(&a.val, "input3"); imports_string_t b; assert(imports_list_in_variant3(&a, &b)); @@ -90,10 +90,10 @@ void exports_test_imports() { a.ptr = a_val; a.len = 2; - imports_list_expected_void_void_t b; - imports_expected_void_void_t b_val[2]; - b_val[0].tag = 0; - b_val[1].tag = 1; + imports_list_expected_unit_unit_t b; + imports_expected_unit_unit_t b_val[2]; + b_val[0].is_err = false; + b_val[1].is_err = true; b.ptr = b_val; b.len = 2; @@ -105,7 +105,7 @@ void exports_test_imports() { c.len = 2; imports_list_bool_t d; - imports_list_expected_void_void_t e; + imports_list_expected_unit_unit_t e; imports_list_my_errno_t f; imports_list_of_variants(&a, &b, &c, &d, &e, &f); @@ -114,15 +114,15 @@ void exports_test_imports() { assert(d.ptr[1] == true); assert(e.len == 2); - assert(e.ptr[0].tag == 1); - assert(e.ptr[1].tag == 0); + assert(e.ptr[0].is_err == true); + assert(e.ptr[1].is_err == false); assert(f.len == 2); assert(f.ptr[0] == IMPORTS_MY_ERRNO_A); assert(f.ptr[1] == IMPORTS_MY_ERRNO_B); imports_list_bool_free(&d); - imports_list_expected_void_void_free(&e); + imports_list_expected_unit_unit_free(&e); imports_list_my_errno_free(&f); } } @@ -149,11 +149,11 @@ void exports_list_in_record4(exports_list_in_alias_t *a, exports_list_in_alias_t } void exports_list_in_variant1(exports_list_in_variant1_v1_t *a, exports_list_in_variant1_v2_t *b, exports_list_in_variant1_v3_t *c) { - assert(a->tag == EXPORTS_LIST_IN_VARIANT1_V1_SOME); + assert(a->is_some); assert(memcmp(a->val.ptr, "foo", a->val.len) == 0); exports_list_in_variant1_v1_free(a); - assert(b->tag == EXPORTS_LIST_IN_VARIANT1_V2_ERR); + assert(b->is_err); assert(memcmp(b->val.err.ptr, "bar", b->val.err.len) == 0); exports_list_in_variant1_v2_free(b); @@ -168,7 +168,7 @@ bool exports_list_in_variant2(exports_string_t *ret0) { } bool exports_list_in_variant3(exports_list_in_variant3_t *a, exports_string_t *ret0) { - assert(a->tag == EXPORTS_LIST_IN_VARIANT3_SOME); + assert(a->is_some); assert(memcmp(a->val.ptr, "input3", a->val.len) == 0); exports_list_in_variant3_free(a); exports_string_dup(ret0, "output3"); diff --git a/tests/runtime/handles/wasm.c b/tests/runtime/handles/wasm.c index 49ed2674c..63aaa1693 100644 --- a/tests/runtime/handles/wasm.c +++ b/tests/runtime/handles/wasm.c @@ -38,16 +38,16 @@ void exports_test_imports() { } { imports_host_state_param_option_t a; - a.tag = 1; + a.is_some = true; a.val = d; imports_host_state2_param_option(&a); } { imports_host_state_param_result_t a; - a.tag = 0; + a.is_err = false; a.val.ok = d; imports_host_state2_param_result(&a); - a.tag = 1; + a.is_err = true; a.val.err = 2; imports_host_state2_param_result(&a); } @@ -95,7 +95,7 @@ void exports_test_imports() { { imports_host_state_result_result_t a; imports_host_state2_result_result(&a); - assert(a.tag == 0); + assert(!a.is_err); imports_host_state2_free(&a.val.ok); } { @@ -194,7 +194,7 @@ bool exports_wasm_state2_result_option(exports_wasm_state2_t *ret0) { } void exports_wasm_state2_result_result(exports_wasm_state_result_result_t *ret0) { - ret0->tag = EXPORTS_WASM_STATE_RESULT_RESULT_OK; + ret0->is_err = false; ret0->val.ok = exports_wasm_state2_new((void*) 555); } diff --git a/tests/runtime/variants/exports.wit b/tests/runtime/variants/exports.wit index cf2eee779..019db9ca2 100644 --- a/tests/runtime/variants/exports.wit +++ b/tests/runtime/variants/exports.wit @@ -26,5 +26,5 @@ variant-zeros: function(a: zeros) -> zeros type option-typedef = option type bool-typedef = bool -type result-typedef = expected +type result-typedef = expected variant-typedefs: function(a: option-typedef, b: bool-typedef, c: result-typedef) diff --git a/tests/runtime/variants/host.ts b/tests/runtime/variants/host.ts index 66b981997..89f63bdc8 100644 --- a/tests/runtime/variants/host.ts +++ b/tests/runtime/variants/host.ts @@ -23,7 +23,7 @@ async function run() { variantTypedefs(x, y, z) {}, variantEnums(a, b, c) { assert.deepStrictEqual(a, true); - assert.deepStrictEqual(b, { tag: 'ok' }); + assert.deepStrictEqual(b, { tag: 'ok', val: undefined }); assert.deepStrictEqual(c, MyErrno.Success); return [ false, @@ -107,7 +107,7 @@ async function run() { assert.deepStrictEqual(a4, { tag: 'a', val: 4 }); } - wasm.variantTypedefs(null, false, { tag: 'err' }); + wasm.variantTypedefs(null, false, { tag: 'err', val: undefined }); } await run() diff --git a/tests/runtime/variants/imports.wit b/tests/runtime/variants/imports.wit index b817e62bd..7ddab93b3 100644 --- a/tests/runtime/variants/imports.wit +++ b/tests/runtime/variants/imports.wit @@ -24,8 +24,8 @@ variant-zeros: function(a: zeros) -> zeros type option-typedef = option type bool-typedef = bool -type result-typedef = expected +type result-typedef = expected variant-typedefs: function(a: option-typedef, b: bool-typedef, c: result-typedef) enum my-errno { success, a, b } -variant-enums: function(a: bool, b: expected<_, _>, c: my-errno) -> tuple, my-errno> +variant-enums: function(a: bool, b: expected, c: my-errno) -> tuple, my-errno> diff --git a/tests/runtime/variants/wasm.c b/tests/runtime/variants/wasm.c index 184ff9981..cf371e4b4 100644 --- a/tests/runtime/variants/wasm.c +++ b/tests/runtime/variants/wasm.c @@ -6,13 +6,13 @@ void exports_test_imports() { { imports_option_float32_t a; uint8_t r; - a.tag = 1; + a.is_some = true; a.val = 1; assert(imports_roundtrip_option(&a, &r) && r == 1); assert(r == 1); - a.tag = 0; + a.is_some = false; assert(!imports_roundtrip_option(&a, &r)); - a.tag = 2; + a.is_some = true; a.val = 2; assert(imports_roundtrip_option(&a, &r) && r == 2); } @@ -22,21 +22,21 @@ void exports_test_imports() { imports_expected_u32_float32_t a; imports_expected_float64_u8_t b; - a.tag = 0; + a.is_err = false; a.val.ok = 2; imports_roundtrip_result(&a, &b); - assert(b.tag == 0); + assert(!b.is_err); assert(b.val.ok == 2.0); a.val.ok = 4; imports_roundtrip_result(&a, &b); - assert(b.tag == 0); + assert(!b.is_err); assert(b.val.ok == 4); - a.tag = 1; + a.is_err = true; a.val.err = 5.3; imports_roundtrip_result(&a, &b); - assert(b.tag == 1); + assert(b.is_err); assert(b.val.err == 5); } @@ -144,37 +144,38 @@ void exports_test_imports() { { imports_option_typedef_t a; - a.tag = 0; + a.is_some = false; bool b = false; imports_result_typedef_t c; - c.tag = 1; + c.is_err = true; imports_variant_typedefs(&a, b, &c); } { bool a; - imports_expected_void_void_t b; + imports_expected_unit_unit_t b; imports_my_errno_t c; - imports_variant_enums(true, 0, IMPORTS_MY_ERRNO_SUCCESS, &a, &b, &c); + b.is_err = false; + imports_variant_enums(true, &b, IMPORTS_MY_ERRNO_SUCCESS, &a, &b, &c); assert(a == false); - assert(b.tag == 1); + assert(b.is_err); assert(c == IMPORTS_MY_ERRNO_A); } } bool exports_roundtrip_option(exports_option_float32_t *a, uint8_t *ret0) { - if (a->tag) { + if (a->is_some) { *ret0 = a->val; } - return a->tag; + return a->is_some; } void exports_roundtrip_result(exports_expected_u32_float32_t *a, exports_expected_float64_u8_t *ret0) { - ret0->tag = a->tag; - if (a->tag == 0) { - ret0->val.ok = a->val.ok; - } else { + ret0->is_err = a->is_err; + if (a->is_err) { ret0->val.err = a->val.err; + } else { + ret0->val.ok = a->val.ok; } } From 62f123d4a815903c73a4761539d54d8f90721143 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Wed, 4 May 2022 11:00:58 -0700 Subject: [PATCH 2/3] Remove `Variant::as_{option,expected}` ... as these are separate variants now. --- crates/gen-c/src/lib.rs | 40 ++++++++----------------- crates/gen-rust-wasm/src/lib.rs | 2 +- crates/gen-rust/src/lib.rs | 52 +++++++-------------------------- crates/gen-wasmtime/src/lib.rs | 27 +++++++---------- crates/parser/src/lib.rs | 26 ----------------- 5 files changed, 34 insertions(+), 113 deletions(-) diff --git a/crates/gen-c/src/lib.rs b/crates/gen-c/src/lib.rs index 023103263..2a00b3e41 100644 --- a/crates/gen-c/src/lib.rs +++ b/crates/gen-c/src/lib.rs @@ -493,11 +493,7 @@ impl C { continue; } self.src.c(&format!("case {}: {{\n", i)); - let expr = if v.as_option().is_some() { - String::from("&ptr->val") - } else { - format!("&ptr->val.{}", case_field_name(case)) - }; + let expr = format!("&ptr->val.{}", case_field_name(case)); self.free(iface, case_ty, &expr); self.src.c("break;\n"); self.src.c("}\n"); @@ -777,24 +773,16 @@ impl Generator for C { self.src.h("typedef struct {\n"); self.src.h(int_repr(variant.tag)); self.src.h(" tag;\n"); - match variant.as_option() { - Some(ty) => { + self.src.h("union {\n"); + for case in variant.cases.iter() { + if let Some(ty) = &case.ty { self.print_ty(iface, ty); - self.src.h(" val;\n"); - } - None => { - self.src.h("union {\n"); - for case in variant.cases.iter() { - if let Some(ty) = &case.ty { - self.print_ty(iface, ty); - self.src.h(" "); - self.src.h(&case_field_name(case)); - self.src.h(";\n"); - } - } - self.src.h("} val;\n"); + self.src.h(" "); + self.src.h(&case_field_name(case)); + self.src.h(";\n"); } } + self.src.h("} val;\n"); self.src.h("} "); self.print_namespace(iface); self.src.h(&name.to_snake_case()); @@ -1605,10 +1593,8 @@ impl Bindgen for FunctionBindgen<'_> { "const {} *{} = &({}).val", ty, payload, operands[0], )); - if !variant.as_option().is_some() { - self.src.push_str("."); - self.src.push_str(&case_field_name(case)); - } + self.src.push_str("."); + self.src.push_str(&case_field_name(case)); self.src.push_str(";\n"); } } @@ -1644,10 +1630,8 @@ impl Bindgen for FunctionBindgen<'_> { if case.ty.is_some() { assert!(block_results.len() == 1); let mut dst = format!("{}.val", result); - if !variant.as_option().is_some() { - dst.push_str("."); - dst.push_str(&case_field_name(case)); - } + dst.push_str("."); + dst.push_str(&case_field_name(case)); self.store_op(&block_results[0], &dst); } else { assert!(block_results.is_empty()); diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs index 3865fefec..ee0f1243b 100644 --- a/crates/gen-rust-wasm/src/lib.rs +++ b/crates/gen-rust-wasm/src/lib.rs @@ -1049,7 +1049,7 @@ impl Bindgen for FunctionBindgen<'_> { result.push_str(&i.to_string()); } result.push_str(" => "); - self.variant_lift_case(iface, *ty, variant, case, &block, &mut result); + self.variant_lift_case(iface, *ty, case, &block, &mut result); result.push_str(",\n"); } if !unchecked { diff --git a/crates/gen-rust/src/lib.rs b/crates/gen-rust/src/lib.rs index 0c5fc27a8..a7e07fe0c 100644 --- a/crates/gen-rust/src/lib.rs +++ b/crates/gen-rust/src/lib.rs @@ -846,32 +846,17 @@ pub trait RustFunctionGenerator { results: &mut Vec, blocks: Vec, ) { - let has_name = iface.types[id].name.is_some(); self.let_results(nresults, results); self.push_str("match "); self.push_str(operand); self.push_str("{\n"); for (case, block) in ty.cases.iter().zip(blocks) { - if ty.as_expected().is_some() { - self.push_str(&case.name.to_camel_case()); - self.push_str("("); - self.push_str(if case.ty.is_some() { "e" } else { "()" }); - self.push_str(")"); - } else if ty.as_option().is_some() { - self.push_str(&case.name.to_camel_case()); - if case.ty.is_some() { - self.push_str("(e)"); - } - } else if has_name { - let name = self.typename_lower(iface, id); - self.push_str(&name); - self.push_str("::"); - self.push_str(&case_name(&case.name)); - if case.ty.is_some() { - self.push_str("(e)"); - } - } else { - unimplemented!() + let name = self.typename_lower(iface, id); + self.push_str(&name); + self.push_str("::"); + self.push_str(&case_name(&case.name)); + if case.ty.is_some() { + self.push_str("(e)"); } self.push_str(" => { "); self.push_str(&block); @@ -884,34 +869,17 @@ pub trait RustFunctionGenerator { &mut self, iface: &Interface, id: TypeId, - ty: &Variant, case: &Case, block: &str, result: &mut String, ) { - if ty.as_expected().is_some() { - result.push_str(&case.name.to_camel_case()); + result.push_str(&self.typename_lift(iface, id)); + result.push_str("::"); + result.push_str(&case_name(&case.name)); + if case.ty.is_some() { result.push_str("("); result.push_str(block); result.push_str(")"); - } else if ty.as_option().is_some() { - result.push_str(&case.name.to_camel_case()); - if case.ty.is_some() { - result.push_str("("); - result.push_str(block); - result.push_str(")"); - } - } else if iface.types[id].name.is_some() { - result.push_str(&self.typename_lift(iface, id)); - result.push_str("::"); - result.push_str(&case_name(&case.name)); - if case.ty.is_some() { - result.push_str("("); - result.push_str(block); - result.push_str(")"); - } - } else { - unimplemented!() } } } diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index a15fb6b5b..cd962ee8d 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -144,7 +144,7 @@ enum FunctionRet { /// The function returns a `Result` in both wasm and in Rust, but the /// Rust error type is a custom error and must be converted to `err`. The /// `ok` variant payload is provided here too. - CustomToError { ok: Option, err: String }, + CustomToError { ok: Type, err: String }, } impl Wasmtime { @@ -200,16 +200,14 @@ impl Wasmtime { } if let Type::Id(id) = &f.result { - if let TypeDefKind::Variant(v) = &iface.types[*id].kind { - if let Some((ok, Some(err))) = v.as_expected() { - if let Type::Id(err) = err { - if let Some(name) = &iface.types[*err].name { - self.needs_custom_error_to_types.insert(name.clone()); - return FunctionRet::CustomToError { - ok: ok.cloned(), - err: name.to_string(), - }; - } + if let TypeDefKind::Expected(e) = &iface.types[*id].kind { + if let Type::Id(err) = e.err { + if let Some(name) = &iface.types[err].name { + self.needs_custom_error_to_types.insert(name.clone()); + return FunctionRet::CustomToError { + ok: e.ok, + err: name.to_string(), + }; } } } @@ -571,10 +569,7 @@ impl Generator for Wasmtime { } FunctionRet::CustomToError { ok, .. } => { self.push_str(" -> Result<"); - match ok { - Some(ty) => self.print_ty(iface, &ty, TypeMode::Owned), - None => self.push_str("()"), - } + self.print_ty(iface, &ok, TypeMode::Owned); self.push_str(", Self::Error>"); } } @@ -1650,7 +1645,7 @@ impl Bindgen for FunctionBindgen<'_> { for (i, (case, block)) in variant.cases.iter().zip(blocks).enumerate() { result.push_str(&i.to_string()); result.push_str(" => "); - self.variant_lift_case(iface, *ty, variant, case, &block, &mut result); + self.variant_lift_case(iface, *ty, case, &block, &mut result); result.push_str(",\n"); } let variant_name = name.to_camel_case(); diff --git a/crates/parser/src/lib.rs b/crates/parser/src/lib.rs index 1843e0057..18870f21a 100644 --- a/crates/parser/src/lib.rs +++ b/crates/parser/src/lib.rs @@ -171,32 +171,6 @@ impl Variant { .enumerate() .all(|(i, c)| c.name.parse().ok() == Some(i) && c.ty.is_some()) } - - pub fn as_option(&self) -> Option<&Type> { - if self.cases.len() != 2 { - return None; - } - if self.cases[0].name != "none" || self.cases[0].ty.is_some() { - return None; - } - if self.cases[1].name != "some" { - return None; - } - self.cases[1].ty.as_ref() - } - - pub fn as_expected(&self) -> Option<(Option<&Type>, Option<&Type>)> { - if self.cases.len() != 2 { - return None; - } - if self.cases[0].name != "ok" { - return None; - } - if self.cases[1].name != "err" { - return None; - } - Some((self.cases[0].ty.as_ref(), self.cases[1].ty.as_ref())) - } } #[derive(Debug)] From f1e9edd0579fee44cbdd72662295e2a776a8f308 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 5 May 2022 08:12:30 -0700 Subject: [PATCH 3/3] Review comments --- crates/gen-c/src/lib.rs | 6 +- crates/parser/src/abi.rs | 24 ++-- crates/wasmlink/tests/run.rs | 2 +- crates/wit-component/src/encoding.rs | 18 ++- .../tests/interfaces/variants/variants.wat | 133 ++++++++---------- 5 files changed, 94 insertions(+), 89 deletions(-) diff --git a/crates/gen-c/src/lib.rs b/crates/gen-c/src/lib.rs index 2a00b3e41..82c27cbb6 100644 --- a/crates/gen-c/src/lib.rs +++ b/crates/gen-c/src/lib.rs @@ -1567,7 +1567,7 @@ impl Bindgen for FunctionBindgen<'_> { .drain(self.payloads.len() - variant.cases.len()..) .collect::>(); - let mut variant_results = Vec::new(); + let mut variant_results = Vec::with_capacity(result_types.len()); for ty in result_types.iter() { let name = self.locals.tmp("variant"); results.push(name.clone()); @@ -1652,7 +1652,6 @@ impl Bindgen for FunctionBindgen<'_> { let some_payload = self.payloads.pop().unwrap(); let _none_payload = self.payloads.pop().unwrap(); - let mut variant_results = Vec::new(); for (i, ty) in result_types.iter().enumerate() { let name = self.locals.tmp("option"); results.push(name.clone()); @@ -1664,7 +1663,6 @@ impl Bindgen for FunctionBindgen<'_> { some.push_str(&format!("{name} = {some_result};\n")); let none_result = &none_results[i]; none.push_str(&format!("{name} = {none_result};\n")); - variant_results.push(name); } let op0 = &operands[0]; @@ -1730,7 +1728,6 @@ impl Bindgen for FunctionBindgen<'_> { let err_payload = self.payloads.pop().unwrap(); let ok_payload = self.payloads.pop().unwrap(); - let mut variant_results = Vec::new(); for (i, ty) in result_types.iter().enumerate() { let name = self.locals.tmp("expected"); results.push(name.clone()); @@ -1742,7 +1739,6 @@ impl Bindgen for FunctionBindgen<'_> { ok.push_str(&format!("{name} = {ok_result};\n")); let err_result = &err_results[i]; err.push_str(&format!("{name} = {err_result};\n")); - variant_results.push(name); } let op0 = &operands[0]; diff --git a/crates/parser/src/abi.rs b/crates/parser/src/abi.rs index 0aeb6c80c..11daa5b16 100644 --- a/crates/parser/src/abi.rs +++ b/crates/parser/src/abi.rs @@ -582,27 +582,35 @@ def_instruction! { ty: TypeId, } : [1] => [1], - /// TODO + /// Specialization of `VariantLower` for specifically `option` types, + /// otherwise behaves the same as `VariantLower` (e.g. two blocks for + /// the two cases. OptionLower { payload: &'a Type, ty: TypeId, results: &'a [WasmType], } : [1] => [results.len()], - /// TODO + /// Specialization of `VariantLift` for specifically the `option` + /// type. Otherwise behaves the same as the `VariantLift` instruction + /// with two blocks for the lift. OptionLift { payload: &'a Type, ty: TypeId, } : [1] => [1], - /// TODO + /// Specialization of `VariantLower` for specifically `expected` + /// types, otherwise behaves the same as `VariantLower` (e.g. two blocks + /// for the two cases. ExpectedLower { expected: &'a Expected, ty: TypeId, results: &'a [WasmType], } : [1] => [results.len()], - /// TODO + /// Specialization of `VariantLift` for specifically the `expected` type. Otherwise behaves the same as the `VariantLift` + /// instruction with two blocks for the lift. ExpectedLift { expected: &'a Expected, ty: TypeId, @@ -2048,7 +2056,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { // individual block is pretty simple and just reads the payload type // from the corresponding offset if one is available. TypeDefKind::Variant(variant) => { - self.read_variant_arms_to_memory( + self.read_variant_arms_from_memory( offset, addr, variant.tag, @@ -2062,12 +2070,12 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Option(t) => { - self.read_variant_arms_to_memory(offset, addr, Int::U8, [None, Some(t)]); + self.read_variant_arms_from_memory(offset, addr, Int::U8, [None, Some(t)]); self.emit(&OptionLift { payload: t, ty: id }); } TypeDefKind::Expected(e) => { - self.read_variant_arms_to_memory( + self.read_variant_arms_from_memory( offset, addr, Int::U8, @@ -2088,7 +2096,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } } - fn read_variant_arms_to_memory<'b>( + fn read_variant_arms_from_memory<'b>( &mut self, offset: i32, addr: B::Operand, diff --git a/crates/wasmlink/tests/run.rs b/crates/wasmlink/tests/run.rs index a226c8fef..b9718a817 100644 --- a/crates/wasmlink/tests/run.rs +++ b/crates/wasmlink/tests/run.rs @@ -28,7 +28,7 @@ fn wasmlink_file_tests() -> Result<()> { let entry = entry?; let path = entry.path(); - println!("{:?}", path); + println!("running test {:?}", path); match ( path.file_stem().and_then(OsStr::to_str), diff --git a/crates/wit-component/src/encoding.rs b/crates/wit-component/src/encoding.rs index 61d8a4e3c..6575c0515 100644 --- a/crates/wit-component/src/encoding.rs +++ b/crates/wit-component/src/encoding.rs @@ -147,7 +147,8 @@ impl PartialEq for TypeDefKey<'_> { .all(|(c1, c2)| c1.name == c2.name) } (TypeDefKind::List(t1), TypeDefKind::List(t2)) - | (TypeDefKind::Type(t1), TypeDefKind::Type(t2)) => TypeKey { + | (TypeDefKind::Type(t1), TypeDefKind::Type(t2)) + | (TypeDefKind::Option(t1), TypeDefKind::Option(t2)) => TypeKey { interface: self.interface, ty: *t1, } @@ -155,6 +156,21 @@ impl PartialEq for TypeDefKey<'_> { interface: other.interface, ty: *t2, }), + (TypeDefKind::Expected(e1), TypeDefKind::Expected(e2)) => { + TypeKey { + interface: self.interface, + ty: e1.ok, + } == TypeKey { + interface: other.interface, + ty: e2.ok, + } && TypeKey { + interface: self.interface, + ty: e1.err, + } == TypeKey { + interface: other.interface, + ty: e2.err, + } + } _ => false, } } diff --git a/crates/wit-component/tests/interfaces/variants/variants.wat b/crates/wit-component/tests/interfaces/variants/variants.wat index 9f256be2e..caba3815c 100644 --- a/crates/wit-component/tests/interfaces/variants/variants.wat +++ b/crates/wit-component/tests/interfaces/variants/variants.wat @@ -18,61 +18,46 @@ (type (;16;) (option (type 0))) (type (;17;) (option float32)) (type (;18;) (option (type 3))) - (type (;19;) (option bool)) - (type (;20;) (option (type 19))) - (type (;21;) (func (param "a" (type 12)) (param "b" (type 14)) (param "c" (type 15)) (param "d" (type 16)) (param "e" (type 17)) (param "f" (type 18)) (param "g" (type 20)))) - (type (;22;) (option bool)) - (type (;23;) (option (type 13))) - (type (;24;) (option u32)) - (type (;25;) (option (type 0))) - (type (;26;) (option float32)) - (type (;27;) (option (type 3))) - (type (;28;) (option bool)) - (type (;29;) (option (type 28))) - (type (;30;) (tuple (type 22) (type 23) (type 24) (type 25) (type 26) (type 27) (type 29))) - (type (;31;) (func (result (type 30)))) - (type (;32;) (variant (case "a" s32) (case "b" float32))) - (type (;33;) (variant (case "a" float64) (case "b" float32))) - (type (;34;) (variant (case "a" float64) (case "b" u64))) - (type (;35;) (variant (case "a" u32) (case "b" s64))) - (type (;36;) (variant (case "a" float32) (case "b" s64))) - (type (;37;) (tuple float32 u32)) - (type (;38;) (tuple u32 u32)) - (type (;39;) (variant (case "a" (type 37)) (case "b" (type 38)))) - (type (;40;) (tuple (type 32) (type 33) (type 34) (type 35) (type 36) (type 39))) - (type (;41;) (func (param "a" (type 32)) (param "b" (type 33)) (param "c" (type 34)) (param "d" (type 35)) (param "e" (type 36)) (param "f" (type 39)) (result (type 40)))) - (type (;42;) (expected unit unit)) - (type (;43;) (expected unit (type 0))) - (type (;44;) (expected (type 0) unit)) - (type (;45;) (expected (type 13) (type 13))) - (type (;46;) (expected u32 (type 7))) - (type (;47;) (list u8)) - (type (;48;) (expected string (type 47))) - (type (;49;) (func (param "a" (type 42)) (param "b" (type 43)) (param "c" (type 44)) (param "d" (type 45)) (param "e" (type 46)) (param "f" (type 48)))) - (type (;50;) (expected unit unit)) - (type (;51;) (expected unit (type 0))) - (type (;52;) (expected (type 0) unit)) - (type (;53;) (expected (type 13) (type 13))) - (type (;54;) (expected u32 (type 7))) - (type (;55;) (expected string (type 47))) - (type (;56;) (tuple (type 50) (type 51) (type 52) (type 53) (type 54) (type 55))) - (type (;57;) (func (result (type 56)))) - (type (;58;) (enum "bad1" "bad2")) - (type (;59;) (expected s32 (type 58))) - (type (;60;) (func (result (type 59)))) - (type (;61;) (expected unit (type 58))) - (type (;62;) (func (result (type 61)))) - (type (;63;) (expected (type 58) (type 58))) - (type (;64;) (func (result (type 63)))) - (type (;65;) (tuple s32 u32)) - (type (;66;) (expected (type 65) (type 58))) - (type (;67;) (func (result (type 66)))) - (type (;68;) (option s32)) - (type (;69;) (func (result (type 68)))) - (type (;70;) (option (type 58))) - (type (;71;) (func (result (type 70)))) - (type (;72;) (expected u32 s32)) - (type (;73;) (func (result (type 72)))) + (type (;19;) (option (type 12))) + (type (;20;) (func (param "a" (type 12)) (param "b" (type 14)) (param "c" (type 15)) (param "d" (type 16)) (param "e" (type 17)) (param "f" (type 18)) (param "g" (type 19)))) + (type (;21;) (tuple (type 12) (type 14) (type 15) (type 16) (type 17) (type 18) (type 19))) + (type (;22;) (func (result (type 21)))) + (type (;23;) (variant (case "a" s32) (case "b" float32))) + (type (;24;) (variant (case "a" float64) (case "b" float32))) + (type (;25;) (variant (case "a" float64) (case "b" u64))) + (type (;26;) (variant (case "a" u32) (case "b" s64))) + (type (;27;) (variant (case "a" float32) (case "b" s64))) + (type (;28;) (tuple float32 u32)) + (type (;29;) (tuple u32 u32)) + (type (;30;) (variant (case "a" (type 28)) (case "b" (type 29)))) + (type (;31;) (tuple (type 23) (type 24) (type 25) (type 26) (type 27) (type 30))) + (type (;32;) (func (param "a" (type 23)) (param "b" (type 24)) (param "c" (type 25)) (param "d" (type 26)) (param "e" (type 27)) (param "f" (type 30)) (result (type 31)))) + (type (;33;) (expected unit unit)) + (type (;34;) (expected unit (type 0))) + (type (;35;) (expected (type 0) unit)) + (type (;36;) (expected (type 13) (type 13))) + (type (;37;) (expected u32 (type 7))) + (type (;38;) (list u8)) + (type (;39;) (expected string (type 38))) + (type (;40;) (func (param "a" (type 33)) (param "b" (type 34)) (param "c" (type 35)) (param "d" (type 36)) (param "e" (type 37)) (param "f" (type 39)))) + (type (;41;) (tuple (type 33) (type 34) (type 35) (type 36) (type 37) (type 39))) + (type (;42;) (func (result (type 41)))) + (type (;43;) (enum "bad1" "bad2")) + (type (;44;) (expected s32 (type 43))) + (type (;45;) (func (result (type 44)))) + (type (;46;) (expected unit (type 43))) + (type (;47;) (func (result (type 46)))) + (type (;48;) (expected (type 43) (type 43))) + (type (;49;) (func (result (type 48)))) + (type (;50;) (tuple s32 u32)) + (type (;51;) (expected (type 50) (type 43))) + (type (;52;) (func (result (type 51)))) + (type (;53;) (option s32)) + (type (;54;) (func (result (type 53)))) + (type (;55;) (option (type 43))) + (type (;56;) (func (result (type 55)))) + (type (;57;) (expected u32 s32)) + (type (;58;) (func (result (type 57)))) (export "e1" (type 0)) (export "e1-arg" (type 1)) (export "e1-result" (type 2)) @@ -85,23 +70,23 @@ (export "v1-result" (type 9)) (export "bool-arg" (type 10)) (export "bool-result" (type 11)) - (export "option-arg" (type 21)) - (export "option-result" (type 31)) - (export "casts1" (type 32)) - (export "casts2" (type 33)) - (export "casts3" (type 34)) - (export "casts4" (type 35)) - (export "casts5" (type 36)) - (export "casts6" (type 39)) - (export "casts" (type 41)) - (export "expected-arg" (type 49)) - (export "expected-result" (type 57)) - (export "my-errno" (type 58)) - (export "return-expected-sugar" (type 60)) - (export "return-expected-sugar2" (type 62)) - (export "return-expected-sugar3" (type 64)) - (export "return-expected-sugar4" (type 67)) - (export "return-option-sugar" (type 69)) - (export "return-option-sugar2" (type 71)) - (export "expected-simple" (type 73)) + (export "option-arg" (type 20)) + (export "option-result" (type 22)) + (export "casts1" (type 23)) + (export "casts2" (type 24)) + (export "casts3" (type 25)) + (export "casts4" (type 26)) + (export "casts5" (type 27)) + (export "casts6" (type 30)) + (export "casts" (type 32)) + (export "expected-arg" (type 40)) + (export "expected-result" (type 42)) + (export "my-errno" (type 43)) + (export "return-expected-sugar" (type 45)) + (export "return-expected-sugar2" (type 47)) + (export "return-expected-sugar3" (type 49)) + (export "return-expected-sugar4" (type 52)) + (export "return-option-sugar" (type 54)) + (export "return-option-sugar2" (type 56)) + (export "expected-simple" (type 58)) ) \ No newline at end of file