Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
for func in &mut output.functions {
peephole_opts::composite_construct(&types, func);
peephole_opts::vector_ops(output.header.as_mut().unwrap(), &types, func);
peephole_opts::bool_fusion(output.header.as_mut().unwrap(), &types, func);
}
}

Expand Down
168 changes: 167 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/peephole_opts.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::id;
use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::bug;

pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
Expand Down Expand Up @@ -447,3 +448,168 @@ pub fn vector_ops(
}
}
}

fn can_fuse_bool(
types: &FxHashMap<Word, Instruction>,
defs: &FxHashMap<Word, (usize, Instruction)>,
inst: &Instruction,
) -> bool {
fn constant_value(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
let inst = match types.get(&val) {
None => return None,
Some(inst) => inst,
};
if inst.class.opcode != Op::Constant {
return None;
}
match inst.operands[0] {
Operand::LiteralInt32(v) => Some(v),
_ => None,
}
}

fn visit(
types: &FxHashMap<Word, Instruction>,
defs: &FxHashMap<Word, (usize, Instruction)>,
visited: &mut FxHashSet<Word>,
value: Word,
) -> bool {
if visited.insert(value) {
let inst = match defs.get(&value) {
Some((_, inst)) => inst,
None => return false,
};
match inst.class.opcode {
Op::Select => {
constant_value(types, inst.operands[1].unwrap_id_ref()) == Some(1)
&& constant_value(types, inst.operands[2].unwrap_id_ref()) == Some(0)
}
Op::Phi => inst
.operands
.iter()
.step_by(2)
.all(|op| visit(types, defs, visited, op.unwrap_id_ref())),
_ => false,
}
} else {
true
}
}

if inst.class.opcode != Op::INotEqual
|| constant_value(types, inst.operands[1].unwrap_id_ref()) != Some(0)
{
return false;
}
let int_value = inst.operands[0].unwrap_id_ref();

visit(types, defs, &mut FxHashSet::default(), int_value)
}

fn fuse_bool(
header: &mut ModuleHeader,
defs: &FxHashMap<Word, (usize, Instruction)>,
phis_to_insert: &mut Vec<(usize, Instruction)>,
already_mapped: &mut FxHashMap<Word, Word>,
bool_ty: Word,
int_value: Word,
) -> Word {
if let Some(&result) = already_mapped.get(&int_value) {
return result;
}
let (block_of_inst, inst) = defs.get(&int_value).unwrap();
match inst.class.opcode {
Op::Select => inst.operands[0].unwrap_id_ref(),
Op::Phi => {
let result_id = id(header);
already_mapped.insert(int_value, result_id);
let new_phi_args = inst
.operands
.chunks(2)
.flat_map(|arr| {
let phi_value = &arr[0];
let block = &arr[1];
[
Operand::IdRef(fuse_bool(
header,
defs,
phis_to_insert,
already_mapped,
bool_ty,
phi_value.unwrap_id_ref(),
)),
block.clone(),
]
})
.collect::<Vec<_>>();
let inst = Instruction::new(Op::Phi, Some(bool_ty), Some(result_id), new_phi_args);
phis_to_insert.push((*block_of_inst, inst));
result_id
}
_ => bug!("can_fuse_bool should have prevented this case"),
}
}

// The compiler generates a lot of code that looks like this:
// %v_int = OpSelect %int %v %const_1 %const_0
// %v2 = OpINotEqual %bool %v_int %const_0
// (This is due to rustc/spirv not supporting bools in memory, and needing to convert to u8, but
// then things get inlined/mem2reg'd)
//
// This pass fuses together those two instructions to strip out the intermediate integer variable.
// The purpose is to make simple code that doesn't actually do memory-stuff with bools not require
// the Int8 capability (and so we can't rely on spirv-opt to do this same pass).
//
// Unfortunately, things get complicated because of phis: the majority of actually useful cases to
// do this pass need to track pseudo-bool ints through phi instructions.
//
// The logic goes like:
// 1) Figure out what we *can* fuse. This means finding OpINotEqual instructions (converting back
// from int->bool) and tracing the value back recursively through any phis, and making sure each
// one terminates in either a loop back around to something we've already seen, or an OpSelect
// (converting from bool->int).
// 2) Do the fusion. Trace back through phis, generating a second bool-typed phi alongside the
// original int-typed phi, and when hitting an OpSelect, taking the bool value directly.
// 3) DCE the dead OpSelects/int-typed OpPhis (done in a later pass). We don't nuke them here,
// since they might be used elsewhere, and don't want to accidentally leave a dangling
// reference.
pub fn bool_fusion(
header: &mut ModuleHeader,
types: &FxHashMap<Word, Instruction>,
function: &mut Function,
) {
let defs: FxHashMap<Word, (usize, Instruction)> = function
.blocks
.iter()
.enumerate()
.flat_map(|(block_id, block)| {
block
.instructions
.iter()
.filter_map(move |inst| Some((inst.result_id?, (block_id, inst.clone()))))
})
.collect();
let mut rewrite_rules = FxHashMap::default();
let mut phis_to_insert = Default::default();
let mut already_mapped = Default::default();
for block in &mut function.blocks {
for inst in &mut block.instructions {
if can_fuse_bool(types, &defs, inst) {
let rewrite_to = fuse_bool(
header,
&defs,
&mut phis_to_insert,
&mut already_mapped,
inst.result_type.unwrap(),
inst.operands[0].unwrap_id_ref(),
);
rewrite_rules.insert(inst.result_id.unwrap(), rewrite_to);
*inst = Instruction::new(Op::Nop, None, None, Vec::new());
}
}
}
for (block, phi) in phis_to_insert {
function.blocks[block].instructions.insert(0, phi);
}
super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
}
12 changes: 12 additions & 0 deletions tests/ui/lang/core/unwrap_or.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// unwrap_or generates some memory-bools (as u8). Test to make sure they're fused away.
// OpINotEqual, as well as %bool, should not appear in the output.

// build-pass
// compile-flags: -C llvm-args=--disassemble-entry=main

use spirv_std as _;

#[spirv(fragment)]
pub fn main(out: &mut u32) {
*out = None.unwrap_or(15);
}
39 changes: 39 additions & 0 deletions tests/ui/lang/core/unwrap_or.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
%1 = OpFunction %2 None %3
%4 = OpLabel
OpLine %5 11 11
%6 = OpCompositeInsert %7 %8 %9 0
OpLine %5 11 11
%10 = OpCompositeExtract %11 %6 1
OpLine %12 767 14
%13 = OpBitcast %14 %8
OpLine %12 767 8
OpSelectionMerge %15 None
OpSwitch %13 %16 0 %17 1 %18
%16 = OpLabel
OpLine %12 767 14
OpUnreachable
%17 = OpLabel
OpLine %12 769 20
OpBranch %15
%18 = OpLabel
OpLine %12 771 4
OpBranch %15
%15 = OpLabel
%19 = OpPhi %20 %21 %17 %22 %18
%23 = OpPhi %11 %24 %17 %10 %18
OpBranch %25
%25 = OpLabel
OpLine %12 771 4
OpSelectionMerge %26 None
OpBranchConditional %19 %27 %28
%27 = OpLabel
OpLine %12 771 4
OpBranch %26
%28 = OpLabel
OpBranch %26
%26 = OpLabel
OpLine %5 11 4
OpStore %29 %23
OpLine %5 12 1
OpReturn
OpFunctionEnd