Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for occupied niches #121174

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
) -> MergingSucc {
debug!("codegen_terminator: {:?}", terminator);

if bx.tcx().may_insert_niche_checks() {
if let mir::TerminatorKind::Return = terminator.kind {
let op = mir::Operand::Copy(mir::Place::return_place());
let ty = op.ty(self.mir, bx.tcx());
let ty = self.monomorphize(ty);
if let Some(niche) = bx.layout_of(ty).largest_niche {
self.codegen_niche_check(bx, op, niche, terminator.source_info);
}
}
}

let helper = TerminatorCodegenHelper { bb, terminator };

let mergeable_succ = || {
Expand Down Expand Up @@ -1598,7 +1609,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
tuple.layout.fields.count()
}

fn get_caller_location(
pub fn get_caller_location(
&mut self,
bx: &mut Bx,
source_info: mir::SourceInfo,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod coverageinfo;
pub mod debuginfo;
mod intrinsic;
mod locals;
mod niche_check;
pub mod operand;
pub mod place;
mod rvalue;
Expand Down
283 changes: 283 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/niche_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
use rustc_hir::LangItem;
use rustc_middle::mir;
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
use rustc_middle::ty::{Mutability, Ty, TyCtxt};
use rustc_span::def_id::LOCAL_CRATE;
use rustc_span::Span;
use rustc_target::abi::{Float, Integer, Niche, Primitive, Size};

use super::FunctionCx;
use crate::mir::place::PlaceValue;
use crate::mir::OperandValue;
use crate::traits::*;
use crate::{base, common};

pub struct NicheFinder<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> {
pub fx: &'s mut FunctionCx<'a, 'tcx, Bx>,
pub bx: &'s mut Bx,
pub places: Vec<(mir::Operand<'tcx>, Niche)>,
}

impl<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> Visitor<'tcx> for NicheFinder<'s, 'a, 'tcx, Bx> {
fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: mir::Location) {
match rvalue {
mir::Rvalue::Cast(mir::CastKind::Transmute, op, ty) => {
let ty = self.fx.monomorphize(*ty);
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
self.places.push((op.clone(), niche));
}
}
_ => self.super_rvalue(rvalue, location),
}
}

fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, _location: mir::Location) {
if let mir::TerminatorKind::Return = terminator.kind {
let op = mir::Operand::Copy(mir::Place::return_place());
let ty = op.ty(self.fx.mir, self.bx.tcx());
let ty = self.fx.monomorphize(ty);
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
self.places.push((op, niche));
}
}
}

fn visit_place(
&mut self,
place: &mir::Place<'tcx>,
context: PlaceContext,
_location: mir::Location,
) {
match context {
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
) => {}
_ => {
return;
}
}

let ty = place.ty(self.fx.mir, self.bx.tcx()).ty;
let ty = self.fx.monomorphize(ty);
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
self.places.push((mir::Operand::Copy(*place), niche));
};
}
}

use rustc_target::abi::{Abi, Scalar, WrappingRange};
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
fn value_in_niche(
&mut self,
bx: &mut Bx,
op: crate::mir::OperandRef<'tcx, Bx::Value>,
niche: Niche,
) -> Option<Bx::Value> {
let niche_ty = niche.ty(bx.tcx());
let niche_layout = bx.layout_of(niche_ty);

let (imm, from_scalar, from_backend_ty) = match op.val {
OperandValue::Immediate(imm) => {
let Abi::Scalar(from_scalar) = op.layout.abi else { unreachable!() };
let from_backend_ty = bx.backend_type(op.layout);
(imm, from_scalar, from_backend_ty)
}
OperandValue::Pair(first, second) => {
let Abi::ScalarPair(first_scalar, second_scalar) = op.layout.abi else {
unreachable!()
};
if niche.offset == Size::ZERO {
(first, first_scalar, bx.scalar_pair_element_backend_type(op.layout, 0, true))
} else {
// yolo
(second, second_scalar, bx.scalar_pair_element_backend_type(op.layout, 1, true))
}
}
OperandValue::ZeroSized => unreachable!(),
OperandValue::Ref(PlaceValue { llval: ptr, .. }) => {
// General case: Load the niche primitive via pointer arithmetic.
let niche_ptr_ty = Ty::new_ptr(bx.tcx(), niche_ty, Mutability::Not);
let ptr = bx.pointercast(ptr, bx.backend_type(bx.layout_of(niche_ptr_ty)));

let offset = niche.offset.bytes() / niche_layout.size.bytes();
let niche_backend_ty = bx.backend_type(bx.layout_of(niche_ty));
let ptr = bx.inbounds_gep(niche_backend_ty, ptr, &[bx.const_usize(offset)]);
let value = bx.load(niche_backend_ty, ptr, rustc_target::abi::Align::ONE);
return Some(value);
}
};

// Any type whose ABI is a Scalar bool is turned into an i1, so it cannot contain a value
// outside of its niche.
if from_scalar.is_bool() {
return None;
}

let to_scalar = Scalar::Initialized {
value: niche.value,
valid_range: WrappingRange::full(niche.size(bx.tcx())),
};
let to_backend_ty = bx.backend_type(niche_layout);
if from_backend_ty == to_backend_ty {
return Some(imm);
}
let value = self.transmute_immediate(
bx,
imm,
from_scalar,
from_backend_ty,
to_scalar,
to_backend_ty,
);
Some(value)
}

