From 5d19b061ce6922670177c9024c8391e03e03d992 Mon Sep 17 00:00:00 2001
From: khyperia <953151+khyperia@users.noreply.github.com>
Date: Wed, 27 Oct 2021 14:12:54 +0200
Subject: [PATCH] Implement bool fusion pass

Fixes #677
---
 crates/rustc_codegen_spirv/src/linker/mod.rs  |   1 +
 .../src/linker/peephole_opts.rs               | 195 +++++++++++++++++-
 tests/ui/lang/core/unwrap_or.rs               |  12 ++
 tests/ui/lang/core/unwrap_or.stderr           |  39 ++++
 4 files changed, 246 insertions(+), 1 deletion(-)
 create mode 100644 tests/ui/lang/core/unwrap_or.rs
 create mode 100644 tests/ui/lang/core/unwrap_or.stderr

diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs
index 351419b5e4..2a24c55baa 100644
--- a/crates/rustc_codegen_spirv/src/linker/mod.rs
+++ b/crates/rustc_codegen_spirv/src/linker/mod.rs
@@ -266,6 +266,7 @@ pub fn link(sess: &Session, mut inputs: Vec, opts: &Options) -> Result FxHashMap {
@@ -447,3 +448,195 @@ pub fn vector_ops(
         }
     }
 }
