Skip to content

Commit

Permalink
Auto merge of rust-lang#3205 - RalfJung:simd-bitmask, r=RalfJung
Browse files Browse the repository at this point in the history
also test simd_select_bitmask on arrays for less than 8 elements
  • Loading branch information
bors committed Dec 3, 2023
2 parents 28f9fe3 + 6e74d2a commit 6da0959
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 48 deletions.
88 changes: 45 additions & 43 deletions src/tools/miri/src/shims/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,37 +405,35 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.write_immediate(*val, &dest)?;
}
}
// Variant of `select` that takes a bitmask rather than a "vector of bool".
"select_bitmask" => {
let [mask, yes, no] = check_arg_count(args)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.place_to_simd(dest)?;
let bitmask_len = dest_len.max(8);

// The mask must be an integer or an array.
assert!(
mask.layout.ty.is_integral()
|| matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
);
assert!(bitmask_len <= 64);
assert_eq!(bitmask_len, mask.layout.size.bits());
assert_eq!(dest_len, yes_len);
assert_eq!(dest_len, no_len);
let dest_len = u32::try_from(dest_len).unwrap();
let bitmask_len = u32::try_from(bitmask_len).unwrap();

// The mask can be a single integer or an array.
let mask: u64 = match mask.layout.ty.kind() {
ty::Int(..) | ty::Uint(..) =>
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(),
ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => {
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
let mask = mask.transmute(mask_ty, this)?;
this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap()
}
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
};
// To read the mask, we transmute it to an integer.
// That does the right thing wrt endianess.
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
let mask = mask.transmute(mask_ty, this)?;
let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();

for i in 0..dest_len {
let mask = mask
& 1u64
.checked_shl(simd_bitmask_index(i, dest_len, this.data_layout().endian))
.unwrap();
let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
let mask = mask & 1u64.checked_shl(bit_i).unwrap();
let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
let dest = this.project_index(&dest, i.into())?;
Expand All @@ -445,6 +443,8 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
for i in dest_len..bitmask_len {
// If the mask is "padded", ensure that padding is all-zero.
// This deliberately does not use `simd_bitmask_index`; these bits are outside
// the bitmask. It does not matter in which order we check them.
let mask = mask & 1u64.checked_shl(i).unwrap();
if mask != 0 {
throw_ub_format!(
Expand All @@ -453,6 +453,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
}
}
// Converts a "vector of bool" into a bitmask.
"bitmask" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let bitmask_len = op_len.max(8);

// Returns either an unsigned integer or array of `u8`.
assert!(
dest.layout.ty.is_integral()
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
);
assert!(bitmask_len <= 64);
assert_eq!(bitmask_len, dest.layout.size.bits());
let op_len = u32::try_from(op_len).unwrap();

let mut res = 0u64;
for i in 0..op_len {
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
if simd_element_to_bool(op)? {
res |= 1u64
.checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
.unwrap();
}
}
// We have to change the type of the place to be able to write `res` into it. This
// transmutes the integer to an array, which does the right thing wrt endianess.
let dest =
dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
this.write_int(res, &dest)?;
}
"cast" | "as" | "cast_ptr" | "expose_addr" | "from_exposed_addr" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
Expand Down Expand Up @@ -635,34 +665,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}
}
}
"bitmask" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let bitmask_len = op_len.max(8);

// Returns either an unsigned integer or array of `u8`.
assert!(
dest.layout.ty.is_integral()
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
);
assert!(bitmask_len <= 64);
assert_eq!(bitmask_len, dest.layout.size.bits());
let op_len = u32::try_from(op_len).unwrap();

let mut res = 0u64;
for i in 0..op_len {
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
if simd_element_to_bool(op)? {
res |= 1u64
.checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
.unwrap();
}
}
// We have to force the place type to be an int so that we can write `res` into it.
let mut dest = this.force_allocation(dest)?;
dest.layout = this.machine.layouts.uint(dest.layout.size).unwrap();
this.write_int(res, &dest)?;
}

name => throw_unsup_format!("unimplemented intrinsic: `simd_{name}`"),
}
Expand Down
54 changes: 49 additions & 5 deletions src/tools/miri/tests/pass/portable-simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
#![allow(incomplete_features, internal_features)]
use std::simd::{prelude::*, StdFloat};

extern "platform-intrinsic" {
pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
}

fn simd_ops_f32() {
let a = f32x4::splat(10.0);
let b = f32x4::from_array([1.0, 2.0, 3.0, -4.0]);
Expand Down Expand Up @@ -218,6 +214,11 @@ fn simd_ops_i32() {
}

fn simd_mask() {
extern "platform-intrinsic" {
pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
pub(crate) fn simd_select_bitmask<M, T>(m: M, yes: T, no: T) -> T;
}

let intmask = Mask::from_int(i32x4::from_array([0, -1, 0, 0]));
assert_eq!(intmask, Mask::from_array([false, true, false, false]));
assert_eq!(intmask.to_array(), [false, true, false, false]);
Expand Down Expand Up @@ -266,7 +267,16 @@ fn simd_mask() {
}
}

// This used to cause an ICE.
// This used to cause an ICE. It exercises simd_select_bitmask with an array as input.
if cfg!(target_endian = "little") {
// FIXME this test currently fails on big-endian:
// <https://github.com/rust-lang/portable-simd/issues/379>
let bitmask = u8x4::from_array([0b00001101, 0, 0, 0]);
assert_eq!(
mask32x4::from_bitmask_vector(bitmask),
mask32x4::from_array([true, false, true, true]),
);
}
let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
assert_eq!(
mask32x8::from_bitmask_vector(bitmask),
Expand All @@ -281,6 +291,40 @@ fn simd_mask() {
true, true, true,
]),
);

// Also directly call simd_select_bitmask, to test both kinds of argument types.
unsafe {
// These masks are exactly the results we got out above in the `simd_bitmask` tests.
let selected1 = simd_select_bitmask::<u16, _>(
if cfg!(target_endian = "little") { 0b1010001101001001 } else { 0b1001001011000101 },
i32x16::splat(1), // yes
i32x16::splat(0), // no
);
let selected2 = simd_select_bitmask::<[u8; 2], _>(
if cfg!(target_endian = "little") {
[0b01001001, 0b10100011]
} else {
[0b10010010, 0b11000101]
},
i32x16::splat(1), // yes
i32x16::splat(0), // no
);
assert_eq!(selected1, i32x16::from_array([1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1]));
assert_eq!(selected2, selected1);
// Also try masks less than a byte long.
let selected1 = simd_select_bitmask::<u8, _>(
if cfg!(target_endian = "little") { 0b1000 } else { 0b0001 },
i32x4::splat(1), // yes
i32x4::splat(0), // no
);
let selected2 = simd_select_bitmask::<[u8; 1], _>(
if cfg!(target_endian = "little") { [0b1000] } else { [0b0001] },
i32x4::splat(1), // yes
i32x4::splat(0), // no
);
assert_eq!(selected1, i32x4::from_array([0, 0, 0, 1]));
assert_eq!(selected2, selected1);
}
}

fn simd_cast() {
Expand Down

0 comments on commit 6da0959

Please sign in to comment.