diff --git a/crates/gen-c/src/lib.rs b/crates/gen-c/src/lib.rs index 5b26b891c..d90d75a28 100644 --- a/crates/gen-c/src/lib.rs +++ b/crates/gen-c/src/lib.rs @@ -488,16 +488,12 @@ impl C { TypeDefKind::Variant(v) => { self.src.c("switch ((int32_t) ptr->tag) {\n"); for (i, case) in v.cases.iter().enumerate() { - let case_ty = match &case.ty { - Some(ty) => ty, - None => continue, - }; - if !self.owns_anything(iface, case_ty) { + if !self.owns_anything(iface, &case.ty) { continue; } self.src.c(&format!("case {}: {{\n", i)); let expr = format!("&ptr->val.{}", case.name.to_snake_case()); - self.free(iface, case_ty, &expr); + self.free(iface, &case.ty, &expr); self.src.c("break;\n"); self.src.c("}\n"); } @@ -554,11 +550,7 @@ impl C { TypeDefKind::Flags(_) => false, TypeDefKind::Enum(_) => false, TypeDefKind::List(_) => true, - TypeDefKind::Variant(v) => v - .cases - .iter() - .filter_map(|c| c.ty.as_ref()) - .any(|t| self.owns_anything(iface, t)), + TypeDefKind::Variant(v) => v.cases.iter().any(|c| self.owns_anything(iface, &c.ty)), TypeDefKind::Union(v) => v .cases .iter() @@ -793,16 +785,17 @@ impl Generator for C { self.docs(docs); self.names.insert(&name.to_snake_case()).unwrap(); self.src.h("typedef struct {\n"); - self.src.h(int_repr(variant.tag)); + self.src.h(int_repr(variant.tag())); self.src.h(" tag;\n"); self.src.h("union {\n"); for case in variant.cases.iter() { - if let Some(ty) = &case.ty { - self.print_ty(iface, ty); - self.src.h(" "); - self.src.h(&case.name.to_snake_case()); - self.src.h(";\n"); + if self.is_empty_type(iface, &case.ty) { + continue; } + self.print_ty(iface, &case.ty); + self.src.h(" "); + self.src.h(&case.name.to_snake_case()); + self.src.h(";\n"); } self.src.h("} val;\n"); self.src.h("} "); @@ -1637,17 +1630,15 @@ impl Bindgen for FunctionBindgen<'_> { variant.cases.iter().zip(blocks).zip(payloads).enumerate() { self.src.push_str(&format!("case {}: {{\n", i)); - if let Some(ty) = &case.ty { - if !self.gen.is_empty_type(iface, ty) { - let ty = self.gen.type_string(iface, ty); - self.src.push_str(&format!( - "const {} *{} = &({}).val", - ty, payload, operands[0], - )); - self.src.push_str("."); - self.src.push_str(&case.name.to_snake_case()); - self.src.push_str(";\n"); - } + if !self.gen.is_empty_type(iface, &case.ty) { + let ty = self.gen.type_string(iface, &case.ty); + self.src.push_str(&format!( + "const {} *{} = &({}).val", + ty, payload, operands[0], + )); + self.src.push_str("."); + self.src.push_str(&case.name.to_snake_case()); + self.src.push_str(";\n"); } self.src.push_str(&block); @@ -1677,15 +1668,13 @@ impl Bindgen for FunctionBindgen<'_> { { self.src.push_str(&format!("case {}: {{\n", i)); self.src.push_str(&block); + assert!(block_results.len() == 1); - if case.ty.is_some() { - assert!(block_results.len() == 1); + if !self.gen.is_empty_type(iface, &case.ty) { let mut dst = format!("{}.val", result); dst.push_str("."); dst.push_str(&case.name.to_snake_case()); self.store_op(&block_results[0], &dst); - } else { - assert!(block_results.is_empty()); } self.src.push_str("break;\n}\n"); } @@ -1810,9 +1799,10 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::OptionLift { payload, ty, .. } => { let (some, some_results) = self.blocks.pop().unwrap(); let (none, none_results) = self.blocks.pop().unwrap(); - assert!(none_results.is_empty()); + assert!(none_results.len() == 1); assert!(some_results.len() == 1); let some_result = &some_results[0]; + assert_eq!(none_results[0], "INVALID"); let ty = self.gen.type_string(iface, &Type::Id(*ty)); let result = self.locals.tmp("option"); diff --git a/crates/gen-core/src/lib.rs b/crates/gen-core/src/lib.rs index a2a186956..76ba77d83 100644 --- a/crates/gen-core/src/lib.rs +++ b/crates/gen-core/src/lib.rs @@ -223,9 +223,7 @@ impl Types { TypeDefKind::Enum(_) => {} TypeDefKind::Variant(v) => { for case in v.cases.iter() { - if let Some(ty) = &case.ty { - info |= self.type_info(iface, ty); - } + info |= self.type_info(iface, &case.ty); } } TypeDefKind::List(ty) => { @@ -279,9 +277,7 @@ impl Types { TypeDefKind::Enum(_) => {} TypeDefKind::Variant(v) => { for case in v.cases.iter() { - if let Some(ty) = &case.ty { - self.set_param_result_ty(iface, ty, param, result) - } + self.set_param_result_ty(iface, &case.ty, param, result) } } TypeDefKind::List(ty) | TypeDefKind::Type(ty) | TypeDefKind::Option(ty) => { diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index fb608f001..cbd0d2cc9 100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -445,9 +445,9 @@ impl Generator for Js { self.src.ts("tag: \""); self.src.ts(&case.name); self.src.ts("\",\n"); - if let Some(ty) = &case.ty { + if case.ty != Type::Unit { self.src.ts("val: "); - self.print_ty(iface, ty); + self.print_ty(iface, &case.ty); self.src.ts(",\n"); } self.src.ts("}\n"); @@ -1595,7 +1595,7 @@ impl Bindgen for FunctionBindgen<'_> { for (case, (block, block_results)) in variant.cases.iter().zip(blocks) { self.src .js(&format!("case \"{}\": {{\n", case.name.as_str())); - if case.ty.is_some() { + if case.ty != Type::Unit { self.src.js(&format!("const e = variant{}.val;\n", tmp)); } self.src.js(&block); @@ -1633,11 +1633,11 @@ impl Bindgen for FunctionBindgen<'_> { self.src.js(&format!("variant{} = {{\n", tmp)); self.src.js(&format!("tag: \"{}\",\n", case.name.as_str())); - if case.ty.is_some() { - assert!(block_results.len() == 1); + assert!(block_results.len() == 1); + if case.ty != Type::Unit { self.src.js(&format!("val: {},\n", block_results[0])); } else { - assert!(block_results.is_empty()); + assert_eq!(block_results[0], "undefined"); } self.src.js("};\n"); self.src.js("break;\n}\n"); @@ -1789,9 +1789,10 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::OptionLift { payload, .. } => { let (some, some_results) = self.blocks.pop().unwrap(); let (none, none_results) = self.blocks.pop().unwrap(); - assert!(none_results.is_empty()); + assert!(none_results.len() == 1); assert!(some_results.len() == 1); let some_result = &some_results[0]; + assert_eq!(none_results[0], "undefined"); let tmp = self.tmp(); diff --git a/crates/gen-markdown/src/lib.rs b/crates/gen-markdown/src/lib.rs index 769eb69e7..74ab4367a 100644 --- a/crates/gen-markdown/src/lib.rs +++ b/crates/gen-markdown/src/lib.rs @@ -262,10 +262,8 @@ impl Generator for Markdown { format!("{}::{}", name, case.name), format!("#{}.{}", name.to_snake_case(), case.name.to_snake_case()), ); - if let Some(ty) = &case.ty { - self.src.push_str(": "); - self.print_ty(iface, ty, false); - } + self.src.push_str(": "); + self.print_ty(iface, &case.ty, false); self.src.indent(1); self.src.push_str("\n\n"); self.docs(&case.docs); diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs index dbc0f8760..ca434d524 100644 --- a/crates/gen-rust-wasm/src/lib.rs +++ b/crates/gen-rust-wasm/src/lib.rs @@ -1018,10 +1018,11 @@ impl Bindgen for FunctionBindgen<'_> { for (case, block) in variant.cases.iter().zip(blocks) { let case_name = case.name.to_camel_case(); self.push_str(&format!("{name}::{case_name}")); - if case.ty.is_some() { - self.push_str("(e)"); + if case.ty == Type::Unit { + self.push_str(&format!(" => {{\nlet e = ();\n{block}\n}}\n")); + } else { + self.push_str(&format!("(e) => {block},\n")); } - self.push_str(&format!(" => {block},\n")); } self.push_str("};\n"); } @@ -1029,7 +1030,7 @@ impl Bindgen for FunctionBindgen<'_> { // In unchecked mode when this type is a named enum then we know we // defined the type so we can transmute directly into it. Instruction::VariantLift { name, variant, .. } - if variant.cases.iter().all(|c| c.ty.is_none()) && unchecked => + if variant.cases.iter().all(|c| c.ty == Type::Unit) && unchecked => { self.blocks.drain(self.blocks.len() - variant.cases.len()..); let mut result = format!("core::mem::transmute::<_, "); @@ -1037,7 +1038,7 @@ impl Bindgen for FunctionBindgen<'_> { result.push_str(">("); result.push_str(&operands[0]); result.push_str(" as "); - result.push_str(int_repr(variant.tag)); + result.push_str(int_repr(variant.tag())); result.push_str(")"); results.push(result); } @@ -1056,7 +1057,7 @@ impl Bindgen for FunctionBindgen<'_> { } else { i.to_string() }; - let block = if case.ty.is_some() { + let block = if case.ty != Type::Unit { format!("({block})") } else { String::new() @@ -1124,8 +1125,8 @@ impl Bindgen for FunctionBindgen<'_> { let operand = &operands[0]; self.push_str(&format!( "match {operand} {{ - Some(e) => {{ {some} }}, - None => {{ {none} }}, + Some(e) => {some}, + None => {{\nlet e = ();\n{none}\n}}, }};" )); } diff --git a/crates/gen-rust/src/lib.rs b/crates/gen-rust/src/lib.rs index ae1361fa1..7d7899158 100644 --- a/crates/gen-rust/src/lib.rs +++ b/crates/gen-rust/src/lib.rs @@ -457,7 +457,7 @@ pub trait RustGenerator { variant .cases .iter() - .map(|c| (c.name.to_camel_case(), &c.docs, c.ty.as_ref())), + .map(|c| (c.name.to_camel_case(), &c.docs, &c.ty)), docs, ); } @@ -473,7 +473,7 @@ pub trait RustGenerator { .cases .iter() .enumerate() - .map(|(i, c)| (format!("V{i}"), &c.docs, Some(&c.ty))), + .map(|(i, c)| (format!("V{i}"), &c.docs, &c.ty)), docs, ); } @@ -482,7 +482,7 @@ pub trait RustGenerator { &mut self, iface: &Interface, id: TypeId, - cases: impl IntoIterator)> + Clone, + cases: impl IntoIterator + Clone, docs: &Docs, ) where Self: Sized, @@ -504,9 +504,9 @@ pub trait RustGenerator { for (case_name, docs, payload) in cases.clone() { self.rustdoc(docs); self.push_str(&case_name); - if let Some(ty) = payload { + if *payload != Type::Unit { self.push_str("("); - self.print_ty(iface, ty, mode); + self.print_ty(iface, payload, mode); self.push_str(")") } self.push_str(",\n"); @@ -530,7 +530,7 @@ pub trait RustGenerator { id: TypeId, mode: TypeMode, name: &str, - cases: impl IntoIterator)>, + cases: impl IntoIterator, ) where Self: Sized, { @@ -548,12 +548,12 @@ pub trait RustGenerator { self.push_str(name); self.push_str("::"); self.push_str(&case_name); - if payload.is_some() { + if *payload != Type::Unit { self.push_str("(e)"); } self.push_str(" => {\n"); self.push_str(&format!("f.debug_tuple(\"{}::{}\")", name, case_name)); - if payload.is_some() { + if *payload != Type::Unit { self.push_str(".field(e)"); } self.push_str(".finish()\n"); @@ -689,7 +689,10 @@ pub trait RustGenerator { id, TypeMode::Owned, &name, - enum_.cases.iter().map(|c| (c.name.to_camel_case(), None)), + enum_ + .cases + .iter() + .map(|c| (c.name.to_camel_case(), &Type::Unit)), ) } } diff --git a/crates/gen-wasmtime-py/src/lib.rs b/crates/gen-wasmtime-py/src/lib.rs index ad86e353d..e41df1843 100644 --- a/crates/gen-wasmtime-py/src/lib.rs +++ b/crates/gen-wasmtime-py/src/lib.rs @@ -650,14 +650,9 @@ impl Generator for WasmtimePy { let name = format!("{}{}", name.to_camel_case(), case.name.to_camel_case()); self.src.push_str(&format!("class {}:\n", name)); self.indent(); - match &case.ty { - Some(ty) => { - self.src.push_str("value: "); - self.print_ty(iface, ty); - self.src.push_str("\n"); - } - None => self.src.push_str("pass\n"), - } + self.src.push_str("value: "); + self.print_ty(iface, &case.ty); + self.src.push_str("\n"); self.deindent(); self.src.push_str("\n"); cases.push(name); @@ -1658,12 +1653,8 @@ impl Bindgen for FunctionBindgen<'_> { case.name.to_camel_case() )); self.src.indent(2); - - if case.ty.is_some() { - self.src - .push_str(&format!("{} = {}.value\n", payload, operands[0])); - } - + self.src + .push_str(&format!("{} = {}.value\n", payload, operands[0])); self.src.push_str(&block); for (i, result) in block_results.iter().enumerate() { @@ -1713,12 +1704,8 @@ impl Bindgen for FunctionBindgen<'_> { name.to_camel_case(), case.name.to_camel_case() )); - if case.ty.is_some() { - assert!(block_results.len() == 1); - self.src.push_str(&block_results[0]); - } else { - assert!(block_results.is_empty()); - } + assert!(block_results.len() == 1); + self.src.push_str(&block_results[0]); self.src.push_str(")\n"); self.src.deindent(2); } @@ -1848,9 +1835,10 @@ impl Bindgen for FunctionBindgen<'_> { Instruction::OptionLift { ty, .. } => { let (some, some_results) = self.blocks.pop().unwrap(); let (none, none_results) = self.blocks.pop().unwrap(); - assert!(none_results.is_empty()); + assert!(none_results.len() == 1); assert!(some_results.len() == 1); let some_result = &some_results[0]; + assert_eq!(none_results[0], "None"); let result = self.locals.tmp("option"); self.src.push_str(&format!( diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index c6c0e2ace..d799fc650 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -1634,10 +1634,11 @@ impl Bindgen for FunctionBindgen<'_> { for (case, block) in variant.cases.iter().zip(blocks) { let case_name = case.name.to_camel_case(); self.push_str(&format!("{name}::{case_name}")); - if case.ty.is_some() { - self.push_str("(e)"); + if case.ty == Type::Unit { + self.push_str(&format!(" => {{\nlet e = ();\n{block}\n}}\n")); + } else { + self.push_str(&format!("(e) => {block},\n")); } - self.push_str(&format!(" => {block},\n")); } self.push_str("};\n"); } @@ -1651,7 +1652,7 @@ impl Bindgen for FunctionBindgen<'_> { let mut result = format!("match {op0} {{\n"); let name = self.typename_lift(iface, *ty); for (i, (case, block)) in variant.cases.iter().zip(blocks).enumerate() { - let block = if case.ty.is_some() { + let block = if case.ty != Type::Unit { format!("({block})") } else { String::new() @@ -1711,8 +1712,8 @@ impl Bindgen for FunctionBindgen<'_> { let operand = &operands[0]; self.push_str(&format!( "match {operand} {{ - Some(e) => {{ {some} }}, - None => {{ {none} }}, + Some(e) => {some}, + None => {{\nlet e = ();\n{none}\n}}, }};" )); } diff --git a/crates/parser/src/abi.rs b/crates/parser/src/abi.rs index 0d4e9d8a7..431ee2837 100644 --- a/crates/parser/src/abi.rs +++ b/crates/parser/src/abi.rs @@ -1005,25 +1005,25 @@ impl Interface { } TypeDefKind::Variant(v) => { - result.push(v.tag.into()); - self.push_wasm_variants(variant, v.cases.iter().map(|c| c.ty.as_ref()), result); + result.push(v.tag().into()); + self.push_wasm_variants(variant, v.cases.iter().map(|c| &c.ty), result); } TypeDefKind::Enum(e) => result.push(e.tag().into()), TypeDefKind::Option(t) => { result.push(WasmType::I32); - self.push_wasm_variants(variant, [None, Some(t)], result); + self.push_wasm_variants(variant, [&Type::Unit, t], result); } TypeDefKind::Expected(e) => { result.push(WasmType::I32); - self.push_wasm_variants(variant, [Some(&e.ok), Some(&e.err)], result); + self.push_wasm_variants(variant, [&e.ok, &e.err], result); } TypeDefKind::Union(u) => { result.push(WasmType::I32); - self.push_wasm_variants(variant, u.cases.iter().map(|c| Some(&c.ty)), result); + self.push_wasm_variants(variant, u.cases.iter().map(|c| &c.ty), result); } }, } @@ -1032,7 +1032,7 @@ impl Interface { fn push_wasm_variants<'a>( &self, variant: AbiVariant, - tys: impl IntoIterator>, + tys: impl IntoIterator, result: &mut Vec, ) { let mut temp = Vec::new(); @@ -1044,11 +1044,7 @@ impl Interface { // "unification" so we can handle things like `Result` where that turns into `[i32 i32]` where the second // `i32` might be the `f32` bitcasted. - for case in tys { - let ty = match case { - Some(ty) => ty, - None => continue, - }; + for ty in tys { self.push_wasm(variant, ty, &mut temp); for (i, ty) in temp.drain(..).enumerate() { @@ -1546,8 +1542,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Variant(v) => { - let results = - self.lower_variant_arms(ty, v.cases.iter().map(|c| c.ty.as_ref())); + let results = self.lower_variant_arms(ty, v.cases.iter().map(|c| &c.ty)); self.emit(&VariantLower { variant: v, ty: id, @@ -1563,7 +1558,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { }); } TypeDefKind::Option(t) => { - let results = self.lower_variant_arms(ty, [None, Some(t)]); + let results = self.lower_variant_arms(ty, [&Type::Unit, t]); self.emit(&OptionLower { payload: t, ty: id, @@ -1571,7 +1566,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { }); } TypeDefKind::Expected(e) => { - let results = self.lower_variant_arms(ty, [Some(&e.ok), Some(&e.err)]); + let results = self.lower_variant_arms(ty, [&e.ok, &e.err]); self.emit(&ExpectedLower { expected: e, ty: id, @@ -1579,8 +1574,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { }); } TypeDefKind::Union(union) => { - let results = - self.lower_variant_arms(ty, union.cases.iter().map(|c| Some(&c.ty))); + let results = self.lower_variant_arms(ty, union.cases.iter().map(|c| &c.ty)); self.emit(&UnionLower { union, ty: id, @@ -1595,43 +1589,41 @@ impl<'a, B: Bindgen> Generator<'a, B> { fn lower_variant_arms<'b>( &mut self, ty: &Type, - cases: impl IntoIterator>, + cases: impl IntoIterator, ) -> Vec { use Instruction::*; let mut results = Vec::new(); let mut temp = Vec::new(); let mut casts = Vec::new(); self.iface.push_wasm(self.variant, ty, &mut results); - for (i, case) in cases.into_iter().enumerate() { + for (i, ty) in cases.into_iter().enumerate() { self.push_block(); self.emit(&VariantPayloadName); let payload_name = self.stack.pop().unwrap(); self.emit(&I32Const { val: i as i32 }); let mut pushed = 1; - if let Some(ty) = case { - // Using the payload of this block we lower the type to - // raw wasm values. - self.stack.push(payload_name.clone()); - self.lower(ty); - - // Determine the types of all the wasm values we just - // pushed, and record how many. If we pushed too few - // then we'll need to push some zeros after this. - temp.truncate(0); - self.iface.push_wasm(self.variant, ty, &mut temp); - pushed += temp.len(); - - // For all the types pushed we may need to insert some - // bitcasts. This will go through and cast everything - // to the right type to ensure all blocks produce the - // same set of results. - casts.truncate(0); - for (actual, expected) in temp.iter().zip(&results[1..]) { - casts.push(cast(*actual, *expected)); - } - if casts.iter().any(|c| *c != Bitcast::None) { - self.emit(&Bitcasts { casts: &casts }); - } + // Using the payload of this block we lower the type to + // raw wasm values. + self.stack.push(payload_name.clone()); + self.lower(ty); + + // Determine the types of all the wasm values we just + // pushed, and record how many. If we pushed too few + // then we'll need to push some zeros after this. + temp.truncate(0); + self.iface.push_wasm(self.variant, ty, &mut temp); + pushed += temp.len(); + + // For all the types pushed we may need to insert some + // bitcasts. This will go through and cast everything + // to the right type to ensure all blocks produce the + // same set of results. + casts.truncate(0); + for (actual, expected) in temp.iter().zip(&results[1..]) { + casts.push(cast(*actual, *expected)); + } + if casts.iter().any(|c| *c != Bitcast::None) { + self.emit(&Bitcasts { casts: &casts }); } // If we haven't pushed enough items in this block to match @@ -1760,7 +1752,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Variant(v) => { - self.lift_variant_arms(ty, v.cases.iter().map(|c| c.ty.as_ref())); + self.lift_variant_arms(ty, v.cases.iter().map(|c| &c.ty)); self.emit(&VariantLift { variant: v, ty: id, @@ -1777,12 +1769,12 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Option(t) => { - self.lift_variant_arms(ty, [None, Some(t)]); + self.lift_variant_arms(ty, [&Type::Unit, t]); self.emit(&OptionLift { payload: t, ty: id }); } TypeDefKind::Expected(e) => { - self.lift_variant_arms(ty, [Some(&e.ok), Some(&e.err)]); + self.lift_variant_arms(ty, [&e.ok, &e.err]); self.emit(&ExpectedLift { expected: e, ty: id, @@ -1790,7 +1782,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Union(union) => { - self.lift_variant_arms(ty, union.cases.iter().map(|c| Some(&c.ty))); + self.lift_variant_arms(ty, union.cases.iter().map(|c| &c.ty)); self.emit(&UnionLift { union, ty: id, @@ -1801,11 +1793,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } } - fn lift_variant_arms<'b>( - &mut self, - ty: &Type, - cases: impl IntoIterator>, - ) { + fn lift_variant_arms<'b>(&mut self, ty: &Type, cases: impl IntoIterator) { let mut params = Vec::new(); let mut temp = Vec::new(); let mut casts = Vec::new(); @@ -1814,30 +1802,28 @@ impl<'a, B: Bindgen> Generator<'a, B> { .stack .drain(self.stack.len() + 1 - params.len()..) .collect::>(); - for case in cases { + for ty in cases { self.push_block(); - if let Some(ty) = case { - // Push only the values we need for this variant onto - // the stack. - temp.truncate(0); - self.iface.push_wasm(self.variant, ty, &mut temp); - self.stack - .extend(block_inputs[..temp.len()].iter().cloned()); - - // Cast all the types we have on the stack to the actual - // types needed for this variant, if necessary. - casts.truncate(0); - for (actual, expected) in temp.iter().zip(¶ms[1..]) { - casts.push(cast(*expected, *actual)); - } - if casts.iter().any(|c| *c != Bitcast::None) { - self.emit(&Instruction::Bitcasts { casts: &casts }); - } - - // Then recursively lift this variant's payload. - self.lift(ty); + // Push only the values we need for this variant onto + // the stack. + temp.truncate(0); + self.iface.push_wasm(self.variant, ty, &mut temp); + self.stack + .extend(block_inputs[..temp.len()].iter().cloned()); + + // Cast all the types we have on the stack to the actual + // types needed for this variant, if necessary. + casts.truncate(0); + for (actual, expected) in temp.iter().zip(¶ms[1..]) { + casts.push(cast(*expected, *actual)); } - self.finish_block(case.is_some() as usize); + if casts.iter().any(|c| *c != Bitcast::None) { + self.emit(&Instruction::Bitcasts { casts: &casts }); + } + + // Then recursively lift this variant's payload. + self.lift(ty); + self.finish_block(1); } } @@ -1923,8 +1909,8 @@ impl<'a, B: Bindgen> Generator<'a, B> { self.write_variant_arms_to_memory( offset, addr, - v.tag, - v.cases.iter().map(|c| c.ty.as_ref()), + v.tag(), + v.cases.iter().map(|c| &c.ty), ); self.emit(&VariantLower { variant: v, @@ -1935,7 +1921,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Option(t) => { - self.write_variant_arms_to_memory(offset, addr, Int::U8, [None, Some(t)]); + self.write_variant_arms_to_memory(offset, addr, Int::U8, [&Type::Unit, t]); self.emit(&OptionLower { payload: t, ty: id, @@ -1944,12 +1930,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Expected(e) => { - self.write_variant_arms_to_memory( - offset, - addr, - Int::U8, - [Some(&e.ok), Some(&e.err)], - ); + self.write_variant_arms_to_memory(offset, addr, Int::U8, [&e.ok, &e.err]); self.emit(&ExpectedLower { expected: e, ty: id, @@ -1968,7 +1949,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { offset, addr, union.tag(), - union.cases.iter().map(|c| Some(&c.ty)), + union.cases.iter().map(|c| &c.ty), ); self.emit(&UnionLower { union, @@ -1986,21 +1967,19 @@ impl<'a, B: Bindgen> Generator<'a, B> { offset: i32, addr: B::Operand, tag: Int, - cases: impl IntoIterator> + Clone, + cases: impl IntoIterator + Clone, ) { let payload_offset = offset + (self.bindgen.sizes().payload_offset(tag, cases.clone()) as i32); - for (i, case) in cases.into_iter().enumerate() { + for (i, ty) in cases.into_iter().enumerate() { self.push_block(); self.emit(&Instruction::VariantPayloadName); let payload_name = self.stack.pop().unwrap(); self.emit(&Instruction::I32Const { val: i as i32 }); self.stack.push(addr.clone()); self.store_intrepr(offset, tag); - if let Some(ty) = case { - self.stack.push(payload_name.clone()); - self.write_to_memory(ty, addr.clone(), payload_offset); - } + self.stack.push(payload_name.clone()); + self.write_to_memory(ty, addr.clone(), payload_offset); self.finish_block(0); } } @@ -2113,8 +2092,8 @@ impl<'a, B: Bindgen> Generator<'a, B> { self.read_variant_arms_from_memory( offset, addr, - variant.tag, - variant.cases.iter().map(|c| c.ty.as_ref()), + variant.tag(), + variant.cases.iter().map(|c| &c.ty), ); self.emit(&VariantLift { variant, @@ -2124,17 +2103,12 @@ impl<'a, B: Bindgen> Generator<'a, B> { } TypeDefKind::Option(t) => { - self.read_variant_arms_from_memory(offset, addr, Int::U8, [None, Some(t)]); + self.read_variant_arms_from_memory(offset, addr, Int::U8, [&Type::Unit, t]); self.emit(&OptionLift { payload: t, ty: id }); } TypeDefKind::Expected(e) => { - self.read_variant_arms_from_memory( - offset, - addr, - Int::U8, - [Some(&e.ok), Some(&e.err)], - ); + self.read_variant_arms_from_memory(offset, addr, Int::U8, [&e.ok, &e.err]); self.emit(&ExpectedLift { expected: e, ty: id, @@ -2152,7 +2126,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { offset, addr, union.tag(), - union.cases.iter().map(|c| Some(&c.ty)), + union.cases.iter().map(|c| &c.ty), ); self.emit(&UnionLift { union, @@ -2169,18 +2143,16 @@ impl<'a, B: Bindgen> Generator<'a, B> { offset: i32, addr: B::Operand, tag: Int, - cases: impl IntoIterator> + Clone, + cases: impl IntoIterator + Clone, ) { self.stack.push(addr.clone()); self.load_intrepr(offset, tag); let payload_offset = offset + (self.bindgen.sizes().payload_offset(tag, cases.clone()) as i32); - for case in cases { + for ty in cases { self.push_block(); - if let Some(ty) = case { - self.read_from_memory(ty, addr.clone(), payload_offset); - } - self.finish_block(case.is_some() as usize); + self.read_from_memory(ty, addr.clone(), payload_offset); + self.finish_block(1); } } diff --git a/crates/parser/src/ast.rs b/crates/parser/src/ast.rs index 4390676c5..a688160df 100644 --- a/crates/parser/src/ast.rs +++ b/crates/parser/src/ast.rs @@ -120,7 +120,6 @@ struct Flag<'a> { } struct Variant<'a> { - tag: Option>>, span: Span, cases: Vec>, } @@ -309,7 +308,6 @@ impl<'a> TypeDef<'a> { tokens.expect(Token::Variant)?; let name = parse_id(tokens)?; let ty = Type::Variant(Variant { - tag: None, span: name.span, cases: parse_list( tokens, diff --git a/crates/parser/src/ast/resolve.rs b/crates/parser/src/ast/resolve.rs index 990e84b88..24ea457d5 100644 --- a/crates/parser/src/ast/resolve.rs +++ b/crates/parser/src/ast/resolve.rs @@ -19,7 +19,7 @@ pub struct Resolver { #[derive(PartialEq, Eq, Hash)] enum Key { - Variant(Vec<(String, Option)>), + Variant(Vec<(String, Type)>), Record(Vec<(String, Type)>), Flags(Vec), Tuple(Vec), @@ -224,10 +224,9 @@ impl Resolver { .map(|case| Case { docs: case.docs.clone(), name: case.name.clone(), - ty: case.ty.map(|t| self.copy_type(dep_name, dep, t)), + ty: self.copy_type(dep_name, dep, case.ty), }) .collect(), - tag: v.tag, }), TypeDefKind::Enum(e) => TypeDefKind::Enum(Enum { cases: e.cases.clone(), @@ -431,19 +430,13 @@ impl Resolver { docs: self.docs(&case.docs), name: case.name.name.to_string(), ty: match &case.ty { - Some(ty) => Some(self.resolve_type(ty)?), - None => None, + Some(ty) => self.resolve_type(ty)?, + None => Type::Unit, }, }) }) .collect::>>()?; - TypeDefKind::Variant(Variant { - tag: match &variant.tag { - Some(ty) => self.get_variant_tag(ty), - None => Variant::infer_tag(cases.len()), - }, - cases, - }) + TypeDefKind::Variant(Variant { cases }) } super::Type::Enum(e) => { if e.cases.is_empty() { @@ -493,32 +486,6 @@ impl Resolver { }) } - fn get_variant_tag(&self, tag: &super::Type) -> Int { - match tag { - super::Type::U8 => Int::U8, - super::Type::U16 => Int::U16, - super::Type::U32 => Int::U32, - super::Type::U64 => Int::U64, - super::Type::Name(name) => { - let ty = self.type_lookup[&*name.name]; - self.get_variant_tag_id(ty) - } - _ => panic!("unknown explicit variant tag"), - } - } - - fn get_variant_tag_id(&self, ty: TypeId) -> Int { - match &self.types[ty].kind { - TypeDefKind::Type(Type::U8) => Int::U8, - TypeDefKind::Type(Type::U16) => Int::U16, - TypeDefKind::Type(Type::U32) => Int::U32, - TypeDefKind::Type(Type::U64) => Int::U64, - TypeDefKind::Type(Type::Id(id)) => self.get_variant_tag_id(*id), - TypeDefKind::Variant(v) => v.tag, - _ => panic!("unknown typedef"), - } - } - fn resolve_type(&mut self, ty: &super::Type<'_>) -> Result { let kind = self.resolve_type_def(ty)?; Ok(self.anon_type_def(TypeDef { @@ -700,7 +667,7 @@ impl Resolver { } TypeDefKind::Variant(v) => { for case in v.cases.iter() { - if let Some(Type::Id(id)) = case.ty { + if let Type::Id(id) = case.ty { self.validate_type_not_recursive(span, id, visiting, valid)?; } } diff --git a/crates/parser/src/lib.rs b/crates/parser/src/lib.rs index 3f6cd9bc1..494f7e101 100644 --- a/crates/parser/src/lib.rs +++ b/crates/parser/src/lib.rs @@ -143,25 +143,21 @@ pub struct Tuple { #[derive(Debug)] pub struct Variant { pub cases: Vec, - /// The bit representation of the width of this variant's tag when the - /// variant is stored in memory. - pub tag: Int, } #[derive(Debug)] pub struct Case { pub docs: Docs, pub name: String, - pub ty: Option, + pub ty: Type, } impl Variant { - pub fn infer_tag(cases: usize) -> Int { - match cases { + pub fn tag(&self) -> Int { + match self.cases.len() { n if n <= u8::max_value() as usize => Int::U8, n if n <= u16::max_value() as usize => Int::U16, n if n <= u32::max_value() as usize => Int::U32, - n if n <= u64::max_value() as usize => Int::U64, _ => panic!("too many cases to fit in a repr"), } } @@ -411,9 +407,7 @@ impl Interface { } TypeDefKind::Variant(v) => { for v in v.cases.iter() { - if let Some(ty) = &v.ty { - self.topo_visit_ty(ty, list, visited); - } + self.topo_visit_ty(&v.ty, list, visited); } } TypeDefKind::Option(ty) => self.topo_visit_ty(ty, list, visited), diff --git a/crates/parser/src/sizealign.rs b/crates/parser/src/sizealign.rs index 79e6ef817..b3651e159 100644 --- a/crates/parser/src/sizealign.rs +++ b/crates/parser/src/sizealign.rs @@ -25,11 +25,11 @@ impl SizeAlign { FlagsRepr::U16 => (2, 2), FlagsRepr::U32(n) => (n * 4, 4), }, - TypeDefKind::Variant(v) => self.variant(v.tag, v.cases.iter().map(|c| c.ty.as_ref())), + TypeDefKind::Variant(v) => self.variant(v.tag(), v.cases.iter().map(|c| &c.ty)), TypeDefKind::Enum(e) => self.variant(e.tag(), []), - TypeDefKind::Option(t) => self.variant(Int::U8, [None, Some(t)]), - TypeDefKind::Expected(e) => self.variant(Int::U8, [Some(&e.ok), Some(&e.err)]), - TypeDefKind::Union(u) => self.variant(u.tag(), u.cases.iter().map(|c| Some(&c.ty))), + TypeDefKind::Option(t) => self.variant(Int::U8, [&Type::Unit, t]), + TypeDefKind::Expected(e) => self.variant(Int::U8, [&e.ok, &e.err]), + TypeDefKind::Union(u) => self.variant(u.tag(), u.cases.iter().map(|c| &c.ty)), } } @@ -68,16 +68,10 @@ impl SizeAlign { .collect() } - pub fn payload_offset<'a>( - &self, - tag: Int, - cases: impl IntoIterator>, - ) -> usize { + pub fn payload_offset<'a>(&self, tag: Int, cases: impl IntoIterator) -> usize { let mut max_align = 1; - for c in cases { - if let Some(ty) = c { - max_align = max_align.max(self.align(ty)); - } + for ty in cases { + max_align = max_align.max(self.align(ty)); } let tag_size = int_size_align(tag).0; align_to(tag_size, max_align) @@ -95,21 +89,15 @@ impl SizeAlign { (align_to(size, align), align) } - fn variant<'a>( - &self, - tag: Int, - types: impl IntoIterator>, - ) -> (usize, usize) { + fn variant<'a>(&self, tag: Int, types: impl IntoIterator) -> (usize, usize) { let (discrim_size, discrim_align) = int_size_align(tag); let mut size = discrim_size; let mut align = discrim_align; for ty in types { - if let Some(ty) = ty { - let case_size = self.size(ty); - let case_align = self.align(ty); - align = align.max(case_align); - size = size.max(align_to(discrim_size, case_align) + case_size); - } + let case_size = self.size(ty); + let case_align = self.align(ty); + align = align.max(case_align); + size = size.max(align_to(discrim_size, case_align) + case_size); } (size, align) } diff --git a/crates/parser/tests/all.rs b/crates/parser/tests/all.rs index 54f9c7af8..4f051257f 100644 --- a/crates/parser/tests/all.rs +++ b/crates/parser/tests/all.rs @@ -188,30 +188,15 @@ fn to_json(i: &Interface) -> String { #[serde(rename_all = "kebab-case")] enum Type { Primitive(String), - Record { - fields: Vec<(String, String)>, - }, - Flags { - flags: Vec, - }, - Enum { - cases: Vec, - }, - Variant { - cases: Vec<(String, Option)>, - }, - Tuple { - types: Vec, - }, + Record { fields: Vec<(String, String)> }, + Flags { flags: Vec }, + Enum { cases: Vec }, + Variant { cases: Vec<(String, String)> }, + Tuple { types: Vec }, Option(String), - Expected { - ok: String, - err: String, - }, + Expected { ok: String, err: String }, List(String), - Union { - cases: Vec, - }, + Union { cases: Vec }, } #[derive(Serialize)] @@ -298,7 +283,7 @@ fn to_json(i: &Interface) -> String { cases: v .cases .iter() - .map(|f| (f.name.clone(), f.ty.as_ref().map(translate_type))) + .map(|f| (f.name.clone(), translate_type(&f.ty))) .collect(), }, TypeDefKind::Option(t) => Type::Option(translate_type(t)), diff --git a/crates/parser/tests/ui/types.wit.result b/crates/parser/tests/ui/types.wit.result index ee9108937..7acc6aa81 100644 --- a/crates/parser/tests/ui/types.wit.result +++ b/crates/parser/tests/ui/types.wit.result @@ -269,7 +269,7 @@ "cases": [ [ "a", - null + "unit" ] ] } @@ -281,11 +281,11 @@ "cases": [ [ "a", - null + "unit" ], [ "b", - null + "unit" ] ] } @@ -297,11 +297,11 @@ "cases": [ [ "a", - null + "unit" ], [ "b", - null + "unit" ] ] } @@ -313,7 +313,7 @@ "cases": [ [ "a", - null + "unit" ], [ "b", @@ -329,7 +329,7 @@ "cases": [ [ "a", - null + "unit" ], [ "b", diff --git a/crates/wasmlink/src/adapter/call.rs b/crates/wasmlink/src/adapter/call.rs index 19feeab92..d4ba03163 100644 --- a/crates/wasmlink/src/adapter/call.rs +++ b/crates/wasmlink/src/adapter/call.rs @@ -488,8 +488,8 @@ impl<'a> CallAdapter<'a> { TypeDefKind::Variant(v) => Self::push_variant_operands( interface, sizes, - v.tag, - v.cases.iter().map(|c| c.ty.as_ref()), + v.tag(), + v.cases.iter().map(|c| &c.ty), params, mode, locals_count, @@ -499,7 +499,7 @@ impl<'a> CallAdapter<'a> { interface, sizes, u.tag(), - u.cases.iter().map(|c| Some(&c.ty)), + u.cases.iter().map(|c| &c.ty), params, mode, locals_count, @@ -509,7 +509,7 @@ impl<'a> CallAdapter<'a> { interface, sizes, Int::U8, - [None, Some(t)], + [&Type::Unit, t], params, mode, locals_count, @@ -519,7 +519,7 @@ impl<'a> CallAdapter<'a> { interface, sizes, Int::U8, - [Some(&e.ok), Some(&e.err)], + [&e.ok, &e.err], params, mode, locals_count, @@ -570,7 +570,7 @@ impl<'a> CallAdapter<'a> { interface: &'a WitInterface, sizes: &SizeAlign, tag: Int, - all_cases: impl IntoIterator>, + all_cases: impl IntoIterator, params: &mut T, mode: PushMode, locals_count: &mut u32, @@ -581,27 +581,25 @@ impl<'a> CallAdapter<'a> { let discriminant = params.next().unwrap(); let mut count = 0; let mut cases = Vec::new(); - for (i, c) in all_cases.into_iter().enumerate() { - if let Some(ty) = c { - let mut iter = params.clone(); - let mut operands = Vec::new(); + for (i, ty) in all_cases.into_iter().enumerate() { + let mut iter = params.clone(); + let mut operands = Vec::new(); - Self::push_operands( - interface, - sizes, - ty, - &mut iter, - mode, - locals_count, - &mut operands, - ); - - if !operands.is_empty() { - cases.push((i as u32, operands)); - } + Self::push_operands( + interface, + sizes, + ty, + &mut iter, + mode, + locals_count, + &mut operands, + ); - count = std::cmp::max(count, params.len() - iter.len()); + if !operands.is_empty() { + cases.push((i as u32, operands)); } + + count = std::cmp::max(count, params.len() - iter.len()); } if !cases.is_empty() { @@ -708,8 +706,8 @@ impl<'a> CallAdapter<'a> { interface, sizes, offset, - v.tag, - v.cases.iter().map(|c| c.ty.as_ref()), + v.tag(), + v.cases.iter().map(|c| &c.ty), mode, locals_count, operands, @@ -719,7 +717,7 @@ impl<'a> CallAdapter<'a> { sizes, offset, u.tag(), - u.cases.iter().map(|c| Some(&c.ty)), + u.cases.iter().map(|c| &c.ty), mode, locals_count, operands, @@ -729,7 +727,7 @@ impl<'a> CallAdapter<'a> { sizes, offset, Int::U8, - [None, Some(t)], + [&Type::Unit, t], mode, locals_count, operands, @@ -739,7 +737,7 @@ impl<'a> CallAdapter<'a> { sizes, offset, Int::U8, - [Some(&e.ok), Some(&e.err)], + [&e.ok, &e.err], mode, locals_count, operands, @@ -778,7 +776,7 @@ impl<'a> CallAdapter<'a> { sizes: &SizeAlign, offset: u32, tag: Int, - all_cases: impl IntoIterator> + Clone, + all_cases: impl IntoIterator + Clone, mode: PushMode, locals_count: &mut u32, operands: &mut Vec>, @@ -786,21 +784,19 @@ impl<'a> CallAdapter<'a> { let payload_offset = sizes.payload_offset(tag, all_cases.clone()) as u32; let mut cases = Vec::new(); - for (i, c) in all_cases.into_iter().enumerate() { - if let Some(ty) = c { - let mut operands = Vec::new(); - Self::push_element_operands( - interface, - sizes, - ty, - offset + payload_offset, - mode, - locals_count, - &mut operands, - ); - if !operands.is_empty() { - cases.push((i as u32, operands)); - } + for (i, ty) in all_cases.into_iter().enumerate() { + let mut operands = Vec::new(); + Self::push_element_operands( + interface, + sizes, + ty, + offset + payload_offset, + mode, + locals_count, + &mut operands, + ); + if !operands.is_empty() { + cases.push((i as u32, operands)); } } diff --git a/crates/wasmlink/src/module.rs b/crates/wasmlink/src/module.rs index 346ebdc45..86468ca0e 100644 --- a/crates/wasmlink/src/module.rs +++ b/crates/wasmlink/src/module.rs @@ -51,11 +51,7 @@ fn has_list(interface: &WitInterface, ty: &WitType) -> bool { TypeDefKind::Flags(_) => false, TypeDefKind::Record(r) => r.fields.iter().any(|f| has_list(interface, &f.ty)), TypeDefKind::Tuple(t) => t.types.iter().any(|ty| has_list(interface, ty)), - TypeDefKind::Variant(v) => v.cases.iter().any(|c| { - c.ty.as_ref() - .map(|t| has_list(interface, t)) - .unwrap_or(false) - }), + TypeDefKind::Variant(v) => v.cases.iter().any(|c| has_list(interface, &c.ty)), TypeDefKind::Union(v) => v.cases.iter().any(|c| has_list(interface, &c.ty)), TypeDefKind::Option(t) => has_list(interface, t), TypeDefKind::Expected(e) => has_list(interface, &e.ok) || has_list(interface, &e.err), diff --git a/crates/wit-component/src/decoding.rs b/crates/wit-component/src/decoding.rs index 23f6c4540..35cc52676 100644 --- a/crates/wit-component/src/decoding.rs +++ b/crates/wit-component/src/decoding.rs @@ -378,8 +378,6 @@ impl<'a> InterfaceDecoder<'a> { let variant_name = variant_name.ok_or_else(|| anyhow!("interface has an unnamed variant type"))?; - let cases_len = cases.len(); - let variant = Variant { cases: cases .map(|(name, case)| { @@ -393,16 +391,10 @@ impl<'a> InterfaceDecoder<'a> { Ok(Case { docs: Docs::default(), name: name.to_string(), - ty: match case.ty { - types::InterfaceTypeRef::Primitive(PrimitiveInterfaceType::Unit) => { - None - } - _ => Some(self.decode_type(&case.ty)?), - }, + ty: self.decode_type(&case.ty)?, }) }) .collect::>()?, - tag: Variant::infer_tag(cases_len), }; Ok(Type::Id(self.alloc_type( diff --git a/crates/wit-component/src/encoding.rs b/crates/wit-component/src/encoding.rs index 83aad1347..e326f7b0f 100644 --- a/crates/wit-component/src/encoding.rs +++ b/crates/wit-component/src/encoding.rs @@ -124,16 +124,13 @@ impl PartialEq for TypeDefKey<'_> { v1.cases.iter().zip(v2.cases.iter()).all(|(c1, c2)| { c1.name == c2.name - && c1 - .ty - .map(|ty| TypeKey { - interface: self.interface, - ty, - }) - .eq(&c2.ty.map(|ty| TypeKey { - interface: other.interface, - ty, - })) + && TypeKey { + interface: self.interface, + ty: c1.ty, + } == TypeKey { + interface: other.interface, + ty: c2.ty, + } }) } (TypeDefKind::Union(v1), TypeDefKind::Union(v2)) => { @@ -229,10 +226,10 @@ impl Hash for TypeDefKey<'_> { state.write_u8(3); for c in &v.cases { c.name.hash(state); - c.ty.map(|ty| TypeKey { + TypeKey { interface: self.interface, - ty, - }) + ty: c.ty, + } .hash(state); } } @@ -573,10 +570,7 @@ impl<'a> TypeEncoder<'a> { .map(|c| { Ok(( c.name.as_str(), - match c.ty.as_ref() { - Some(ty) => self.encode_type(interface, instance, ty)?, - None => InterfaceTypeRef::Primitive(PrimitiveInterfaceType::Unit), - }, + self.encode_type(interface, instance, &c.ty)?, None, // TODO: support defaulting case values in the future )) }) @@ -712,7 +706,7 @@ impl RequiredOptions { Self::for_type(interface, &e.ok) | Self::for_type(interface, &e.err) } TypeDefKind::Variant(v) => { - Self::for_types(interface, v.cases.iter().filter_map(|c| c.ty.as_ref())) + Self::for_types(interface, v.cases.iter().map(|c| &c.ty)) } TypeDefKind::Union(v) => Self::for_types(interface, v.cases.iter().map(|c| &c.ty)), TypeDefKind::Enum(_) => Self::None, diff --git a/crates/wit-component/src/printing.rs b/crates/wit-component/src/printing.rs index ca902666f..9eed3b863 100644 --- a/crates/wit-component/src/printing.rs +++ b/crates/wit-component/src/printing.rs @@ -267,9 +267,7 @@ impl InterfacePrinter { variant: &Variant, ) -> Result<()> { for case in variant.cases.iter() { - if let Some(ty) = &case.ty { - self.declare_type(interface, ty)?; - } + self.declare_type(interface, &case.ty)?; } let name = match name { @@ -279,9 +277,9 @@ impl InterfacePrinter { writeln!(&mut self.output, "variant {} {{", name)?; for case in &variant.cases { write!(&mut self.output, " {}", case.name)?; - if let Some(ty) = &case.ty { + if case.ty != Type::Unit { self.output.push('('); - self.print_type_name(interface, ty)?; + self.print_type_name(interface, &case.ty)?; self.output.push(')'); } self.output.push_str(",\n");