Skip to content

Commit

Permalink
Check for occupied niches
Browse files Browse the repository at this point in the history
  • Loading branch information
saethlin committed Mar 5, 2024
1 parent d18480b commit 71e0b42
Show file tree
Hide file tree
Showing 38 changed files with 1,959 additions and 15 deletions.
9 changes: 7 additions & 2 deletions compiler/rustc_codegen_ssa/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use rustc_hir::LangItem;
use rustc_middle::mir;
use rustc_middle::ty::{self, layout::TyAndLayout, Ty, TyCtxt};
use rustc_middle::ty::{self, layout::TyAndLayout, GenericArg, Ty, TyCtxt};
use rustc_span::Span;

use crate::base;
Expand Down Expand Up @@ -120,10 +120,15 @@ pub fn build_langcall<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
bx: &Bx,
span: Option<Span>,
li: LangItem,
generic: Option<GenericArg<'tcx>>,
) -> (Bx::FnAbiOfResult, Bx::Value) {
let tcx = bx.tcx();
let def_id = tcx.require_lang_item(li, span);
let instance = ty::Instance::mono(tcx, def_id);
let instance = if let Some(arg) = generic {
ty::Instance::new(def_id, tcx.mk_args(&[arg]))
} else {
ty::Instance::mono(tcx, def_id)
};
(bx.fn_abi_of_instance(instance, ty::List::empty()), bx.get_fn_addr(instance))
}

Expand Down
27 changes: 21 additions & 6 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
};

let (fn_abi, llfn) = common::build_langcall(bx, Some(span), lang_item);
let (fn_abi, llfn) = common::build_langcall(bx, Some(span), lang_item, None);

// Codegen the actual panic invoke/call.
let merging_succ = helper.do_call(self, bx, fn_abi, llfn, &args, None, unwind, &[], false);
Expand All @@ -658,7 +658,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
self.set_debug_loc(bx, terminator.source_info);

// Obtain the panic entry point.
let (fn_abi, llfn) = common::build_langcall(bx, Some(span), reason.lang_item());
let (fn_abi, llfn) = common::build_langcall(bx, Some(span), reason.lang_item(), None);

