Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype: Add unstable -Z reference-niches option #113166

Merged
merged 13 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 98 additions & 10 deletions compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ bitflags! {
}
}

/// Policy describing which niches references may expose in addition to the
/// `null` niche.
///
/// Each flag switches one kind of extra niche on. The derived `Default`
/// leaves both flags `false`, i.e. no extra niche is enabled.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "nightly", derive(Encodable, Decodable, HashStable_Generic))]
pub struct ReferenceNichePolicy {
    /// Enables the size-related niche.
    // NOTE(review): presumably the addresses a reference can never hold because
    // the pointee would not fit before the end of the address space — confirm
    // against the layout computation.
    pub size: bool,
    /// Enables the alignment-related niche.
    // NOTE(review): presumably the addresses excluded by the pointee's
    // alignment requirement — confirm against the layout computation.
    pub align: bool,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "nightly", derive(Encodable, Decodable, HashStable_Generic))]
pub enum IntegerType {
Expand Down Expand Up @@ -346,6 +354,33 @@ impl TargetDataLayout {
}
}

/// Returns the largest unsigned integer representable in `pointer_size`,
/// i.e. the target's `usize::MAX`, as a `u64`.
///
/// # Panics
///
/// Panics if that value does not fit in a `u64`.
#[inline]
pub fn target_usize_max(&self) -> u64 {
    let max = self.pointer_size.unsigned_int_max();
    max.try_into().unwrap()
}

/// Returns the smallest signed integer representable in `pointer_size`,
/// i.e. the target's `isize::MIN`, as an `i64`.
///
/// # Panics
///
/// Panics if that value does not fit in an `i64`.
#[inline]
pub fn target_isize_min(&self) -> i64 {
    let min = self.pointer_size.signed_int_min();
    min.try_into().unwrap()
}

/// Returns the largest signed integer representable in `pointer_size`,
/// i.e. the target's `isize::MAX`, as an `i64`.
///
/// # Panics
///
/// Panics if that value does not fit in an `i64`.
#[inline]
pub fn target_isize_max(&self) -> i64 {
    let max = self.pointer_size.signed_int_max();
    max.try_into().unwrap()
}

/// Computes the inclusive `(lowest, highest)` pair of addresses at which an
/// allocation of the given `size` and `align` could start.
///
/// The lowest address is `align` itself. The highest is the target's
/// `usize::MAX` minus `size`, rounded down to a multiple of `align`.
///
/// Target-specific limitations are not taken into account.
// NOTE(review): what happens when `size` exceeds the target's `usize::MAX`
// depends on `Size`'s subtraction, which is not visible here — confirm.
#[inline]
pub fn address_range_for(&self, size: Size, align: Align) -> (u64, u64) {
    let address_space_end = Size::from_bytes(self.target_usize_max());
    let lowest = align.bytes();
    let highest = (address_space_end - size).align_down_to(align);
    (lowest, highest.bytes())
}

