diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 4121620b61..234d261453 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -17,7 +17,7 @@ use rustc_middle::ty::Ty; use rustc_span::Span; use rustc_target::abi::{Abi, Align, Scalar, Size}; use std::convert::TryInto; -use std::iter::empty; +use std::iter::{self, empty}; use std::ops::Range; macro_rules! simple_op { @@ -38,8 +38,8 @@ macro_rules! simple_op { let size = Size::from_bits(bits); let as_u128 = |const_val| { let x = match const_val { - SpirvConst::U32(_, x) => x as u128, - SpirvConst::U64(_, x) => x as u128, + SpirvConst::U32(x) => x as u128, + SpirvConst::U64(x) => x as u128, _ => return None, }; Some(if signed { @@ -125,7 +125,7 @@ fn memset_dynamic_scalar( .composite_construct( composite_type, None, - std::iter::repeat(fill_var).take(byte_width), + iter::repeat(fill_var).take(byte_width), ) .unwrap(); let result_type = if is_float { @@ -214,15 +214,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); self.constant_composite( ty.clone().def(self.span(), self), - vec![elem_pat; count as usize], + iter::repeat(elem_pat).take(count as usize), ) .def(self) } SpirvType::Array { element, count } => { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); let count = self.builder.lookup_const_u64(count).unwrap() as usize; - self.constant_composite(ty.clone().def(self.span(), self), vec![elem_pat; count]) - .def(self) + self.constant_composite( + ty.clone().def(self.span(), self), + iter::repeat(elem_pat).take(count), + ) + .def(self) } SpirvType::RuntimeArray { .. } => { self.fatal("memset on runtime arrays not implemented yet") @@ -267,7 +270,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .composite_construct( ty.clone().def(self.span(), self), None, - std::iter::repeat(elem_pat).take(count), + iter::repeat(elem_pat).take(count), ) .unwrap() } @@ -277,7 +280,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .composite_construct( ty.clone().def(self.span(), self), None, - std::iter::repeat(elem_pat).take(count as usize), + iter::repeat(elem_pat).take(count as usize), ) .unwrap() } @@ -835,8 +838,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn load(&mut self, ptr: Self::Value, _align: Align) -> Self::Value { - // See comment on `SpirvValueKind::ConstantPointer` - if let Some(value) = ptr.const_ptr_val(self) { + if let Some(value) = ptr.const_fold_load(self) { return value; } let ty = match self.lookup_type(ptr.ty) { @@ -1662,9 +1664,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { }; let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self)); if src_element_size.is_some() && src_element_size == const_size.map(Size::from_bytes) { - // See comment on `SpirvValueKind::ConstantPointer` - - if let Some(const_value) = src.const_ptr_val(self) { + if let Some(const_value) = src.const_fold_load(self) { self.store(const_value, dst, Align::from_bytes(0).unwrap()); } else { self.emit() @@ -1791,13 +1791,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } .def(self.span(), self); if self.builder.lookup_const(elt).is_some() { - self.constant_composite(result_type, vec![elt.def(self); num_elts]) + self.constant_composite(result_type, iter::repeat(elt.def(self)).take(num_elts)) } else { self.emit() .composite_construct( result_type, None, - std::iter::repeat(elt.def(self)).take(num_elts), + iter::repeat(elt.def(self)).take(num_elts), ) .unwrap() .with_type(result_type) diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index fc287e3fc0..bdc2a10e5d 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -1,43 +1,32 @@ use crate::builder; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use bimap::BiHashMap; use rspirv::dr::{Block, Builder, Module, Operand}; -use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word}; +use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, StorageClass, Word}; use rspirv::{binary::Assemble, binary::Disassemble}; +use rustc_data_structures::fx::FxHashMap; use rustc_middle::bug; use rustc_span::{Span, DUMMY_SP}; use std::cell::{RefCell, RefMut}; +use std::rc::Rc; use std::{fs::File, io::Write, path::Path}; #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] pub enum SpirvValueKind { Def(Word), + /// The ID of a global instruction matching a `SpirvConst`, but which cannot + /// pass validation. Used to error (or attach zombie spans), at the usesites + /// of such constants, instead of where they're generated (and cached). + IllegalConst(Word), + // FIXME(eddyb) this shouldn't be needed, but `rustc_codegen_ssa` still relies // on converting `Function`s to `Value`s even for direct calls, the `Builder` // should just have direct and indirect `call` variants (or a `Callee` enum). - // FIXME(eddyb) document? not sure what to do with the `ConstantPointer` comment. FnAddr { function: Word, }, - /// There are a fair number of places where `rustc_codegen_ssa` creates a pointer to something - /// that cannot be pointed to in SPIR-V. For example, constant values are frequently emitted as - /// a pointer to constant memory, and then dereferenced where they're used. Functions are the - /// same way, when compiling a call, the function's pointer is loaded, then dereferenced, then - /// called. Directly translating these constructs is impossible, because SPIR-V doesn't allow - /// pointers to constants, or function pointers. So, instead, we create this ConstantPointer - /// "meta-value": directly using it is an error, however, if it is attempted to be - /// dereferenced, the "load" is instead a no-op that returns the underlying value directly. - ConstantPointer { - initializer: Word, - - /// The global (module-scoped) `OpVariable` (with `initializer` set as - /// its initializer) to attach zombies to. - global_var: Word, - }, - /// Deferred pointer cast, for the `Logical` addressing model (which doesn't /// really support raw pointers in the way Rust expects to be able to use). /// @@ -64,22 +53,29 @@ pub struct SpirvValue { } impl SpirvValue { - pub fn const_ptr_val(self, cx: &CodegenCx<'_>) -> Option { + pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option { match self.kind { - SpirvValueKind::ConstantPointer { - initializer, - global_var: _, - } => { - let ty = match cx.lookup_type(self.ty) { - SpirvType::Pointer { pointee } => pointee, - ty => bug!("load called on variable that wasn't a pointer: {:?}", ty), - }; - Some(initializer.with_type(ty)) + SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => { + let entry = cx.builder.id_to_const.borrow().get(&id)?.clone(); + match entry.val { + SpirvConst::PtrTo { pointee } => { + let ty = match cx.lookup_type(self.ty) { + SpirvType::Pointer { pointee } => pointee, + ty => bug!("load called on value that wasn't a pointer: {:?}", ty), + }; + // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. + let kind = if entry.legal.is_ok() { + SpirvValueKind::Def(pointee) + } else { + SpirvValueKind::IllegalConst(pointee) + }; + Some(SpirvValue { kind, ty }) + } + _ => None, + } } - SpirvValueKind::FnAddr { .. } - | SpirvValueKind::Def(_) - | SpirvValueKind::LogicalPtrCast { .. } => None, + _ => None, } } @@ -98,13 +94,50 @@ impl SpirvValue { pub fn def_with_span(self, cx: &CodegenCx<'_>, span: Span) -> Word { match self.kind { - SpirvValueKind::Def(word) => word, + SpirvValueKind::Def(id) => id, + + SpirvValueKind::IllegalConst(id) => { + let entry = &cx.builder.id_to_const.borrow()[&id]; + let msg = match entry.legal.unwrap_err() { + IllegalConst::Shallow(cause) => { + if let ( + LeafIllegalConst::CompositeContainsPtrTo, + SpirvConst::Composite(_fields), + ) = (cause, &entry.val) + { + // FIXME(eddyb) materialize this at runtime, using + // `OpCompositeConstruct` (transitively, i.e. after + // putting every field through `SpirvValue::def`), + // if we have a `Builder` to do that in. + // FIXME(eddyb) this isn't possible right now, as + // the builder would be dynamically "locked" anyway + // (i.e. attempting to do `bx.emit()` would panic). + } + + cause.message() + } + + IllegalConst::Indirect(cause) => cause.message(), + }; + + // HACK(eddyb) we don't know whether this constant originated + // in a system crate, so it's better to always zombie. + cx.zombie_even_in_user_code(id, span, msg); + + id + } + SpirvValueKind::FnAddr { .. } => { if cx.is_system_crate() { - *cx.zombie_undefs_for_system_fn_addrs + cx.builder + .const_to_id .borrow() - .get(&self.ty) + .get(&WithType { + ty: self.ty, + val: SpirvConst::ZombieUndefForFnAddr, + }) .expect("FnAddr didn't go through proper undef registration") + .val } else { cx.tcx .sess @@ -115,21 +148,6 @@ impl SpirvValue { } } - SpirvValueKind::ConstantPointer { - initializer: _, - global_var, - } => { - // HACK(eddyb) we don't know whether this constant originated - // in a system crate, so it's better to always zombie. - cx.zombie_even_in_user_code( - global_var, - span, - "Cannot use this pointer directly, it must be dereferenced first", - ); - - global_var - } - SpirvValueKind::LogicalPtrCast { original_ptr: _, original_pointee_ty, @@ -171,16 +189,76 @@ impl SpirvValueExt for Word { #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] pub enum SpirvConst { - U32(Word, u32), - U64(Word, u64), + U32(u32), + U64(u64), /// f32 isn't hash, so store bits - F32(Word, u32), + F32(u32), /// f64 isn't hash, so store bits - F64(Word, u64), - Bool(Word, bool), - Composite(Word, Vec), - Null(Word), - Undef(Word), + F64(u64), + Bool(bool), + + Null, + Undef, + + /// Like `Undef`, but cached separately to avoid `FnAddr` zombies accidentally + /// applying to non-zombie `Undef`s of the same types. + // FIXME(eddyb) include the function ID so that multiple `fn` pointers to + // different functions, but of the same type, don't overlap their zombies. + ZombieUndefForFnAddr, + + Composite(Rc<[Word]>), + + /// Pointer to constant data, i.e. `&pointee`, represented as an `OpVariable` + /// in the `Private` storage class, and with `pointee` as its initializer. + PtrTo { + pointee: Word, + }, +} + +#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] +struct WithType { + ty: Word, + val: V, +} + +/// Primary causes for a `SpirvConst` to be deemed illegal. +#[derive(Copy, Clone, Debug)] +enum LeafIllegalConst { + /// `SpirvConst::Composite` containing a `SpirvConst::PtrTo` as a field. + /// This is illegal because `OpConstantComposite` must have other constants + /// as its operands, and `OpVariable`s are never considered constant. + // FIXME(eddyb) figure out if this is an accidental omission in SPIR-V. + CompositeContainsPtrTo, +} + +impl LeafIllegalConst { + fn message(&self) -> &'static str { + match *self { + Self::CompositeContainsPtrTo => { + "constant arrays/structs cannot contain pointers to other constants" + } + } + } +} + +#[derive(Copy, Clone, Debug)] +enum IllegalConst { + /// This `SpirvConst` is (or contains) a "leaf" illegal constant. As there + /// is no indirection, some of these could still be materialized at runtime, + /// using e.g. `OpCompositeConstruct` instead of `OpConstantComposite`. + Shallow(LeafIllegalConst), + + /// This `SpirvConst` is (or contains/points to) a `PtrTo` which points to + /// a "leaf" illegal constant. As the data would have to live for `'static`, + /// there is no way to materialize it as a pointer in SPIR-V. However, it + /// could still be legalized during codegen by e.g. folding loads from it. + Indirect(LeafIllegalConst), +} + +#[derive(Copy, Clone, Debug)] +struct WithConstLegality { + val: V, + legal: Result<(), IllegalConst>, } /// Cursor system: @@ -214,7 +292,13 @@ pub struct BuilderCursor { pub struct BuilderSpirv { builder: RefCell, - constants: RefCell>, + + // Bidirectional maps between `SpirvConst` and the ID of the defined global + // (e.g. `OpConstant...`) instruction. + // NOTE(eddyb) both maps have `WithConstLegality` around their keys, which + // allows getting that legality information without additional lookups. + const_to_id: RefCell, WithConstLegality>>, + id_to_const: RefCell>>, } impl BuilderSpirv { @@ -257,7 +341,8 @@ impl BuilderSpirv { } Self { builder: RefCell::new(builder), - constants: Default::default(), + const_to_id: Default::default(), + id_to_const: Default::default(), } } @@ -329,44 +414,143 @@ impl BuilderSpirv { bug!("Function not found: {}", id); } - pub fn def_constant(&self, val: SpirvConst) -> SpirvValue { + pub fn def_constant(&self, ty: Word, val: SpirvConst) -> SpirvValue { + let val_with_type = WithType { ty, val }; let mut builder = self.builder(BuilderCursor::default()); - if let Some(value) = self.constants.borrow_mut().get_by_left(&val) { - return *value; + if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) { + // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. + let kind = if entry.legal.is_ok() { + SpirvValueKind::Def(entry.val) + } else { + SpirvValueKind::IllegalConst(entry.val) + }; + return SpirvValue { kind, ty }; } + let val = val_with_type.val; let id = match val { - SpirvConst::U32(ty, v) => builder.constant_u32(ty, v).with_type(ty), - SpirvConst::U64(ty, v) => builder.constant_u64(ty, v).with_type(ty), - SpirvConst::F32(ty, v) => builder.constant_f32(ty, f32::from_bits(v)).with_type(ty), - SpirvConst::F64(ty, v) => builder.constant_f64(ty, f64::from_bits(v)).with_type(ty), - SpirvConst::Bool(ty, v) => { + SpirvConst::U32(v) => builder.constant_u32(ty, v), + SpirvConst::U64(v) => builder.constant_u64(ty, v), + SpirvConst::F32(v) => builder.constant_f32(ty, f32::from_bits(v)), + SpirvConst::F64(v) => builder.constant_f64(ty, f64::from_bits(v)), + SpirvConst::Bool(v) => { if v { - builder.constant_true(ty).with_type(ty) + builder.constant_true(ty) } else { - builder.constant_false(ty).with_type(ty) + builder.constant_false(ty) } } - SpirvConst::Composite(ty, ref v) => builder - .constant_composite(ty, v.iter().copied()) - .with_type(ty), - SpirvConst::Null(ty) => builder.constant_null(ty).with_type(ty), - SpirvConst::Undef(ty) => builder.undef(ty, None).with_type(ty), + + SpirvConst::Null => builder.constant_null(ty), + SpirvConst::Undef | SpirvConst::ZombieUndefForFnAddr => builder.undef(ty, None), + + SpirvConst::Composite(ref v) => builder.constant_composite(ty, v.iter().copied()), + + SpirvConst::PtrTo { pointee } => { + builder.variable(ty, None, StorageClass::Private, Some(pointee)) + } }; - self.constants - .borrow_mut() - .insert_no_overwrite(val, id) - .unwrap(); - id + #[allow(clippy::match_same_arms)] + let legal = match val { + SpirvConst::U32(_) + | SpirvConst::U64(_) + | SpirvConst::F32(_) + | SpirvConst::F64(_) + | SpirvConst::Bool(_) => Ok(()), + + SpirvConst::Null => { + // FIXME(eddyb) check that the type supports `OpConstantNull`. + Ok(()) + } + SpirvConst::Undef => { + // FIXME(eddyb) check that the type supports `OpUndef`. + Ok(()) + } + + SpirvConst::ZombieUndefForFnAddr => { + // This can be considered legal as it's already marked as zombie. + // FIXME(eddyb) is it possible for the original zombie to lack a + // span, and should we go through `IllegalConst` in order to be + // able to attach a proper usesite span? + Ok(()) + } + + SpirvConst::Composite(ref v) => v.iter().fold(Ok(()), |composite_legal, field| { + let field_entry = &self.id_to_const.borrow()[field]; + let field_legal_in_composite = field_entry.legal.and_then(|()| { + // `field` is itself some legal `SpirvConst`, but can we have + // it as part of an `OpConstantComposite`? + match field_entry.val { + SpirvConst::PtrTo { .. } => Err(IllegalConst::Shallow( + LeafIllegalConst::CompositeContainsPtrTo, + )), + _ => Ok(()), + } + }); + + match (composite_legal, field_legal_in_composite) { + (Ok(()), Ok(())) => Ok(()), + (Err(illegal), Ok(())) | (Ok(()), Err(illegal)) => Err(illegal), + + // Combining two causes of an illegal `SpirvConst` has to + // take into account which is "worse", i.e. which imposes + // more restrictions on how the resulting value can be used. + // `Indirect` is worse than `Shallow` because it cannot be + // materialized at runtime in the same way `Shallow` can be. + (Err(illegal @ IllegalConst::Indirect(_)), Err(_)) + | (Err(_), Err(illegal @ IllegalConst::Indirect(_))) + | (Err(illegal @ IllegalConst::Shallow(_)), Err(IllegalConst::Shallow(_))) => { + Err(illegal) + } + } + }), + + SpirvConst::PtrTo { pointee } => match self.id_to_const.borrow()[&pointee].legal { + Ok(()) => Ok(()), + + // `Shallow` becomes `Indirect` when placed behind a pointer. + Err(IllegalConst::Shallow(cause)) | Err(IllegalConst::Indirect(cause)) => { + Err(IllegalConst::Indirect(cause)) + } + }, + }; + assert_matches!( + self.const_to_id.borrow_mut().insert( + WithType { + ty, + val: val.clone() + }, + WithConstLegality { val: id, legal } + ), + None + ); + assert_matches!( + self.id_to_const + .borrow_mut() + .insert(id, WithConstLegality { val, legal }), + None + ); + // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. + let kind = if legal.is_ok() { + SpirvValueKind::Def(id) + } else { + SpirvValueKind::IllegalConst(id) + }; + SpirvValue { kind, ty } } pub fn lookup_const(&self, def: SpirvValue) -> Option { - self.constants.borrow().get_by_right(&def).cloned() + match def.kind { + SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => { + Some(self.id_to_const.borrow().get(&id)?.val.clone()) + } + _ => None, + } } pub fn lookup_const_u64(&self, def: SpirvValue) -> Option { match self.lookup_const(def)? { - SpirvConst::U32(_, v) => Some(v as u64), - SpirvConst::U64(_, v) => Some(v), + SpirvConst::U32(v) => Some(v as u64), + SpirvConst::U64(v) => Some(v), _ => None, } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index 1e17af1338..64e5cdb7b6 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -4,7 +4,7 @@ use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use rspirv::spirv::Word; use rustc_codegen_ssa::mir::place::PlaceRef; -use rustc_codegen_ssa::traits::{ConstMethods, MiscMethods, StaticMethods}; +use rustc_codegen_ssa::traits::{BaseTypeMethods, ConstMethods, MiscMethods, StaticMethods}; use rustc_middle::bug; use rustc_middle::mir::interpret::{AllocId, Allocation, GlobalAlloc, Pointer, ScalarMaybeUninit}; use rustc_middle::ty::layout::TyAndLayout; @@ -16,27 +16,27 @@ use rustc_target::abi::{self, AddressSpace, HasDataLayout, Integer, LayoutOf, Pr impl<'tcx> CodegenCx<'tcx> { pub fn constant_u8(&self, span: Span, val: u8) -> SpirvValue { let ty = SpirvType::Integer(8, false).def(span, self); - self.builder.def_constant(SpirvConst::U32(ty, val as u32)) + self.builder.def_constant(ty, SpirvConst::U32(val as u32)) } pub fn constant_u16(&self, span: Span, val: u16) -> SpirvValue { let ty = SpirvType::Integer(16, false).def(span, self); - self.builder.def_constant(SpirvConst::U32(ty, val as u32)) + self.builder.def_constant(ty, SpirvConst::U32(val as u32)) } pub fn constant_i32(&self, span: Span, val: i32) -> SpirvValue { let ty = SpirvType::Integer(32, !self.kernel_mode).def(span, self); - self.builder.def_constant(SpirvConst::U32(ty, val as u32)) + self.builder.def_constant(ty, SpirvConst::U32(val as u32)) } pub fn constant_u32(&self, span: Span, val: u32) -> SpirvValue { let ty = SpirvType::Integer(32, false).def(span, self); - self.builder.def_constant(SpirvConst::U32(ty, val)) + self.builder.def_constant(ty, SpirvConst::U32(val)) } pub fn constant_u64(&self, span: Span, val: u64) -> SpirvValue { let ty = SpirvType::Integer(64, false).def(span, self); - self.builder.def_constant(SpirvConst::U64(ty, val)) + self.builder.def_constant(ty, SpirvConst::U64(val)) } pub fn constant_int(&self, ty: Word, val: u64) -> SpirvValue { @@ -44,18 +44,18 @@ impl<'tcx> CodegenCx<'tcx> { SpirvType::Integer(bits @ 8..=32, signed) => { let size = Size::from_bits(bits); let val = val as u128; - self.builder.def_constant(SpirvConst::U32( + self.builder.def_constant( ty, - if signed { + SpirvConst::U32(if signed { size.sign_extend(val) } else { size.truncate(val) - } as u32, - )) + } as u32), + ) } - SpirvType::Integer(64, _) => self.builder.def_constant(SpirvConst::U64(ty, val)), + SpirvType::Integer(64, _) => self.builder.def_constant(ty, SpirvConst::U64(val)), SpirvType::Bool => match val { - 0 | 1 => self.builder.def_constant(SpirvConst::Bool(ty, val != 0)), + 0 | 1 => self.builder.def_constant(ty, SpirvConst::Bool(val != 0)), _ => self .tcx .sess @@ -76,23 +76,23 @@ impl<'tcx> CodegenCx<'tcx> { pub fn constant_f32(&self, span: Span, val: f32) -> SpirvValue { let ty = SpirvType::Float(32).def(span, self); self.builder - .def_constant(SpirvConst::F32(ty, val.to_bits())) + .def_constant(ty, SpirvConst::F32(val.to_bits())) } pub fn constant_f64(&self, span: Span, val: f64) -> SpirvValue { let ty = SpirvType::Float(64).def(span, self); self.builder - .def_constant(SpirvConst::F64(ty, val.to_bits())) + .def_constant(ty, SpirvConst::F64(val.to_bits())) } pub fn constant_float(&self, ty: Word, val: f64) -> SpirvValue { match self.lookup_type(ty) { SpirvType::Float(32) => self .builder - .def_constant(SpirvConst::F32(ty, (val as f32).to_bits())), + .def_constant(ty, SpirvConst::F32((val as f32).to_bits())), SpirvType::Float(64) => self .builder - .def_constant(SpirvConst::F64(ty, val.to_bits())), + .def_constant(ty, SpirvConst::F64(val.to_bits())), other => self.tcx.sess.fatal(&format!( "constant_float invalid on type {}", other.debug(ty, self) @@ -102,19 +102,20 @@ impl<'tcx> CodegenCx<'tcx> { pub fn constant_bool(&self, span: Span, val: bool) -> SpirvValue { let ty = SpirvType::Bool.def(span, self); - self.builder.def_constant(SpirvConst::Bool(ty, val)) + self.builder.def_constant(ty, SpirvConst::Bool(val)) } - pub fn constant_composite(&self, ty: Word, val: Vec) -> SpirvValue { - self.builder.def_constant(SpirvConst::Composite(ty, val)) + pub fn constant_composite(&self, ty: Word, fields: impl Iterator) -> SpirvValue { + self.builder + .def_constant(ty, SpirvConst::Composite(fields.collect())) } pub fn constant_null(&self, ty: Word) -> SpirvValue { - self.builder.def_constant(SpirvConst::Null(ty)) + self.builder.def_constant(ty, SpirvConst::Null) } pub fn undef(&self, ty: Word) -> SpirvValue { - self.builder.def_constant(SpirvConst::Undef(ty)) + self.builder.def_constant(ty, SpirvConst::Undef) } } @@ -165,7 +166,12 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { .spirv_type(DUMMY_SP, self); // FIXME(eddyb) include the actual byte data. ( - self.make_constant_pointer(DUMMY_SP, self.undef(str_ty)), + self.builder.def_constant( + self.type_ptr_to(str_ty), + SpirvConst::PtrTo { + pointee: self.undef(str_ty).def_cx(self), + }, + ), self.const_usize(len as u64), ) } @@ -182,7 +188,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { field_names: None, } .def(DUMMY_SP, self); - self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect()) + self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self))) } fn const_to_opt_uint(&self, v: Self::Value) -> Option { @@ -448,7 +454,7 @@ impl<'tcx> CodegenCx<'tcx> { "create_const_alloc must consume all bytes of an Allocation after an unsized struct" ); } - self.constant_composite(ty, values) + self.constant_composite(ty, values.into_iter()) } SpirvType::Opaque { name } => self.tcx.sess.fatal(&format!( "Cannot create const alloc of type opaque: {}", @@ -456,12 +462,10 @@ impl<'tcx> CodegenCx<'tcx> { )), SpirvType::Array { element, count } => { let count = self.builder.lookup_const_u64(count).unwrap() as usize; - let values = (0..count) - .map(|_| { - self.create_const_alloc2(alloc, offset, element) - .def_cx(self) - }) - .collect::>(); + let values = (0..count).map(|_| { + self.create_const_alloc2(alloc, offset, element) + .def_cx(self) + }); self.constant_composite(ty, values) } SpirvType::Vector { element, count } => { @@ -469,16 +473,15 @@ impl<'tcx> CodegenCx<'tcx> { .sizeof(self) .expect("create_const_alloc: Vectors must be sized"); let final_offset = *offset + total_size; - let values = (0..count) - .map(|_| { - self.create_const_alloc2(alloc, offset, element) - .def_cx(self) - }) - .collect::>(); + let values = (0..count).map(|_| { + self.create_const_alloc2(alloc, offset, element) + .def_cx(self) + }); + let result = self.constant_composite(ty, values); assert!(*offset <= final_offset); // Vectors sometimes have padding at the end (e.g. vec3), skip over it. *offset = final_offset; - self.constant_composite(ty, values) + result } SpirvType::RuntimeArray { element } => { let mut values = Vec::new(); @@ -488,7 +491,7 @@ impl<'tcx> CodegenCx<'tcx> { .def_cx(self), ); } - let result = self.constant_composite(ty, values); + let result = self.constant_composite(ty, values.into_iter()); // TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv: /* __constant struct A { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index c2f8c440bd..0644336c8a 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -6,14 +6,14 @@ use crate::decorations::UnrollLoopsDecoration; use crate::spirv_type::SpirvType; use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word}; use rustc_attr::InlineAttr; -use rustc_codegen_ssa::traits::{PreDefineMethods, StaticMethods}; +use rustc_codegen_ssa::traits::{BaseTypeMethods, PreDefineMethods, StaticMethods}; use rustc_middle::bug; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::{Linkage, MonoItem, Visibility}; use rustc_middle::ty::layout::FnAbiExt; use rustc_middle::ty::{self, Instance, ParamEnv, TypeFoldable}; use rustc_span::def_id::DefId; -use rustc_span::{Span, DUMMY_SP}; +use rustc_span::Span; use rustc_target::abi::call::FnAbi; use rustc_target::abi::{Align, LayoutOf}; @@ -242,7 +242,12 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { impl<'tcx> StaticMethods for CodegenCx<'tcx> { fn static_addr_of(&self, cv: Self::Value, _align: Align, _kind: Option<&str>) -> Self::Value { - self.make_constant_pointer(DUMMY_SP, cv) + self.builder.def_constant( + self.type_ptr_to(cv.ty), + SpirvConst::PtrTo { + pointee: cv.def_cx(self), + }, + ) } fn codegen_static(&self, def_id: DefId, _is_mutable: bool) { @@ -264,9 +269,8 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> { let mut v = self.create_const_alloc(alloc, value_ty); if self.lookup_type(v.ty) == SpirvType::Bool { - let val = self.builder.lookup_const(v).unwrap(); - let val_int = match val { - SpirvConst::Bool(_, val) => val as u8, + let val_int = match self.builder.lookup_const(v).unwrap() { + SpirvConst::Bool(val) => val as u8, _ => bug!(), }; v = self.constant_u8(span, val_int); diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index 5223677601..5492366d50 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -4,14 +4,14 @@ mod entry; mod type_; use crate::builder::{ExtInst, InstructionTable}; -use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueKind}; +use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvConst, SpirvValue, SpirvValueKind}; use crate::decorations::{ CustomDecoration, SerializedSpan, UnrollLoopsDecoration, ZombieDecoration, }; use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache}; use crate::symbols::Symbols; use rspirv::dr::{Module, Operand}; -use rspirv::spirv::{AddressingModel, Decoration, LinkageType, MemoryModel, StorageClass, Word}; +use rspirv::spirv::{AddressingModel, Decoration, LinkageType, MemoryModel, Word}; use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind}; use rustc_codegen_ssa::traits::{ AsmMethods, BackendTypes, CoverageInfoMethods, DebugInfoMethods, MiscMethods, @@ -59,7 +59,6 @@ pub struct CodegenCx<'tcx> { /// Cache of all the builtin symbols we need pub sym: Rc, pub instruction_table: InstructionTable, - pub zombie_undefs_for_system_fn_addrs: RefCell>, pub libm_intrinsics: RefCell>, /// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`. @@ -120,7 +119,6 @@ impl<'tcx> CodegenCx<'tcx> { kernel_mode, sym, instruction_table: InstructionTable::new(), - zombie_undefs_for_system_fn_addrs: Default::default(), libm_intrinsics: Default::default(), panic_fn_id: Default::default(), panic_bounds_check_fn_id: Default::default(), @@ -230,36 +228,6 @@ impl<'tcx> CodegenCx<'tcx> { once(Operand::LiteralString(name)).chain(once(Operand::LinkageType(linkage))), ) } - - /// See note on `SpirvValueKind::ConstantPointer` - pub fn make_constant_pointer(&self, span: Span, value: SpirvValue) -> SpirvValue { - let ty = SpirvType::Pointer { pointee: value.ty }.def(span, self); - let initializer = value.def_cx(self); - - // Create these up front instead of on demand in SpirvValue::def because - // SpirvValue::def can't use cx.emit() - // FIXME(eddyb) figure out what the correct storage class is. - let global_var = - self.emit_global() - .variable(ty, None, StorageClass::Private, Some(initializer)); - - // In all likelihood, this zombie message will get overwritten in SpirvValue::def_with_span - // to the use site of this constant. However, if this constant happens to never get used, we - // still want to zobmie it, so zombie here. - self.zombie_even_in_user_code( - global_var, - span, - "Cannot use this pointer directly, it must be dereferenced first", - ); - - SpirvValue { - kind: SpirvValueKind::ConstantPointer { - initializer, - global_var, - }, - ty, - } - } } pub struct CodegenArgs { @@ -377,13 +345,8 @@ impl<'tcx> MiscMethods<'tcx> for CodegenCx<'tcx> { if self.is_system_crate() { // Create these undefs up front instead of on demand in SpirvValue::def because // SpirvValue::def can't use cx.emit() - self.zombie_undefs_for_system_fn_addrs - .borrow_mut() - .entry(ty) - .or_insert_with(|| { - // We want a unique ID for these undefs, so don't use the caching system. - self.emit_global().undef(ty, None) - }); + self.builder + .def_constant(ty, SpirvConst::ZombieUndefForFnAddr); } SpirvValue { diff --git a/tests/ui/lang/consts/nested-ref-in-composite.rs b/tests/ui/lang/consts/nested-ref-in-composite.rs new file mode 100644 index 0000000000..5c99d9243a --- /dev/null +++ b/tests/ui/lang/consts/nested-ref-in-composite.rs @@ -0,0 +1,28 @@ +// Test `&'static T` constants where the `T` values themselves contain references, +// nested in `OpConstantComposite` (structs/arrays) - currently these are disallowed. + +// build-fail + +use spirv_std as _; + +use glam::{const_mat2, Mat2, Vec2}; + +#[inline(never)] +fn pair_deep_load(r: &'static (&'static u32, &'static f32)) -> (u32, f32) { + (*r.0, *r.1) +} + +#[inline(never)] +fn array3_deep_load(r: &'static [&'static u32; 3]) -> [u32; 3] { + [*r[0], *r[1], *r[2]] +} + +#[spirv(fragment)] +pub fn main_pair(pair_out: &mut (u32, f32)) { + *pair_out = pair_deep_load(&(&123, &3.14)); +} + +#[spirv(fragment)] +pub fn main_array3(array3_out: &mut [u32; 3]) { + *array3_out = array3_deep_load(&[&0, &1, &2]); +} diff --git a/tests/ui/lang/consts/nested-ref-in-composite.stderr b/tests/ui/lang/consts/nested-ref-in-composite.stderr new file mode 100644 index 0000000000..97cc199a2f --- /dev/null +++ b/tests/ui/lang/consts/nested-ref-in-composite.stderr @@ -0,0 +1,27 @@ +error: constant arrays/structs cannot contain pointers to other constants + --> $DIR/nested-ref-in-composite.rs:22:17 + | +22 | *pair_out = pair_deep_load(&(&123, &3.14)); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: Stack: + nested_ref_in_composite::main_pair + Unnamed function ID %26 + +error: constant arrays/structs cannot contain pointers to other constants + --> $DIR/nested-ref-in-composite.rs:27:19 + | +27 | *array3_out = array3_deep_load(&[&0, &1, &2]); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: Stack: + nested_ref_in_composite::main_array3 + Unnamed function ID %33 + +error: invalid binary:0:0 - No OpEntryPoint instruction was found. This is only allowed if the Linkage capability is being used. + | + = note: spirv-val failed + = note: module `$TEST_BUILD_DIR/lang/consts/nested-ref-in-composite.stage-id.spv` + +error: aborting due to 3 previous errors + diff --git a/tests/ui/lang/consts/nested-ref.rs b/tests/ui/lang/consts/nested-ref.rs new file mode 100644 index 0000000000..58a8799f4c --- /dev/null +++ b/tests/ui/lang/consts/nested-ref.rs @@ -0,0 +1,32 @@ +// Test `&'static &'static T` constants where the `T` values don't themselves +// contain references, and where the `T` values aren't immediatelly loaded from. + +// build-pass + +use spirv_std as _; + +use glam::{const_mat2, Mat2, Vec2}; + +#[inline(never)] +fn deep_load(r: &'static &'static u32) -> u32 { + **r +} + +const ROT90: &Mat2 = &const_mat2![[0.0, 1.0], [-1.0, 0.0]]; + +#[inline(never)] +fn deep_transpose(r: &'static &'static Mat2) -> Mat2 { + r.transpose() +} + +#[spirv(fragment)] +pub fn main( + scalar_out: &mut u32, + #[spirv(push_constant)] vec_in: &Vec2, + bool_out: &mut bool, + vec_out: &mut Vec2, +) { + *scalar_out = deep_load(&&123); + *bool_out = vec_in == &Vec2::ZERO; + *vec_out = deep_transpose(&ROT90) * *vec_in; +} diff --git a/tests/ui/lang/consts/shallow-ref.rs b/tests/ui/lang/consts/shallow-ref.rs new file mode 100644 index 0000000000..76253bad32 --- /dev/null +++ b/tests/ui/lang/consts/shallow-ref.rs @@ -0,0 +1,22 @@ +// Test `&'static T` constants where the `T` values don't themselves contain +// references, and where the `T` values aren't immediatelly loaded from. + +// build-pass + +use spirv_std as _; + +use glam::{const_mat2, Mat2, Vec2}; + +#[inline(never)] +fn scalar_load(r: &'static u32) -> u32 { + *r +} + +const ROT90: Mat2 = const_mat2![[0.0, 1.0], [-1.0, 0.0]]; + +#[spirv(fragment)] +pub fn main(scalar_out: &mut u32, vec_in: Vec2, bool_out: &mut bool, vec_out: &mut Vec2) { + *scalar_out = scalar_load(&123); + *bool_out = vec_in == Vec2::ZERO; + *vec_out = ROT90.transpose() * vec_in; +}