From 34164d31e5dd5c485bf2b705bfe4a65eaa54f8f4 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 28 Apr 2022 12:25:57 -0700 Subject: [PATCH 1/4] Split out the `flags` type from `record` This commit creates a dedicated `flags` type which is distinct from the `record` type rather than the previous inferred-from-the-structure logic. This also additionally changes the canonical ABI of the `flags` type where 64-bit flags are now passed as two `i32` values instead of one `i64`. This ended up changing a significant amount of the logic internally in each code generator, notably around the new lift/lower behavior. Along the way I tried to refactor code to support 64+ flags in a few more places. While some support may be there, though, this is untested and will need a full-fledged feature in the future. --- crates/gen-c/src/lib.rs | 141 ++++++++++++++-------- crates/gen-core/src/lib.rs | 4 + crates/gen-js/src/lib.rs | 145 ++++++++++++++++------- crates/gen-markdown/src/lib.rs | 38 +++++- crates/gen-rust-wasm/src/lib.rs | 112 +++++++----------- crates/gen-rust/src/lib.rs | 37 ++++++ crates/gen-spidermonkey/src/lib.rs | 28 +++-- crates/gen-wasmtime-py/src/lib.rs | 74 ++++++++---- crates/gen-wasmtime/src/lib.rs | 111 +++++++++-------- crates/parser/src/abi.rs | 171 +++++++++++---------------- crates/parser/src/ast.rs | 25 ++-- crates/parser/src/ast/resolve.rs | 28 +++-- crates/parser/src/lib.rs | 68 +++++++---- crates/parser/src/sizealign.rs | 21 ++-- crates/parser/tests/all.rs | 6 + crates/test-helpers/src/lib.rs | 1 + crates/wasmlink/src/adapter/call.rs | 17 +-- crates/wasmlink/src/module.rs | 1 + crates/wasmlink/tests/flags.wat | 2 +- crates/wasmtime/src/lib.rs | 15 ++- crates/wit-component/src/decoding.rs | 14 +-- crates/wit-component/src/encoding.rs | 30 +++-- crates/wit-component/src/printing.rs | 34 +++--- tests/runtime/many_arguments/host.rs | 4 +- 24 files changed, 656 insertions(+), 471 deletions(-) diff --git a/crates/gen-c/src/lib.rs b/crates/gen-c/src/lib.rs index 067badd3b..02898c0e2 100644 --- a/crates/gen-c/src/lib.rs +++ b/crates/gen-c/src/lib.rs @@ -170,7 +170,7 @@ impl C { Type::Id(id) => match &iface.types[*id].kind { TypeDefKind::Type(t) => self.is_arg_by_pointer(iface, t), TypeDefKind::Variant(v) => !v.is_enum(), - TypeDefKind::Record(r) if r.is_flags() => false, + TypeDefKind::Flags(_) => false, TypeDefKind::Record(_) | TypeDefKind::List(_) => true, }, Type::String => true, @@ -246,14 +246,10 @@ impl C { } match &ty.kind { TypeDefKind::Type(t) => self.print_ty(iface, t), - TypeDefKind::Variant(_) => { - self.public_anonymous_types.insert(*id); - self.private_anonymous_types.remove(id); - self.print_namespace(iface); - self.print_ty_name(iface, &Type::Id(*id)); - self.src.h("_t"); - } - TypeDefKind::Record(_) | TypeDefKind::List(_) => { + TypeDefKind::Flags(_) + | TypeDefKind::Record(_) + | TypeDefKind::List(_) + | TypeDefKind::Variant(_) => { self.public_anonymous_types.insert(*id); self.private_anonymous_types.remove(id); self.print_namespace(iface); @@ -289,6 +285,7 @@ impl C { } match &ty.kind { TypeDefKind::Type(t) => self.print_ty_name(iface, t), + TypeDefKind::Flags(_) => unimplemented!(), TypeDefKind::Record(r) => { assert!(r.is_tuple()); self.src.h("tuple"); @@ -331,7 +328,7 @@ impl C { self.src.h("typedef "); let kind = &iface.types[ty].kind; match kind { - TypeDefKind::Type(_) => { + TypeDefKind::Type(_) | TypeDefKind::Flags(_) => { unreachable!() } TypeDefKind::Record(r) => { @@ -459,6 +456,8 @@ impl C { match &iface.types[id].kind { TypeDefKind::Type(t) => self.free(iface, t, "ptr"), + TypeDefKind::Flags(_) => {} + TypeDefKind::Record(r) => { for field in r.fields.iter() { if !self.owns_anything(iface, &field.ty) { @@ -525,6 +524,7 @@ impl C { match &iface.types[id].kind { TypeDefKind::Type(t) => self.owns_anything(iface, t), TypeDefKind::Record(r) => r.fields.iter().any(|t| self.owns_anything(iface, &t.ty)), + TypeDefKind::Flags(_) => false, TypeDefKind::List(_) => true, TypeDefKind::Variant(v) => v .cases @@ -580,7 +580,7 @@ impl Return { TypeDefKind::Type(t) => self.return_single(iface, t, orig_ty), // Flags are returned as their bare values - TypeDefKind::Record(r) if r.is_flags() => { + TypeDefKind::Flags(_) => { self.scalar = Some(Scalar::Type(*orig_ty)); } @@ -669,41 +669,52 @@ impl Generator for C { let prev = mem::take(&mut self.src.header); self.docs(docs); self.names.insert(&name.to_snake_case()).unwrap(); - if record.is_flags() { - self.src.h("typedef "); - let repr = iface - .flags_repr(record) - .expect("unsupported number of flags"); - self.src.h(int_repr(repr)); + self.src.h("typedef struct {\n"); + for field in record.fields.iter() { + self.print_ty(iface, &field.ty); self.src.h(" "); - self.print_namespace(iface); - self.src.h(&name.to_snake_case()); - self.src.h("_t;\n"); - - for (i, field) in record.fields.iter().enumerate() { - self.src.h(&format!( - "#define {}_{}_{} (1 << {})\n", - iface.name.to_shouty_snake_case(), - name.to_shouty_snake_case(), - field.name.to_shouty_snake_case(), - i, - )); + if record.is_tuple() { + self.src.h("f"); } - } else { - self.src.h("typedef struct {\n"); - for field in record.fields.iter() { - self.print_ty(iface, &field.ty); - self.src.h(" "); - if record.is_tuple() { - self.src.h("f"); - } - self.src.h(&field.name.to_snake_case()); - self.src.h(";\n"); - } - self.src.h("} "); - self.print_namespace(iface); - self.src.h(&name.to_snake_case()); - self.src.h("_t;\n"); + self.src.h(&field.name.to_snake_case()); + self.src.h(";\n"); + } + self.src.h("} "); + self.print_namespace(iface); + self.src.h(&name.to_snake_case()); + self.src.h("_t;\n"); + + self.types + .insert(id, mem::replace(&mut self.src.header, prev)); + } + + fn type_flags( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + let prev = mem::take(&mut self.src.header); + self.docs(docs); + self.names.insert(&name.to_snake_case()).unwrap(); + self.src.h("typedef "); + let repr = flags_repr(flags); + self.src.h(int_repr(repr)); + self.src.h(" "); + self.print_namespace(iface); + self.src.h(&name.to_snake_case()); + self.src.h("_t;\n"); + + for (i, flag) in flags.flags.iter().enumerate() { + self.src.h(&format!( + "#define {}_{}_{} (1 << {})\n", + iface.name.to_shouty_snake_case(), + name.to_shouty_snake_case(), + flag.name.to_shouty_snake_case(), + i, + )); } self.types @@ -1398,15 +1409,31 @@ impl Bindgen for FunctionBindgen<'_> { } // TODO: checked - Instruction::FlagsLower { record, .. } | Instruction::FlagsLift { record, .. } => { - match record.num_i32s() { - 0 | 1 => results.push(operands.pop().unwrap()), - _ => panic!("unsupported bitflags"), + Instruction::FlagsLower { flags, ty, .. } => match flags_repr(flags) { + Int::U8 | Int::U16 | Int::U32 => { + results.push(operands.pop().unwrap()); } - } - Instruction::FlagsLower64 { .. } | Instruction::FlagsLift64 { .. } => { - results.push(operands.pop().unwrap()); - } + Int::U64 => { + let name = self.gen.type_string(iface, &Type::Id(*ty)); + let tmp = self.locals.tmp("flags"); + self.src + .push_str(&format!("{name} {tmp} = {};\n", operands[0])); + results.push(format!("{tmp} & 0xffffffff")); + results.push(format!("({tmp} >> 32) & 0xffffffff")); + } + }, + + Instruction::FlagsLift { flags, ty, .. } => match flags_repr(flags) { + Int::U8 | Int::U16 | Int::U32 => { + results.push(operands.pop().unwrap()); + } + Int::U64 => { + let name = self.gen.type_string(iface, &Type::Id(*ty)); + let op0 = &operands[0]; + let op1 = &operands[1]; + results.push(format!("(({name}) ({op0})) | ((({name}) ({op1})) << 32)")); + } + }, Instruction::VariantPayloadName => { let name = self.locals.tmp("payload"); @@ -1824,3 +1851,13 @@ fn case_field_name(case: &Case) -> String { case.name.to_snake_case() } } + +fn flags_repr(f: &Flags) -> Int { + match f.repr() { + FlagsRepr::U8 => Int::U8, + FlagsRepr::U16 => Int::U16, + FlagsRepr::U32(1) => Int::U32, + FlagsRepr::U32(2) => Int::U64, + repr => panic!("unimplemented flags {:?}", repr), + } +} diff --git a/crates/gen-core/src/lib.rs b/crates/gen-core/src/lib.rs index df4940f96..337cfd168 100644 --- a/crates/gen-core/src/lib.rs +++ b/crates/gen-core/src/lib.rs @@ -55,6 +55,7 @@ pub trait Generator { record: &Record, docs: &Docs, ); + fn type_flags(&mut self, iface: &Interface, id: TypeId, name: &str, flags: &Flags, docs: &Docs); fn type_variant( &mut self, iface: &Interface, @@ -88,6 +89,7 @@ pub trait Generator { }; match &ty.kind { TypeDefKind::Record(record) => self.type_record(iface, id, name, record, &ty.docs), + TypeDefKind::Flags(flags) => self.type_flags(iface, id, name, flags, &ty.docs), TypeDefKind::Variant(variant) => { self.type_variant(iface, id, name, variant, &ty.docs) } @@ -188,6 +190,7 @@ impl Types { info |= self.type_info(iface, &field.ty); } } + TypeDefKind::Flags(_) => {} TypeDefKind::Variant(v) => { for case in v.cases.iter() { if let Some(ty) = &case.ty { @@ -225,6 +228,7 @@ impl Types { self.set_param_result_ty(iface, &field.ty, param, result) } } + TypeDefKind::Flags(_) => {} TypeDefKind::Variant(v) => { for case in v.cases.iter() { if let Some(ty) = &case.ty { diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs index fe813fd3b..d5e8d92bf 100644 --- a/crates/gen-js/src/lib.rs +++ b/crates/gen-js/src/lib.rs @@ -171,6 +171,7 @@ impl Js { TypeDefKind::Type(t) => self.print_ty(iface, t), TypeDefKind::Record(r) if r.is_tuple() => self.print_tuple(iface, r), TypeDefKind::Record(_) => panic!("anonymous record"), + TypeDefKind::Flags(_) => panic!("anonymous flags"), TypeDefKind::Variant(v) => { if self.is_nullable_option(iface, v) { self.print_ty(iface, v.cases[1].ty.as_ref().unwrap()); @@ -336,37 +337,6 @@ impl Generator for Js { .ts(&format!("export type {} = ", name.to_camel_case())); self.print_tuple(iface, record); self.src.ts(";\n"); - } else if record.is_flags() { - let repr = iface - .flags_repr(record) - .expect("unsupported number of flags"); - let suffix = if repr == Int::U64 { - self.src - .ts(&format!("export type {} = bigint;\n", name.to_camel_case())); - "n" - } else { - self.src - .ts(&format!("export type {} = number;\n", name.to_camel_case())); - "" - }; - let name = name.to_shouty_snake_case(); - for (i, field) in record.fields.iter().enumerate() { - let field = field.name.to_shouty_snake_case(); - self.src.js(&format!( - "export const {}_{} = {}{};\n", - name, - field, - 1u64 << i, - suffix, - )); - self.src.ts(&format!( - "export const {}_{} = {}{};\n", - name, - field, - 1u64 << i, - suffix, - )); - } } else { self.src .ts(&format!("export interface {} {{\n", name.to_camel_case())); @@ -384,6 +354,34 @@ impl Generator for Js { } } + fn type_flags( + &mut self, + _iface: &Interface, + _id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + self.docs(docs); + let repr = js_flags_repr(flags); + let ty = repr.ty(); + let suffix = repr.suffix(); + self.src + .ts(&format!("export type {} = {ty};\n", name.to_camel_case())); + let name = name.to_shouty_snake_case(); + for (i, flag) in flags.flags.iter().enumerate() { + let flag = flag.name.to_shouty_snake_case(); + self.src.js(&format!( + "export const {name}_{flag} = {}{suffix};\n", + 1u128 << i, + )); + self.src.ts(&format!( + "export const {name}_{flag} = {}{suffix};\n", + 1u128 << i, + )); + } + } + fn type_variant( &mut self, iface: &Interface, @@ -1412,20 +1410,56 @@ impl Bindgen for FunctionBindgen<'_> { } } - Instruction::FlagsLower { record, .. } | Instruction::FlagsLift { record, .. } => { - match record.num_i32s() { - 0 | 1 => { - let validate = self.gen.intrinsic(Intrinsic::ValidateFlags); - let mask = (1u64 << record.fields.len()) - 1; - results.push(format!("{}({}, {})", validate, operands[0], mask)); + Instruction::FlagsLower { flags, .. } => { + let repr = js_flags_repr(flags); + let validate = match repr { + JsFlagsRepr::Number => self.gen.intrinsic(Intrinsic::ValidateFlags), + JsFlagsRepr::Bigint => self.gen.intrinsic(Intrinsic::ValidateFlags64), + }; + let op0 = &operands[0]; + let len = flags.flags.len(); + let n = repr.suffix(); + let tmp = self.tmp(); + let mask = (1u128 << len) - 1; + self.src.js(&format!( + "const flags{tmp} = {validate}({op0}, {mask}{n});\n" + )); + match repr { + JsFlagsRepr::Number => { + results.push(format!("flags{}", tmp)); + } + JsFlagsRepr::Bigint => { + for i in 0..flags.repr().count() { + let i = 32 * i; + results.push(format!("Number((flags{tmp} >> {i}n) & 0xffffffffn)",)); + } } - _ => panic!("unsupported bitflags"), } } - Instruction::FlagsLower64 { record, .. } | Instruction::FlagsLift64 { record, .. } => { - let validate = self.gen.intrinsic(Intrinsic::ValidateFlags64); - let mask = (1u128 << record.fields.len()) - 1; - results.push(format!("{}({}, {}n)", validate, operands[0], mask)); + + Instruction::FlagsLift { flags, .. } => { + let repr = js_flags_repr(flags); + let n = repr.suffix(); + let tmp = self.tmp(); + let operand = match repr { + JsFlagsRepr::Number => operands[0].clone(), + JsFlagsRepr::Bigint => { + self.src.js(&format!("let flags{tmp} = 0n;\n")); + for (i, op) in operands.iter().enumerate() { + let i = 32 * i; + self.src + .js(&format!("flags{tmp} |= BigInt({op}) << {i}n;\n",)); + } + format!("flags{tmp}") + } + }; + let validate = match repr { + JsFlagsRepr::Number => self.gen.intrinsic(Intrinsic::ValidateFlags), + JsFlagsRepr::Bigint => self.gen.intrinsic(Intrinsic::ValidateFlags64), + }; + let len = flags.flags.len(); + let mask = (1u128 << len) - 1; + results.push(format!("{validate}({operand}, {mask}{n})")); } Instruction::VariantPayloadName => results.push("e".to_string()), @@ -2238,3 +2272,30 @@ impl Source { self.ts.push_str(s); } } + +enum JsFlagsRepr { + Number, + Bigint, +} + +impl JsFlagsRepr { + fn ty(&self) -> &'static str { + match self { + JsFlagsRepr::Number => "number", + JsFlagsRepr::Bigint => "bigint", + } + } + fn suffix(&self) -> &'static str { + match self { + JsFlagsRepr::Number => "", + JsFlagsRepr::Bigint => "n", + } + } +} + +fn js_flags_repr(f: &Flags) -> JsFlagsRepr { + match f.repr() { + FlagsRepr::U8 | FlagsRepr::U16 | FlagsRepr::U32(1) => JsFlagsRepr::Number, + FlagsRepr::U32(_) => JsFlagsRepr::Bigint, + } +} diff --git a/crates/gen-markdown/src/lib.rs b/crates/gen-markdown/src/lib.rs index 87c7a646c..b71da58ed 100644 --- a/crates/gen-markdown/src/lib.rs +++ b/crates/gen-markdown/src/lib.rs @@ -79,6 +79,7 @@ impl Markdown { } self.src.push_str(")"); } + TypeDefKind::Flags(_) => unreachable!(), TypeDefKind::Variant(v) => { if let Some(t) = v.as_option() { self.src.push_str("option<"); @@ -162,7 +163,7 @@ impl Generator for Markdown { self.src.push_str("record\n\n"); self.print_type_info(id, docs); self.src.push_str("\n### Record Fields\n\n"); - for (i, field) in record.fields.iter().enumerate() { + for field in record.fields.iter() { self.src.push_str(&format!( "- [`{name}`](#{r}.{f}): ", r = name.to_snake_case(), @@ -178,9 +179,38 @@ impl Generator for Markdown { self.src.push_str("\n\n"); self.docs(&field.docs); self.src.deindent(1); - if record.is_flags() { - self.src.push_str(&format!("Bit: {}\n", i)); - } + self.src.push_str("\n"); + } + } + + fn type_flags( + &mut self, + _iface: &Interface, + id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + self.print_type_header(name); + self.src.push_str("record\n\n"); + self.print_type_info(id, docs); + self.src.push_str("\n### Record Fields\n\n"); + for (i, flag) in flags.flags.iter().enumerate() { + self.src.push_str(&format!( + "- [`{name}`](#{r}.{f}): ", + r = name.to_snake_case(), + f = flag.name.to_snake_case(), + name = flag.name, + )); + self.hrefs.insert( + format!("{}::{}", name, flag.name), + format!("#{}.{}", name.to_snake_case(), flag.name.to_snake_case()), + ); + self.src.indent(1); + self.src.push_str("\n\n"); + self.docs(&flag.docs); + self.src.deindent(1); + self.src.push_str(&format!("Bit: {}\n", i)); self.src.push_str("\n"); } } diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs index 4b5caf4f6..407cda398 100644 --- a/crates/gen-rust-wasm/src/lib.rs +++ b/crates/gen-rust-wasm/src/lib.rs @@ -8,7 +8,7 @@ use wit_bindgen_gen_core::wit_parser::abi::{ }; use wit_bindgen_gen_core::{wit_parser::*, Direction, Files, Generator, Source, TypeInfo, Types}; use wit_bindgen_gen_rust::{ - int_repr, wasm_type, FnSig, RustFunctionGenerator, RustGenerator, TypeMode, + int_repr, wasm_type, FnSig, RustFlagsRepr, RustFunctionGenerator, RustGenerator, TypeMode, }; #[derive(Default)] @@ -165,72 +165,53 @@ impl Generator for RustWasm { &mut self, iface: &Interface, id: TypeId, - name: &str, + _name: &str, record: &Record, docs: &Docs, ) { - if record.is_flags() { - self.src - .push_str("wit_bindgen_rust::bitflags::bitflags! {\n"); - self.rustdoc(docs); - let repr = iface - .flags_repr(record) - .expect("unsupported number of flags"); - self.src.push_str(&format!( - "pub struct {}: {} {{\n", - name.to_camel_case(), - int_repr(repr) - )); - for (i, field) in record.fields.iter().enumerate() { - self.rustdoc(&field.docs); - self.src.push_str(&format!( - "const {} = 1 << {};\n", - field.name.to_shouty_snake_case(), - i, - )); - } - self.src.push_str("}\n"); - self.src.push_str("}\n"); + self.print_typedef_record(iface, id, record, docs); + } - // Add a `from_bits_preserve` method. - self.src - .push_str(&format!("impl {} {{\n", name.to_camel_case())); - self.src.push_str(&format!( - " /// Convert from a raw integer, preserving any unknown bits. See\n" - )); - self.src.push_str(&format!(" /// \n")); - self.src.push_str(&format!( - " pub fn from_bits_preserve(bits: {}) -> Self {{\n", - int_repr(repr) - )); - self.src.push_str(&format!(" Self {{ bits }}\n")); - self.src.push_str(&format!(" }}\n")); - self.src.push_str(&format!("}}\n")); - - // Add a `AsI64` etc. method. - let as_trait = match repr { - Int::U8 | Int::U16 | Int::U32 => "i32", - Int::U64 => "i64", - }; - self.src.push_str(&format!( - "impl wit_bindgen_rust::rt::As{} for {} {{\n", - as_trait.to_camel_case(), - name.to_camel_case() - )); - self.src.push_str(&format!(" #[inline]")); + fn type_flags( + &mut self, + _iface: &Interface, + _id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + self.src + .push_str("wit_bindgen_rust::bitflags::bitflags! {\n"); + self.rustdoc(docs); + let repr = RustFlagsRepr::new(flags); + self.src + .push_str(&format!("pub struct {}: {repr} {{\n", name.to_camel_case(),)); + for (i, flag) in flags.flags.iter().enumerate() { + self.rustdoc(&flag.docs); self.src.push_str(&format!( - " fn as_{}(self) -> {} {{\n", - as_trait, as_trait + "const {} = 1 << {};\n", + flag.name.to_shouty_snake_case(), + i, )); - self.src - .push_str(&format!(" self.bits() as {}\n", as_trait)); - self.src.push_str(&format!(" }}")); - self.src.push_str(&format!("}}\n")); - - return; } + self.src.push_str("}\n"); + self.src.push_str("}\n"); - self.print_typedef_record(iface, id, record, docs); + // Add a `from_bits_preserve` method. + self.src + .push_str(&format!("impl {} {{\n", name.to_camel_case())); + self.src.push_str(&format!( + " /// Convert from a raw integer, preserving any unknown bits. See\n" + )); + self.src.push_str(&format!( + " /// \n" + )); + self.src.push_str(&format!( + " pub fn from_bits_preserve(bits: {repr}) -> Self {{\n", + )); + self.src.push_str(&format!(" Self {{ bits }}\n")); + self.src.push_str(&format!(" }}\n")); + self.src.push_str(&format!("}}\n")); } fn type_variant( @@ -934,23 +915,20 @@ impl Bindgen for FunctionBindgen<'_> { )); } - Instruction::FlagsLower { record, .. } => { + Instruction::FlagsLower { flags, .. } => { let tmp = self.tmp(); self.push_str(&format!("let flags{} = {};\n", tmp, operands[0])); - for i in 0..record.num_i32s() { + for i in 0..flags.repr().count() { results.push(format!("(flags{}.bits() >> {}) as i32", tmp, i * 32)); } } - Instruction::FlagsLower64 { .. } => { - let s = operands.pop().unwrap(); - results.push(format!("wit_bindgen_rust::rt::as_i64({})", s)); - } - Instruction::FlagsLift { name, .. } | Instruction::FlagsLift64 { name, .. } => { + Instruction::FlagsLift { name, flags, .. } => { + let repr = RustFlagsRepr::new(flags); let name = name.to_camel_case(); let mut result = format!("{}::empty()", name); for (i, op) in operands.iter().enumerate() { result.push_str(&format!( - " | {}::from_bits_preserve((({} as u32) << {}) as _)", + " | {}::from_bits_preserve((({} as {repr}) << {}) as _)", name, op, i * 32 diff --git a/crates/gen-rust/src/lib.rs b/crates/gen-rust/src/lib.rs index 170dee605..ac10272e0 100644 --- a/crates/gen-rust/src/lib.rs +++ b/crates/gen-rust/src/lib.rs @@ -1,4 +1,5 @@ use heck::*; +use std::fmt; use wit_bindgen_gen_core::wit_parser::abi::{Bitcast, LiftLower, WasmType}; use wit_bindgen_gen_core::{wit_parser::*, TypeInfo, Types}; @@ -276,6 +277,9 @@ pub trait RustGenerator { TypeDefKind::Record(_) => { panic!("unsupported anonymous type reference: record") } + TypeDefKind::Flags(_) => { + panic!("unsupported anonymous type reference: flags") + } TypeDefKind::Type(t) => self.print_ty(iface, t, mode), } @@ -942,3 +946,36 @@ pub fn bitcast(casts: &[Bitcast], operands: &[String], results: &mut Vec }); } } + +pub enum RustFlagsRepr { + U8, + U16, + U32, + U64, + U128, +} + +impl RustFlagsRepr { + pub fn new(f: &Flags) -> RustFlagsRepr { + match f.repr() { + FlagsRepr::U8 => RustFlagsRepr::U8, + FlagsRepr::U16 => RustFlagsRepr::U16, + FlagsRepr::U32(1) => RustFlagsRepr::U32, + FlagsRepr::U32(2) => RustFlagsRepr::U64, + FlagsRepr::U32(3 | 4) => RustFlagsRepr::U128, + FlagsRepr::U32(n) => panic!("unsupported number of flags: {}", n * 32), + } + } +} + +impl fmt::Display for RustFlagsRepr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RustFlagsRepr::U8 => "u8".fmt(f), + RustFlagsRepr::U16 => "u16".fmt(f), + RustFlagsRepr::U32 => "u32".fmt(f), + RustFlagsRepr::U64 => "u64".fmt(f), + RustFlagsRepr::U128 => "u128".fmt(f), + } + } +} diff --git a/crates/gen-spidermonkey/src/lib.rs b/crates/gen-spidermonkey/src/lib.rs index 0cce7d878..edf35554f 100644 --- a/crates/gen-spidermonkey/src/lib.rs +++ b/crates/gen-spidermonkey/src/lib.rs @@ -16,7 +16,7 @@ use wasm_encoder::Instruction; use wit_bindgen_gen_core::{ wit_parser::{ abi::{self, AbiVariant, WasmSignature, WasmType}, - Docs, Function, Interface, Record, ResourceId, SizeAlign, Type, TypeId, Variant, + Docs, Flags, Function, Interface, Record, ResourceId, SizeAlign, Type, TypeId, Variant, }, Direction, Files, Generator, }; @@ -951,6 +951,18 @@ impl Generator for SpiderMonkeyWasm<'_> { todo!() } + fn type_flags( + &mut self, + iface: &Interface, + id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + let _ = (iface, id, name, flags, docs); + todo!() + } + fn type_variant( &mut self, iface: &Interface, @@ -1878,22 +1890,12 @@ impl abi::Bindgen for Bindgen<'_, '_> { ty: _, } => todo!(), abi::Instruction::FlagsLower { - record: _, - name: _, - ty: _, - } => todo!(), - abi::Instruction::FlagsLower64 { - record: _, + flags: _, name: _, ty: _, } => todo!(), abi::Instruction::FlagsLift { - record: _, - name: _, - ty: _, - } => todo!(), - abi::Instruction::FlagsLift64 { - record: _, + flags: _, name: _, ty: _, } => todo!(), diff --git a/crates/gen-wasmtime-py/src/lib.rs b/crates/gen-wasmtime-py/src/lib.rs index be7cc5bc5..76fc70619 100644 --- a/crates/gen-wasmtime-py/src/lib.rs +++ b/crates/gen-wasmtime-py/src/lib.rs @@ -413,7 +413,7 @@ impl WasmtimePy { TypeDefKind::Record(r) if r.is_tuple() => { self.print_tuple(iface, r.fields.iter().map(|f| &f.ty)) } - TypeDefKind::Record(_) => unreachable!(), + TypeDefKind::Record(_) | TypeDefKind::Flags(_) => unreachable!(), TypeDefKind::Variant(v) => { if let Some(t) = v.as_option() { self.pyimport("typing", "Optional"); @@ -584,21 +584,6 @@ impl Generator for WasmtimePy { if record.is_tuple() { self.src.push_str(&format!("{} = ", name.to_camel_case())); self.print_tuple(iface, record.fields.iter().map(|f| &f.ty)); - } else if record.is_flags() { - self.pyimport("enum", "Flag"); - self.pyimport("enum", "auto"); - self.src - .push_str(&format!("class {}(Flag):\n", name.to_camel_case())); - self.indent(); - for field in record.fields.iter() { - self.docs(&field.docs); - self.src - .push_str(&format!("{} = auto()\n", field.name.to_shouty_snake_case())); - } - if record.fields.is_empty() { - self.src.push_str("pass\n"); - } - self.deindent(); } else { self.pyimport("dataclasses", "dataclass"); self.src @@ -619,6 +604,32 @@ impl Generator for WasmtimePy { self.src.push_str("\n"); } + fn type_flags( + &mut self, + _iface: &Interface, + _id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + self.docs(docs); + self.pyimport("enum", "Flag"); + self.pyimport("enum", "auto"); + self.src + .push_str(&format!("class {}(Flag):\n", name.to_camel_case())); + self.indent(); + for flag in flags.flags.iter() { + self.docs(&flag.docs); + self.src + .push_str(&format!("{} = auto()\n", flag.name.to_shouty_snake_case())); + } + if flags.flags.is_empty() { + self.src.push_str("pass\n"); + } + self.deindent(); + self.src.push_str("\n"); + } + fn type_variant( &mut self, iface: &Interface, @@ -1522,12 +1533,33 @@ impl Bindgen for FunctionBindgen<'_> { results.push(format!("{}({})", name.to_camel_case(), operands.join(", "))); } } - Instruction::FlagsLift { name, .. } | Instruction::FlagsLift64 { name, .. } => { - results.push(format!("{}({})", name.to_camel_case(), operands[0])); - } - Instruction::FlagsLower { .. } | Instruction::FlagsLower64 { .. } => { - results.push(format!("({}).value", operands[0])); + Instruction::FlagsLift { name, .. } => { + let operand = match operands.len() { + 1 => operands[0].clone(), + _ => { + let tmp = self.locals.tmp("flags"); + self.src.push_str(&format!("{tmp} = 0\n")); + for (i, op) in operands.iter().enumerate() { + let i = 32 * i; + self.src.push_str(&format!("{tmp} |= {op} << {i}\n")); + } + tmp + } + }; + results.push(format!("{}({})", name.to_camel_case(), operand)); } + Instruction::FlagsLower { flags, .. } => match flags.repr().count() { + 1 => results.push(format!("({}).value", operands[0])), + n => { + let tmp = self.locals.tmp("flags"); + self.src + .push_str(&format!("{tmp} = ({}).value\n", operands[0])); + for i in 0..n { + let i = 32 * i; + results.push(format!("({tmp} >> {i}) & 0xffffffff")); + } + } + }, Instruction::VariantPayloadName => { let name = self.locals.tmp("payload"); diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs index 5f0d26e48..0edc28932 100644 --- a/crates/gen-wasmtime/src/lib.rs +++ b/crates/gen-wasmtime/src/lib.rs @@ -9,7 +9,7 @@ use wit_bindgen_gen_core::wit_parser::abi::{ }; use wit_bindgen_gen_core::{wit_parser::*, Direction, Files, Generator, Source, TypeInfo, Types}; use wit_bindgen_gen_rust::{ - int_repr, to_rust_ident, wasm_type, FnSig, RustFunctionGenerator, RustGenerator, TypeMode, + to_rust_ident, wasm_type, FnSig, RustFlagsRepr, RustFunctionGenerator, RustGenerator, TypeMode, }; #[derive(Default)] @@ -328,49 +328,6 @@ impl Generator for Wasmtime { record: &Record, docs: &Docs, ) { - if record.is_flags() { - self.src - .push_str("wit_bindgen_wasmtime::bitflags::bitflags! {\n"); - self.rustdoc(docs); - self.src - .push_str(&format!("pub struct {}: ", name.to_camel_case())); - let repr = iface - .flags_repr(record) - .expect("unsupported number of flags"); - self.int_repr(repr); - self.src.push_str(" {\n"); - for (i, field) in record.fields.iter().enumerate() { - self.rustdoc(&field.docs); - self.src.push_str(&format!( - "const {} = 1 << {};\n", - field.name.to_shouty_snake_case(), - i, - )); - } - self.src.push_str("}\n"); - self.src.push_str("}\n\n"); - - self.src.push_str("impl core::fmt::Display for "); - self.src.push_str(&name.to_camel_case()); - self.src.push_str( - "{\nfn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", - ); - - self.src.push_str("f.write_str(\""); - self.src.push_str(&name.to_camel_case()); - self.src.push_str("(\")?;\n"); - self.src.push_str("core::fmt::Debug::fmt(self, f)?;\n"); - self.src.push_str("f.write_str(\" (0x\")?;\n"); - self.src - .push_str("core::fmt::LowerHex::fmt(&self.bits, f)?;\n"); - self.src.push_str("f.write_str(\"))\")?;\n"); - self.src.push_str("Ok(())"); - - self.src.push_str("}\n"); - self.src.push_str("}\n\n"); - return; - } - self.print_typedef_record(iface, id, record, docs); // If this record might be used as a slice type in various places then @@ -418,6 +375,51 @@ impl Generator for Wasmtime { } } + fn type_flags( + &mut self, + _iface: &Interface, + _id: TypeId, + name: &str, + flags: &Flags, + docs: &Docs, + ) { + self.src + .push_str("wit_bindgen_wasmtime::bitflags::bitflags! {\n"); + self.rustdoc(docs); + let repr = RustFlagsRepr::new(flags); + self.src + .push_str(&format!("pub struct {}: {repr} {{\n", name.to_camel_case())); + for (i, flag) in flags.flags.iter().enumerate() { + self.rustdoc(&flag.docs); + self.src.push_str(&format!( + "const {} = 1 << {};\n", + flag.name.to_shouty_snake_case(), + i, + )); + } + self.src.push_str("}\n"); + self.src.push_str("}\n\n"); + + self.src.push_str("impl core::fmt::Display for "); + self.src.push_str(&name.to_camel_case()); + self.src.push_str( + "{\nfn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n", + ); + + self.src.push_str("f.write_str(\""); + self.src.push_str(&name.to_camel_case()); + self.src.push_str("(\")?;\n"); + self.src.push_str("core::fmt::Debug::fmt(self, f)?;\n"); + self.src.push_str("f.write_str(\" (0x\")?;\n"); + self.src + .push_str("core::fmt::LowerHex::fmt(&self.bits, f)?;\n"); + self.src.push_str("f.write_str(\"))\")?;\n"); + self.src.push_str("Ok(())"); + + self.src.push_str("}\n"); + self.src.push_str("}\n\n"); + } + fn type_variant( &mut self, iface: &Interface, @@ -1538,36 +1540,29 @@ impl Bindgen for FunctionBindgen<'_> { self.record_lift(iface, *ty, record, operands, results); } - Instruction::FlagsLower { record, .. } => { + Instruction::FlagsLower { flags, .. } => { let tmp = self.tmp(); self.push_str(&format!("let flags{} = {};\n", tmp, operands[0])); - for i in 0..record.num_i32s() { + for i in 0..flags.repr().count() { results.push(format!("(flags{}.bits >> {}) as i32", tmp, i * 32)); } } - Instruction::FlagsLower64 { .. } => { - results.push(format!("({}).bits as i64", operands[0])); - } - Instruction::FlagsLift { record, name, .. } - | Instruction::FlagsLift64 { record, name, .. } => { + Instruction::FlagsLift { flags, name, .. } => { self.gen.needs_validate_flags = true; - let repr = iface - .flags_repr(record) - .expect("unsupported number of flags"); + let repr = RustFlagsRepr::new(flags); let mut flags = String::from("0"); for (i, op) in operands.iter().enumerate() { - flags.push_str(&format!("| (i64::from({}) << {})", op, i * 32)); + flags.push_str(&format!("| (({} as {repr}) << {})", op, i * 32)); } results.push(format!( "validate_flags( {}, - {name}::all().bits() as i64, + {name}::all().bits(), \"{name}\", - |b| {name} {{ bits: b as {ty} }} + |bits| {name} {{ bits }} )?", flags, name = name.to_camel_case(), - ty = int_repr(repr), )); } diff --git a/crates/parser/src/abi.rs b/crates/parser/src/abi.rs index b79ef55e6..998af4b7f 100644 --- a/crates/parser/src/abi.rs +++ b/crates/parser/src/abi.rs @@ -1,6 +1,7 @@ use crate::sizealign::align_to; use crate::{ - Function, Int, Interface, Record, RecordKind, ResourceId, Type, TypeDefKind, TypeId, Variant, + Flags, FlagsRepr, Function, Int, Interface, Record, ResourceId, Type, TypeDefKind, TypeId, + Variant, }; use std::mem; @@ -512,27 +513,17 @@ def_instruction! { /// Converts a language-specific record-of-bools to a list of `i32`. FlagsLower { - record: &'a Record, - name: &'a str, - ty: TypeId, - } : [1] => [record.num_i32s()], - FlagsLower64 { - record: &'a Record, + flags: &'a Flags, name: &'a str, ty: TypeId, - } : [1] => [1], + } : [1] => [flags.repr().count()], /// Converts a list of native wasm `i32` to a language-specific /// record-of-bools. FlagsLift { - record: &'a Record, + flags: &'a Flags, name: &'a str, ty: TypeId, - } : [record.num_i32s()] => [1], - FlagsLift64 { - record: &'a Record, - name: &'a str, - ty: TypeId, - } : [1] => [1], + } : [flags.repr().count()] => [1], // variants @@ -913,21 +904,18 @@ impl Interface { Type::Id(id) => match &self.types[*id].kind { TypeDefKind::Type(t) => self.push_wasm(variant, t, result), - TypeDefKind::Record(r) if r.is_flags() => match self.flags_repr(r) { - Some(int) => result.push(int.into()), - None => { - for _ in 0..r.num_i32s() { - result.push(WasmType::I32); - } - } - }, - TypeDefKind::Record(r) => { for field in r.fields.iter() { self.push_wasm(variant, &field.ty, result); } } + TypeDefKind::Flags(r) => { + for _ in 0..r.repr().count() { + result.push(WasmType::I32); + } + } + TypeDefKind::List(_) => { result.push(WasmType::I32); result.push(WasmType::I32); @@ -963,18 +951,6 @@ impl Interface { } } - pub fn flags_repr(&self, record: &Record) -> Option { - match record.kind { - RecordKind::Flags(Some(hint)) => Some(hint), - RecordKind::Flags(None) if record.fields.len() <= 8 => Some(Int::U8), - RecordKind::Flags(None) if record.fields.len() <= 16 => Some(Int::U16), - RecordKind::Flags(None) if record.fields.len() <= 32 => Some(Int::U32), - RecordKind::Flags(None) if record.fields.len() <= 64 => Some(Int::U64), - RecordKind::Flags(None) => None, - _ => panic!("not a flags record"), - } - } - /// Generates an abstract sequence of instructions which represents this /// function being adapted as an imported function. /// @@ -1425,20 +1401,6 @@ impl<'a, B: Bindgen> Generator<'a, B> { self.emit(&ListLower { element, realloc }); } } - TypeDefKind::Record(record) if record.is_flags() => { - match self.iface.flags_repr(record) { - Some(Int::U64) => self.emit(&FlagsLower64 { - record, - ty: id, - name: self.iface.types[id].name.as_ref().unwrap(), - }), - _ => self.emit(&FlagsLower { - record, - ty: id, - name: self.iface.types[id].name.as_ref().unwrap(), - }), - } - } TypeDefKind::Record(record) => { self.emit(&RecordLower { record, @@ -1455,6 +1417,14 @@ impl<'a, B: Bindgen> Generator<'a, B> { } } + TypeDefKind::Flags(flags) => { + self.emit(&FlagsLower { + flags, + ty: id, + name: self.iface.types[id].name.as_ref().unwrap(), + }); + } + TypeDefKind::Variant(v) => { let mut results = Vec::new(); let mut temp = Vec::new(); @@ -1583,20 +1553,6 @@ impl<'a, B: Bindgen> Generator<'a, B> { }); } } - TypeDefKind::Record(record) if record.is_flags() => { - match self.iface.flags_repr(record) { - Some(Int::U64) => self.emit(&FlagsLift64 { - record, - ty: id, - name: self.iface.types[id].name.as_ref().unwrap(), - }), - _ => self.emit(&FlagsLift { - record, - ty: id, - name: self.iface.types[id].name.as_ref().unwrap(), - }), - } - } TypeDefKind::Record(record) => { let mut temp = Vec::new(); self.iface.push_wasm(self.variant, ty, &mut temp); @@ -1616,6 +1572,13 @@ impl<'a, B: Bindgen> Generator<'a, B> { name: self.iface.types[id].name.as_deref(), }); } + TypeDefKind::Flags(flags) => { + self.emit(&FlagsLift { + flags, + ty: id, + name: self.iface.types[id].name.as_ref().unwrap(), + }); + } TypeDefKind::Variant(v) => { let mut params = Vec::new(); @@ -1694,24 +1657,6 @@ impl<'a, B: Bindgen> Generator<'a, B> { TypeDefKind::Type(t) => self.write_to_memory(t, addr, offset), TypeDefKind::List(_) => self.write_list_to_memory(ty, addr, offset), - TypeDefKind::Record(r) if r.is_flags() => { - self.lower(ty); - match self.iface.flags_repr(r) { - Some(repr) => { - self.stack.push(addr); - self.store_intrepr(offset, repr); - } - None => { - for i in 0..r.num_i32s() { - self.stack.push(addr.clone()); - self.emit(&I32Store { - offset: offset + (i as i32) * 4, - }); - } - } - } - } - // Decompose the record into its components and then write all // the components into memory one-by-one. TypeDefKind::Record(record) => { @@ -1741,6 +1686,28 @@ impl<'a, B: Bindgen> Generator<'a, B> { } } + TypeDefKind::Flags(f) => { + self.lower(ty); + match f.repr() { + FlagsRepr::U8 => { + self.stack.push(addr); + self.store_intrepr(offset, Int::U8); + } + FlagsRepr::U16 => { + self.stack.push(addr); + self.store_intrepr(offset, Int::U16); + } + FlagsRepr::U32(n) => { + for i in (0..n).rev() { + self.stack.push(addr.clone()); + self.emit(&I32Store { + offset: offset + (i as i32) * 4, + }); + } + } + } + } + // Each case will get its own block, and the first item in each // case is writing the discriminant. After that if we have a // payload we write the payload after the discriminant, aligned up @@ -1811,24 +1778,6 @@ impl<'a, B: Bindgen> Generator<'a, B> { TypeDefKind::List(_) => self.read_list_from_memory(ty, addr, offset), - TypeDefKind::Record(r) if r.is_flags() => { - match self.iface.flags_repr(r) { - Some(repr) => { - self.stack.push(addr); - self.load_intrepr(offset, repr); - } - None => { - for i in 0..r.num_i32s() { - self.stack.push(addr.clone()); - self.emit(&I32Load { - offset: offset + (i as i32) * 4, - }); - } - } - } - self.lift(ty); - } - // Read and lift each field individually, adjusting the offset // as we go along, then aggregate all the fields into the // record. @@ -1853,6 +1802,28 @@ impl<'a, B: Bindgen> Generator<'a, B> { }); } + TypeDefKind::Flags(f) => { + match f.repr() { + FlagsRepr::U8 => { + self.stack.push(addr); + self.load_intrepr(offset, Int::U8); + } + FlagsRepr::U16 => { + self.stack.push(addr); + self.load_intrepr(offset, Int::U16); + } + FlagsRepr::U32(n) => { + for i in 0..n { + self.stack.push(addr.clone()); + self.emit(&I32Load { + offset: offset + (i as i32) * 4, + }); + } + } + } + self.lift(ty); + } + // Each case will get its own block, and we'll dispatch to the // right block based on the `i32.load` we initially perform. Each // individual block is pretty simple and just reads the payload type diff --git a/crates/parser/src/ast.rs b/crates/parser/src/ast.rs index 73bdf0196..23d24cfba 100644 --- a/crates/parser/src/ast.rs +++ b/crates/parser/src/ast.rs @@ -91,12 +91,12 @@ enum Type<'a> { Name(Id<'a>), List(Box>), Record(Record<'a>), + Flags(Flags<'a>), Variant(Variant<'a>), } struct Record<'a> { tuple_hint: bool, - flags_repr: Option>>, fields: Vec>, } @@ -106,6 +106,15 @@ struct Field<'a> { ty: Type<'a>, } +struct Flags<'a> { + flags: Vec>, +} + +struct Flag<'a> { + docs: Docs<'a>, + name: Id<'a>, +} + struct Variant<'a> { tag: Option>>, span: Span, @@ -234,20 +243,14 @@ impl<'a> TypeDef<'a> { fn parse_flags(tokens: &mut Tokenizer<'a>, docs: Docs<'a>) -> Result { tokens.expect(Token::Flags)?; let name = parse_id(tokens)?; - let ty = Type::Record(Record { - flags_repr: None, - tuple_hint: false, - fields: parse_list( + let ty = Type::Flags(Flags { + flags: parse_list( tokens, Token::LeftBrace, Token::RightBrace, |docs, tokens| { let name = parse_id(tokens)?; - Ok(Field { - docs, - name, - ty: Type::Bool, - }) + Ok(Flag { docs, name }) }, )?, }); @@ -258,7 +261,6 @@ impl<'a> TypeDef<'a> { tokens.expect(Token::Record)?; let name = parse_id(tokens)?; let ty = Type::Record(Record { - flags_repr: None, tuple_hint: false, fields: parse_list( tokens, @@ -475,7 +477,6 @@ impl<'a> Type<'a> { )?; Ok(Type::Record(Record { fields, - flags_repr: None, tuple_hint: true, })) } diff --git a/crates/parser/src/ast/resolve.rs b/crates/parser/src/ast/resolve.rs index af2b42866..964ee6fc1 100644 --- a/crates/parser/src/ast/resolve.rs +++ b/crates/parser/src/ast/resolve.rs @@ -21,6 +21,7 @@ pub struct Resolver { enum Key { Variant(Vec<(String, Option)>), Record(Vec<(String, Type)>), + Flags(Vec), List(Type), } @@ -204,6 +205,7 @@ impl Resolver { .collect(), kind: r.kind, }), + TypeDefKind::Flags(f) => TypeDefKind::Flags(f.clone()), TypeDefKind::Variant(v) => TypeDefKind::Variant(Variant { cases: v .cases @@ -367,20 +369,23 @@ impl Resolver { TypeDefKind::Record(Record { kind: if record.tuple_hint { RecordKind::Tuple - } else if let Some(hint) = &record.flags_repr { - RecordKind::Flags(Some(match &**hint { - super::Type::U8 => Int::U8, - super::Type::U16 => Int::U16, - super::Type::U32 => Int::U32, - super::Type::U64 => Int::U64, - _ => panic!("unknown explicit flags repr"), - })) } else { - RecordKind::infer(&self.types, &fields) + RecordKind::infer(&fields) }, fields, }) } + super::Type::Flags(flags) => { + let flags = flags + .flags + .iter() + .map(|flag| Flag { + docs: self.docs(&flag.docs), + name: flag.name.name.to_string(), + }) + .collect::>(); + TypeDefKind::Flags(Flags { flags }) + } super::Type::Variant(variant) => { if variant.cases.is_empty() { return Err(Error { @@ -465,6 +470,9 @@ impl Resolver { .map(|case| (case.name.clone(), case.ty)) .collect::>(), ), + TypeDefKind::Flags(r) => { + Key::Flags(r.flags.iter().map(|f| f.name.clone()).collect::>()) + } TypeDefKind::List(ty) => Key::List(*ty), }; let types = &mut self.types; @@ -624,7 +632,7 @@ impl Resolver { } } - TypeDefKind::List(_) | TypeDefKind::Type(_) => {} + TypeDefKind::Flags(_) | TypeDefKind::List(_) | TypeDefKind::Type(_) => {} } valid.insert(ty); diff --git a/crates/parser/src/lib.rs b/crates/parser/src/lib.rs index 9096a1c66..c37e7ab04 100644 --- a/crates/parser/src/lib.rs +++ b/crates/parser/src/lib.rs @@ -46,6 +46,7 @@ pub struct TypeDef { #[derive(Debug)] pub enum TypeDefKind { Record(Record), + Flags(Flags), Variant(Variant), List(Type), Type(Type), @@ -88,7 +89,6 @@ pub struct Record { #[derive(Copy, Clone, Debug)] pub enum RecordKind { Other, - Flags(Option), Tuple, } @@ -103,27 +103,14 @@ impl Record { pub fn is_tuple(&self) -> bool { matches!(self.kind, RecordKind::Tuple) } - - pub fn is_flags(&self) -> bool { - matches!(self.kind, RecordKind::Flags(_)) - } - - pub fn num_i32s(&self) -> usize { - (self.fields.len() + 31) / 32 - } } impl RecordKind { - fn infer(types: &Arena, fields: &[Field]) -> RecordKind { + fn infer(fields: &[Field]) -> RecordKind { if fields.is_empty() { return RecordKind::Other; } - // Structs-of-bools are classified to get represented as bitflags. - if fields.iter().all(|t| is_bool(&t.ty, types)) { - return RecordKind::Flags(None); - } - // fields with consecutive integer names get represented as tuples. if fields .iter() @@ -134,16 +121,43 @@ impl RecordKind { } return RecordKind::Other; + } +} - fn is_bool(t: &Type, types: &Arena) -> bool { - match t { - Type::Bool => true, - Type::Id(v) => match &types[*v].kind { - TypeDefKind::Type(t) => is_bool(t, types), - _ => false, - }, - _ => false, - } +#[derive(Debug, Clone)] +pub struct Flags { + pub flags: Vec, +} + +#[derive(Debug, Clone)] +pub struct Flag { + pub docs: Docs, + pub name: String, +} + +#[derive(Debug)] +pub enum FlagsRepr { + U8, + U16, + U32(usize), +} + +impl Flags { + pub fn repr(&self) -> FlagsRepr { + match self.flags.len() { + n if n <= 8 => FlagsRepr::U8, + n if n <= 16 => FlagsRepr::U16, + n => FlagsRepr::U32(sizealign::align_to(n, 32) / 32), + } + } +} + +impl FlagsRepr { + pub fn count(&self) -> usize { + match self { + FlagsRepr::U8 => 1, + FlagsRepr::U16 => 1, + FlagsRepr::U32(n) => *n, } } } @@ -392,6 +406,7 @@ impl Interface { return; } match &self.types[id].kind { + TypeDefKind::Flags(_) => {} TypeDefKind::Type(t) | TypeDefKind::List(t) => self.topo_visit_ty(t, list, visited), TypeDefKind::Record(r) => { for f in r.fields.iter() { @@ -435,6 +450,11 @@ impl Interface { TypeDefKind::List(_) | TypeDefKind::Variant(_) => false, TypeDefKind::Type(t) => self.all_bits_valid(t), TypeDefKind::Record(r) => r.fields.iter().all(|f| self.all_bits_valid(&f.ty)), + + // FIXME: this could perhaps be `true` for multiples-of-32 but + // seems better to probably leave this as unconditionally + // `false` for now, may want to reconsider later? + TypeDefKind::Flags(_) => false, }, } } diff --git a/crates/parser/src/sizealign.rs b/crates/parser/src/sizealign.rs index 780610f22..5037f8bf4 100644 --- a/crates/parser/src/sizealign.rs +++ b/crates/parser/src/sizealign.rs @@ -1,4 +1,4 @@ -use crate::{Int, Interface, Record, RecordKind, Type, TypeDef, TypeDefKind, Variant}; +use crate::{FlagsRepr, Int, Interface, Record, Type, TypeDef, TypeDefKind, Variant}; #[derive(Default)] pub struct SizeAlign { @@ -18,19 +18,12 @@ impl SizeAlign { match &ty.kind { TypeDefKind::Type(t) => (self.size(t), self.align(t)), TypeDefKind::List(_) => (8, 4), - TypeDefKind::Record(r) => { - if let RecordKind::Flags(repr) = r.kind { - return match repr { - Some(i) => int_size_align(i), - None if r.fields.len() <= 8 => (1, 1), - None if r.fields.len() <= 16 => (2, 2), - None if r.fields.len() <= 32 => (4, 4), - None if r.fields.len() <= 64 => (8, 8), - None => (r.num_i32s() * 4, 4), - }; - } - self.record(r.fields.iter().map(|f| &f.ty)) - } + TypeDefKind::Record(r) => self.record(r.fields.iter().map(|f| &f.ty)), + TypeDefKind::Flags(f) => match f.repr() { + FlagsRepr::U8 => (1, 1), + FlagsRepr::U16 => (2, 2), + FlagsRepr::U32(n) => (n * 4, 4), + }, TypeDefKind::Variant(v) => { let (discrim_size, discrim_align) = int_size_align(v.tag); let mut size = discrim_size; diff --git a/crates/parser/tests/all.rs b/crates/parser/tests/all.rs index e7c27c5f6..1ac96449d 100644 --- a/crates/parser/tests/all.rs +++ b/crates/parser/tests/all.rs @@ -191,6 +191,9 @@ fn to_json(i: &Interface) -> String { Record { fields: Vec<(String, String)>, }, + Flags { + flags: Vec, + }, Variant { cases: Vec<(String, Option)>, }, @@ -268,6 +271,9 @@ fn to_json(i: &Interface) -> String { .map(|f| (f.name.clone(), translate_type(&f.ty))) .collect(), }, + TypeDefKind::Flags(r) => Type::Flags { + flags: r.flags.iter().map(|f| f.name.clone()).collect(), + }, TypeDefKind::Variant(v) => Type::Variant { cases: v .cases diff --git a/crates/test-helpers/src/lib.rs b/crates/test-helpers/src/lib.rs index 8adbe5712..1cf23b5ad 100644 --- a/crates/test-helpers/src/lib.rs +++ b/crates/test-helpers/src/lib.rs @@ -179,6 +179,7 @@ pub fn codegen_rust_wasm_export(input: TokenStream) -> TokenStream { let t = quote_ty(param, iface, t); quote::quote! { Vec<#t> } } + TypeDefKind::Flags(_) => panic!("unknown flags"), TypeDefKind::Record(r) => { let fields = r.fields.iter().map(|f| quote_ty(param, iface, &f.ty)); quote::quote! { (#(#fields,)*) } diff --git a/crates/wasmlink/src/adapter/call.rs b/crates/wasmlink/src/adapter/call.rs index 6bd7dc69c..f5364266f 100644 --- a/crates/wasmlink/src/adapter/call.rs +++ b/crates/wasmlink/src/adapter/call.rs @@ -451,17 +451,12 @@ impl<'a> CallAdapter<'a> { operands: element_operands, }); } + TypeDefKind::Flags(r) => { + for _ in 0..r.repr().count() { + params.next().unwrap(); + } + } TypeDefKind::Record(r) => match r.kind { - RecordKind::Flags(_) => match interface.flags_repr(r) { - Some(_) => { - params.next().unwrap(); - } - None => { - for _ in 0..r.num_i32s() { - params.next().unwrap(); - } - } - }, RecordKind::Tuple | RecordKind::Other => { for f in &r.fields { Self::push_operands( @@ -612,8 +607,8 @@ impl<'a> CallAdapter<'a> { operands: element_operands, }); } + TypeDefKind::Flags(_) => {} TypeDefKind::Record(r) => match r.kind { - RecordKind::Flags(_) => {} RecordKind::Tuple | RecordKind::Other => { let offsets = sizes.field_offsets(r); diff --git a/crates/wasmlink/src/module.rs b/crates/wasmlink/src/module.rs index bf60744ef..255439a6d 100644 --- a/crates/wasmlink/src/module.rs +++ b/crates/wasmlink/src/module.rs @@ -48,6 +48,7 @@ fn has_list(interface: &WitInterface, ty: &WitType) -> bool { Type::Id(id) => match &interface.types[*id].kind { TypeDefKind::List(_) => true, TypeDefKind::Type(t) => has_list(interface, t), + TypeDefKind::Flags(_) => false, TypeDefKind::Record(r) => r.fields.iter().any(|f| has_list(interface, &f.ty)), TypeDefKind::Variant(v) => v.cases.iter().any(|c| { c.ty.as_ref() diff --git a/crates/wasmlink/tests/flags.wat b/crates/wasmlink/tests/flags.wat index cd395d28b..7e9270f47 100644 --- a/crates/wasmlink/tests/flags.wat +++ b/crates/wasmlink/tests/flags.wat @@ -17,7 +17,7 @@ (func (export "roundtrip-flag32") (param i32) (result i32) unreachable ) - (func (export "roundtrip-flag64") (param i64) (result i64) + (func (export "roundtrip-flag64") (param i32 i32) (result i32) unreachable ) ) diff --git a/crates/wasmtime/src/lib.rs b/crates/wasmtime/src/lib.rs index 2c2f49129..e515b46ad 100644 --- a/crates/wasmtime/src/lib.rs +++ b/crates/wasmtime/src/lib.rs @@ -73,13 +73,16 @@ pub mod rt { Trap::new(msg) } - pub fn validate_flags( - bits: i64, - all: i64, + pub fn validate_flags( + bits: T, + all: T, name: &str, - mk: impl FnOnce(i64) -> U, - ) -> Result { - if bits & !all != 0 { + mk: impl FnOnce(T) -> U, + ) -> Result + where + T: std::ops::Not + std::ops::BitAnd + From + PartialEq + Copy, + { + if bits & !all != 0u8.into() { let msg = format!("invalid flags specified for `{}`", name); Err(Trap::new(msg)) } else { diff --git a/crates/wit-component/src/decoding.rs b/crates/wit-component/src/decoding.rs index 4bad21968..08eabbf36 100644 --- a/crates/wit-component/src/decoding.rs +++ b/crates/wit-component/src/decoding.rs @@ -5,8 +5,8 @@ use wasmparser::{ Validator, WasmFeatures, }; use wit_parser::{ - validate_id, Case, Docs, Field, Function, FunctionKind, Interface, Record, RecordKind, Type, - TypeDef, TypeDefKind, TypeId, Variant, + validate_id, Case, Docs, Field, Flag, Flags, Function, FunctionKind, Interface, Record, + RecordKind, Type, TypeDef, TypeDefKind, TypeId, Variant, }; /// Represents information about a decoded WebAssembly component. @@ -446,8 +446,8 @@ impl<'a> InterfaceDecoder<'a> { let flags_name = flags_name.ok_or_else(|| anyhow!("interface has an unnamed flags type"))?; - let record = Record { - fields: names + let flags = Flags { + flags: names .map(|name| { validate_id(name).with_context(|| { format!( @@ -456,18 +456,16 @@ impl<'a> InterfaceDecoder<'a> { ) })?; - Ok(Field { + Ok(Flag { docs: Docs::default(), name: name.clone(), - ty: self.decode_primitive(PrimitiveInterfaceType::Bool)?, }) }) .collect::>()?, - kind: RecordKind::Flags(None), }; Ok(Type::Id( - self.alloc_type(Some(flags_name), TypeDefKind::Record(record)), + self.alloc_type(Some(flags_name), TypeDefKind::Flags(flags)), )) } diff --git a/crates/wit-component/src/encoding.rs b/crates/wit-component/src/encoding.rs index 919329a42..f5aeb5377 100644 --- a/crates/wit-component/src/encoding.rs +++ b/crates/wit-component/src/encoding.rs @@ -13,7 +13,8 @@ use wasm_encoder::*; use wasmparser::{Validator, WasmFeatures}; use wit_parser::{ abi::{AbiVariant, WasmSignature, WasmType}, - Function, FunctionKind, Interface, Record, RecordKind, Type, TypeDef, TypeDefKind, Variant, + Flags, Function, FunctionKind, Interface, Record, RecordKind, Type, TypeDef, TypeDefKind, + Variant, }; const INDIRECT_TABLE_NAME: &str = "$imports"; @@ -141,8 +142,14 @@ impl Hash for TypeDefKey<'_> { .hash(state); } } - TypeDefKind::Variant(v) => { + TypeDefKind::Flags(r) => { state.write_u8(1); + for f in &r.flags { + f.name.hash(state); + } + } + TypeDefKind::Variant(v) => { + state.write_u8(2); for c in &v.cases { c.name.hash(state); c.ty.map(|ty| TypeKey { @@ -153,7 +160,7 @@ impl Hash for TypeDefKey<'_> { } } TypeDefKind::List(ty) => { - state.write_u8(2); + state.write_u8(3); TypeKey { interface: self.interface, ty: *ty, @@ -161,7 +168,7 @@ impl Hash for TypeDefKey<'_> { .hash(state); } TypeDefKind::Type(ty) => { - state.write_u8(3); + state.write_u8(4); TypeKey { interface: self.interface, ty: *ty, @@ -345,6 +352,7 @@ impl<'a> TypeEncoder<'a> { } else { let mut encoded = match &ty.kind { TypeDefKind::Record(r) => self.encode_record(interface, instance, r)?, + TypeDefKind::Flags(r) => self.encode_flags(r)?, TypeDefKind::Variant(v) => self.encode_variant(interface, instance, v)?, TypeDefKind::List(ty) => { let ty = self.encode_type(interface, instance, ty)?; @@ -410,12 +418,6 @@ impl<'a> TypeEncoder<'a> { encoder.record(fields); InterfaceTypeRef::Type(index) } - RecordKind::Flags(_) => { - let index = self.types.len(); - let encoder = self.types.interface_type(); - encoder.flags(record.fields.iter().map(|f| f.name.as_str())); - InterfaceTypeRef::Type(index) - } RecordKind::Tuple => { let tys = record .fields @@ -430,6 +432,13 @@ impl<'a> TypeEncoder<'a> { }) } + fn encode_flags(&mut self, flags: &Flags) -> Result { + let index = self.types.len(); + let encoder = self.types.interface_type(); + encoder.flags(flags.flags.iter().map(|f| f.name.as_str())); + Ok(InterfaceTypeRef::Type(index)) + } + fn encode_variant( &mut self, interface: &'a Interface, @@ -565,6 +574,7 @@ impl RequiredOptions { TypeDefKind::Record(r) => { Self::for_types(interface, r.fields.iter().map(|f| &f.ty)) } + TypeDefKind::Flags(_) => Self::None, TypeDefKind::Variant(v) => { Self::for_types(interface, v.cases.iter().filter_map(|c| c.ty.as_ref())) } diff --git a/crates/wit-component/src/printing.rs b/crates/wit-component/src/printing.rs index cba6cf559..39cacf1d9 100644 --- a/crates/wit-component/src/printing.rs +++ b/crates/wit-component/src/printing.rs @@ -1,7 +1,7 @@ use anyhow::{bail, Result}; use std::collections::HashSet; use std::fmt::Write; -use wit_parser::{Interface, Record, Type, TypeDefKind, TypeId, Variant}; +use wit_parser::{Flags, Interface, Record, Type, TypeDefKind, TypeId, Variant}; /// A utility for printing WebAssembly interface definitions to a string. #[derive(Default)] @@ -72,6 +72,9 @@ impl InterfacePrinter { TypeDefKind::Record(r) => { self.print_record_type(interface, r)?; } + TypeDefKind::Flags(_) => { + bail!("interface has unnamed flags type") + } TypeDefKind::Variant(v) => { self.print_variant_type(interface, v)?; } @@ -160,6 +163,7 @@ impl InterfacePrinter { TypeDefKind::Record(r) => { self.declare_record(interface, ty.name.as_deref(), r)? } + TypeDefKind::Flags(f) => self.declare_flags(ty.name.as_deref(), f)?, TypeDefKind::Variant(v) => { self.declare_variant(interface, ty.name.as_deref(), v)? } @@ -201,20 +205,6 @@ impl InterfacePrinter { return Ok(()); } - if record.is_flags() { - match name { - Some(name) => { - writeln!(&mut self.output, "flags {} {{", name)?; - for field in &record.fields { - writeln!(&mut self.output, " {},", field.name)?; - } - self.output.push_str("}\n\n"); - } - None => bail!("interface has unnamed flags type"), - } - return Ok(()); - } - match name { Some(name) => { writeln!(&mut self.output, "record {} {{", name)?; @@ -231,6 +221,20 @@ impl InterfacePrinter { } } + fn declare_flags(&mut self, name: Option<&str>, flags: &Flags) -> Result<()> { + match name { + Some(name) => { + writeln!(&mut self.output, "flags {} {{", name)?; + for flag in &flags.flags { + writeln!(&mut self.output, " {},", flag.name)?; + } + self.output.push_str("}\n\n"); + } + None => bail!("interface has unnamed flags type"), + } + Ok(()) + } + fn declare_variant( &mut self, interface: &Interface, diff --git a/tests/runtime/many_arguments/host.rs b/tests/runtime/many_arguments/host.rs index 61f962580..1123b85ab 100644 --- a/tests/runtime/many_arguments/host.rs +++ b/tests/runtime/many_arguments/host.rs @@ -3,9 +3,7 @@ use anyhow::Result; wit_bindgen_wasmtime::export!("../../tests/runtime/many_arguments/imports.wit"); #[derive(Default)] -pub struct MyImports { - scalar: u32, -} +pub struct MyImports {} impl imports::Imports for MyImports { fn many_arguments( From 52f1d286ec599027b4f0d16e77cf9e98264985af Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 28 Apr 2022 15:19:53 -0700 Subject: [PATCH 2/4] Fix a test --- .../modules/crates/records/records.wit | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/test-modules/modules/crates/records/records.wit b/crates/test-modules/modules/crates/records/records.wit index 2dfd56668..92afac641 100644 --- a/crates/test-modules/modules/crates/records/records.wit +++ b/crates/test-modules/modules/crates/records/records.wit @@ -14,16 +14,16 @@ record scalars { scalar-arg: function(x: scalars) scalar-result: function() -> scalars -record really-flags { - a: bool, - b: bool, - c: bool, - d: bool, - e: bool, - f: bool, - g: bool, - h: bool, - i: bool, +flags really-flags { + a, + b, + c, + d, + e, + f, + g, + h, + i, } flags-arg: function(x: really-flags) From fbc0912ae70a0e42ac3a16b770de2385fab4bbde Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 28 Apr 2022 15:31:31 -0700 Subject: [PATCH 3/4] Fix some more flags --- crates/parser/tests/ui/types.wit.result | 42 +++++++------------------ crates/wasmlink/tests/flags.baseline | 38 ++++++++++++++-------- crates/wasmlink/tests/flags.wat | 1 + crates/wasmlink/tests/records.wit | 20 ++++++------ 4 files changed, 48 insertions(+), 53 deletions(-) diff --git a/crates/parser/tests/ui/types.wit.result b/crates/parser/tests/ui/types.wit.result index a6ea850da..a36f72571 100644 --- a/crates/parser/tests/ui/types.wit.result +++ b/crates/parser/tests/ui/types.wit.result @@ -291,47 +291,29 @@ { "idx": 31, "name": "t30", - "record": { - "fields": [] + "flags": { + "flags": [] } }, { "idx": 32, "name": "t31", - "record": { - "fields": [ - [ - "a", - "bool" - ], - [ - "b", - "bool" - ], - [ - "c", - "bool" - ] + "flags": { + "flags": [ + "a", + "b", + "c" ] } }, { "idx": 33, "name": "t32", - "record": { - "fields": [ - [ - "a", - "bool" - ], - [ - "b", - "bool" - ], - [ - "c", - "bool" - ] + "flags": { + "flags": [ + "a", + "b", + "c" ] } }, diff --git a/crates/wasmlink/tests/flags.baseline b/crates/wasmlink/tests/flags.baseline index 52b8c442b..9611fc2a4 100644 --- a/crates/wasmlink/tests/flags.baseline +++ b/crates/wasmlink/tests/flags.baseline @@ -1,9 +1,10 @@ (module (type (;0;) (func (param i32) (result i32))) - (type (;1;) (func (param i64) (result i64))) + (type (;1;) (func (param i32 i32 i32))) + (import "$parent" "memory" (memory (;0;) 0)) (module (;0;) (type (;0;) (func (param i32) (result i32))) - (type (;1;) (func (param i64) (result i64))) + (type (;1;) (func (param i32 i32) (result i32))) (func (;0;) (type 0) (param i32) (result i32) unreachable) (func (;1;) (type 0) (param i32) (result i32) @@ -16,8 +17,10 @@ unreachable) (func (;5;) (type 0) (param i32) (result i32) unreachable) - (func (;6;) (type 1) (param i64) (result i64) + (func (;6;) (type 1) (param i32 i32) (result i32) unreachable) + (memory (;0;) 1) + (export "memory" (memory 0)) (export "roundtrip-flag1" (func 0)) (export "roundtrip-flag2" (func 1)) (export "roundtrip-flag4" (func 2)) @@ -25,15 +28,16 @@ (export "roundtrip-flag16" (func 4)) (export "roundtrip-flag32" (func 5)) (export "roundtrip-flag64" (func 6))) - (instance (;0;) + (instance (;1;) (instantiate 0)) - (alias 0 "roundtrip-flag1" (func (;0;))) - (alias 0 "roundtrip-flag2" (func (;1;))) - (alias 0 "roundtrip-flag4" (func (;2;))) - (alias 0 "roundtrip-flag8" (func (;3;))) - (alias 0 "roundtrip-flag16" (func (;4;))) - (alias 0 "roundtrip-flag32" (func (;5;))) - (alias 0 "roundtrip-flag64" (func (;6;))) + (alias 1 "memory" (memory (;1;))) + (alias 1 "roundtrip-flag1" (func (;0;))) + (alias 1 "roundtrip-flag2" (func (;1;))) + (alias 1 "roundtrip-flag4" (func (;2;))) + (alias 1 "roundtrip-flag8" (func (;3;))) + (alias 1 "roundtrip-flag16" (func (;4;))) + (alias 1 "roundtrip-flag32" (func (;5;))) + (alias 1 "roundtrip-flag64" (func (;6;))) (func (;7;) (type 0) (param i32) (result i32) local.get 0 call 0) @@ -52,9 +56,17 @@ (func (;12;) (type 0) (param i32) (result i32) local.get 0 call 5) - (func (;13;) (type 1) (param i64) (result i64) + (func (;13;) (type 1) (param i32 i32 i32) + (local i32) local.get 0 - call 6) + local.get 1 + call 6 + local.set 3 + local.get 2 + local.get 3 + i32.const 8 + memory.copy 0 1) + (export "memory" (memory 1)) (export "roundtrip-flag1" (func 7)) (export "roundtrip-flag2" (func 8)) (export "roundtrip-flag4" (func 9)) diff --git a/crates/wasmlink/tests/flags.wat b/crates/wasmlink/tests/flags.wat index 7e9270f47..ae7f34ac1 100644 --- a/crates/wasmlink/tests/flags.wat +++ b/crates/wasmlink/tests/flags.wat @@ -1,4 +1,5 @@ (module + (memory (export "memory") 1) (func (export "roundtrip-flag1") (param i32) (result i32) unreachable ) diff --git a/crates/wasmlink/tests/records.wit b/crates/wasmlink/tests/records.wit index 2dfd56668..92afac641 100644 --- a/crates/wasmlink/tests/records.wit +++ b/crates/wasmlink/tests/records.wit @@ -14,16 +14,16 @@ record scalars { scalar-arg: function(x: scalars) scalar-result: function() -> scalars -record really-flags { - a: bool, - b: bool, - c: bool, - d: bool, - e: bool, - f: bool, - g: bool, - h: bool, - i: bool, +flags really-flags { + a, + b, + c, + d, + e, + f, + g, + h, + i, } flags-arg: function(x: really-flags) From d7238784e4d4f7fd5879f1674aab7672b675a8f6 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Thu, 28 Apr 2022 15:40:16 -0700 Subject: [PATCH 4/4] Fix more tests --- crates/wit-component/src/encoding.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crates/wit-component/src/encoding.rs b/crates/wit-component/src/encoding.rs index f5aeb5377..01f301d10 100644 --- a/crates/wit-component/src/encoding.rs +++ b/crates/wit-component/src/encoding.rs @@ -91,6 +91,16 @@ impl PartialEq for TypeDefKey<'_> { }) }) } + (TypeDefKind::Flags(f1), TypeDefKind::Flags(f2)) => { + if f1.flags.len() != f2.flags.len() { + return false; + } + + f1.flags + .iter() + .zip(f2.flags.iter()) + .all(|(f1, f2)| f1.name == f2.name) + } (TypeDefKind::Variant(v1), TypeDefKind::Variant(v2)) => { if v1.cases.len() != v2.cases.len() { return false;