diff --git a/crates/gen-c/src/lib.rs b/crates/gen-c/src/lib.rs
index 067badd3b..02898c0e2 100644
--- a/crates/gen-c/src/lib.rs
+++ b/crates/gen-c/src/lib.rs
@@ -170,7 +170,7 @@ impl C {
Type::Id(id) => match &iface.types[*id].kind {
TypeDefKind::Type(t) => self.is_arg_by_pointer(iface, t),
TypeDefKind::Variant(v) => !v.is_enum(),
- TypeDefKind::Record(r) if r.is_flags() => false,
+ TypeDefKind::Flags(_) => false,
TypeDefKind::Record(_) | TypeDefKind::List(_) => true,
},
Type::String => true,
@@ -246,14 +246,10 @@ impl C {
}
match &ty.kind {
TypeDefKind::Type(t) => self.print_ty(iface, t),
- TypeDefKind::Variant(_) => {
- self.public_anonymous_types.insert(*id);
- self.private_anonymous_types.remove(id);
- self.print_namespace(iface);
- self.print_ty_name(iface, &Type::Id(*id));
- self.src.h("_t");
- }
- TypeDefKind::Record(_) | TypeDefKind::List(_) => {
+ TypeDefKind::Flags(_)
+ | TypeDefKind::Record(_)
+ | TypeDefKind::List(_)
+ | TypeDefKind::Variant(_) => {
self.public_anonymous_types.insert(*id);
self.private_anonymous_types.remove(id);
self.print_namespace(iface);
@@ -289,6 +285,7 @@ impl C {
}
match &ty.kind {
TypeDefKind::Type(t) => self.print_ty_name(iface, t),
+ TypeDefKind::Flags(_) => unimplemented!(),
TypeDefKind::Record(r) => {
assert!(r.is_tuple());
self.src.h("tuple");
@@ -331,7 +328,7 @@ impl C {
self.src.h("typedef ");
let kind = &iface.types[ty].kind;
match kind {
- TypeDefKind::Type(_) => {
+ TypeDefKind::Type(_) | TypeDefKind::Flags(_) => {
unreachable!()
}
TypeDefKind::Record(r) => {
@@ -459,6 +456,8 @@ impl C {
match &iface.types[id].kind {
TypeDefKind::Type(t) => self.free(iface, t, "ptr"),
+ TypeDefKind::Flags(_) => {}
+
TypeDefKind::Record(r) => {
for field in r.fields.iter() {
if !self.owns_anything(iface, &field.ty) {
@@ -525,6 +524,7 @@ impl C {
match &iface.types[id].kind {
TypeDefKind::Type(t) => self.owns_anything(iface, t),
TypeDefKind::Record(r) => r.fields.iter().any(|t| self.owns_anything(iface, &t.ty)),
+ TypeDefKind::Flags(_) => false,
TypeDefKind::List(_) => true,
TypeDefKind::Variant(v) => v
.cases
@@ -580,7 +580,7 @@ impl Return {
TypeDefKind::Type(t) => self.return_single(iface, t, orig_ty),
// Flags are returned as their bare values
- TypeDefKind::Record(r) if r.is_flags() => {
+ TypeDefKind::Flags(_) => {
self.scalar = Some(Scalar::Type(*orig_ty));
}
@@ -669,41 +669,52 @@ impl Generator for C {
let prev = mem::take(&mut self.src.header);
self.docs(docs);
self.names.insert(&name.to_snake_case()).unwrap();
- if record.is_flags() {
- self.src.h("typedef ");
- let repr = iface
- .flags_repr(record)
- .expect("unsupported number of flags");
- self.src.h(int_repr(repr));
+ self.src.h("typedef struct {\n");
+ for field in record.fields.iter() {
+ self.print_ty(iface, &field.ty);
self.src.h(" ");
- self.print_namespace(iface);
- self.src.h(&name.to_snake_case());
- self.src.h("_t;\n");
-
- for (i, field) in record.fields.iter().enumerate() {
- self.src.h(&format!(
- "#define {}_{}_{} (1 << {})\n",
- iface.name.to_shouty_snake_case(),
- name.to_shouty_snake_case(),
- field.name.to_shouty_snake_case(),
- i,
- ));
+ if record.is_tuple() {
+ self.src.h("f");
}
- } else {
- self.src.h("typedef struct {\n");
- for field in record.fields.iter() {
- self.print_ty(iface, &field.ty);
- self.src.h(" ");
- if record.is_tuple() {
- self.src.h("f");
- }
- self.src.h(&field.name.to_snake_case());
- self.src.h(";\n");
- }
- self.src.h("} ");
- self.print_namespace(iface);
- self.src.h(&name.to_snake_case());
- self.src.h("_t;\n");
+ self.src.h(&field.name.to_snake_case());
+ self.src.h(";\n");
+ }
+ self.src.h("} ");
+ self.print_namespace(iface);
+ self.src.h(&name.to_snake_case());
+ self.src.h("_t;\n");
+
+ self.types
+ .insert(id, mem::replace(&mut self.src.header, prev));
+ }
+
+ fn type_flags(
+ &mut self,
+ iface: &Interface,
+ id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ let prev = mem::take(&mut self.src.header);
+ self.docs(docs);
+ self.names.insert(&name.to_snake_case()).unwrap();
+ self.src.h("typedef ");
+ let repr = flags_repr(flags);
+ self.src.h(int_repr(repr));
+ self.src.h(" ");
+ self.print_namespace(iface);
+ self.src.h(&name.to_snake_case());
+ self.src.h("_t;\n");
+
+ for (i, flag) in flags.flags.iter().enumerate() {
+ self.src.h(&format!(
+ "#define {}_{}_{} (1 << {})\n",
+ iface.name.to_shouty_snake_case(),
+ name.to_shouty_snake_case(),
+ flag.name.to_shouty_snake_case(),
+ i,
+ ));
}
self.types
@@ -1398,15 +1409,31 @@ impl Bindgen for FunctionBindgen<'_> {
}
// TODO: checked
- Instruction::FlagsLower { record, .. } | Instruction::FlagsLift { record, .. } => {
- match record.num_i32s() {
- 0 | 1 => results.push(operands.pop().unwrap()),
- _ => panic!("unsupported bitflags"),
+ Instruction::FlagsLower { flags, ty, .. } => match flags_repr(flags) {
+ Int::U8 | Int::U16 | Int::U32 => {
+ results.push(operands.pop().unwrap());
}
- }
- Instruction::FlagsLower64 { .. } | Instruction::FlagsLift64 { .. } => {
- results.push(operands.pop().unwrap());
- }
+ Int::U64 => {
+ let name = self.gen.type_string(iface, &Type::Id(*ty));
+ let tmp = self.locals.tmp("flags");
+ self.src
+ .push_str(&format!("{name} {tmp} = {};\n", operands[0]));
+ results.push(format!("{tmp} & 0xffffffff"));
+ results.push(format!("({tmp} >> 32) & 0xffffffff"));
+ }
+ },
+
+ Instruction::FlagsLift { flags, ty, .. } => match flags_repr(flags) {
+ Int::U8 | Int::U16 | Int::U32 => {
+ results.push(operands.pop().unwrap());
+ }
+ Int::U64 => {
+ let name = self.gen.type_string(iface, &Type::Id(*ty));
+ let op0 = &operands[0];
+ let op1 = &operands[1];
+ results.push(format!("(({name}) ({op0})) | ((({name}) ({op1})) << 32)"));
+ }
+ },
Instruction::VariantPayloadName => {
let name = self.locals.tmp("payload");
@@ -1824,3 +1851,13 @@ fn case_field_name(case: &Case) -> String {
case.name.to_snake_case()
}
}
+
+fn flags_repr(f: &Flags) -> Int {
+ match f.repr() {
+ FlagsRepr::U8 => Int::U8,
+ FlagsRepr::U16 => Int::U16,
+ FlagsRepr::U32(1) => Int::U32,
+ FlagsRepr::U32(2) => Int::U64,
+ repr => panic!("unimplemented flags {:?}", repr),
+ }
+}
diff --git a/crates/gen-core/src/lib.rs b/crates/gen-core/src/lib.rs
index df4940f96..337cfd168 100644
--- a/crates/gen-core/src/lib.rs
+++ b/crates/gen-core/src/lib.rs
@@ -55,6 +55,7 @@ pub trait Generator {
record: &Record,
docs: &Docs,
);
+ fn type_flags(&mut self, iface: &Interface, id: TypeId, name: &str, flags: &Flags, docs: &Docs);
fn type_variant(
&mut self,
iface: &Interface,
@@ -88,6 +89,7 @@ pub trait Generator {
};
match &ty.kind {
TypeDefKind::Record(record) => self.type_record(iface, id, name, record, &ty.docs),
+ TypeDefKind::Flags(flags) => self.type_flags(iface, id, name, flags, &ty.docs),
TypeDefKind::Variant(variant) => {
self.type_variant(iface, id, name, variant, &ty.docs)
}
@@ -188,6 +190,7 @@ impl Types {
info |= self.type_info(iface, &field.ty);
}
}
+ TypeDefKind::Flags(_) => {}
TypeDefKind::Variant(v) => {
for case in v.cases.iter() {
if let Some(ty) = &case.ty {
@@ -225,6 +228,7 @@ impl Types {
self.set_param_result_ty(iface, &field.ty, param, result)
}
}
+ TypeDefKind::Flags(_) => {}
TypeDefKind::Variant(v) => {
for case in v.cases.iter() {
if let Some(ty) = &case.ty {
diff --git a/crates/gen-js/src/lib.rs b/crates/gen-js/src/lib.rs
index fe813fd3b..d5e8d92bf 100644
--- a/crates/gen-js/src/lib.rs
+++ b/crates/gen-js/src/lib.rs
@@ -171,6 +171,7 @@ impl Js {
TypeDefKind::Type(t) => self.print_ty(iface, t),
TypeDefKind::Record(r) if r.is_tuple() => self.print_tuple(iface, r),
TypeDefKind::Record(_) => panic!("anonymous record"),
+ TypeDefKind::Flags(_) => panic!("anonymous flags"),
TypeDefKind::Variant(v) => {
if self.is_nullable_option(iface, v) {
self.print_ty(iface, v.cases[1].ty.as_ref().unwrap());
@@ -336,37 +337,6 @@ impl Generator for Js {
.ts(&format!("export type {} = ", name.to_camel_case()));
self.print_tuple(iface, record);
self.src.ts(";\n");
- } else if record.is_flags() {
- let repr = iface
- .flags_repr(record)
- .expect("unsupported number of flags");
- let suffix = if repr == Int::U64 {
- self.src
- .ts(&format!("export type {} = bigint;\n", name.to_camel_case()));
- "n"
- } else {
- self.src
- .ts(&format!("export type {} = number;\n", name.to_camel_case()));
- ""
- };
- let name = name.to_shouty_snake_case();
- for (i, field) in record.fields.iter().enumerate() {
- let field = field.name.to_shouty_snake_case();
- self.src.js(&format!(
- "export const {}_{} = {}{};\n",
- name,
- field,
- 1u64 << i,
- suffix,
- ));
- self.src.ts(&format!(
- "export const {}_{} = {}{};\n",
- name,
- field,
- 1u64 << i,
- suffix,
- ));
- }
} else {
self.src
.ts(&format!("export interface {} {{\n", name.to_camel_case()));
@@ -384,6 +354,34 @@ impl Generator for Js {
}
}
+ fn type_flags(
+ &mut self,
+ _iface: &Interface,
+ _id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ self.docs(docs);
+ let repr = js_flags_repr(flags);
+ let ty = repr.ty();
+ let suffix = repr.suffix();
+ self.src
+ .ts(&format!("export type {} = {ty};\n", name.to_camel_case()));
+ let name = name.to_shouty_snake_case();
+ for (i, flag) in flags.flags.iter().enumerate() {
+ let flag = flag.name.to_shouty_snake_case();
+ self.src.js(&format!(
+ "export const {name}_{flag} = {}{suffix};\n",
+ 1u128 << i,
+ ));
+ self.src.ts(&format!(
+ "export const {name}_{flag} = {}{suffix};\n",
+ 1u128 << i,
+ ));
+ }
+ }
+
fn type_variant(
&mut self,
iface: &Interface,
@@ -1412,20 +1410,56 @@ impl Bindgen for FunctionBindgen<'_> {
}
}
- Instruction::FlagsLower { record, .. } | Instruction::FlagsLift { record, .. } => {
- match record.num_i32s() {
- 0 | 1 => {
- let validate = self.gen.intrinsic(Intrinsic::ValidateFlags);
- let mask = (1u64 << record.fields.len()) - 1;
- results.push(format!("{}({}, {})", validate, operands[0], mask));
+ Instruction::FlagsLower { flags, .. } => {
+ let repr = js_flags_repr(flags);
+ let validate = match repr {
+ JsFlagsRepr::Number => self.gen.intrinsic(Intrinsic::ValidateFlags),
+ JsFlagsRepr::Bigint => self.gen.intrinsic(Intrinsic::ValidateFlags64),
+ };
+ let op0 = &operands[0];
+ let len = flags.flags.len();
+ let n = repr.suffix();
+ let tmp = self.tmp();
+ let mask = (1u128 << len) - 1;
+ self.src.js(&format!(
+ "const flags{tmp} = {validate}({op0}, {mask}{n});\n"
+ ));
+ match repr {
+ JsFlagsRepr::Number => {
+ results.push(format!("flags{}", tmp));
+ }
+ JsFlagsRepr::Bigint => {
+ for i in 0..flags.repr().count() {
+ let i = 32 * i;
+ results.push(format!("Number((flags{tmp} >> {i}n) & 0xffffffffn)",));
+ }
}
- _ => panic!("unsupported bitflags"),
}
}
- Instruction::FlagsLower64 { record, .. } | Instruction::FlagsLift64 { record, .. } => {
- let validate = self.gen.intrinsic(Intrinsic::ValidateFlags64);
- let mask = (1u128 << record.fields.len()) - 1;
- results.push(format!("{}({}, {}n)", validate, operands[0], mask));
+
+ Instruction::FlagsLift { flags, .. } => {
+ let repr = js_flags_repr(flags);
+ let n = repr.suffix();
+ let tmp = self.tmp();
+ let operand = match repr {
+ JsFlagsRepr::Number => operands[0].clone(),
+ JsFlagsRepr::Bigint => {
+ self.src.js(&format!("let flags{tmp} = 0n;\n"));
+ for (i, op) in operands.iter().enumerate() {
+ let i = 32 * i;
+ self.src
+ .js(&format!("flags{tmp} |= BigInt({op}) << {i}n;\n",));
+ }
+ format!("flags{tmp}")
+ }
+ };
+ let validate = match repr {
+ JsFlagsRepr::Number => self.gen.intrinsic(Intrinsic::ValidateFlags),
+ JsFlagsRepr::Bigint => self.gen.intrinsic(Intrinsic::ValidateFlags64),
+ };
+ let len = flags.flags.len();
+ let mask = (1u128 << len) - 1;
+ results.push(format!("{validate}({operand}, {mask}{n})"));
}
Instruction::VariantPayloadName => results.push("e".to_string()),
@@ -2238,3 +2272,30 @@ impl Source {
self.ts.push_str(s);
}
}
+
+enum JsFlagsRepr {
+ Number,
+ Bigint,
+}
+
+impl JsFlagsRepr {
+ fn ty(&self) -> &'static str {
+ match self {
+ JsFlagsRepr::Number => "number",
+ JsFlagsRepr::Bigint => "bigint",
+ }
+ }
+ fn suffix(&self) -> &'static str {
+ match self {
+ JsFlagsRepr::Number => "",
+ JsFlagsRepr::Bigint => "n",
+ }
+ }
+}
+
+fn js_flags_repr(f: &Flags) -> JsFlagsRepr {
+ match f.repr() {
+ FlagsRepr::U8 | FlagsRepr::U16 | FlagsRepr::U32(1) => JsFlagsRepr::Number,
+ FlagsRepr::U32(_) => JsFlagsRepr::Bigint,
+ }
+}
diff --git a/crates/gen-markdown/src/lib.rs b/crates/gen-markdown/src/lib.rs
index 87c7a646c..b71da58ed 100644
--- a/crates/gen-markdown/src/lib.rs
+++ b/crates/gen-markdown/src/lib.rs
@@ -79,6 +79,7 @@ impl Markdown {
}
self.src.push_str(")");
}
+ TypeDefKind::Flags(_) => unreachable!(),
TypeDefKind::Variant(v) => {
if let Some(t) = v.as_option() {
self.src.push_str("option<");
@@ -162,7 +163,7 @@ impl Generator for Markdown {
self.src.push_str("record\n\n");
self.print_type_info(id, docs);
self.src.push_str("\n### Record Fields\n\n");
- for (i, field) in record.fields.iter().enumerate() {
+ for field in record.fields.iter() {
self.src.push_str(&format!(
"- [`{name}`](#{r}.{f}): ",
r = name.to_snake_case(),
@@ -178,9 +179,38 @@ impl Generator for Markdown {
self.src.push_str("\n\n");
self.docs(&field.docs);
self.src.deindent(1);
- if record.is_flags() {
- self.src.push_str(&format!("Bit: {}\n", i));
- }
+ self.src.push_str("\n");
+ }
+ }
+
+ fn type_flags(
+ &mut self,
+ _iface: &Interface,
+ id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ self.print_type_header(name);
+ self.src.push_str("record\n\n");
+ self.print_type_info(id, docs);
+ self.src.push_str("\n### Record Fields\n\n");
+ for (i, flag) in flags.flags.iter().enumerate() {
+ self.src.push_str(&format!(
+ "- [`{name}`](#{r}.{f}): ",
+ r = name.to_snake_case(),
+ f = flag.name.to_snake_case(),
+ name = flag.name,
+ ));
+ self.hrefs.insert(
+ format!("{}::{}", name, flag.name),
+ format!("#{}.{}", name.to_snake_case(), flag.name.to_snake_case()),
+ );
+ self.src.indent(1);
+ self.src.push_str("\n\n");
+ self.docs(&flag.docs);
+ self.src.deindent(1);
+ self.src.push_str(&format!("Bit: {}\n", i));
self.src.push_str("\n");
}
}
diff --git a/crates/gen-rust-wasm/src/lib.rs b/crates/gen-rust-wasm/src/lib.rs
index 4b5caf4f6..407cda398 100644
--- a/crates/gen-rust-wasm/src/lib.rs
+++ b/crates/gen-rust-wasm/src/lib.rs
@@ -8,7 +8,7 @@ use wit_bindgen_gen_core::wit_parser::abi::{
};
use wit_bindgen_gen_core::{wit_parser::*, Direction, Files, Generator, Source, TypeInfo, Types};
use wit_bindgen_gen_rust::{
- int_repr, wasm_type, FnSig, RustFunctionGenerator, RustGenerator, TypeMode,
+ int_repr, wasm_type, FnSig, RustFlagsRepr, RustFunctionGenerator, RustGenerator, TypeMode,
};
#[derive(Default)]
@@ -165,72 +165,53 @@ impl Generator for RustWasm {
&mut self,
iface: &Interface,
id: TypeId,
- name: &str,
+ _name: &str,
record: &Record,
docs: &Docs,
) {
- if record.is_flags() {
- self.src
- .push_str("wit_bindgen_rust::bitflags::bitflags! {\n");
- self.rustdoc(docs);
- let repr = iface
- .flags_repr(record)
- .expect("unsupported number of flags");
- self.src.push_str(&format!(
- "pub struct {}: {} {{\n",
- name.to_camel_case(),
- int_repr(repr)
- ));
- for (i, field) in record.fields.iter().enumerate() {
- self.rustdoc(&field.docs);
- self.src.push_str(&format!(
- "const {} = 1 << {};\n",
- field.name.to_shouty_snake_case(),
- i,
- ));
- }
- self.src.push_str("}\n");
- self.src.push_str("}\n");
+ self.print_typedef_record(iface, id, record, docs);
+ }
- // Add a `from_bits_preserve` method.
- self.src
- .push_str(&format!("impl {} {{\n", name.to_camel_case()));
- self.src.push_str(&format!(
- " /// Convert from a raw integer, preserving any unknown bits. See\n"
- ));
- self.src.push_str(&format!(" /// \n"));
- self.src.push_str(&format!(
- " pub fn from_bits_preserve(bits: {}) -> Self {{\n",
- int_repr(repr)
- ));
- self.src.push_str(&format!(" Self {{ bits }}\n"));
- self.src.push_str(&format!(" }}\n"));
- self.src.push_str(&format!("}}\n"));
-
- // Add a `AsI64` etc. method.
- let as_trait = match repr {
- Int::U8 | Int::U16 | Int::U32 => "i32",
- Int::U64 => "i64",
- };
- self.src.push_str(&format!(
- "impl wit_bindgen_rust::rt::As{} for {} {{\n",
- as_trait.to_camel_case(),
- name.to_camel_case()
- ));
- self.src.push_str(&format!(" #[inline]"));
+ fn type_flags(
+ &mut self,
+ _iface: &Interface,
+ _id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ self.src
+ .push_str("wit_bindgen_rust::bitflags::bitflags! {\n");
+ self.rustdoc(docs);
+ let repr = RustFlagsRepr::new(flags);
+ self.src
+ .push_str(&format!("pub struct {}: {repr} {{\n", name.to_camel_case(),));
+ for (i, flag) in flags.flags.iter().enumerate() {
+ self.rustdoc(&flag.docs);
self.src.push_str(&format!(
- " fn as_{}(self) -> {} {{\n",
- as_trait, as_trait
+ "const {} = 1 << {};\n",
+ flag.name.to_shouty_snake_case(),
+ i,
));
- self.src
- .push_str(&format!(" self.bits() as {}\n", as_trait));
- self.src.push_str(&format!(" }}"));
- self.src.push_str(&format!("}}\n"));
-
- return;
}
+ self.src.push_str("}\n");
+ self.src.push_str("}\n");
- self.print_typedef_record(iface, id, record, docs);
+ // Add a `from_bits_preserve` method.
+ self.src
+ .push_str(&format!("impl {} {{\n", name.to_camel_case()));
+ self.src.push_str(&format!(
+ " /// Convert from a raw integer, preserving any unknown bits. See\n"
+ ));
+ self.src.push_str(&format!(
+ " /// \n"
+ ));
+ self.src.push_str(&format!(
+ " pub fn from_bits_preserve(bits: {repr}) -> Self {{\n",
+ ));
+ self.src.push_str(&format!(" Self {{ bits }}\n"));
+ self.src.push_str(&format!(" }}\n"));
+ self.src.push_str(&format!("}}\n"));
}
fn type_variant(
@@ -934,23 +915,20 @@ impl Bindgen for FunctionBindgen<'_> {
));
}
- Instruction::FlagsLower { record, .. } => {
+ Instruction::FlagsLower { flags, .. } => {
let tmp = self.tmp();
self.push_str(&format!("let flags{} = {};\n", tmp, operands[0]));
- for i in 0..record.num_i32s() {
+ for i in 0..flags.repr().count() {
results.push(format!("(flags{}.bits() >> {}) as i32", tmp, i * 32));
}
}
- Instruction::FlagsLower64 { .. } => {
- let s = operands.pop().unwrap();
- results.push(format!("wit_bindgen_rust::rt::as_i64({})", s));
- }
- Instruction::FlagsLift { name, .. } | Instruction::FlagsLift64 { name, .. } => {
+ Instruction::FlagsLift { name, flags, .. } => {
+ let repr = RustFlagsRepr::new(flags);
let name = name.to_camel_case();
let mut result = format!("{}::empty()", name);
for (i, op) in operands.iter().enumerate() {
result.push_str(&format!(
- " | {}::from_bits_preserve((({} as u32) << {}) as _)",
+ " | {}::from_bits_preserve((({} as {repr}) << {}) as _)",
name,
op,
i * 32
diff --git a/crates/gen-rust/src/lib.rs b/crates/gen-rust/src/lib.rs
index 170dee605..ac10272e0 100644
--- a/crates/gen-rust/src/lib.rs
+++ b/crates/gen-rust/src/lib.rs
@@ -1,4 +1,5 @@
use heck::*;
+use std::fmt;
use wit_bindgen_gen_core::wit_parser::abi::{Bitcast, LiftLower, WasmType};
use wit_bindgen_gen_core::{wit_parser::*, TypeInfo, Types};
@@ -276,6 +277,9 @@ pub trait RustGenerator {
TypeDefKind::Record(_) => {
panic!("unsupported anonymous type reference: record")
}
+ TypeDefKind::Flags(_) => {
+ panic!("unsupported anonymous type reference: flags")
+ }
TypeDefKind::Type(t) => self.print_ty(iface, t, mode),
}
@@ -942,3 +946,36 @@ pub fn bitcast(casts: &[Bitcast], operands: &[String], results: &mut Vec
});
}
}
+
+pub enum RustFlagsRepr {
+ U8,
+ U16,
+ U32,
+ U64,
+ U128,
+}
+
+impl RustFlagsRepr {
+ pub fn new(f: &Flags) -> RustFlagsRepr {
+ match f.repr() {
+ FlagsRepr::U8 => RustFlagsRepr::U8,
+ FlagsRepr::U16 => RustFlagsRepr::U16,
+ FlagsRepr::U32(1) => RustFlagsRepr::U32,
+ FlagsRepr::U32(2) => RustFlagsRepr::U64,
+ FlagsRepr::U32(3 | 4) => RustFlagsRepr::U128,
+ FlagsRepr::U32(n) => panic!("unsupported number of flags: {}", n * 32),
+ }
+ }
+}
+
+impl fmt::Display for RustFlagsRepr {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ RustFlagsRepr::U8 => "u8".fmt(f),
+ RustFlagsRepr::U16 => "u16".fmt(f),
+ RustFlagsRepr::U32 => "u32".fmt(f),
+ RustFlagsRepr::U64 => "u64".fmt(f),
+ RustFlagsRepr::U128 => "u128".fmt(f),
+ }
+ }
+}
diff --git a/crates/gen-spidermonkey/src/lib.rs b/crates/gen-spidermonkey/src/lib.rs
index 0cce7d878..edf35554f 100644
--- a/crates/gen-spidermonkey/src/lib.rs
+++ b/crates/gen-spidermonkey/src/lib.rs
@@ -16,7 +16,7 @@ use wasm_encoder::Instruction;
use wit_bindgen_gen_core::{
wit_parser::{
abi::{self, AbiVariant, WasmSignature, WasmType},
- Docs, Function, Interface, Record, ResourceId, SizeAlign, Type, TypeId, Variant,
+ Docs, Flags, Function, Interface, Record, ResourceId, SizeAlign, Type, TypeId, Variant,
},
Direction, Files, Generator,
};
@@ -951,6 +951,18 @@ impl Generator for SpiderMonkeyWasm<'_> {
todo!()
}
+ fn type_flags(
+ &mut self,
+ iface: &Interface,
+ id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ let _ = (iface, id, name, flags, docs);
+ todo!()
+ }
+
fn type_variant(
&mut self,
iface: &Interface,
@@ -1878,22 +1890,12 @@ impl abi::Bindgen for Bindgen<'_, '_> {
ty: _,
} => todo!(),
abi::Instruction::FlagsLower {
- record: _,
- name: _,
- ty: _,
- } => todo!(),
- abi::Instruction::FlagsLower64 {
- record: _,
+ flags: _,
name: _,
ty: _,
} => todo!(),
abi::Instruction::FlagsLift {
- record: _,
- name: _,
- ty: _,
- } => todo!(),
- abi::Instruction::FlagsLift64 {
- record: _,
+ flags: _,
name: _,
ty: _,
} => todo!(),
diff --git a/crates/gen-wasmtime-py/src/lib.rs b/crates/gen-wasmtime-py/src/lib.rs
index be7cc5bc5..76fc70619 100644
--- a/crates/gen-wasmtime-py/src/lib.rs
+++ b/crates/gen-wasmtime-py/src/lib.rs
@@ -413,7 +413,7 @@ impl WasmtimePy {
TypeDefKind::Record(r) if r.is_tuple() => {
self.print_tuple(iface, r.fields.iter().map(|f| &f.ty))
}
- TypeDefKind::Record(_) => unreachable!(),
+ TypeDefKind::Record(_) | TypeDefKind::Flags(_) => unreachable!(),
TypeDefKind::Variant(v) => {
if let Some(t) = v.as_option() {
self.pyimport("typing", "Optional");
@@ -584,21 +584,6 @@ impl Generator for WasmtimePy {
if record.is_tuple() {
self.src.push_str(&format!("{} = ", name.to_camel_case()));
self.print_tuple(iface, record.fields.iter().map(|f| &f.ty));
- } else if record.is_flags() {
- self.pyimport("enum", "Flag");
- self.pyimport("enum", "auto");
- self.src
- .push_str(&format!("class {}(Flag):\n", name.to_camel_case()));
- self.indent();
- for field in record.fields.iter() {
- self.docs(&field.docs);
- self.src
- .push_str(&format!("{} = auto()\n", field.name.to_shouty_snake_case()));
- }
- if record.fields.is_empty() {
- self.src.push_str("pass\n");
- }
- self.deindent();
} else {
self.pyimport("dataclasses", "dataclass");
self.src
@@ -619,6 +604,32 @@ impl Generator for WasmtimePy {
self.src.push_str("\n");
}
+ fn type_flags(
+ &mut self,
+ _iface: &Interface,
+ _id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ self.docs(docs);
+ self.pyimport("enum", "Flag");
+ self.pyimport("enum", "auto");
+ self.src
+ .push_str(&format!("class {}(Flag):\n", name.to_camel_case()));
+ self.indent();
+ for flag in flags.flags.iter() {
+ self.docs(&flag.docs);
+ self.src
+ .push_str(&format!("{} = auto()\n", flag.name.to_shouty_snake_case()));
+ }
+ if flags.flags.is_empty() {
+ self.src.push_str("pass\n");
+ }
+ self.deindent();
+ self.src.push_str("\n");
+ }
+
fn type_variant(
&mut self,
iface: &Interface,
@@ -1522,12 +1533,33 @@ impl Bindgen for FunctionBindgen<'_> {
results.push(format!("{}({})", name.to_camel_case(), operands.join(", ")));
}
}
- Instruction::FlagsLift { name, .. } | Instruction::FlagsLift64 { name, .. } => {
- results.push(format!("{}({})", name.to_camel_case(), operands[0]));
- }
- Instruction::FlagsLower { .. } | Instruction::FlagsLower64 { .. } => {
- results.push(format!("({}).value", operands[0]));
+ Instruction::FlagsLift { name, .. } => {
+ let operand = match operands.len() {
+ 1 => operands[0].clone(),
+ _ => {
+ let tmp = self.locals.tmp("flags");
+ self.src.push_str(&format!("{tmp} = 0\n"));
+ for (i, op) in operands.iter().enumerate() {
+ let i = 32 * i;
+ self.src.push_str(&format!("{tmp} |= {op} << {i}\n"));
+ }
+ tmp
+ }
+ };
+ results.push(format!("{}({})", name.to_camel_case(), operand));
}
+ Instruction::FlagsLower { flags, .. } => match flags.repr().count() {
+ 1 => results.push(format!("({}).value", operands[0])),
+ n => {
+ let tmp = self.locals.tmp("flags");
+ self.src
+ .push_str(&format!("{tmp} = ({}).value\n", operands[0]));
+ for i in 0..n {
+ let i = 32 * i;
+ results.push(format!("({tmp} >> {i}) & 0xffffffff"));
+ }
+ }
+ },
Instruction::VariantPayloadName => {
let name = self.locals.tmp("payload");
diff --git a/crates/gen-wasmtime/src/lib.rs b/crates/gen-wasmtime/src/lib.rs
index 5f0d26e48..0edc28932 100644
--- a/crates/gen-wasmtime/src/lib.rs
+++ b/crates/gen-wasmtime/src/lib.rs
@@ -9,7 +9,7 @@ use wit_bindgen_gen_core::wit_parser::abi::{
};
use wit_bindgen_gen_core::{wit_parser::*, Direction, Files, Generator, Source, TypeInfo, Types};
use wit_bindgen_gen_rust::{
- int_repr, to_rust_ident, wasm_type, FnSig, RustFunctionGenerator, RustGenerator, TypeMode,
+ to_rust_ident, wasm_type, FnSig, RustFlagsRepr, RustFunctionGenerator, RustGenerator, TypeMode,
};
#[derive(Default)]
@@ -328,49 +328,6 @@ impl Generator for Wasmtime {
record: &Record,
docs: &Docs,
) {
- if record.is_flags() {
- self.src
- .push_str("wit_bindgen_wasmtime::bitflags::bitflags! {\n");
- self.rustdoc(docs);
- self.src
- .push_str(&format!("pub struct {}: ", name.to_camel_case()));
- let repr = iface
- .flags_repr(record)
- .expect("unsupported number of flags");
- self.int_repr(repr);
- self.src.push_str(" {\n");
- for (i, field) in record.fields.iter().enumerate() {
- self.rustdoc(&field.docs);
- self.src.push_str(&format!(
- "const {} = 1 << {};\n",
- field.name.to_shouty_snake_case(),
- i,
- ));
- }
- self.src.push_str("}\n");
- self.src.push_str("}\n\n");
-
- self.src.push_str("impl core::fmt::Display for ");
- self.src.push_str(&name.to_camel_case());
- self.src.push_str(
- "{\nfn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n",
- );
-
- self.src.push_str("f.write_str(\"");
- self.src.push_str(&name.to_camel_case());
- self.src.push_str("(\")?;\n");
- self.src.push_str("core::fmt::Debug::fmt(self, f)?;\n");
- self.src.push_str("f.write_str(\" (0x\")?;\n");
- self.src
- .push_str("core::fmt::LowerHex::fmt(&self.bits, f)?;\n");
- self.src.push_str("f.write_str(\"))\")?;\n");
- self.src.push_str("Ok(())");
-
- self.src.push_str("}\n");
- self.src.push_str("}\n\n");
- return;
- }
-
self.print_typedef_record(iface, id, record, docs);
// If this record might be used as a slice type in various places then
@@ -418,6 +375,51 @@ impl Generator for Wasmtime {
}
}
+ fn type_flags(
+ &mut self,
+ _iface: &Interface,
+ _id: TypeId,
+ name: &str,
+ flags: &Flags,
+ docs: &Docs,
+ ) {
+ self.src
+ .push_str("wit_bindgen_wasmtime::bitflags::bitflags! {\n");
+ self.rustdoc(docs);
+ let repr = RustFlagsRepr::new(flags);
+ self.src
+ .push_str(&format!("pub struct {}: {repr} {{\n", name.to_camel_case()));
+ for (i, flag) in flags.flags.iter().enumerate() {
+ self.rustdoc(&flag.docs);
+ self.src.push_str(&format!(
+ "const {} = 1 << {};\n",
+ flag.name.to_shouty_snake_case(),
+ i,
+ ));
+ }
+ self.src.push_str("}\n");
+ self.src.push_str("}\n\n");
+
+ self.src.push_str("impl core::fmt::Display for ");
+ self.src.push_str(&name.to_camel_case());
+ self.src.push_str(
+ "{\nfn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {\n",
+ );
+
+ self.src.push_str("f.write_str(\"");
+ self.src.push_str(&name.to_camel_case());
+ self.src.push_str("(\")?;\n");
+ self.src.push_str("core::fmt::Debug::fmt(self, f)?;\n");
+ self.src.push_str("f.write_str(\" (0x\")?;\n");
+ self.src
+ .push_str("core::fmt::LowerHex::fmt(&self.bits, f)?;\n");
+ self.src.push_str("f.write_str(\"))\")?;\n");
+ self.src.push_str("Ok(())");
+
+ self.src.push_str("}\n");
+ self.src.push_str("}\n\n");
+ }
+
fn type_variant(
&mut self,
iface: &Interface,
@@ -1538,36 +1540,29 @@ impl Bindgen for FunctionBindgen<'_> {
self.record_lift(iface, *ty, record, operands, results);
}
- Instruction::FlagsLower { record, .. } => {
+ Instruction::FlagsLower { flags, .. } => {
let tmp = self.tmp();
self.push_str(&format!("let flags{} = {};\n", tmp, operands[0]));
- for i in 0..record.num_i32s() {
+ for i in 0..flags.repr().count() {
results.push(format!("(flags{}.bits >> {}) as i32", tmp, i * 32));
}
}
- Instruction::FlagsLower64 { .. } => {
- results.push(format!("({}).bits as i64", operands[0]));
- }
- Instruction::FlagsLift { record, name, .. }
- | Instruction::FlagsLift64 { record, name, .. } => {
+ Instruction::FlagsLift { flags, name, .. } => {
self.gen.needs_validate_flags = true;
- let repr = iface
- .flags_repr(record)
- .expect("unsupported number of flags");
+ let repr = RustFlagsRepr::new(flags);
let mut flags = String::from("0");
for (i, op) in operands.iter().enumerate() {
- flags.push_str(&format!("| (i64::from({}) << {})", op, i * 32));
+ flags.push_str(&format!("| (({} as {repr}) << {})", op, i * 32));
}
results.push(format!(
"validate_flags(
{},
- {name}::all().bits() as i64,
+ {name}::all().bits(),
\"{name}\",
- |b| {name} {{ bits: b as {ty} }}
+ |bits| {name} {{ bits }}
)?",
flags,
name = name.to_camel_case(),
- ty = int_repr(repr),
));
}
diff --git a/crates/parser/src/abi.rs b/crates/parser/src/abi.rs
index b79ef55e6..998af4b7f 100644
--- a/crates/parser/src/abi.rs
+++ b/crates/parser/src/abi.rs
@@ -1,6 +1,7 @@
use crate::sizealign::align_to;
use crate::{
- Function, Int, Interface, Record, RecordKind, ResourceId, Type, TypeDefKind, TypeId, Variant,
+ Flags, FlagsRepr, Function, Int, Interface, Record, ResourceId, Type, TypeDefKind, TypeId,
+ Variant,
};
use std::mem;
@@ -512,27 +513,17 @@ def_instruction! {
/// Converts a language-specific record-of-bools to a list of `i32`.
FlagsLower {
- record: &'a Record,
- name: &'a str,
- ty: TypeId,
- } : [1] => [record.num_i32s()],
- FlagsLower64 {
- record: &'a Record,
+ flags: &'a Flags,
name: &'a str,
ty: TypeId,
- } : [1] => [1],
+ } : [1] => [flags.repr().count()],
/// Converts a list of native wasm `i32` to a language-specific
/// record-of-bools.
FlagsLift {
- record: &'a Record,
+ flags: &'a Flags,
name: &'a str,
ty: TypeId,
- } : [record.num_i32s()] => [1],
- FlagsLift64 {
- record: &'a Record,
- name: &'a str,
- ty: TypeId,
- } : [1] => [1],
+ } : [flags.repr().count()] => [1],
// variants
@@ -913,21 +904,18 @@ impl Interface {
Type::Id(id) => match &self.types[*id].kind {
TypeDefKind::Type(t) => self.push_wasm(variant, t, result),
- TypeDefKind::Record(r) if r.is_flags() => match self.flags_repr(r) {
- Some(int) => result.push(int.into()),
- None => {
- for _ in 0..r.num_i32s() {
- result.push(WasmType::I32);
- }
- }
- },
-
TypeDefKind::Record(r) => {
for field in r.fields.iter() {
self.push_wasm(variant, &field.ty, result);
}
}
+ TypeDefKind::Flags(r) => {
+ for _ in 0..r.repr().count() {
+ result.push(WasmType::I32);
+ }
+ }
+
TypeDefKind::List(_) => {
result.push(WasmType::I32);
result.push(WasmType::I32);
@@ -963,18 +951,6 @@ impl Interface {
}
}
- pub fn flags_repr(&self, record: &Record) -> Option {
- match record.kind {
- RecordKind::Flags(Some(hint)) => Some(hint),
- RecordKind::Flags(None) if record.fields.len() <= 8 => Some(Int::U8),
- RecordKind::Flags(None) if record.fields.len() <= 16 => Some(Int::U16),
- RecordKind::Flags(None) if record.fields.len() <= 32 => Some(Int::U32),
- RecordKind::Flags(None) if record.fields.len() <= 64 => Some(Int::U64),
- RecordKind::Flags(None) => None,
- _ => panic!("not a flags record"),
- }
- }
-
/// Generates an abstract sequence of instructions which represents this
/// function being adapted as an imported function.
///
@@ -1425,20 +1401,6 @@ impl<'a, B: Bindgen> Generator<'a, B> {
self.emit(&ListLower { element, realloc });
}
}
- TypeDefKind::Record(record) if record.is_flags() => {
- match self.iface.flags_repr(record) {
- Some(Int::U64) => self.emit(&FlagsLower64 {
- record,
- ty: id,
- name: self.iface.types[id].name.as_ref().unwrap(),
- }),
- _ => self.emit(&FlagsLower {
- record,
- ty: id,
- name: self.iface.types[id].name.as_ref().unwrap(),
- }),
- }
- }
TypeDefKind::Record(record) => {
self.emit(&RecordLower {
record,
@@ -1455,6 +1417,14 @@ impl<'a, B: Bindgen> Generator<'a, B> {
}
}
+ TypeDefKind::Flags(flags) => {
+ self.emit(&FlagsLower {
+ flags,
+ ty: id,
+ name: self.iface.types[id].name.as_ref().unwrap(),
+ });
+ }
+
TypeDefKind::Variant(v) => {
let mut results = Vec::new();
let mut temp = Vec::new();
@@ -1583,20 +1553,6 @@ impl<'a, B: Bindgen> Generator<'a, B> {
});
}
}
- TypeDefKind::Record(record) if record.is_flags() => {
- match self.iface.flags_repr(record) {
- Some(Int::U64) => self.emit(&FlagsLift64 {
- record,
- ty: id,
- name: self.iface.types[id].name.as_ref().unwrap(),
- }),
- _ => self.emit(&FlagsLift {
- record,
- ty: id,
- name: self.iface.types[id].name.as_ref().unwrap(),
- }),
- }
- }
TypeDefKind::Record(record) => {
let mut temp = Vec::new();
self.iface.push_wasm(self.variant, ty, &mut temp);
@@ -1616,6 +1572,13 @@ impl<'a, B: Bindgen> Generator<'a, B> {
name: self.iface.types[id].name.as_deref(),
});
}
+ TypeDefKind::Flags(flags) => {
+ self.emit(&FlagsLift {
+ flags,
+ ty: id,
+ name: self.iface.types[id].name.as_ref().unwrap(),
+ });
+ }
TypeDefKind::Variant(v) => {
let mut params = Vec::new();
@@ -1694,24 +1657,6 @@ impl<'a, B: Bindgen> Generator<'a, B> {
TypeDefKind::Type(t) => self.write_to_memory(t, addr, offset),
TypeDefKind::List(_) => self.write_list_to_memory(ty, addr, offset),
- TypeDefKind::Record(r) if r.is_flags() => {
- self.lower(ty);
- match self.iface.flags_repr(r) {
- Some(repr) => {
- self.stack.push(addr);
- self.store_intrepr(offset, repr);
- }
- None => {
- for i in 0..r.num_i32s() {
- self.stack.push(addr.clone());
- self.emit(&I32Store {
- offset: offset + (i as i32) * 4,
- });
- }
- }
- }
- }
-
// Decompose the record into its components and then write all
// the components into memory one-by-one.
TypeDefKind::Record(record) => {
@@ -1741,6 +1686,28 @@ impl<'a, B: Bindgen> Generator<'a, B> {
}
}
+ TypeDefKind::Flags(f) => {
+ self.lower(ty);
+ match f.repr() {
+ FlagsRepr::U8 => {
+ self.stack.push(addr);
+ self.store_intrepr(offset, Int::U8);
+ }
+ FlagsRepr::U16 => {
+ self.stack.push(addr);
+ self.store_intrepr(offset, Int::U16);
+ }
+ FlagsRepr::U32(n) => {
+ for i in (0..n).rev() {
+ self.stack.push(addr.clone());
+ self.emit(&I32Store {
+ offset: offset + (i as i32) * 4,
+ });
+ }
+ }
+ }
+ }
+
// Each case will get its own block, and the first item in each
// case is writing the discriminant. After that if we have a
// payload we write the payload after the discriminant, aligned up
@@ -1811,24 +1778,6 @@ impl<'a, B: Bindgen> Generator<'a, B> {
TypeDefKind::List(_) => self.read_list_from_memory(ty, addr, offset),
- TypeDefKind::Record(r) if r.is_flags() => {
- match self.iface.flags_repr(r) {
- Some(repr) => {
- self.stack.push(addr);
- self.load_intrepr(offset, repr);
- }
- None => {
- for i in 0..r.num_i32s() {
- self.stack.push(addr.clone());
- self.emit(&I32Load {
- offset: offset + (i as i32) * 4,
- });
- }
- }
- }
- self.lift(ty);
- }
-
// Read and lift each field individually, adjusting the offset
// as we go along, then aggregate all the fields into the
// record.
@@ -1853,6 +1802,28 @@ impl<'a, B: Bindgen> Generator<'a, B> {
});
}
+ TypeDefKind::Flags(f) => {
+ match f.repr() {
+ FlagsRepr::U8 => {
+ self.stack.push(addr);
+ self.load_intrepr(offset, Int::U8);
+ }
+ FlagsRepr::U16 => {
+ self.stack.push(addr);
+ self.load_intrepr(offset, Int::U16);
+ }
+ FlagsRepr::U32(n) => {
+ for i in 0..n {
+ self.stack.push(addr.clone());
+ self.emit(&I32Load {
+ offset: offset + (i as i32) * 4,
+ });
+ }
+ }
+ }
+ self.lift(ty);
+ }
+
// Each case will get its own block, and we'll dispatch to the
// right block based on the `i32.load` we initially perform. Each
// individual block is pretty simple and just reads the payload type
diff --git a/crates/parser/src/ast.rs b/crates/parser/src/ast.rs
index 73bdf0196..23d24cfba 100644
--- a/crates/parser/src/ast.rs
+++ b/crates/parser/src/ast.rs
@@ -91,12 +91,12 @@ enum Type<'a> {
Name(Id<'a>),
List(Box>),
Record(Record<'a>),
+ Flags(Flags<'a>),
Variant(Variant<'a>),
}
struct Record<'a> {
tuple_hint: bool,
- flags_repr: Option>>,
fields: Vec>,
}
@@ -106,6 +106,15 @@ struct Field<'a> {
ty: Type<'a>,
}
+struct Flags<'a> {
+ flags: Vec>,
+}
+
+struct Flag<'a> {
+ docs: Docs<'a>,
+ name: Id<'a>,
+}
+
struct Variant<'a> {
tag: Option>>,
span: Span,
@@ -234,20 +243,14 @@ impl<'a> TypeDef<'a> {
fn parse_flags(tokens: &mut Tokenizer<'a>, docs: Docs<'a>) -> Result {
tokens.expect(Token::Flags)?;
let name = parse_id(tokens)?;
- let ty = Type::Record(Record {
- flags_repr: None,
- tuple_hint: false,
- fields: parse_list(
+ let ty = Type::Flags(Flags {
+ flags: parse_list(
tokens,
Token::LeftBrace,
Token::RightBrace,
|docs, tokens| {
let name = parse_id(tokens)?;
- Ok(Field {
- docs,
- name,
- ty: Type::Bool,
- })
+ Ok(Flag { docs, name })
},
)?,
});
@@ -258,7 +261,6 @@ impl<'a> TypeDef<'a> {
tokens.expect(Token::Record)?;
let name = parse_id(tokens)?;
let ty = Type::Record(Record {
- flags_repr: None,
tuple_hint: false,
fields: parse_list(
tokens,
@@ -475,7 +477,6 @@ impl<'a> Type<'a> {
)?;
Ok(Type::Record(Record {
fields,
- flags_repr: None,
tuple_hint: true,
}))
}
diff --git a/crates/parser/src/ast/resolve.rs b/crates/parser/src/ast/resolve.rs
index af2b42866..964ee6fc1 100644
--- a/crates/parser/src/ast/resolve.rs
+++ b/crates/parser/src/ast/resolve.rs
@@ -21,6 +21,7 @@ pub struct Resolver {
enum Key {
Variant(Vec<(String, Option)>),
Record(Vec<(String, Type)>),
+ Flags(Vec),
List(Type),
}
@@ -204,6 +205,7 @@ impl Resolver {
.collect(),
kind: r.kind,
}),
+ TypeDefKind::Flags(f) => TypeDefKind::Flags(f.clone()),
TypeDefKind::Variant(v) => TypeDefKind::Variant(Variant {
cases: v
.cases
@@ -367,20 +369,23 @@ impl Resolver {
TypeDefKind::Record(Record {
kind: if record.tuple_hint {
RecordKind::Tuple
- } else if let Some(hint) = &record.flags_repr {
- RecordKind::Flags(Some(match &**hint {
- super::Type::U8 => Int::U8,
- super::Type::U16 => Int::U16,
- super::Type::U32 => Int::U32,
- super::Type::U64 => Int::U64,
- _ => panic!("unknown explicit flags repr"),
- }))
} else {
- RecordKind::infer(&self.types, &fields)
+ RecordKind::infer(&fields)
},
fields,
})
}
+ super::Type::Flags(flags) => {
+ let flags = flags
+ .flags
+ .iter()
+ .map(|flag| Flag {
+ docs: self.docs(&flag.docs),
+ name: flag.name.name.to_string(),
+ })
+ .collect::>();
+ TypeDefKind::Flags(Flags { flags })
+ }
super::Type::Variant(variant) => {
if variant.cases.is_empty() {
return Err(Error {
@@ -465,6 +470,9 @@ impl Resolver {
.map(|case| (case.name.clone(), case.ty))
.collect::>(),
),
+ TypeDefKind::Flags(r) => {
+ Key::Flags(r.flags.iter().map(|f| f.name.clone()).collect::>())
+ }
TypeDefKind::List(ty) => Key::List(*ty),
};
let types = &mut self.types;
@@ -624,7 +632,7 @@ impl Resolver {
}
}
- TypeDefKind::List(_) | TypeDefKind::Type(_) => {}
+ TypeDefKind::Flags(_) | TypeDefKind::List(_) | TypeDefKind::Type(_) => {}
}
valid.insert(ty);
diff --git a/crates/parser/src/lib.rs b/crates/parser/src/lib.rs
index 9096a1c66..c37e7ab04 100644
--- a/crates/parser/src/lib.rs
+++ b/crates/parser/src/lib.rs
@@ -46,6 +46,7 @@ pub struct TypeDef {
#[derive(Debug)]
pub enum TypeDefKind {
Record(Record),
+ Flags(Flags),
Variant(Variant),
List(Type),
Type(Type),
@@ -88,7 +89,6 @@ pub struct Record {
#[derive(Copy, Clone, Debug)]
pub enum RecordKind {
Other,
- Flags(Option),
Tuple,
}
@@ -103,27 +103,14 @@ impl Record {
pub fn is_tuple(&self) -> bool {
matches!(self.kind, RecordKind::Tuple)
}
-
- pub fn is_flags(&self) -> bool {
- matches!(self.kind, RecordKind::Flags(_))
- }
-
- pub fn num_i32s(&self) -> usize {
- (self.fields.len() + 31) / 32
- }
}
impl RecordKind {
- fn infer(types: &Arena, fields: &[Field]) -> RecordKind {
+ fn infer(fields: &[Field]) -> RecordKind {
if fields.is_empty() {
return RecordKind::Other;
}
- // Structs-of-bools are classified to get represented as bitflags.
- if fields.iter().all(|t| is_bool(&t.ty, types)) {
- return RecordKind::Flags(None);
- }
-
// fields with consecutive integer names get represented as tuples.
if fields
.iter()
@@ -134,16 +121,43 @@ impl RecordKind {
}
return RecordKind::Other;
+ }
+}
- fn is_bool(t: &Type, types: &Arena) -> bool {
- match t {
- Type::Bool => true,
- Type::Id(v) => match &types[*v].kind {
- TypeDefKind::Type(t) => is_bool(t, types),
- _ => false,
- },
- _ => false,
- }
+#[derive(Debug, Clone)]
+pub struct Flags {
+ pub flags: Vec,
+}
+
+#[derive(Debug, Clone)]
+pub struct Flag {
+ pub docs: Docs,
+ pub name: String,
+}
+
+#[derive(Debug)]
+pub enum FlagsRepr {
+ U8,
+ U16,
+ U32(usize),
+}
+
+impl Flags {
+ pub fn repr(&self) -> FlagsRepr {
+ match self.flags.len() {
+ n if n <= 8 => FlagsRepr::U8,
+ n if n <= 16 => FlagsRepr::U16,
+ n => FlagsRepr::U32(sizealign::align_to(n, 32) / 32),
+ }
+ }
+}
+
+impl FlagsRepr {
+ pub fn count(&self) -> usize {
+ match self {
+ FlagsRepr::U8 => 1,
+ FlagsRepr::U16 => 1,
+ FlagsRepr::U32(n) => *n,
}
}
}
@@ -392,6 +406,7 @@ impl Interface {
return;
}
match &self.types[id].kind {
+ TypeDefKind::Flags(_) => {}
TypeDefKind::Type(t) | TypeDefKind::List(t) => self.topo_visit_ty(t, list, visited),
TypeDefKind::Record(r) => {
for f in r.fields.iter() {
@@ -435,6 +450,11 @@ impl Interface {
TypeDefKind::List(_) | TypeDefKind::Variant(_) => false,
TypeDefKind::Type(t) => self.all_bits_valid(t),
TypeDefKind::Record(r) => r.fields.iter().all(|f| self.all_bits_valid(&f.ty)),
+
+ // FIXME: this could perhaps be `true` for multiples-of-32 but
+ // seems better to probably leave this as unconditionally
+ // `false` for now, may want to reconsider later?
+ TypeDefKind::Flags(_) => false,
},
}
}
diff --git a/crates/parser/src/sizealign.rs b/crates/parser/src/sizealign.rs
index 780610f22..5037f8bf4 100644
--- a/crates/parser/src/sizealign.rs
+++ b/crates/parser/src/sizealign.rs
@@ -1,4 +1,4 @@
-use crate::{Int, Interface, Record, RecordKind, Type, TypeDef, TypeDefKind, Variant};
+use crate::{FlagsRepr, Int, Interface, Record, Type, TypeDef, TypeDefKind, Variant};
#[derive(Default)]
pub struct SizeAlign {
@@ -18,19 +18,12 @@ impl SizeAlign {
match &ty.kind {
TypeDefKind::Type(t) => (self.size(t), self.align(t)),
TypeDefKind::List(_) => (8, 4),
- TypeDefKind::Record(r) => {
- if let RecordKind::Flags(repr) = r.kind {
- return match repr {
- Some(i) => int_size_align(i),
- None if r.fields.len() <= 8 => (1, 1),
- None if r.fields.len() <= 16 => (2, 2),
- None if r.fields.len() <= 32 => (4, 4),
- None if r.fields.len() <= 64 => (8, 8),
- None => (r.num_i32s() * 4, 4),
- };
- }
- self.record(r.fields.iter().map(|f| &f.ty))
- }
+ TypeDefKind::Record(r) => self.record(r.fields.iter().map(|f| &f.ty)),
+ TypeDefKind::Flags(f) => match f.repr() {
+ FlagsRepr::U8 => (1, 1),
+ FlagsRepr::U16 => (2, 2),
+ FlagsRepr::U32(n) => (n * 4, 4),
+ },
TypeDefKind::Variant(v) => {
let (discrim_size, discrim_align) = int_size_align(v.tag);
let mut size = discrim_size;
diff --git a/crates/parser/tests/all.rs b/crates/parser/tests/all.rs
index e7c27c5f6..1ac96449d 100644
--- a/crates/parser/tests/all.rs
+++ b/crates/parser/tests/all.rs
@@ -191,6 +191,9 @@ fn to_json(i: &Interface) -> String {
Record {
fields: Vec<(String, String)>,
},
+ Flags {
+ flags: Vec,
+ },
Variant {
cases: Vec<(String, Option)>,
},
@@ -268,6 +271,9 @@ fn to_json(i: &Interface) -> String {
.map(|f| (f.name.clone(), translate_type(&f.ty)))
.collect(),
},
+ TypeDefKind::Flags(r) => Type::Flags {
+ flags: r.flags.iter().map(|f| f.name.clone()).collect(),
+ },
TypeDefKind::Variant(v) => Type::Variant {
cases: v
.cases
diff --git a/crates/parser/tests/ui/types.wit.result b/crates/parser/tests/ui/types.wit.result
index a6ea850da..a36f72571 100644
--- a/crates/parser/tests/ui/types.wit.result
+++ b/crates/parser/tests/ui/types.wit.result
@@ -291,47 +291,29 @@
{
"idx": 31,
"name": "t30",
- "record": {
- "fields": []
+ "flags": {
+ "flags": []
}
},
{
"idx": 32,
"name": "t31",
- "record": {
- "fields": [
- [
- "a",
- "bool"
- ],
- [
- "b",
- "bool"
- ],
- [
- "c",
- "bool"
- ]
+ "flags": {
+ "flags": [
+ "a",
+ "b",
+ "c"
]
}
},
{
"idx": 33,
"name": "t32",
- "record": {
- "fields": [
- [
- "a",
- "bool"
- ],
- [
- "b",
- "bool"
- ],
- [
- "c",
- "bool"
- ]
+ "flags": {
+ "flags": [
+ "a",
+ "b",
+ "c"
]
}
},
diff --git a/crates/test-helpers/src/lib.rs b/crates/test-helpers/src/lib.rs
index 8adbe5712..1cf23b5ad 100644
--- a/crates/test-helpers/src/lib.rs
+++ b/crates/test-helpers/src/lib.rs
@@ -179,6 +179,7 @@ pub fn codegen_rust_wasm_export(input: TokenStream) -> TokenStream {
let t = quote_ty(param, iface, t);
quote::quote! { Vec<#t> }
}
+ TypeDefKind::Flags(_) => panic!("unknown flags"),
TypeDefKind::Record(r) => {
let fields = r.fields.iter().map(|f| quote_ty(param, iface, &f.ty));
quote::quote! { (#(#fields,)*) }
diff --git a/crates/test-modules/modules/crates/records/records.wit b/crates/test-modules/modules/crates/records/records.wit
index 2dfd56668..92afac641 100644
--- a/crates/test-modules/modules/crates/records/records.wit
+++ b/crates/test-modules/modules/crates/records/records.wit
@@ -14,16 +14,16 @@ record scalars {
scalar-arg: function(x: scalars)
scalar-result: function() -> scalars
-record really-flags {
- a: bool,
- b: bool,
- c: bool,
- d: bool,
- e: bool,
- f: bool,
- g: bool,
- h: bool,
- i: bool,
+flags really-flags {
+ a,
+ b,
+ c,
+ d,
+ e,
+ f,
+ g,
+ h,
+ i,
}
flags-arg: function(x: really-flags)
diff --git a/crates/wasmlink/src/adapter/call.rs b/crates/wasmlink/src/adapter/call.rs
index 6bd7dc69c..f5364266f 100644
--- a/crates/wasmlink/src/adapter/call.rs
+++ b/crates/wasmlink/src/adapter/call.rs
@@ -451,17 +451,12 @@ impl<'a> CallAdapter<'a> {
operands: element_operands,
});
}
+ TypeDefKind::Flags(r) => {
+ for _ in 0..r.repr().count() {
+ params.next().unwrap();
+ }
+ }
TypeDefKind::Record(r) => match r.kind {
- RecordKind::Flags(_) => match interface.flags_repr(r) {
- Some(_) => {
- params.next().unwrap();
- }
- None => {
- for _ in 0..r.num_i32s() {
- params.next().unwrap();
- }
- }
- },
RecordKind::Tuple | RecordKind::Other => {
for f in &r.fields {
Self::push_operands(
@@ -612,8 +607,8 @@ impl<'a> CallAdapter<'a> {
operands: element_operands,
});
}
+ TypeDefKind::Flags(_) => {}
TypeDefKind::Record(r) => match r.kind {
- RecordKind::Flags(_) => {}
RecordKind::Tuple | RecordKind::Other => {
let offsets = sizes.field_offsets(r);
diff --git a/crates/wasmlink/src/module.rs b/crates/wasmlink/src/module.rs
index bf60744ef..255439a6d 100644
--- a/crates/wasmlink/src/module.rs
+++ b/crates/wasmlink/src/module.rs
@@ -48,6 +48,7 @@ fn has_list(interface: &WitInterface, ty: &WitType) -> bool {
Type::Id(id) => match &interface.types[*id].kind {
TypeDefKind::List(_) => true,
TypeDefKind::Type(t) => has_list(interface, t),
+ TypeDefKind::Flags(_) => false,
TypeDefKind::Record(r) => r.fields.iter().any(|f| has_list(interface, &f.ty)),
TypeDefKind::Variant(v) => v.cases.iter().any(|c| {
c.ty.as_ref()
diff --git a/crates/wasmlink/tests/flags.baseline b/crates/wasmlink/tests/flags.baseline
index 52b8c442b..9611fc2a4 100644
--- a/crates/wasmlink/tests/flags.baseline
+++ b/crates/wasmlink/tests/flags.baseline
@@ -1,9 +1,10 @@
(module
(type (;0;) (func (param i32) (result i32)))
- (type (;1;) (func (param i64) (result i64)))
+ (type (;1;) (func (param i32 i32 i32)))
+ (import "$parent" "memory" (memory (;0;) 0))
(module (;0;)
(type (;0;) (func (param i32) (result i32)))
- (type (;1;) (func (param i64) (result i64)))
+ (type (;1;) (func (param i32 i32) (result i32)))
(func (;0;) (type 0) (param i32) (result i32)
unreachable)
(func (;1;) (type 0) (param i32) (result i32)
@@ -16,8 +17,10 @@
unreachable)
(func (;5;) (type 0) (param i32) (result i32)
unreachable)
- (func (;6;) (type 1) (param i64) (result i64)
+ (func (;6;) (type 1) (param i32 i32) (result i32)
unreachable)
+ (memory (;0;) 1)
+ (export "memory" (memory 0))
(export "roundtrip-flag1" (func 0))
(export "roundtrip-flag2" (func 1))
(export "roundtrip-flag4" (func 2))
@@ -25,15 +28,16 @@
(export "roundtrip-flag16" (func 4))
(export "roundtrip-flag32" (func 5))
(export "roundtrip-flag64" (func 6)))
- (instance (;0;)
+ (instance (;1;)
(instantiate 0))
- (alias 0 "roundtrip-flag1" (func (;0;)))
- (alias 0 "roundtrip-flag2" (func (;1;)))
- (alias 0 "roundtrip-flag4" (func (;2;)))
- (alias 0 "roundtrip-flag8" (func (;3;)))
- (alias 0 "roundtrip-flag16" (func (;4;)))
- (alias 0 "roundtrip-flag32" (func (;5;)))
- (alias 0 "roundtrip-flag64" (func (;6;)))
+ (alias 1 "memory" (memory (;1;)))
+ (alias 1 "roundtrip-flag1" (func (;0;)))
+ (alias 1 "roundtrip-flag2" (func (;1;)))
+ (alias 1 "roundtrip-flag4" (func (;2;)))
+ (alias 1 "roundtrip-flag8" (func (;3;)))
+ (alias 1 "roundtrip-flag16" (func (;4;)))
+ (alias 1 "roundtrip-flag32" (func (;5;)))
+ (alias 1 "roundtrip-flag64" (func (;6;)))
(func (;7;) (type 0) (param i32) (result i32)
local.get 0
call 0)
@@ -52,9 +56,17 @@
(func (;12;) (type 0) (param i32) (result i32)
local.get 0
call 5)
- (func (;13;) (type 1) (param i64) (result i64)
+ (func (;13;) (type 1) (param i32 i32 i32)
+ (local i32)
local.get 0
- call 6)
+ local.get 1
+ call 6
+ local.set 3
+ local.get 2
+ local.get 3
+ i32.const 8
+ memory.copy 0 1)
+ (export "memory" (memory 1))
(export "roundtrip-flag1" (func 7))
(export "roundtrip-flag2" (func 8))
(export "roundtrip-flag4" (func 9))
diff --git a/crates/wasmlink/tests/flags.wat b/crates/wasmlink/tests/flags.wat
index cd395d28b..ae7f34ac1 100644
--- a/crates/wasmlink/tests/flags.wat
+++ b/crates/wasmlink/tests/flags.wat
@@ -1,4 +1,5 @@
(module
+ (memory (export "memory") 1)
(func (export "roundtrip-flag1") (param i32) (result i32)
unreachable
)
@@ -17,7 +18,7 @@
(func (export "roundtrip-flag32") (param i32) (result i32)
unreachable
)
- (func (export "roundtrip-flag64") (param i64) (result i64)
+ (func (export "roundtrip-flag64") (param i32 i32) (result i32)
unreachable
)
)
diff --git a/crates/wasmlink/tests/records.wit b/crates/wasmlink/tests/records.wit
index 2dfd56668..92afac641 100644
--- a/crates/wasmlink/tests/records.wit
+++ b/crates/wasmlink/tests/records.wit
@@ -14,16 +14,16 @@ record scalars {
scalar-arg: function(x: scalars)
scalar-result: function() -> scalars
-record really-flags {
- a: bool,
- b: bool,
- c: bool,
- d: bool,
- e: bool,
- f: bool,
- g: bool,
- h: bool,
- i: bool,
+flags really-flags {
+ a,
+ b,
+ c,
+ d,
+ e,
+ f,
+ g,
+ h,
+ i,
}
flags-arg: function(x: really-flags)
diff --git a/crates/wasmtime/src/lib.rs b/crates/wasmtime/src/lib.rs
index 2c2f49129..e515b46ad 100644
--- a/crates/wasmtime/src/lib.rs
+++ b/crates/wasmtime/src/lib.rs
@@ -73,13 +73,16 @@ pub mod rt {
Trap::new(msg)
}
- pub fn validate_flags(
- bits: i64,
- all: i64,
+ pub fn validate_flags(
+ bits: T,
+ all: T,
name: &str,
- mk: impl FnOnce(i64) -> U,
- ) -> Result {
- if bits & !all != 0 {
+ mk: impl FnOnce(T) -> U,
+ ) -> Result
+ where
+ T: std::ops::Not