diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index d179fba3b6..d96ed8d991 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -128,120 +128,33 @@ impl<'tcx> CodegenCx<'tcx> { id.with_type(fn_void_void) }; - let mut decoration_locations = HashMap::new(); - let interface_globals = arg_abis - .iter() - .zip(hir_params) - .map(|(entry_fn_arg, hir_param)| { - self.declare_interface_global_for_param( - entry_fn_arg.layout, - hir_param, - &mut decoration_locations, - ) - }) - .collect::>(); + let mut op_entry_point_interface_operands = vec![]; + let mut bx = Builder::new_block(self, stub_fn, ""); - // Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s, - // or accessing the sole field of an "interface block" `OpTypeStruct`), - // to match the argument type we have to pass to the Rust entry `fn`. - let arguments: Vec<_> = interface_globals - .iter() - .zip(arg_abis) - .zip(hir_params) - .flat_map( - |((&(global_var, storage_class), entry_fn_arg), hir_param)| { - bx.set_span(hir_param.span); - - let var_value_spirv_type = match self.lookup_type(global_var.ty) { - SpirvType::Pointer { pointee } => pointee, - _ => unreachable!(), - }; - - let (first, second) = match entry_fn_arg.layout.ty.kind() { - TyKind::Ref(_, pointee_ty, _) => { - let arg_pointee_spirv_type = self - .layout_of(pointee_ty) - .spirv_type(hir_param.ty_span, self); - - if let SpirvType::InterfaceBlock { inner_type } = - self.lookup_type(var_value_spirv_type) - { - assert_ty_eq!(self, arg_pointee_spirv_type, inner_type); - - let inner = bx.struct_gep(global_var, 0); - - match entry_fn_arg.mode { - PassMode::Direct(_) => (inner, None), - - // Unsized pointee with length (i.e. `&[T]`). - PassMode::Pair(..) => { - // FIXME(eddyb) shouldn't this be `usize`? - let len_spirv_type = self.type_isize(); - - let len = bx - .emit() - .array_length( - len_spirv_type, - None, - global_var.def(&bx), - 0, - ) - .unwrap() - .with_type(len_spirv_type); - - (inner, Some(len)) - } - - _ => unreachable!(), - } - } else { - assert_ty_eq!(self, arg_pointee_spirv_type, var_value_spirv_type); - assert_matches!(entry_fn_arg.mode, PassMode::Direct(_)); - (global_var, None) - } - } - _ => { - assert_eq!(storage_class, StorageClass::Input); - - let arg_spirv_type = - entry_fn_arg.layout.spirv_type(hir_param.ty_span, self); - - assert_ty_eq!(self, arg_spirv_type, var_value_spirv_type); - - match entry_fn_arg.mode { - PassMode::Indirect { .. } => (global_var, None), - PassMode::Direct(_) => { - (bx.load(global_var, entry_fn_arg.layout.align.abi), None) - } - _ => unreachable!(), - } - } - }; - std::iter::once(first).chain(second) - }, + let mut call_args = vec![]; + let mut decoration_locations = HashMap::new(); + for (entry_arg_abi, hir_param) in arg_abis.iter().zip(hir_params) { + bx.set_span(hir_param.span); + self.declare_shader_interface_for_param( + entry_arg_abi, + hir_param, + &mut op_entry_point_interface_operands, + &mut bx, + &mut call_args, + &mut decoration_locations, ) - .collect(); + } bx.set_span(span); - bx.call(entry_func, &arguments, None); + bx.call(entry_func, &call_args, None); bx.ret_void(); - let interface: Vec<_> = if self.emit_global().version().unwrap() > (1, 3) { - // SPIR-V >= v1.4 includes all OpVariables in the interface. - interface_globals - .into_iter() - .map(|(var, _)| var.def_cx(self)) - .collect() - } else { - // SPIR-V <= v1.3 only includes Input and Output in the interface. - interface_globals - .into_iter() - .filter(|&(_, s)| s == StorageClass::Input || s == StorageClass::Output) - .map(|(var, _)| var.def_cx(self)) - .collect() - }; let stub_fn_id = stub_fn.def_cx(self); - self.emit_global() - .entry_point(execution_model, stub_fn_id, name, interface); + self.emit_global().entry_point( + execution_model, + stub_fn_id, + name, + op_entry_point_interface_operands, + ); stub_fn_id } @@ -342,28 +255,116 @@ impl<'tcx> CodegenCx<'tcx> { (spirv_ty, storage_class) } - fn declare_interface_global_for_param( + fn declare_shader_interface_for_param( &self, - layout: TyAndLayout<'tcx>, + entry_arg_abi: &ArgAbi<'tcx, Ty<'tcx>>, hir_param: &hir::Param<'tcx>, + op_entry_point_interface_operands: &mut Vec, + bx: &mut Builder<'_, 'tcx>, + call_args: &mut Vec, decoration_locations: &mut HashMap, - ) -> (SpirvValue, StorageClass) { + ) { let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id)); - let (mut value_spirv_type, storage_class) = - self.infer_param_ty_and_storage_class(layout, hir_param, &attrs); - // Pre-allocate the module-scoped `OpVariable`'s *Result* ID. - let variable = self.emit_global().id(); + let var = self.emit_global().id(); + + let (value_spirv_type, storage_class) = + self.infer_param_ty_and_storage_class(entry_arg_abi.layout, hir_param, &attrs); + + // Certain storage classes require an `OpTypeStruct` decorated with `Block`, + // which we represent with `SpirvType::InterfaceBlock` (see its doc comment). + // This "interface block" construct is also required for "runtime arrays". + let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none(); + let var_ptr_spirv_type; + let (value_ptr, value_len) = match storage_class { + StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer => { + var_ptr_spirv_type = self.type_ptr_to( + SpirvType::InterfaceBlock { + inner_type: value_spirv_type, + } + .def(hir_param.span, self), + ); + + let value_ptr = bx.struct_gep(var.with_type(var_ptr_spirv_type), 0); + + let value_len = if is_unsized { + match self.lookup_type(value_spirv_type) { + SpirvType::RuntimeArray { .. } => {} + _ => self.tcx.sess.span_err( + hir_param.ty_span, + "only plain slices are supported as unsized types", + ), + } + + // FIXME(eddyb) shouldn't this be `usize`? + let len_spirv_type = self.type_isize(); + let len = bx + .emit() + .array_length(len_spirv_type, None, var, 0) + .unwrap(); + + Some(len.with_type(len_spirv_type)) + } else { + None + }; + + (value_ptr, value_len) + } + _ => { + var_ptr_spirv_type = self.type_ptr_to(value_spirv_type); + + if is_unsized { + self.tcx.sess.span_fatal( + hir_param.ty_span, + &format!( + "unsized types are not supported for storage class {:?}", + storage_class + ), + ); + } + + (var.with_type(var_ptr_spirv_type), None) + } + }; + + // Compute call argument(s) to match what the Rust entry `fn` expects, + // starting from the `value_ptr` pointing to a `value_spirv_type` + // (e.g. `Input` doesn't use indirection, so we have to load from it). + if let TyKind::Ref(..) = entry_arg_abi.layout.ty.kind() { + call_args.push(value_ptr); + match entry_arg_abi.mode { + PassMode::Direct(_) => assert_eq!(value_len, None), + PassMode::Pair(..) => call_args.push(value_len.unwrap()), + _ => unreachable!(), + } + } else { + assert_eq!(storage_class, StorageClass::Input); + + call_args.push(match entry_arg_abi.mode { + PassMode::Indirect { .. } => value_ptr, + PassMode::Direct(_) => bx.load(value_ptr, entry_arg_abi.layout.align.abi), + _ => unreachable!(), + }); + assert_eq!(value_len, None); + } + // FIXME(eddyb) check whether the storage class is compatible with the + // specific shader stage of this entry-point, and any decorations + // (e.g. Vulkan has specific rules for builtin storage classes). + + // Emit `OpName` in the simple case of a pattern that's just a variable + // name (e.g. "foo" for `foo: Vec3`). While `OpName` is *not* suppposed + // to be semantic, OpenGL and some tooling rely on it for reflection. if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { - self.emit_global().name(variable, ident.to_string()); + self.emit_global().name(var, ident.to_string()); } + // Emit `OpDecorate`s based on attributes. let mut decoration_supersedes_location = false; if let Some(builtin) = attrs.builtin.map(|attr| attr.value) { self.emit_global().decorate( - variable, + var, Decoration::BuiltIn, std::iter::once(Operand::BuiltIn(builtin)), ); @@ -371,7 +372,7 @@ impl<'tcx> CodegenCx<'tcx> { } if let Some(index) = attrs.descriptor_set.map(|attr| attr.value) { self.emit_global().decorate( - variable, + var, Decoration::DescriptorSet, std::iter::once(Operand::LiteralInt32(index)), ); @@ -379,7 +380,7 @@ impl<'tcx> CodegenCx<'tcx> { } if let Some(index) = attrs.binding.map(|attr| attr.value) { self.emit_global().decorate( - variable, + var, Decoration::Binding, std::iter::once(Operand::LiteralInt32(index)), ); @@ -387,11 +388,11 @@ impl<'tcx> CodegenCx<'tcx> { } if attrs.flat.is_some() { self.emit_global() - .decorate(variable, Decoration::Flat, std::iter::empty()); + .decorate(var, Decoration::Flat, std::iter::empty()); } if let Some(invariant) = attrs.invariant { self.emit_global() - .decorate(variable, Decoration::Invariant, std::iter::empty()); + .decorate(var, Decoration::Invariant, std::iter::empty()); if storage_class != StorageClass::Output { self.tcx.sess.span_err( invariant.span, @@ -400,44 +401,6 @@ impl<'tcx> CodegenCx<'tcx> { } } - // Certain storage classes require an `OpTypeStruct` decorated with `Block`, - // which we represent with `SpirvType::InterfaceBlock` (see its doc comment). - // This "interface block" construct is also required for "runtime arrays". - let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none(); - match storage_class { - StorageClass::PushConstant | StorageClass::Uniform | StorageClass::StorageBuffer => { - if is_unsized { - match self.lookup_type(value_spirv_type) { - SpirvType::RuntimeArray { .. } => {} - _ => self.tcx.sess.span_err( - hir_param.ty_span, - "only plain slices are supported as unsized types", - ), - } - } - - value_spirv_type = SpirvType::InterfaceBlock { - inner_type: value_spirv_type, - } - .def(hir_param.span, self); - } - _ => { - if is_unsized { - self.tcx.sess.span_fatal( - hir_param.ty_span, - &format!( - "unsized types are not supported for storage class {:?}", - storage_class - ), - ); - } - } - } - - // FIXME(eddyb) check whether the storage class is compatible with the - // specific shader stage of this entry-point, and any decorations - // (e.g. Vulkan has specific rules for builtin storage classes). - // Assign locations from left to right, incrementing each storage class // individually. // TODO: Is this right for UniformConstant? Do they share locations with @@ -452,22 +415,28 @@ impl<'tcx> CodegenCx<'tcx> { .entry(storage_class) .or_insert_with(|| 0); self.emit_global().decorate( - variable, + var, Decoration::Location, std::iter::once(Operand::LiteralInt32(*location)), ); *location += 1; } - // Emit the `OpVariable` with its *Result* ID set to `variable`. - let var_spirv_type = SpirvType::Pointer { - pointee: value_spirv_type, - } - .def(hir_param.span, self); + // Emit the `OpVariable` with its *Result* ID set to `var`. self.emit_global() - .variable(var_spirv_type, Some(variable), storage_class, None); + .variable(var_ptr_spirv_type, Some(var), storage_class, None); - (variable.with_type(var_spirv_type), storage_class) + // Record this `OpVariable` as needing to be added (if applicable), + // to the *Interface* operands of the `OpEntryPoint` instruction. + if self.emit_global().version().unwrap() > (1, 3) { + // SPIR-V >= v1.4 includes all OpVariables in the interface. + op_entry_point_interface_operands.push(var); + } else { + // SPIR-V <= v1.3 only includes Input and Output in the interface. + if storage_class == StorageClass::Input || storage_class == StorageClass::Output { + op_entry_point_interface_operands.push(var); + } + } } // Kernel mode takes its interface as function parameters(??) diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 5c9aae1151..f95e511133 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -518,7 +518,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { .finish(), SpirvType::InterfaceBlock { inner_type } => f - .debug_struct("SampledImage") + .debug_struct("InterfaceBlock") .field("id", &self.id) .field("inner_type", &self.cx.debug_type(inner_type)) .finish(),