Skip to content

Commit

Permalink
chore: track unsafe status on TypeChecker
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench committed Feb 27, 2024
1 parent a1a4286 commit 48230c8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 101 deletions.
94 changes: 41 additions & 53 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ impl<'interner> TypeChecker<'interner> {
/// an equivalent HirExpression::Call in the form `foo(a, b, c)`. This cannot
/// be done earlier since we need to know the type of the object `a` to resolve which
/// function `foo` to refer to.
pub(crate) fn check_expression(&mut self, expr_id: &ExprId, allow_unsafe_call: bool) -> Type {
pub(crate) fn check_expression(&mut self, expr_id: &ExprId) -> Type {
let typ = match self.interner.expression(expr_id) {
HirExpression::Ident(ident) => self.check_ident(ident, expr_id),
HirExpression::Literal(literal) => {
match literal {
HirLiteral::Array(HirArrayLiteral::Standard(arr)) => {
let elem_types =
vecmap(&arr, |arg| self.check_expression(arg, allow_unsafe_call));
let elem_types = vecmap(&arr, |arg| self.check_expression(arg));

let first_elem_type = elem_types
.first()
Expand Down Expand Up @@ -95,7 +94,7 @@ impl<'interner> TypeChecker<'interner> {
arr_type
}
HirLiteral::Array(HirArrayLiteral::Repeated { repeated_element, length }) => {
let elem_type = self.check_expression(&repeated_element, allow_unsafe_call);
let elem_type = self.check_expression(&repeated_element);
let length = match length {
Type::Constant(length) => {
Type::constant_variable(length, self.interner)
Expand All @@ -112,17 +111,16 @@ impl<'interner> TypeChecker<'interner> {
}
HirLiteral::FmtStr(string, idents) => {
let len = Type::Constant(string.len() as u64);
let types =
vecmap(&idents, |elem| self.check_expression(elem, allow_unsafe_call));
let types = vecmap(&idents, |elem| self.check_expression(elem));
Type::FmtString(Box::new(len), Box::new(Type::Tuple(types)))
}
HirLiteral::Unit => Type::Unit,
}
}
HirExpression::Infix(infix_expr) => {
// The type of the infix expression must be looked up from a type table
let lhs_type = self.check_expression(&infix_expr.lhs, allow_unsafe_call);
let rhs_type = self.check_expression(&infix_expr.rhs, allow_unsafe_call);
let lhs_type = self.check_expression(&infix_expr.lhs);
let rhs_type = self.check_expression(&infix_expr.rhs);

let lhs_span = self.interner.expr_span(&infix_expr.lhs);
let rhs_span = self.interner.expr_span(&infix_expr.rhs);
Expand Down Expand Up @@ -151,9 +149,7 @@ impl<'interner> TypeChecker<'interner> {
}
}
}
HirExpression::Index(index_expr) => {
self.check_index_expression(expr_id, index_expr, allow_unsafe_call)
}
HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr),
HirExpression::Call(call_expr) => {
// Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression
// These flags are later used to type check calls to unconstrained functions from constrained functions
Expand All @@ -165,15 +161,15 @@ impl<'interner> TypeChecker<'interner> {

self.check_if_deprecated(&call_expr.func);

let function = self.check_expression(&call_expr.func, allow_unsafe_call);
let function = self.check_expression(&call_expr.func);

let args = vecmap(&call_expr.arguments, |arg| {
let typ = self.check_expression(arg, allow_unsafe_call);
let typ = self.check_expression(arg);
(typ, *arg, self.interner.expr_span(arg))
});

if is_current_func_constrained && is_unconstrained_call {
if !allow_unsafe_call {
if !self.allow_unsafe {
self.errors.push(TypeCheckError::Unsafe {
span: self.interner.expr_span(expr_id),
});
Expand Down Expand Up @@ -208,8 +204,7 @@ impl<'interner> TypeChecker<'interner> {
return_type
}
HirExpression::MethodCall(mut method_call) => {
let mut object_type =
self.check_expression(&method_call.object, allow_unsafe_call).follow_bindings();
let mut object_type = self.check_expression(&method_call.object).follow_bindings();
let method_name = method_call.method.0.contents.as_str();
match self.lookup_method(&object_type, method_name, expr_id) {
Some(method_ref) => {
Expand Down Expand Up @@ -244,24 +239,30 @@ impl<'interner> TypeChecker<'interner> {

// Type check the new call now that it has been changed from a method call
// to a function call. This way we avoid duplicating code.
self.check_expression(expr_id, allow_unsafe_call)
self.check_expression(expr_id)
}
None => Type::Error,
}
}
HirExpression::Cast(cast_expr) => {
// Evaluate the LHS
let lhs_type = self.check_expression(&cast_expr.lhs, allow_unsafe_call);
let lhs_type = self.check_expression(&cast_expr.lhs);
let span = self.interner.expr_span(expr_id);
self.check_cast(lhs_type, cast_expr.r#type, span)
}
HirExpression::Block(block_expr) => {
let mut block_type = Type::Unit;

let allow_unsafe = allow_unsafe_call || block_expr.is_unsafe;
// Before entering the block we cache the old value of `allow_unsafe` so it can be restored.
let old_allow_unsafe = self.allow_unsafe;

// If we're already in an unsafe block then entering a new block should preserve this even if
// the inner block isn't marked as unsafe.
self.allow_unsafe |= block_expr.is_unsafe;

let statements = block_expr.statements();
for (i, stmt) in statements.iter().enumerate() {
let expr_type = self.check_statement(stmt, allow_unsafe);
let expr_type = self.check_statement(stmt);

if let crate::hir_def::stmt::HirStatement::Semi(expr) =
self.interner.statement(stmt)
Expand All @@ -282,24 +283,23 @@ impl<'interner> TypeChecker<'interner> {
}
}

// Finally, we restore the original value of `self.allow_unsafe`.
self.allow_unsafe = old_allow_unsafe;

block_type
}
HirExpression::Prefix(prefix_expr) => {
let rhs_type = self.check_expression(&prefix_expr.rhs, allow_unsafe_call);
let rhs_type = self.check_expression(&prefix_expr.rhs);
let span = self.interner.expr_span(&prefix_expr.rhs);
self.type_check_prefix_operand(&prefix_expr.operator, &rhs_type, span)
}
HirExpression::If(if_expr) => self.check_if_expr(&if_expr, expr_id, allow_unsafe_call),
HirExpression::Constructor(constructor) => {
self.check_constructor(constructor, expr_id, allow_unsafe_call)
}
HirExpression::MemberAccess(access) => {
self.check_member_access(access, *expr_id, allow_unsafe_call)
}
HirExpression::If(if_expr) => self.check_if_expr(&if_expr, expr_id),
HirExpression::Constructor(constructor) => self.check_constructor(constructor, expr_id),
HirExpression::MemberAccess(access) => self.check_member_access(access, *expr_id),
HirExpression::Error => Type::Error,
HirExpression::Tuple(elements) => Type::Tuple(vecmap(&elements, |elem| {
self.check_expression(elem, allow_unsafe_call)
})),
HirExpression::Tuple(elements) => {
Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem)))
}
HirExpression::Lambda(lambda) => {
let captured_vars = vecmap(lambda.captures, |capture| {
self.interner.definition_type(capture.ident.id)
Expand All @@ -313,7 +313,7 @@ impl<'interner> TypeChecker<'interner> {
typ
});

let actual_return = self.check_expression(&lambda.body, allow_unsafe_call);
let actual_return = self.check_expression(&lambda.body);

let span = self.interner.expr_span(&lambda.body);
self.unify(&actual_return, &lambda.return_type, || TypeCheckError::TypeMismatch {
Expand Down Expand Up @@ -542,9 +542,8 @@ impl<'interner> TypeChecker<'interner> {
&mut self,
id: &ExprId,
mut index_expr: expr::HirIndexExpression,
allow_unsafe_call: bool,
) -> Type {
let index_type = self.check_expression(&index_expr.index, allow_unsafe_call);
let index_type = self.check_expression(&index_expr.index);
let span = self.interner.expr_span(&index_expr.index);

index_type.unify(
Expand All @@ -559,7 +558,7 @@ impl<'interner> TypeChecker<'interner> {

// When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many
// times as needed to get the underlying array.
let lhs_type = self.check_expression(&index_expr.collection, allow_unsafe_call);
let lhs_type = self.check_expression(&index_expr.collection);
let (new_lhs, lhs_type) = self.insert_auto_dereferences(index_expr.collection, lhs_type);
index_expr.collection = new_lhs;
self.interner.replace_expr(id, HirExpression::Index(index_expr));
Expand Down Expand Up @@ -612,14 +611,9 @@ impl<'interner> TypeChecker<'interner> {
}
}

fn check_if_expr(
&mut self,
if_expr: &expr::HirIfExpression,
expr_id: &ExprId,
allow_unsafe_call: bool,
) -> Type {
let cond_type = self.check_expression(&if_expr.condition, allow_unsafe_call);
let then_type = self.check_expression(&if_expr.consequence, allow_unsafe_call);
fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type {
let cond_type = self.check_expression(&if_expr.condition);
let then_type = self.check_expression(&if_expr.consequence);

let expr_span = self.interner.expr_span(&if_expr.condition);

Expand All @@ -632,7 +626,7 @@ impl<'interner> TypeChecker<'interner> {
match if_expr.alternative {
None => Type::Unit,
Some(alternative) => {
let else_type = self.check_expression(&alternative, allow_unsafe_call);
let else_type = self.check_expression(&alternative);

let expr_span = self.interner.expr_span(expr_id);
self.unify(&then_type, &else_type, || {
Expand Down Expand Up @@ -662,7 +656,6 @@ impl<'interner> TypeChecker<'interner> {
&mut self,
constructor: expr::HirConstructorExpression,
expr_id: &ExprId,
allow_unsafe_call: bool,
) -> Type {
let typ = constructor.r#type;
let generics = constructor.struct_generics;
Expand All @@ -682,7 +675,7 @@ impl<'interner> TypeChecker<'interner> {
// mismatch here as long as we continue typechecking the rest of the program to the best
// of our ability.
if param_name == arg_ident.0.contents {
let arg_type = self.check_expression(&arg, allow_unsafe_call);
let arg_type = self.check_expression(&arg);

let span = self.interner.expr_span(expr_id);
self.unify_with_coercions(&arg_type, &param_type, arg, || {
Expand All @@ -698,13 +691,8 @@ impl<'interner> TypeChecker<'interner> {
Type::Struct(typ, generics)
}

fn check_member_access(
&mut self,
mut access: expr::HirMemberAccess,
expr_id: ExprId,
allow_unsafe_call: bool,
) -> Type {
let lhs_type = self.check_expression(&access.lhs, allow_unsafe_call).follow_bindings();
fn check_member_access(&mut self, mut access: expr::HirMemberAccess, expr_id: ExprId) -> Type {
let lhs_type = self.check_expression(&access.lhs).follow_bindings();
let span = self.interner.expr_span(&expr_id);
let access_lhs = &mut access.lhs;

Expand Down
18 changes: 15 additions & 3 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ pub struct TypeChecker<'interner> {
errors: Vec<TypeCheckError>,
current_function: Option<FuncId>,

/// Whether the `TypeChecker` should allow unsafe calls.
///
/// This is generally only set to true within an `unsafe` block.
allow_unsafe: bool,

/// Trait constraints are collected during type checking until they are
/// verified at the end of a function. This is because constraints arise
/// on each variable, but it is only until function calls when the types
Expand Down Expand Up @@ -164,11 +169,17 @@ fn function_info(interner: &NodeInterner, function_body_id: &ExprId) -> (noirc_e

impl<'interner> TypeChecker<'interner> {
fn new(interner: &'interner mut NodeInterner) -> Self {
Self { interner, errors: Vec::new(), trait_constraints: Vec::new(), current_function: None }
Self {
interner,
errors: Vec::new(),
allow_unsafe: false,
trait_constraints: Vec::new(),
current_function: None,
}
}

fn check_function_body(&mut self, body: &ExprId) -> Type {
self.check_expression(body, false)
self.check_expression(body)
}

pub fn check_global(
Expand All @@ -178,11 +189,12 @@ impl<'interner> TypeChecker<'interner> {
let mut this = Self {
interner,
errors: Vec::new(),
allow_unsafe: false,
trait_constraints: Vec::new(),
current_function: None,
};
let statement = this.interner.get_global(id).let_statement;
this.check_statement(&statement, false);
this.check_statement(&statement);
this.errors
}

Expand Down
Loading

0 comments on commit 48230c8

Please sign in to comment.