diff --git a/crates/csharp/src/function.rs b/crates/csharp/src/function.rs index 5f1d1ec20..0700cde17 100644 --- a/crates/csharp/src/function.rs +++ b/crates/csharp/src/function.rs @@ -225,6 +225,148 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { results.push(lifted); } + + fn handle_result_import(&mut self, operands: &mut Vec) { + if self.interface_gen.csharp_gen.opts.with_wit_results { + uwriteln!(self.src, "return {};", operands[0]); + return; + } + + let mut payload_is_void = false; + let mut previous = operands[0].clone(); + let mut vars: Vec<(String, Option)> = Vec::with_capacity(self.results.len()); + if let Direction::Import = self.interface_gen.direction { + for ty in &self.results { + let tmp = self.locals.tmp("tmp"); + uwrite!( + self.src, + "\ + if ({previous}.IsOk) + {{ + var {tmp} = {previous}.AsOk; + " + ); + let TypeDefKind::Result(result) = &self.interface_gen.resolve.types[*ty].kind + else { + unreachable!(); + }; + let exception_name = result + .err + .map(|ty| self.interface_gen.type_name_with_qualifier(&ty, true)); + vars.push((previous.clone(), exception_name)); + payload_is_void = result.ok.is_none(); + previous = tmp; + } + } + uwriteln!( + self.src, + "return {};", + if payload_is_void { "" } else { &previous } + ); + for (level, var) in vars.iter().enumerate().rev() { + self.interface_gen.csharp_gen.needs_wit_exception = true; + let (var_name, exception_name) = var; + let exception_name = match exception_name { + Some(type_name) => &format!("WitException<{}>", type_name), + None => "WitException", + }; + uwrite!( + self.src, + "\ + }} + else + {{ + throw new {exception_name}({var_name}.AsErr!, {level}); + }} + " + ); + } + } + + fn handle_result_call( + &mut self, + func: &&wit_parser::Function, + target: String, + func_name: String, + oper: String, + ) -> String { + let ret = self.locals.tmp("ret"); + if self.interface_gen.csharp_gen.opts.with_wit_results { + uwriteln!(self.src, "var {ret} = {target}.{func_name}({oper});"); + return ret; + } + + // otherwise generate exception code + let ty = self + .interface_gen + .type_name_with_qualifier(func.results.iter_types().next().unwrap(), true); + uwriteln!(self.src, "{ty} {ret};"); + let mut cases = Vec::with_capacity(self.results.len()); + let mut oks = Vec::with_capacity(self.results.len()); + let mut payload_is_void = false; + for (index, ty) in self.results.iter().enumerate() { + let TypeDefKind::Result(result) = &self.interface_gen.resolve.types[*ty].kind else { + unreachable!(); + }; + let err_ty = if let Some(ty) = result.err { + self.interface_gen.type_name_with_qualifier(&ty, true) + } else { + "None".to_owned() + }; + let ty = self + .interface_gen + .type_name_with_qualifier(&Type::Id(*ty), true); + let head = oks.concat(); + let tail = oks.iter().map(|_| ")").collect::>().concat(); + cases.push(format!( + "\ + case {index}: + {{ + ret = {head}{ty}.Err(({err_ty}) e.Value){tail}; + break; + }} + " + )); + oks.push(format!("{ty}.Ok(")); + payload_is_void = result.ok.is_none(); + } + if !self.results.is_empty() { + self.src.push_str( + " + try + {\n + ", + ); + } + let head = oks.concat(); + let tail = oks.iter().map(|_| ")").collect::>().concat(); + let val = if payload_is_void { + uwriteln!(self.src, "{target}.{func_name}({oper});"); + "new None()".to_owned() + } else { + format!("{target}.{func_name}({oper})") + }; + uwriteln!(self.src, "{ret} = {head}{val}{tail};"); + if !self.results.is_empty() { + self.interface_gen.csharp_gen.needs_wit_exception = true; + let cases = cases.join("\n"); + uwriteln!( + self.src, + r#"}} + catch (WitException e) + {{ + switch (e.NestingLevel) + {{ + {cases} + + default: throw new ArgumentException($"invalid nesting level: {{e.NestingLevel}}"); + }} + }} + "# + ); + } + ret + } } impl Bindgen for FunctionBindgen<'_, '_> { @@ -814,70 +956,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { match func.results.len() { 0 => uwriteln!(self.src, "{target}.{func_name}({oper});"), 1 => { - let ret = self.locals.tmp("ret"); - let ty = self.interface_gen.type_name_with_qualifier( - func.results.iter_types().next().unwrap(), - true - ); - uwriteln!(self.src, "{ty} {ret};"); - let mut cases = Vec::with_capacity(self.results.len()); - let mut oks = Vec::with_capacity(self.results.len()); - let mut payload_is_void = false; - for (index, ty) in self.results.iter().enumerate() { - let TypeDefKind::Result(result) = &self.interface_gen.resolve.types[*ty].kind else { - unreachable!(); - }; - let err_ty = if let Some(ty) = result.err { - self.interface_gen.type_name_with_qualifier(&ty, true) - } else { - "None".to_owned() - }; - let ty = self.interface_gen.type_name_with_qualifier(&Type::Id(*ty), true); - let head = oks.concat(); - let tail = oks.iter().map(|_| ")").collect::>().concat(); - cases.push( - format!( - "\ - case {index}: {{ - ret = {head}{ty}.Err(({err_ty}) e.Value){tail}; - break; - }} - " - ) - ); - oks.push(format!("{ty}.Ok(")); - payload_is_void = result.ok.is_none(); - } - if !self.results.is_empty() { - self.src.push_str("try {\n"); - } - let head = oks.concat(); - let tail = oks.iter().map(|_| ")").collect::>().concat(); - let val = if payload_is_void { - uwriteln!(self.src, "{target}.{func_name}({oper});"); - "new None()".to_owned() - } else { - format!("{target}.{func_name}({oper})") - }; - uwriteln!( - self.src, - "{ret} = {head}{val}{tail};" - ); - if !self.results.is_empty() { - self.interface_gen.csharp_gen.needs_wit_exception = true; - let cases = cases.join("\n"); - uwriteln!( - self.src, - r#"}} catch (WitException e) {{ - switch (e.NestingLevel) {{ - {cases} - - default: throw new ArgumentException($"invalid nesting level: {{e.NestingLevel}}"); - }} - }} - "# - ); - } + let ret = self.handle_result_call(func, target, func_name, oper); results.push(ret); } _ => { @@ -927,46 +1006,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { match func.results.len() { 0 => (), 1 => { - let mut payload_is_void = false; - let mut previous = operands[0].clone(); - let mut vars: Vec<(String, Option)> = Vec::with_capacity(self.results.len()); - if let Direction::Import = self.interface_gen.direction { - for ty in &self.results { - let tmp = self.locals.tmp("tmp"); - uwrite!( - self.src, - "\ - if ({previous}.IsOk) {{ - var {tmp} = {previous}.AsOk; - " - ); - let TypeDefKind::Result(result) = &self.interface_gen.resolve.types[*ty].kind else { - unreachable!(); - }; - let exception_name = result.err - .map(|ty| self.interface_gen.type_name_with_qualifier(&ty, true)); - vars.push((previous.clone(), exception_name)); - payload_is_void = result.ok.is_none(); - previous = tmp; - } - } - uwriteln!(self.src, "return {};", if payload_is_void { "" } else { &previous }); - for (level, var) in vars.iter().enumerate().rev() { - self.interface_gen.csharp_gen.needs_wit_exception = true; - let (var_name, exception_name) = var; - let exception_name = match exception_name { - Some(type_name) => &format!("WitException<{}>",type_name), - None => "WitException", - }; - uwrite!( - self.src, - "\ - }} else {{ - throw new {exception_name}({var_name}.AsErr!, {level}); - }} - " - ); - } + self.handle_result_import(operands); } _ => { let results = operands.join(", "); diff --git a/crates/csharp/src/interface.rs b/crates/csharp/src/interface.rs index 04226c6c0..57a177797 100644 --- a/crates/csharp/src/interface.rs +++ b/crates/csharp/src/interface.rs @@ -212,6 +212,7 @@ impl InterfaceGenerator<'_> { let (payload, results) = payload_and_results( self.resolve, *func.results.iter_types().next().unwrap(), + self.csharp_gen.opts.with_wit_results, ); ( if let Some(ty) = payload { @@ -358,6 +359,7 @@ impl InterfaceGenerator<'_> { let (payload, results) = payload_and_results( self.resolve, *func.results.iter_types().next().unwrap(), + self.csharp_gen.opts.with_wit_results, ); ( if let Some(ty) = payload { @@ -842,6 +844,7 @@ impl InterfaceGenerator<'_> { let (payload, _) = payload_and_results( self.resolve, *func.results.iter_types().next().unwrap(), + self.csharp_gen.opts.with_wit_results, ); if let Some(ty) = payload { self.csharp_gen.needs_result = true; @@ -1160,7 +1163,15 @@ impl<'a> CoreInterfaceGenerator<'a> for InterfaceGenerator<'a> { } } -fn payload_and_results(resolve: &Resolve, ty: Type) -> (Option, Vec) { +fn payload_and_results( + resolve: &Resolve, + ty: Type, + with_wit_results: bool, +) -> (Option, Vec) { + if with_wit_results { + return (Some(ty), Vec::new()); + } + fn recurse(resolve: &Resolve, ty: Type, results: &mut Vec) -> Option { if let Type::Id(id) = ty { if let TypeDefKind::Result(result) = &resolve.types[id].kind { diff --git a/crates/csharp/src/lib.rs b/crates/csharp/src/lib.rs index 06b3dea10..84e80397f 100644 --- a/crates/csharp/src/lib.rs +++ b/crates/csharp/src/lib.rs @@ -30,6 +30,10 @@ pub struct Opts { /// Skip generating `cabi_realloc`, `WasmImportLinkageAttribute`, and component type files #[cfg_attr(feature = "clap", arg(long))] pub skip_support_files: bool, + + /// Generate code for WIT `Result` types instead of exceptions + #[cfg_attr(feature = "clap", arg(long))] + pub with_wit_results: bool, } impl Opts { diff --git a/crates/csharp/tests/codegen.rs b/crates/csharp/tests/codegen.rs index 6d10104cd..94941cc32 100644 --- a/crates/csharp/tests/codegen.rs +++ b/crates/csharp/tests/codegen.rs @@ -25,6 +25,7 @@ macro_rules! codegen_test { runtime: wit_bindgen_csharp::CSharpRuntime::Mono, internal: false, skip_support_files: false, + with_wit_results: false, } .build() .generate(resolve, world, files) diff --git a/tests/runtime/main.rs b/tests/runtime/main.rs index 0b3ae64f3..02fb4592c 100644 --- a/tests/runtime/main.rs +++ b/tests/runtime/main.rs @@ -663,17 +663,26 @@ fn tests(name: &str, dir_name: &str) -> Result> { let (resolve, world) = resolve_wit_dir(&dir); for path in c_sharp.iter() { let world_name = &resolve.worlds[world].name; - let out_dir = out_dir.join(format!("csharp-{}", world_name)); - drop(fs::remove_dir_all(&out_dir)); - fs::create_dir_all(&out_dir).unwrap(); - for csharp_impl in &c_sharp { - fs::copy( - &csharp_impl, - &out_dir.join(csharp_impl.file_name().unwrap()), - ) - .unwrap(); - } + let gen_option: &str = &path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap() + .split('_') + .skip(1) + .collect::>() + .join("-"); + + let test_dir = if gen_option.is_empty() { + out_dir.join(format!("csharp-{}", world_name)) + } else { + out_dir.join(format!("csharp-{}-{}", world_name, gen_option)) + }; + + drop(fs::remove_dir_all(&test_dir)); + fs::create_dir_all(&test_dir).unwrap(); + + fs::copy(&path, &test_dir.join(path.file_name().unwrap())).unwrap(); let snake = world_name.replace("-", "_"); let camel = format!("{}World", snake.to_upper_camel_case()); @@ -683,10 +692,14 @@ fn tests(name: &str, dir_name: &str) -> Result> { path.file_stem().and_then(|s| s.to_str()).unwrap() ); - let out_wasm = out_dir.join(&assembly_name); + let out_wasm = test_dir.join(&assembly_name); let mut files = Default::default(); let mut opts = wit_bindgen_csharp::Opts::default(); + match gen_option { + "with-wit-results" => opts.with_wit_results = true, + _ => {} + }; if let Some(path) = path.file_name().and_then(|s| s.to_str()) { if path.contains("utf16") { opts.string_encoding = wit_component::StringEncoding::UTF16; @@ -695,17 +708,17 @@ fn tests(name: &str, dir_name: &str) -> Result> { opts.build().generate(&resolve, world, &mut files).unwrap(); for (file, contents) in files.iter() { - let dst = out_dir.join(file); + let dst = test_dir.join(file); fs::write(dst, contents).unwrap(); } let mut csproj = - wit_bindgen_csharp::CSProject::new(out_dir.clone(), &assembly_name, world_name); + wit_bindgen_csharp::CSProject::new(test_dir.clone(), &assembly_name, world_name); csproj.aot(); // Copy test file to target location to be included in compilation let file_name = path.file_name().unwrap(); - fs::copy(path, out_dir.join(file_name.to_str().unwrap()))?; + fs::copy(path, test_dir.join(file_name.to_str().unwrap()))?; csproj.generate()?; @@ -720,11 +733,11 @@ fn tests(name: &str, dir_name: &str) -> Result> { let mut wasm_filename = out_wasm.join(assembly_name); wasm_filename.set_extension("wasm"); - cmd.current_dir(&out_dir); + cmd.current_dir(&test_dir); // add .arg("/bl") to diagnose dotnet build problems cmd.arg("publish") - .arg(out_dir.join(format!("{camel}.csproj"))) + .arg(test_dir.join(format!("{camel}.csproj"))) .arg("-r") .arg("wasi-wasm") .arg("-c") diff --git a/tests/runtime/results/wasm_with_wit_results.cs b/tests/runtime/results/wasm_with_wit_results.cs new file mode 100644 index 000000000..e32ab8bb9 --- /dev/null +++ b/tests/runtime/results/wasm_with_wit_results.cs @@ -0,0 +1,76 @@ +namespace ResultsWorld.wit.exports.test.results +{ + public class TestImpl : ITest + { + public static Result StringError(float a) + { + return imports.test.results.TestInterop.StringError(a); + } + + public static Result EnumError(float a) + { + var result = imports.test.results.TestInterop.EnumError(a); + if (result.IsOk) { + return Result.Ok(result.AsOk); + } else { + switch (result.AsErr){ + case imports.test.results.ITest.E.A: + return Result.Err(ITest.E.A); + case imports.test.results.ITest.E.B: + return Result.Err(ITest.E.B); + case imports.test.results.ITest.E.C: + return Result.Err(ITest.E.C); + default: + throw new Exception("unreachable"); + } + } + } + + public static Result RecordError(float a) + { + var result = imports.test.results.TestInterop.RecordError(a); + if (result.IsOk) { + return Result.Ok(result.AsOk); + } else { + switch (result.AsErr) { + case imports.test.results.ITest.E2: + return Result.Err(new ITest.E2(result.AsErr.line, result.AsErr.column)); + default: + throw new Exception("unreachable"); + } + } + } + + public static Result VariantError(float a) + { + var result = imports.test.results.TestInterop.VariantError(a); + if (result.IsOk) { + return Result.Ok(result.AsOk); + } else { + switch (result.AsErr) { + case imports.test.results.ITest.E3: + switch (result.AsErr.Tag){ + case imports.test.results.ITest.E3.Tags.E1: + return Result.Err(ITest.E3.E1((ITest.E)Enum.Parse(typeof(ITest.E), result.AsErr.AsE1.ToString()))); + case imports.test.results.ITest.E3.Tags.E2: + return Result.Err(ITest.E3.E2(new ITest.E2(result.AsErr.AsE2.line, result.AsErr.AsE2.column))); + default: + throw new Exception("unreachable"); + } + default: + throw new Exception("unreachable"); + } + } + } + + public static Result EmptyError(uint a) + { + return imports.test.results.TestInterop.EmptyError(a); + } + + public static Result, string> DoubleError(uint a) + { + return imports.test.results.TestInterop.DoubleError(a); + } + } +}