// Codegen the actual panic invoke/call.
let merging_succ = helper.do_call(
Expand Down Expand Up @@ -719,8 +719,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let msg = bx.const_str(&msg_str);

// Obtain the panic entry point.
let (fn_abi, llfn) =
common::build_langcall(bx, Some(source_info.span), LangItem::PanicNounwind);
let (fn_abi, llfn) = common::build_langcall(
bx,
Some(source_info.span),
LangItem::PanicNounwind,
None,
);

// Codegen the actual panic invoke/call.
helper.do_call(
Expand Down Expand Up @@ -1194,6 +1198,17 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
) -> MergingSucc {
debug!("codegen_terminator: {:?}", terminator);

if bx.tcx().may_insert_niche_checks() {
if let mir::TerminatorKind::Return = terminator.kind {
let op = mir::Operand::Copy(mir::Place::return_place());
let ty = op.ty(self.mir, bx.tcx());
let ty = self.monomorphize(ty);
if let Some(niche) = bx.layout_of(ty).largest_niche {
self.codegen_niche_check(bx, op, niche, terminator.source_info);
}
}
}

let helper = TerminatorCodegenHelper { bb, terminator };

let mergeable_succ = || {
Expand Down Expand Up @@ -1462,7 +1477,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
tuple.layout.fields.count()
}

fn get_caller_location(
pub fn get_caller_location(
&mut self,
bx: &mut Bx,
source_info: mir::SourceInfo,
Expand Down Expand Up @@ -1603,7 +1618,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {

self.set_debug_loc(&mut bx, mir::SourceInfo::outermost(self.mir.span));

let (fn_abi, fn_ptr) = common::build_langcall(&bx, None, reason.lang_item());
let (fn_abi, fn_ptr) = common::build_langcall(&bx, None, reason.lang_item(), None);
let fn_ty = bx.fn_decl_backend_type(fn_abi);

let llret = bx.call(fn_ty, None, Some(fn_abi), fn_ptr, &[], funclet.as_ref());
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub mod coverageinfo;
pub mod debuginfo;
mod intrinsic;
mod locals;
mod niche_check;
pub mod operand;
pub mod place;
mod rvalue;
Expand Down
295 changes: 295 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/niche_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
use rustc_hir::LangItem;
use rustc_middle::mir;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext};
use rustc_middle::ty::Mutability;
use rustc_middle::ty::Ty;
use rustc_middle::ty::TyCtxt;
use rustc_middle::ty::TypeAndMut;
use rustc_span::Span;
use rustc_target::abi::Integer;
use rustc_target::abi::Niche;
use rustc_target::abi::Primitive;
use rustc_target::abi::Size;

use super::FunctionCx;
use crate::base;
use crate::common;
use crate::mir::OperandValue;
use crate::traits::*;

pub struct NicheFinder<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> {
pub fx: &'s mut FunctionCx<'a, 'tcx, Bx>,
pub bx: &'s mut Bx,
pub places: Vec<(mir::Operand<'tcx>, Niche)>,
}

impl<'s, 'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> Visitor<'tcx> for NicheFinder<'s, 'a, 'tcx, Bx> {
fn visit_rvalue(&mut self, rvalue: &mir::Rvalue<'tcx>, location: mir::Location) {
match rvalue {
mir::Rvalue::Cast(mir::CastKind::Transmute, op, ty) => {
let ty = self.fx.monomorphize(*ty);
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
self.places.push((op.clone(), niche));
}
}
_ => self.super_rvalue(rvalue, location),
}
}

fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>, _location: mir::Location) {
if let mir::TerminatorKind::Return = terminator.kind {
let op = mir::Operand::Copy(mir::Place::return_place());
let ty = op.ty(self.fx.mir, self.bx.tcx());
let ty = self.fx.monomorphize(ty);
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
self.places.push((op, niche));
}
}
}

fn visit_place(
&mut self,
place: &mir::Place<'tcx>,
context: PlaceContext,
_location: mir::Location,
) {
match context {
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
) => {}
_ => {
return;
}
}

let ty = place.ty(self.fx.mir, self.bx.tcx()).ty;
let ty = self.fx.monomorphize(ty);
if let Some(niche) = self.bx.layout_of(ty).largest_niche {
self.places.push((mir::Operand::Copy(*place), niche));
};
}
}

use rustc_target::abi::Abi;
use rustc_target::abi::Scalar;
use rustc_target::abi::WrappingRange;
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
fn value_in_niche(
&mut self,
bx: &mut Bx,
op: crate::mir::OperandRef<'tcx, Bx::Value>,
niche: Niche,
) -> Option<Bx::Value> {
let niche_ty = niche.ty(bx.tcx());
let niche_layout = bx.layout_of(niche_ty);

let (imm, from_scalar, from_backend_ty) = match op.val {
OperandValue::Immediate(imm) => {
let Abi::Scalar(from_scalar) = op.layout.abi else { unreachable!() };
let from_backend_ty = bx.backend_type(op.layout);
(imm, from_scalar, from_backend_ty)
}
OperandValue::Pair(first, second) => {
let Abi::ScalarPair(first_scalar, second_scalar) = op.layout.abi else {
unreachable!()
};
if niche.offset == Size::ZERO {
(first, first_scalar, bx.scalar_pair_element_backend_type(op.layout, 0, true))
} else {
// yolo
(second, second_scalar, bx.scalar_pair_element_backend_type(op.layout, 1, true))
}
}
OperandValue::ZeroSized => unreachable!(),
OperandValue::Ref(ptr, _metadata, _align) => {
// General case: Load the niche primitive via pointer arithmetic.
let niche_ptr_ty =
Ty::new_ptr(bx.tcx(), TypeAndMut { ty: niche_ty, mutbl: Mutability::Not });
let ptr = bx.pointercast(ptr, bx.backend_type(bx.layout_of(niche_ptr_ty)));

let offset = niche.offset.bytes() / niche_layout.size.bytes();
let niche_backend_ty = bx.backend_type(bx.layout_of(niche_ty));
let ptr = bx.inbounds_gep(niche_backend_ty, ptr, &[bx.const_usize(offset)]);
let value = bx.load(niche_backend_ty, ptr, rustc_target::abi::Align::ONE);
return Some(value);
}
};

// Any type whose ABI is a Scalar bool is turned into an i1, so it cannot contain a value
// outside of its niche.
if from_scalar.is_bool() {
return None;
}

let to_scalar = Scalar::Initialized {
value: niche.value,
valid_range: WrappingRange::full(niche.size(bx.tcx())),
};
let to_backend_ty = bx.backend_type(niche_layout);
if from_backend_ty == to_backend_ty {
return Some(imm);
}
let value = self.transmute_immediate(
bx,
imm,
from_scalar,
from_backend_ty,
to_scalar,
to_backend_ty,
);
Some(value)
}

#[instrument(level = "debug", skip(self, bx))]
pub fn codegen_niche_check(
&mut self,
bx: &mut Bx,
mir_op: mir::Operand<'tcx>,
niche: Niche,
source_info: mir::SourceInfo,
) {
let op_ty = self.monomorphize(mir_op.ty(self.mir, bx.tcx()));
if op_ty == bx.tcx().types.bool {
return;
}

let op = self.codegen_operand(bx, &mir_op);

let Some(value_in_niche) = self.value_in_niche(bx, op, niche) else {
return;
};
let size = niche.size(bx.tcx());

let start = niche.scalar(niche.valid_range.start, bx);
let end = niche.scalar(niche.valid_range.end, bx);

let binop_le = base::bin_op_to_icmp_predicate(mir::BinOp::Le.to_hir_binop(), false);
let binop_ge = base::bin_op_to_icmp_predicate(mir::BinOp::Ge.to_hir_binop(), false);
let is_valid = if niche.valid_range.start == 0 {
bx.icmp(binop_le, value_in_niche, end)
} else if niche.valid_range.end == (u128::MAX >> 128 - size.bits()) {
bx.icmp(binop_ge, value_in_niche, start)
} else {
// We need to check if the value is within a *wrapping* range. We could do this:
// (niche >= start) && (niche <= end)
// But what we're going to actually do is this:
// max = end - start
// (niche - start) <= max
// The latter is much more complicated conceptually, but is actually less operations
// because we can compute max in codegen.
let mut max = niche.valid_range.end.wrapping_sub(niche.valid_range.start);
let size = niche.size(bx.tcx());
if size.bits() < 128 {
let mask = (1 << size.bits()) - 1;
max &= mask;
}
let max_adjusted_allowed_value = niche.scalar(max, bx);

let biased = bx.sub(value_in_niche, start);
bx.icmp(binop_le, biased, max_adjusted_allowed_value)
};

// Create destination blocks, branching on is_valid
let panic = bx.append_sibling_block("panic");
let success = bx.append_sibling_block("success");
bx.cond_br(is_valid, success, panic);

// Switch to the failure block and codegen a call to the panic intrinsic
bx.switch_to_block(panic);
self.set_debug_loc(bx, source_info);
let location = self.get_caller_location(bx, source_info).immediate();
self.codegen_panic(
bx,
LangItem::PanicOccupiedNiche,
&[value_in_niche, start, end, location],
source_info.span,
niche.ty(bx.tcx()),
);

// Continue codegen in the success block.
bx.switch_to_block(success);
self.set_debug_loc(bx, source_info);
}

#[instrument(level = "debug", skip(self, bx))]
fn codegen_panic(
&mut self,
bx: &mut Bx,
lang_item: LangItem,
args: &[Bx::Value],
span: Span,
ty: Ty<'tcx>,
) {
let (fn_abi, fn_ptr) = common::build_langcall(bx, Some(span), lang_item, Some(ty.into()));
let fn_ty = bx.fn_decl_backend_type(&fn_abi);
let fn_attrs = if bx.tcx().def_kind(self.instance.def_id()).has_codegen_attrs() {
Some(bx.tcx().codegen_fn_attrs(self.instance.def_id()))
} else {
None
};

bx.call(fn_ty, fn_attrs, Some(&fn_abi), fn_ptr, args, None /* funclet */);
bx.unreachable();
}
}