#[inline]
pub fn vector_align(&self, vec_size: Size) -> AbiAndPrefAlign {
for &(size, align) in &self.vector_align {
Expand Down Expand Up @@ -473,6 +508,12 @@ impl Size {
Size::from_bytes((self.bytes() + mask) & !mask)
}

/// Rounds `self` down to the nearest multiple of `align` by clearing the
/// low bits selected by `align.bytes() - 1`.
// NOTE(review): the masking only yields a multiple of `align` when
// `align.bytes()` is a power of two — presumably guaranteed by `Align`; confirm.
#[inline]
pub fn align_down_to(self, align: Align) -> Size {
    let low_bits = align.bytes() - 1;
    let rounded = self.bytes() & !low_bits;
    Size::from_bytes(rounded)
}

#[inline]
pub fn is_aligned(self, align: Align) -> bool {
let mask = align.bytes() - 1;
Expand Down Expand Up @@ -967,6 +1008,43 @@ impl WrappingRange {
}
}

/// Returns `true` if `range` is contained in `self`.
///
/// An empty `range` is trivially contained. When `self.start > self.end`
/// the range wraps around: it then consists of the values `>= start`
/// together with the values `<= end`.
#[inline(always)]
pub fn contains_range<I: Into<u128> + Ord>(&self, range: RangeInclusive<I>) -> bool {
    if range.is_empty() {
        return true;
    }

    let (vmin, vmax) = range.into_inner();
    let (vmin, vmax) = (vmin.into(), vmax.into());

    if self.start <= self.end {
        // Non-wrapping: both ends of `range` must lie within `start..=end`.
        self.start <= vmin && vmax <= self.end
    } else {
        // Wrapping: `range` (which satisfies `vmin <= vmax`) must sit entirely
        // in the upper part (`start <= vmin`) or entirely in the lower part
        // (`vmax <= end`).
        //
        // The last check is needed to cover the following case:
        // `vmin ... start, end ... vmax`. In this special case there is no gap
        // between `start` and `end` so we must return true.
        //
        // `self.end + 1` cannot overflow here because `self.end < self.start`.
        self.start <= vmin || vmax <= self.end || self.start == self.end + 1
    }
}

/// Returns `true` if at least one value of `range` also lies in `self`.
///
/// An empty `range` overlaps nothing. When `self.start > self.end` the
/// range wraps around and consists of the values `>= start` together with
/// the values `<= end`.
#[inline(always)]
pub fn overlaps_range<I: Into<u128> + Ord>(&self, range: RangeInclusive<I>) -> bool {
    if range.is_empty() {
        return false;
    }

    let (lo, hi) = range.into_inner();
    let lo: u128 = lo.into();
    let hi: u128 = hi.into();

    // `range` extends up to (or past) our start / down to (or below) our end.
    let reaches_start = self.start <= hi;
    let reaches_end = lo <= self.end;

    if self.start <= self.end {
        // Non-wrapping: the two intervals intersect only if both hold.
        reaches_start && reaches_end
    } else {
        // Wrapping: touching either the upper or the lower part is enough.
        reaches_start || reaches_end
    }
}

/// Returns `self` with replaced `start`
#[inline(always)]
pub fn with_start(mut self, start: u128) -> Self {
Expand All @@ -984,9 +1062,15 @@ impl WrappingRange {
/// Returns `true` if the range covers every value of an integer of the
/// given `size`.
///
/// In debug builds this asserts that the range is valid for `size`
/// (see `is_in_range_for`).
#[inline]
pub fn is_full_for(&self, size: Size) -> bool {
    debug_assert!(self.is_in_range_for(size));
    // The range is full when `start` directly follows `end`, with the
    // increment wrapped into the width of `size`.
    let after_end = self.end.wrapping_add(1) & size.unsigned_int_max();
    self.start == after_end
}

/// Returns `true` if the range is valid for `size`, i.e. both `start` and
/// `end` fit in an unsigned integer of that size.
#[inline(always)]
pub fn is_in_range_for(&self, size: Size) -> bool {
    let max_value = size.unsigned_int_max();
    // Deliberately a plain predicate with no `debug_assert!` on the same
    // condition: asserting it would panic in debug builds exactly when this
    // function is supposed to return `false`.
    self.start <= max_value && self.end <= max_value
}
}

Expand Down Expand Up @@ -1427,16 +1511,21 @@ impl Niche {

pub fn reserve<C: HasDataLayout>(&self, cx: &C, count: u128) -> Option<(u128, Scalar)> {
assert!(count > 0);
if count > self.available(cx) {
return None;
}

let Self { value, valid_range: v, .. } = *self;
let size = value.size(cx);
assert!(size.bits() <= 128);
let max_value = size.unsigned_int_max();
let max_value = value.size(cx).unsigned_int_max();
let distance_end_zero = max_value - v.end;

let niche = v.end.wrapping_add(1)..v.start;
let available = niche.end.wrapping_sub(niche.start) & max_value;
if count > available {
return None;
// Null-pointer optimization. This is guaranteed by Rust (at least for `Option<_>`),
// and offers better codegen opportunities.
if count == 1 && matches!(value, Pointer(_)) && !v.contains(0) {
// Select which bound to move to minimize the number of lost niches.
let valid_range =
if v.start - 1 > distance_end_zero { v.with_end(0) } else { v.with_start(0) };
return Some((0, Scalar::Initialized { value, valid_range }));
}

// Extend the range of valid values being reserved by moving either `v.start` or `v.end` bound.
Expand All @@ -1459,7 +1548,6 @@ impl Niche {
let end = v.end.wrapping_add(count) & max_value;
Some((start, Scalar::Initialized { value, valid_range: v.with_end(end) }))
};
let distance_end_zero = max_value - v.end;
if v.start > v.end {
// zero is unavailable because wrapping occurs
move_end(v)
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_codegen_gcc/src/type_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,8 @@ impl<'tcx> LayoutGccExt<'tcx> for TyAndLayout<'tcx> {
return pointee;
}

let result = Ty::ty_and_layout_pointee_info_at(*self, cx, offset);
let assume_valid_ptr = true;
let result = Ty::ty_and_layout_pointee_info_at(*self, cx, offset, assume_valid_ptr);

cx.pointee_infos.borrow_mut().insert((self.ty, offset), result);
result
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_llvm/src/type_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ impl<'tcx> LayoutLlvmExt<'tcx> for TyAndLayout<'tcx> {
if let Some(&pointee) = cx.pointee_infos.borrow().get(&(self.ty, offset)) {
return pointee;
}

let result = Ty::ty_and_layout_pointee_info_at(*self, cx, offset);
let assume_valid_ptr = true;
let result = Ty::ty_and_layout_pointee_info_at(*self, cx, offset, assume_valid_ptr);

cx.pointee_infos.borrow_mut().insert((self.ty, offset), result);
result
Expand Down
1 change: 0 additions & 1 deletion compiler/rustc_const_eval/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ const_eval_not_enough_caller_args =
const_eval_null_box = {$front_matter}: encountered a null box
const_eval_null_fn_ptr = {$front_matter}: encountered a null function pointer
const_eval_null_ref = {$front_matter}: encountered a null reference
const_eval_nullable_ptr_out_of_range = {$front_matter}: encountered a potentially null pointer, but expected something that cannot possibly fail to be {$in_range}
const_eval_nullary_intrinsic_fail =
could not evaluate nullary intrinsic

Expand Down
9 changes: 4 additions & 5 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use rustc_hir::def::DefKind;
use rustc_hir::{LangItem, CRATE_HIR_ID};
use rustc_middle::mir;
use rustc_middle::mir::interpret::PointerArithmetic;
use rustc_middle::ty::layout::{FnAbiOf, TyAndLayout};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_session::lint::builtin::INVALID_ALIGNMENT;
Expand All @@ -17,7 +16,7 @@ use rustc_ast::Mutability;
use rustc_hir::def_id::DefId;
use rustc_middle::mir::AssertMessage;
use rustc_span::symbol::{sym, Symbol};
use rustc_target::abi::{Align, Size};
use rustc_target::abi::{Align, HasDataLayout as _, Size};
use rustc_target::spec::abi::Abi as CallAbi;

use crate::errors::{LongRunning, LongRunningWarn};
Expand Down Expand Up @@ -304,8 +303,8 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
Ok(ControlFlow::Break(()))
} else {
// Not alignable in const, return `usize::MAX`.
let usize_max = Scalar::from_target_usize(self.target_usize_max(), self);
self.write_scalar(usize_max, dest)?;
let usize_max = self.data_layout().target_usize_max();
self.write_scalar(Scalar::from_target_usize(usize_max, self), dest)?;
self.return_to_block(ret)?;
Ok(ControlFlow::Break(()))
}
Expand Down Expand Up @@ -333,7 +332,7 @@ impl<'mir, 'tcx: 'mir> CompileTimeEvalContext<'mir, 'tcx> {
// Inequality with integers other than null can never be known for sure.
(Scalar::Int(int), ptr @ Scalar::Ptr(..))
| (ptr @ Scalar::Ptr(..), Scalar::Int(int))
if int.is_null() && !self.scalar_may_be_null(ptr)? =>
if int.is_null() && !self.ptr_scalar_range(ptr)?.contains(&0) =>
{
0
}
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_const_eval/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,6 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
MutableRefInConst => const_eval_mutable_ref_in_const,
NullFnPtr => const_eval_null_fn_ptr,
NeverVal => const_eval_never_val,
NullablePtrOutOfRange { .. } => const_eval_nullable_ptr_out_of_range,
PtrOutOfRange { .. } => const_eval_ptr_out_of_range,
OutOfRange { .. } => const_eval_out_of_range,
UnsafeCell => const_eval_unsafe_cell,
Expand Down Expand Up @@ -732,9 +731,7 @@ impl<'tcx> ReportErrorExt for ValidationErrorInfo<'tcx> {
| InvalidFnPtr { value } => {
err.set_arg("value", value);
}
NullablePtrOutOfRange { range, max_value } | PtrOutOfRange { range, max_value } => {
add_range_arg(range, max_value, handler, err)
}
PtrOutOfRange { range, max_value } => add_range_arg(range, max_value, handler, err),
OutOfRange { range, max_value, value } => {
err.set_arg("value", value);
add_range_arg(range, max_value, handler, err);
Expand Down
24 changes: 14 additions & 10 deletions compiler/rustc_const_eval/src/interpret/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
use rustc_middle::{mir, ty};
use rustc_target::abi::{self, TagEncoding};
use rustc_target::abi::{VariantIdx, Variants};
use rustc_target::abi::{self, TagEncoding, VariantIdx, Variants, WrappingRange};

use super::{ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Scalar};

Expand Down Expand Up @@ -180,19 +179,24 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// discriminant (encoded in niche/tag) and variant index are the same.
let variants_start = niche_variants.start().as_u32();
let variants_end = niche_variants.end().as_u32();
let variants_len = u128::from(variants_end - variants_start);
let variant = match tag_val.try_to_int() {
Err(dbg_val) => {
// So this is a pointer then, and casting to an int failed.
// Can only happen during CTFE.
// The niche must be just 0, and the ptr not null, then we know this is
// okay. Everything else, we conservatively reject.
let ptr_valid = niche_start == 0
&& variants_start == variants_end
&& !self.scalar_may_be_null(tag_val)?;
if !ptr_valid {
// The pointer and niches ranges must be disjoint, then we know
// this is the untagged variant (as the value is not in the niche).
// Everything else, we conservatively reject.
let range = self.ptr_scalar_range(tag_val)?;
let niches = WrappingRange {
start: niche_start,
end: niche_start.wrapping_add(variants_len),
};
if niches.overlaps_range(range) {
throw_ub!(InvalidTag(dbg_val))
} else {
untagged_variant
}
untagged_variant
}
Ok(tag_bits) => {
let tag_bits = tag_bits.assert_bits(tag_layout.size);
Expand All @@ -205,7 +209,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
let variant_index_relative =
variant_index_relative_val.to_scalar().assert_bits(tag_val.layout.size);
// Check if this is in the range that indicates an actual discriminant.
if variant_index_relative <= u128::from(variants_end - variants_start) {
if variant_index_relative <= variants_len {
let variant_index_relative = u32::try_from(variant_index_relative)
.expect("we checked that this fits into a u32");
// Then computing the absolute variant idx should not overflow any more.
Expand Down
11 changes: 5 additions & 6 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
use rustc_hir::def_id::DefId;
use rustc_middle::mir::{
self,
interpret::{
Allocation, ConstAllocation, ConstValue, GlobalId, InterpResult, PointerArithmetic, Scalar,
},
interpret::{Allocation, ConstAllocation, ConstValue, GlobalId, InterpResult, Scalar},
BinOp, NonDivergingIntrinsic,
};
use rustc_middle::ty;
use rustc_middle::ty::layout::{LayoutOf as _, ValidityRequirement};
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::{Ty, TyCtxt};
use rustc_span::symbol::{sym, Symbol};
use rustc_target::abi::{Abi, Align, Primitive, Size};
use rustc_target::abi::{Abi, Align, HasDataLayout as _, Primitive, Size};

use super::{
util::ensure_monomorphic_enough, CheckInAllocMsg, ImmTy, InterpCx, Machine, OpTy, PlaceTy,
Expand Down Expand Up @@ -361,11 +359,12 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
)?;

// Perform division by size to compute return value.
let dl = self.data_layout();
let ret_layout = if intrinsic_name == sym::ptr_offset_from_unsigned {
assert!(0 <= dist && dist <= self.target_isize_max());
assert!(0 <= dist && dist <= dl.target_isize_max());
usize_layout
} else {
assert!(self.target_isize_min() <= dist && dist <= self.target_isize_max());
assert!(dl.target_isize_min() <= dist && dist <= dl.target_isize_max());
isize_layout
};
let pointee_layout = self.layout_of(instance_args.type_at(0))?;
Expand Down
41 changes: 26 additions & 15 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::assert_matches::assert_matches;
use std::borrow::Cow;
use std::collections::VecDeque;
use std::fmt;
use std::ops::RangeInclusive;
use std::ptr;

use rustc_ast::Mutability;
Expand Down Expand Up @@ -1222,24 +1223,34 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {

/// Machine pointer introspection.
impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
/// Test if this value might be null.
/// Turn a pointer-sized scalar into a (non-empty) range of possible values.
/// If the machine does not support ptr-to-int casts, this is conservative.
pub fn scalar_may_be_null(&self, scalar: Scalar<M::Provenance>) -> InterpResult<'tcx, bool> {
Ok(match scalar.try_to_int() {
Ok(int) => int.is_null(),
Err(_) => {
// Can only happen during CTFE.
let ptr = scalar.to_pointer(self)?;
match self.ptr_try_get_alloc_id(ptr) {
Ok((alloc_id, offset, _)) => {
let (size, _align, _kind) = self.get_alloc_info(alloc_id);
// If the pointer is out-of-bounds, it may be null.
// Note that one-past-the-end (offset == size) is still inbounds, and never null.
offset > size
}
Err(_offset) => bug!("a non-int scalar is always a pointer"),
pub fn ptr_scalar_range(
&self,
scalar: Scalar<M::Provenance>,
) -> InterpResult<'tcx, RangeInclusive<u64>> {
if let Ok(int) = scalar.to_target_usize(self) {
return Ok(int..=int);
}

let ptr = scalar.to_pointer(self)?;

// Can only happen during CTFE.
Ok(match self.ptr_try_get_alloc_id(ptr) {
Ok((alloc_id, offset, _)) => {
let offset = offset.bytes();
let (size, align, _) = self.get_alloc_info(alloc_id);
let dl = self.data_layout();
if offset > size.bytes() {
// If the pointer is out-of-bounds, we do not have a
// meaningful range to return.
0..=dl.target_usize_max()
} else {
let (min, max) = dl.address_range_for(size, align);
(min + offset)..=(max + offset)
}
}
Err(_offset) => bug!("a non-int scalar is always a pointer"),
})
}

Expand Down
Loading
Loading