+
+// Decides whether `inst` is an int->bool conversion that `fuse_bool` is able to eliminate.
+//
+// This holds when `inst` is an `OpINotEqual` whose second operand is a constant 0 and whose
+// first operand traces back, through any number of `OpPhi`s, only to `OpSelect`s that choose
+// between the constants 1 and 0.
+//
+// `types` is used to look up `OpConstant` definitions by id. `defs` maps every result id
+// defined in the function to the index of its block and its defining instruction.
+fn can_fuse_bool(
+    types: &FxHashMap<Word, Instruction>,
+    defs: &FxHashMap<Word, (usize, Instruction)>,
+    inst: &Instruction,
+) -> bool {
+    // Returns the literal of `val` if `types` defines it as an `OpConstant` whose first operand
+    // is a 32-bit integer literal, and `None` otherwise.
+    fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
+        let inst = types.get(&val)?;
+        if inst.class.opcode != Op::Constant {
+            return None;
+        }
+        match inst.operands[0] {
+            Operand::LiteralInt32(v) => Some(v),
+            _ => None,
+        }
+    }
+
+    // Returns true if `value` is an acceptable source for fusion. A value that was already
+    // visited counts as acceptable: that is what lets a phi cycle (a loop) terminate.
+    fn visit(
+        types: &FxHashMap<Word, Instruction>,
+        defs: &FxHashMap<Word, (usize, Instruction)>,
+        visited: &mut FxHashSet<Word>,
+        value: Word,
+    ) -> bool {
+        if visited.insert(value) {
+            // Values with no entry in `defs` cannot be traced any further.
+            let inst = match defs.get(&value) {
+                Some((_, inst)) => inst,
+                None => return false,
+            };
+            match inst.class.opcode {
+                // bool->int conversion: select(cond, 1, 0).
+                Op::Select => {
+                    constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
+                        && constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
+                }
+                // Phi operands alternate (value, block); only the values need checking.
+                Op::Phi => inst
+                    .operands
+                    .iter()
+                    .step_by(2)
+                    .all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
+                _ => false,
+            }
+        } else {
+            true
+        }
+    }
+
+    if inst.class.opcode != Op::INotEqual
+        || constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
+    {
+        return false;
+    }
+    let int_value = inst.operands[0].unwrap_id_ref();
+
+    visit(types, defs, &mut FxHashSet::default(), int_value)
+}
+
+// Returns the id of a bool-typed value equivalent to the pseudo-bool integer `int_value`.
+//
+// An `OpSelect` yields its condition operand directly. An `OpPhi` yields the id of a new
+// `bool_ty`-typed phi, built by converting each incoming value recursively; the new phi is
+// queued in `phis_to_insert` together with the index of the block holding the original phi.
+// `already_mapped` memoizes converted phis, and is filled in *before* recursing so that a phi
+// cycle resolves to the id under construction instead of recursing forever.
+//
+// Panics (via `unwrap`/`bug!`) if `int_value` has no entry in `defs` or is defined by any other
+// opcode, so callers must check `can_fuse_bool` first.
+fn fuse_bool(
+    header: &mut ModuleHeader,
+    defs: &FxHashMap<Word, (usize, Instruction)>,
+    phis_to_insert: &mut Vec<(usize, Instruction)>,
+    already_mapped: &mut FxHashMap<Word, Word>,
+    bool_ty: Word,
+    int_value: Word,
+) -> Word {
+    if let Some(&result) = already_mapped.get(&int_value) {
+        return result;
+    }
+    let (block_of_inst, inst) = defs.get(&int_value).unwrap();
+    match inst.class.opcode {
+        Op::Select => inst.operands[0].unwrap_id_ref(),
+        Op::Phi => {
+            let result_id = id(header);
+            already_mapped.insert(int_value, result_id);
+            let new_phi_args = inst
+                .operands
+                .chunks(2)
+                .flat_map(|arr| {
+                    let phi_value = &arr[0];
+                    let block = &arr[1];
+                    [
+                        Operand::IdRef(fuse_bool(
+                            header,
+                            defs,
+                            phis_to_insert,
+                            already_mapped,
+                            bool_ty,
+                            phi_value.unwrap_id_ref(),
+                        )),
+                        block.clone(),
+                    ]
+                })
+                .collect::<Vec<_>>();
+            let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
+            phis_to_insert.push((*block_of_inst, inst));
+            result_id
+        }
+        _ => bug!("can_fuse_bool should have prevented this case"),
+    }
+}
+
+// The compiler generates a lot of code that looks like this:
+// %v_int = OpSelect %int %v %const_1 %const_0
+// %v2 = OpINotEqual %bool %v_int %const_0
+// (This is due to rustc/spirv not supporting bools in memory, and needing to convert to u8, but
+// then things get inlined/mem2reg'd)
+//
+// This pass fuses together those two instructions to strip out the intermediate integer variable.
+// The purpose is to make simple code that doesn't actually do memory-stuff with bools not require
+// the Int8 capability (and so we can't rely on spirv-opt to do this same pass).
+//
+// Unfortunately, things get complicated because of phis: the majority of actually useful cases to
+// do this pass need to track pseudo-bool ints through phi instructions.
+//
+// The logic goes like:
+// 1) Figure out what we *can* fuse. This means finding OpINotEqual instructions (converting back
+// from int->bool) and tracing the value back recursively through any phis, and making sure each
+// one terminates in either a loop back around to something we've already seen, or an OpSelect
+// (converting from bool->int).
+// 2) Do the fusion. Trace back through phis, generating a second bool-typed phi alongside the
+// original int-typed phi, and when hitting an OpSelect, taking the bool value directly.
+// 3) DCE the dead OpSelects/int-typed OpPhis (done in a later pass). We don't nuke them here,
+// since they might be used elsewhere, and we don't want to accidentally leave a dangling
+// reference.
+pub fn bool_fusion(
+    header: &mut ModuleHeader,
+    types: &FxHashMap<Word, Instruction>,
+    function: &mut Function,
+) {
+    // Snapshot of every instruction with a result id, keyed by that id and tagged with the index
+    // of its block. Instructions are cloned so that the blocks can be mutated below.
+    let defs: FxHashMap<Word, (usize, Instruction)> = function
+        .blocks
+        .iter()
+        .enumerate()
+        .flat_map(|(block_id, block)| {
+            block
+                .instructions
+                .iter()
+                .filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
+        })
+        .collect();
+    let mut rewrite_rules = FxHashMap::default();
+    let mut phis_to_insert = Default::default();
+    let mut already_mapped = Default::default();
+    for block in &mut function.blocks {
+        for inst in &mut block.instructions {
+            if can_fuse_bool(types, &defs, inst) {
+                let rewrite_to = fuse_bool(
+                    header,
+                    &defs,
+                    &mut phis_to_insert,
+                    &mut already_mapped,
+                    inst.result_type.unwrap(),
+                    inst.operands[0].unwrap_id_ref(),
+                );
+                // Map the OpINotEqual result to the fused bool (handed to `apply_rewrite_rules`
+                // below), and replace the comparison itself with an OpNop.
+                rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
+                *inst = Instruction::new(Op::Nop, None, None, Vec::new());
+            }
+        }
+    }
+    // New phis go at the front of their block's instruction list.
+    for (block, phi) in phis_to_insert {
+        function.blocks[block].instructions.insert(0, phi);
+    }
+    super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
+}
diff --git a/tests/ui/lang/core/unwrap_or.rs b/tests/ui/lang/core/unwrap_or.rs
new file mode 100644
index 0000000000..84adbbf4bb
--- /dev/null
+++ b/tests/ui/lang/core/unwrap_or.rs
@@ -0,0 +1,12 @@
+// unwrap_or generates some memory-bools (as u8). Test to make sure they're fused away.
+// OpINotEqual, as well as %bool, should not appear in the output.
+
+// build-pass
+// compile-flags: -C llvm-args=--disassemble-entry=main
+
+use spirv_std as _;
+
+#[spirv(fragment)]
+pub fn main(out: &mut u32) {
+    *out = None.unwrap_or(15);
+}
diff --git a/tests/ui/lang/core/unwrap_or.stderr b/tests/ui/lang/core/unwrap_or.stderr
new file mode 100644
index 0000000000..87b084357a
--- /dev/null
+++ b/tests/ui/lang/core/unwrap_or.stderr
@@ -0,0 +1,39 @@
+%1 = OpFunction %2 None %3
+%4 = OpLabel
+OpLine %5 11 11
+%6 = OpCompositeInsert %7 %8 %9 0
+OpLine %5 11 11
+%10 = OpCompositeExtract %11 %6 1
+OpLine %12 767 14
+%13 = OpBitcast %14 %8
+OpLine %12 767 8
+OpSelectionMerge %15 None
+OpSwitch %13 %16 0 %17 1 %18
+%16 = OpLabel
+OpLine %12 767 14
+OpUnreachable
+%17 = OpLabel
+OpLine %12 769 20
+OpBranch %15
+%18 = OpLabel
+OpLine %12 771 4
+OpBranch %15
+%15 = OpLabel
+%19 = OpPhi %20 %21 %17 %22 %18
+%23 = OpPhi %11 %24 %17 %10 %18
+OpBranch %25
+%25 = OpLabel
+OpLine %12 771 4
+OpSelectionMerge %26 None
+OpBranchConditional %19 %27 %28
+%27 = OpLabel
+OpLine %12 771 4
+OpBranch %26
+%28 = OpLabel
+OpBranch %26
+%26 = OpLabel
+OpLine %5 11 4
+OpStore %29 %23
+OpLine %5 12 1
+OpReturn
+OpFunctionEnd