Skip to content

Commit

Permalink
asm: infer the result type of OpAccessChain (when indexing into arrays).
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed Apr 16, 2021
1 parent fedac0f commit bb7adc9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 6 deletions.
40 changes: 34 additions & 6 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,30 +591,52 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
struct Ambiguous;

/// Construct a type from `pat`, replacing `TyPat::Var(i)` with `ty_vars[i]`.
/// `leftover_operands` is used for `IndexComposite` patterns, if any exist.
/// If the pattern isn't constraining enough to determine an unique type,
/// `Err(Ambiguous)` is returned instead.
fn subst_ty_pat(
cx: &CodegenCx<'_>,
pat: &TyPat<'_>,
ty_vars: &[Option<Word>],
leftover_operands: &[dr::Operand],
) -> Result<Word, Ambiguous> {
Ok(match pat {
&TyPat::Var(i) => match ty_vars.get(i) {
Some(&Some(ty)) => ty,
_ => return Err(Ambiguous),
},

TyPat::Pointer(_, pat) => SpirvType::Pointer {
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
}
.def(DUMMY_SP, cx),

TyPat::Vector4(pat) => SpirvType::Vector {
element: subst_ty_pat(cx, pat, ty_vars)?,
element: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
count: 4,
}
.def(DUMMY_SP, cx),

TyPat::SampledImage(pat) => SpirvType::SampledImage {
image_type: subst_ty_pat(cx, pat, ty_vars)?,
image_type: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
}
.def(DUMMY_SP, cx),

TyPat::IndexComposite(pat) => {
let mut ty = subst_ty_pat(cx, pat, ty_vars, leftover_operands)?;
for _index in leftover_operands {
// FIXME(eddyb) support more than just arrays, by looking
// up the indices (of struct fields) as constant integers.
ty = match cx.lookup_type(ty) {
SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element } => element,

_ => return Err(Ambiguous),
};
}
ty
}

_ => return Err(Ambiguous),
})
}
Expand All @@ -631,11 +653,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {

let mut combined_ty_vars = [None];

let mut ids = instruction.operands.iter().filter_map(|o| o.id_ref_any());
let mut operands = instruction.operands.iter();
let mut next_id_operand = || operands.find_map(|o| o.id_ref_any());
while let TyListPat::Cons { first: pat, suffix } = *sig.input_types {
sig.input_types = suffix;

let match_result = match id_to_type_map.get(&ids.next()?) {
let match_result = match id_to_type_map.get(&next_id_operand()?) {
Some(&ty) => match_ty_pat(self, pat, ty),

// Non-value ID operand (or value operand of unknown type),
Expand Down Expand Up @@ -673,14 +696,19 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {

TyListPat::Any => {}
TyListPat::Nil => {
if ids.next().is_some() {
if next_id_operand().is_some() {
return None;
}
}
_ => return None,
}

match subst_ty_pat(self, sig.output_type.unwrap(), &combined_ty_vars) {
match subst_ty_pat(
self,
sig.output_type.unwrap(),
&combined_ty_vars,
operands.as_slice(),
) {
Ok(ty) => Some(ty),
Err(Ambiguous) => None,
}
Expand Down
22 changes: 22 additions & 0 deletions tests/ui/lang/asm/infer-access-chain-array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Tests that `asm!` can infer the result type of `OpAccessChain`,
// when used to index arrays.

// build-pass

use spirv_std as _;

use glam::Vec4;

#[spirv(fragment)]
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], i: u32, out: &mut Vec4) {
unsafe {
asm!(
"%val_ptr = OpAccessChain _ {array_ptr} {index}",
"%val = OpLoad _ %val_ptr",
"OpStore {out_ptr} %val",
array_ptr = in(reg) array_in,
index = in(reg) i,
out_ptr = in(reg) out,
);
}
}

0 comments on commit bb7adc9

Please sign in to comment.