pub trait NicheExt {
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx>;
fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size;
fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value;
}

impl NicheExt for Niche {
fn ty<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let types = &tcx.types;
match self.value {
Primitive::Int(Integer::I8, _) => types.u8,
Primitive::Int(Integer::I16, _) => types.u16,
Primitive::Int(Integer::I32, _) => types.u32,
Primitive::Int(Integer::I64, _) => types.u64,
Primitive::Int(Integer::I128, _) => types.u128,
Primitive::Pointer(_) => {
Ty::new_ptr(tcx, TypeAndMut { ty: types.unit, mutbl: Mutability::Not })
}
Primitive::F16 => types.u16,
Primitive::F32 => types.u32,
Primitive::F64 => types.u64,
Primitive::F128 => types.u128,
}
}

fn size<'tcx>(&self, tcx: TyCtxt<'tcx>) -> Size {
let bits = match self.value {
Primitive::Int(Integer::I8, _) => 8,
Primitive::Int(Integer::I16, _) => 16,
Primitive::Int(Integer::I32, _) => 32,
Primitive::Int(Integer::I64, _) => 64,
Primitive::Int(Integer::I128, _) => 128,
Primitive::Pointer(_) => tcx.sess.target.pointer_width as usize,
Primitive::F16 => 16,
Primitive::F32 => 32,
Primitive::F64 => 64,
Primitive::F128 => 128,
};
Size::from_bits(bits)
}

fn scalar<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(&self, val: u128, bx: &mut Bx) -> Bx::Value {
use rustc_middle::mir::interpret::Pointer;
use rustc_middle::mir::interpret::Scalar;

let tcx = bx.tcx();
let niche_ty = self.ty(tcx);
let value = if niche_ty.is_any_ptr() {
Scalar::from_maybe_pointer(Pointer::from_addr_invalid(val as u64), &tcx)
} else {
Scalar::from_uint(val, self.size(tcx))
};
let layout = rustc_target::abi::Scalar::Initialized {
value: self.value,
valid_range: WrappingRange::full(self.size(tcx)),
};
bx.scalar_to_backend(value, layout, bx.backend_type(bx.layout_of(self.ty(tcx))))
}
}
Loading

0 comments on commit 71e0b42

Please sign in to comment.