diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 6ef0f5a70e..17c0e9e41a 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -5,12 +5,15 @@ use crate::builder_spirv::SpirvValue; use crate::spirv_type::SpirvType; use rspirv::dr::Operand; use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word}; +use rustc_codegen_ssa::traits::BaseTypeMethods; use rustc_hir as hir; -use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::layout::{HasParamEnv, TyAndLayout}; use rustc_middle::ty::{Instance, Ty, TyKind}; use rustc_span::Span; -use rustc_target::abi::call::{FnAbi, PassMode}; -use rustc_target::abi::LayoutOf; +use rustc_target::abi::{ + call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode}, + LayoutOf, Size, +}; use std::collections::HashMap; impl<'tcx> CodegenCx<'tcx> { @@ -37,9 +40,27 @@ impl<'tcx> CodegenCx<'tcx> { }; let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id); let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id)); + const EMPTY: ArgAttribute = ArgAttribute::empty(); for (abi, arg) in fn_abi.args.iter().zip(body.params) { match abi.mode { - PassMode::Direct(_) | PassMode::Indirect { .. } => {} + PassMode::Direct(_) + | PassMode::Indirect { .. } + // plain DST/RTA/VLA + | PassMode::Pair( + ArgAttributes { + pointee_size: Size::ZERO, + .. + }, + ArgAttributes { regular: EMPTY, .. }, + ) + // DST struct with fields before the DST member + | PassMode::Pair( + ArgAttributes { .. }, + ArgAttributes { + pointee_size: Size::ZERO, + .. + }, + ) => {} _ => self.tcx.sess.span_err( arg.span, &format!("PassMode {:?} invalid for entry point parameter", abi.mode), @@ -63,7 +84,7 @@ impl<'tcx> CodegenCx<'tcx> { self.shader_entry_stub( self.tcx.def_span(instance.def_id()), entry_func, - fn_abi, + &fn_abi.args, body.params, name, execution_model, @@ -82,7 +103,7 @@ impl<'tcx> CodegenCx<'tcx> { &self, span: Span, entry_func: SpirvValue, - entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + arg_abis: &[ArgAbi<'tcx, Ty<'tcx>>], hir_params: &[hir::Param<'tcx>], name: String, execution_model: ExecutionModel, @@ -94,10 +115,7 @@ impl<'tcx> CodegenCx<'tcx> { } .def(span, self); let entry_func_return_type = match self.lookup_type(entry_func.ty) { - SpirvType::Function { - return_type, - arguments: _, - } => return_type, + SpirvType::Function { return_type, .. } => return_type, other => self.tcx.sess.fatal(&format!( "Invalid entry_stub type: {}", other.debug(entry_func.ty, self) @@ -105,14 +123,14 @@ impl<'tcx> CodegenCx<'tcx> { }; let mut decoration_locations = HashMap::new(); // Create OpVariables before OpFunction so they're global instead of local vars. - let declared_params = entry_fn_abi - .args + let declared_params = arg_abis .iter() .zip(hir_params) .map(|(entry_fn_arg, hir_param)| { self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations) }) .collect::>(); + let len_t = self.type_isize(); let mut emit = self.emit_global(); let fn_id = emit .begin_function(void, None, FunctionControl::NONE, fn_void_void) @@ -121,12 +139,19 @@ impl<'tcx> CodegenCx<'tcx> { // Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s). let arguments: Vec<_> = declared_params .iter() - .zip(&entry_fn_abi.args) + .zip(arg_abis) .zip(hir_params) - .map(|((&(var, storage_class), entry_fn_arg), hir_param)| { - match entry_fn_arg.layout.ty.kind() { - TyKind::Ref(..) => var, - + .flat_map(|((&(var, storage_class), entry_fn_arg), hir_param)| { + let mut dst_len_arg = None; + let arg = match entry_fn_arg.layout.ty.kind() { + TyKind::Ref(_, ty, _) => { + if !ty.is_sized(self.tcx.at(span), self.param_env()) { + dst_len_arg.replace( + self.dst_length_argument(&mut emit, ty, hir_param, len_t, var), + ); + } + var + } _ => match entry_fn_arg.mode { PassMode::Indirect { .. } => var, PassMode::Direct(_) => { @@ -142,7 +167,8 @@ impl<'tcx> CodegenCx<'tcx> { } _ => unreachable!(), }, - } + }; + std::iter::once(arg).chain(dst_len_arg) }) .collect(); emit.function_call( @@ -170,6 +196,38 @@ impl<'tcx> CodegenCx<'tcx> { fn_id } + fn dst_length_argument( + &self, + emit: &mut std::cell::RefMut<'_, rspirv::dr::Builder>, + ty: Ty<'tcx>, + hir_param: &hir::Param<'tcx>, + len_t: Word, + var: Word, + ) -> Word { + match ty.kind() { + TyKind::Adt(adt_def, substs) => { + let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap(); + let field_ty = field_def.ty(self.tcx, substs); + if !matches!(field_ty.kind(), TyKind::Slice(..)) { + self.tcx.sess.span_fatal( + hir_param.ty_span, + "DST parameters are currently restricted to a reference to a struct whose last field is a slice.", + ) + } + emit.array_length(len_t, None, var, member_idx as u32) + .unwrap() + } + TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal( + hir_param.ty_span, + "Straight slices are not yet supported, wrap the slice in a newtype.", + ), + _ => self + .tcx + .sess + .span_fatal(hir_param.ty_span, "Unsupported parameter type."), + } + } + fn declare_parameter( &self, layout: TyAndLayout<'tcx>, diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 3cac1176d5..098147ed3c 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -188,6 +188,17 @@ impl SpirvType { } Self::RuntimeArray { element } => { let result = cx.emit_global().type_runtime_array(element); + // ArrayStride decoration wants in *bytes* + let element_size = cx + .lookup_type(element) + .sizeof(cx) + .expect("Element of sized array must be sized") + .bytes(); + cx.emit_global().decorate( + result, + Decoration::ArrayStride, + iter::once(Operand::LiteralInt32(element_size as u32)), + ); if cx.kernel_mode { cx.zombie_with_span(result, def_span, "RuntimeArray in kernel mode"); } diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index ff3863e61c..5829d279de 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -1,4 +1,4 @@ -use super::{dis_fn, dis_globals, val, val_vulkan}; +use super::{dis_entry_fn, dis_fn, dis_globals, val, val_vulkan}; use std::ffi::OsStr; struct SetEnvVar<'a> { @@ -183,20 +183,21 @@ OpEntryPoint Fragment %1 "main" OpExecutionMode %1 OriginUpperLeft OpName %2 "test_project::add_decorate" OpName %3 "test_project::main" -OpDecorate %4 DescriptorSet 0 -OpDecorate %4 Binding 0 -%5 = OpTypeVoid -%6 = OpTypeFunction %5 -%7 = OpTypeInt 32 0 -%8 = OpTypePointer Function %7 -%9 = OpConstant %7 1 -%10 = OpTypeFloat 32 -%11 = OpTypeImage %10 2D 0 0 0 1 Unknown -%12 = OpTypeSampledImage %11 -%13 = OpTypeRuntimeArray %12 -%14 = OpTypePointer UniformConstant %13 -%4 = OpVariable %14 UniformConstant -%15 = OpTypePointer UniformConstant %12"#, +OpDecorate %4 ArrayStride 4 +OpDecorate %5 DescriptorSet 0 +OpDecorate %5 Binding 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 0 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 1 +%11 = OpTypeFloat 32 +%12 = OpTypeImage %11 2D 0 0 0 1 Unknown +%13 = OpTypeSampledImage %12 +%4 = OpTypeRuntimeArray %13 +%14 = OpTypePointer UniformConstant %4 +%5 = OpVariable %14 UniformConstant +%15 = OpTypePointer UniformConstant %13"#, ); } @@ -479,3 +480,54 @@ fn ptr_copy_from_method() { "# ); } + +#[test] +fn index_user_dst() { + dis_entry_fn( + r#" +#[spirv(fragment)] +pub fn main( + #[spirv(uniform, descriptor_set = 0, binding = 0)] slice: &mut SliceF32, +) { + let float: f32 = slice.rta[0]; + let _ = float; +} + +pub struct SliceF32 { + rta: [f32], +} + "#, + "main", + r#"%1 = OpFunction %2 None %3 +%4 = OpLabel +%5 = OpArrayLength %6 %7 0 +%8 = OpCompositeInsert %9 %7 %10 0 +%11 = OpCompositeInsert %9 %5 %8 1 +%12 = OpAccessChain %13 %7 %14 +%15 = OpULessThan %16 %14 %5 +OpSelectionMerge %17 None +OpBranchConditional %15 %18 %19 +%18 = OpLabel +%20 = OpAccessChain %13 %7 %14 +%21 = OpInBoundsAccessChain %22 %20 %14 +%23 = OpLoad %24 %21 +OpReturn +%19 = OpLabel +OpBranch %25 +%25 = OpLabel +OpBranch %26 +%26 = OpLabel +%27 = OpPhi %16 %28 %25 %28 %29 +OpLoopMerge %30 %29 None +OpBranchConditional %27 %31 %30 +%31 = OpLabel +OpBranch %29 +%29 = OpLabel +OpBranch %26 +%30 = OpLabel +OpUnreachable +%17 = OpLabel +OpUnreachable +OpFunctionEnd"#, + ) +} diff --git a/crates/spirv-builder/src/test/mod.rs b/crates/spirv-builder/src/test/mod.rs index c4ce15236e..4d819470c5 100644 --- a/crates/spirv-builder/src/test/mod.rs +++ b/crates/spirv-builder/src/test/mod.rs @@ -159,6 +159,33 @@ fn dis_fn(src: &str, func: &str, expect: &str) { assert_str_eq(expect, &func.disassemble()) } +fn dis_entry_fn(src: &str, func: &str, expect: &str) { + let _lock = global_lock(); + let module = read_module(&build(src)).unwrap(); + let id = module + .entry_points + .iter() + .find(|inst| inst.operands.last().unwrap().unwrap_literal_string() == func) + .unwrap_or_else(|| { + panic!( + "no entry point with the name `{}` found in:\n{}\n", + func, + module.disassemble() + ) + }) + .operands[1] + .unwrap_id_ref(); + let mut func = module + .functions + .into_iter() + .find(|f| f.def_id().unwrap() == id) + .unwrap(); + // Compact to make IDs more stable + compact_ids(&mut func); + use rspirv::binary::Disassemble; + assert_str_eq(expect, &func.disassemble()) +} + fn dis_globals(src: &str, expect: &str) { let _lock = global_lock(); let module = read_module(&build(src)).unwrap(); diff --git a/tests/ui/lang/core/ptr/allocate_const_scalar.stderr b/tests/ui/lang/core/ptr/allocate_const_scalar.stderr index 1bc13a388c..15428bb61a 100644 --- a/tests/ui/lang/core/ptr/allocate_const_scalar.stderr +++ b/tests/ui/lang/core/ptr/allocate_const_scalar.stderr @@ -2,7 +2,7 @@ error: pointer has non-null integer address | = note: Stack: allocate_const_scalar::main - Unnamed function ID %4 + Unnamed function ID %5 error: invalid binary:0:0 - No OpEntryPoint instruction was found. This is only allowed if the Linkage capability is being used. |