#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_niche_check(
&mut self,
bx: &mut Bx,
mir_op: mir::Operand<'tcx>,
niche: Niche,
source_info: mir::SourceInfo,
) {
let tcx = bx.tcx();
let op_ty = self.monomorphize(mir_op.ty(self.mir, tcx));
if op_ty == tcx.types.bool {
return;
}

let op = self.codegen_operand(bx, &mir_op);

let Some(value_in_niche) = self.value_in_niche(bx, op, niche) else {
return;
};
let size = niche.size(tcx);

let start = niche.scalar(niche.valid_range.start, bx);
let end = niche.scalar(niche.valid_range.end, bx);

let binop_le = base::bin_op_to_icmp_predicate(mir::BinOp::Le.to_hir_binop(), false);
let binop_ge = base::bin_op_to_icmp_predicate(mir::BinOp::Ge.to_hir_binop(), false);
let is_valid = if niche.valid_range.start == 0 {
bx.icmp(binop_le, value_in_niche, end)
} else if niche.valid_range.end == (u128::MAX >> 128 - size.bits()) {
bx.icmp(binop_ge, value_in_niche, start)
} else {
// We need to check if the value is within a *wrapping* range. We could do this:
// (niche >= start) && (niche <= end)
// But what we're going to actually do is this:
// max = end - start
// (niche - start) <= max
// The latter is much more complicated conceptually, but is actually less operations
// because we can compute max in codegen.
let mut max = niche.valid_range.end.wrapping_sub(niche.valid_range.start);
let size = niche.size(tcx);
if size.bits() < 128 {
let mask = (1 << size.bits()) - 1;
max &= mask;
}
let max_adjusted_allowed_value = niche.scalar(max, bx);

let biased = bx.sub(value_in_niche, start);
bx.icmp(binop_le, biased, max_adjusted_allowed_value)
};

// Create destination blocks, branching on is_valid
let panic = bx.append_sibling_block("panic");
let success = bx.append_sibling_block("success");
bx.cond_br(is_valid, success, panic);

// Switch to the failure block and codegen a call to the panic intrinsic
bx.switch_to_block(panic);
self.set_debug_loc(bx, source_info);
let location = self.get_caller_location(bx, source_info).immediate();
self.codegen_panic(
bx,
niche.lang_item(),
&[value_in_niche, start, end, location],
source_info.span,
);

// Continue codegen in the success block.
bx.switch_to_block(success);
self.set_debug_loc(bx, source_info);
}

#[instrument(level = "debug", skip(self, bx))]
fn codegen_panic(&mut self, bx: &mut Bx, lang_item: LangItem, args: &[Bx::Value], span: Span) {
if bx.tcx().is_compiler_builtins(LOCAL_CRATE) {
bx.abort()
} else {
let (fn_abi, fn_ptr, instance) = common::build_langcall(bx, Some(span), lang_item);
let fn_ty = bx.fn_decl_backend_type(&fn_abi);
let fn_attrs = if bx.tcx().def_kind(self.instance.def_id()).has_codegen_attrs() {
Some(bx.tcx().codegen_fn_attrs(self.instance.def_id()))
} else {
None
};
bx.call(fn_ty, fn_attrs, Some(&fn_abi), fn_ptr, args, None, Some(instance));
}
bx.unreachable();
}
}

pub trait NicheExt {
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx>;
fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size;
fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value;
fn lang_item(&self) -> LangItem;
}

impl NicheExt for Niche {
fn lang_item(&self) -> LangItem {
match self.value {
Primitive::Int(Integer::I8, _) => LangItem::PanicOccupiedNicheU8,
Primitive::Int(Integer::I16, _) => LangItem::PanicOccupiedNicheU16,
Primitive::Int(Integer::I32, _) => LangItem::PanicOccupiedNicheU32,
Primitive::Int(Integer::I64, _) => LangItem::PanicOccupiedNicheU64,
Primitive::Int(Integer::I128, _) => LangItem::PanicOccupiedNicheU128,
Primitive::Pointer(_) => LangItem::PanicOccupiedNichePtr,
Primitive::Float(Float::F16) => LangItem::PanicOccupiedNicheU16,
Primitive::Float(Float::F32) => LangItem::PanicOccupiedNicheU32,
Primitive::Float(Float::F64) => LangItem::PanicOccupiedNicheU64,
Primitive::Float(Float::F128) => LangItem::PanicOccupiedNicheU128,
}
}

fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let types = &tcx.types;
match self.value {
Primitive::Int(Integer::I8, _) => types.u8,
Primitive::Int(Integer::I16, _) => types.u16,
Primitive::Int(Integer::I32, _) => types.u32,
Primitive::Int(Integer::I64, _) => types.u64,
Primitive::Int(Integer::I128, _) => types.u128,
Primitive::Pointer(_) => Ty::new_ptr(tcx, types.unit, Mutability::Not),
Primitive::Float(Float::F16) => types.u16,
Primitive::Float(Float::F32) => types.u32,
Primitive::Float(Float::F64) => types.u64,
Primitive::Float(Float::F128) => types.u128,
}
}

fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size {
self.value.size(&tcx)
}

fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value {
use rustc_middle::mir::interpret::{Pointer, Scalar};

let tcx = bx.tcx();
let niche_ty = self.ty(tcx);
let value = if niche_ty.is_any_ptr() {
Scalar::from_maybe_pointer(Pointer::from_addr_invalid(val as u64), &tcx)
} else {
Scalar::from_uint(val, self.size(tcx))
};
let layout = rustc_target::abi::Scalar::Initialized {
value: self.value,
valid_range: WrappingRange::full(self.size(tcx)),
};
bx.scalar_to_backend(value, layout, bx.backend_type(bx.layout_of(self.ty(tcx))))
}
}
6 changes: 3 additions & 3 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}

fn codegen_transmute(
pub fn codegen_transmute(
&mut self,
bx: &mut Bx,
src: OperandRef<'tcx, Bx::Value>,
Expand Down Expand Up @@ -194,7 +194,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
///
/// Returns `None` for cases that can't work in that framework, such as for
/// `Immediate`->`Ref` that needs an `alloc` to get the location.
fn codegen_transmute_operand(
pub fn codegen_transmute_operand(
&mut self,
bx: &mut Bx,
operand: OperandRef<'tcx, Bx::Value>,
Expand Down Expand Up @@ -336,7 +336,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
///
/// `to_backend_ty` must be the *non*-immediate backend type (so it will be
/// `i8`, not `i1`, for `bool`-like types.)
fn transmute_immediate(
pub fn transmute_immediate(
&self,
bx: &mut Bx,
mut imm: Bx::Value,
Expand Down
23 changes: 21 additions & 2 deletions compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
use rustc_middle::mir::{self, NonDivergingIntrinsic};
use rustc_middle::span_bug;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::NonDivergingIntrinsic;
use rustc_middle::{mir, span_bug};
use rustc_session::config::OptLevel;
use tracing::instrument;

use super::{FunctionCx, LocalRef};
use crate::mir::niche_check::NicheFinder;
use crate::traits::*;

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
fn niches_to_check(
&mut self,
bx: &mut Bx,
statement: &mir::Statement<'tcx>,
) -> Vec<(mir::Operand<'tcx>, rustc_target::abi::Niche)> {
let mut finder = NicheFinder { fx: self, bx, places: Vec::new() };
finder.visit_statement(statement, rustc_middle::mir::Location::START);
finder.places
}

#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_statement(&mut self, bx: &mut Bx, statement: &mir::Statement<'tcx>) {
self.set_debug_loc(bx, statement.source_info);

if bx.tcx().may_insert_niche_checks() {
for (op, niche) in self.niches_to_check(bx, statement) {
self.codegen_niche_check(bx, op, niche, statement.source_info);
}
}

match statement.kind {
mir::StatementKind::Assign(box (ref place, ref rvalue)) => {
if let Some(index) = place.as_local() {
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ language_item_table! {
ConstPanicFmt, sym::const_panic_fmt, const_panic_fmt, Target::Fn, GenericRequirement::None;
PanicBoundsCheck, sym::panic_bounds_check, panic_bounds_check_fn, Target::Fn, GenericRequirement::Exact(0);
PanicMisalignedPointerDereference, sym::panic_misaligned_pointer_dereference, panic_misaligned_pointer_dereference_fn, Target::Fn, GenericRequirement::Exact(0);
PanicOccupiedNicheU8, sym::panic_occupied_niche_u8, panic_occupied_niche_u8, Target::Fn, GenericRequirement::None;
PanicOccupiedNicheU16, sym::panic_occupied_niche_u16, panic_occupied_niche_u16, Target::Fn, GenericRequirement::None;
PanicOccupiedNicheU32, sym::panic_occupied_niche_u32, panic_occupied_niche_u32, Target::Fn, GenericRequirement::None;
PanicOccupiedNicheU64, sym::panic_occupied_niche_u64, panic_occupied_niche_u64, Target::Fn, GenericRequirement::None;
PanicOccupiedNicheU128, sym::panic_occupied_niche_u128, panic_occupied_niche_u128, Target::Fn, GenericRequirement::None;
PanicOccupiedNichePtr, sym::panic_occupied_niche_ptr, panic_occupied_niche_ptr, Target::Fn, GenericRequirement::None;
PanicInfo, sym::panic_info, panic_info, Target::Struct, GenericRequirement::None;
PanicLocation, sym::panic_location, panic_location, Target::Struct, GenericRequirement::None;
PanicImpl, sym::panic_impl, panic_impl, Target::Fn, GenericRequirement::None;
Expand Down
Loading
Loading