From 48230c8c30394e345e5431e8486266748704f86f Mon Sep 17 00:00:00 2001 From: Tom French Date: Tue, 27 Feb 2024 23:26:05 +0000 Subject: [PATCH] chore: track `unsafe` status on `TypeChecker` --- .../noirc_frontend/src/hir/type_check/expr.rs | 94 ++++++++----------- .../noirc_frontend/src/hir/type_check/mod.rs | 18 +++- .../noirc_frontend/src/hir/type_check/stmt.rs | 75 ++++++--------- 3 files changed, 86 insertions(+), 101 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index ce0246f1e94..c073f90a14f 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -56,14 +56,13 @@ impl<'interner> TypeChecker<'interner> { /// an equivalent HirExpression::Call in the form `foo(a, b, c)`. This cannot /// be done earlier since we need to know the type of the object `a` to resolve which /// function `foo` to refer to. - pub(crate) fn check_expression(&mut self, expr_id: &ExprId, allow_unsafe_call: bool) -> Type { + pub(crate) fn check_expression(&mut self, expr_id: &ExprId) -> Type { let typ = match self.interner.expression(expr_id) { HirExpression::Ident(ident) => self.check_ident(ident, expr_id), HirExpression::Literal(literal) => { match literal { HirLiteral::Array(HirArrayLiteral::Standard(arr)) => { - let elem_types = - vecmap(&arr, |arg| self.check_expression(arg, allow_unsafe_call)); + let elem_types = vecmap(&arr, |arg| self.check_expression(arg)); let first_elem_type = elem_types .first() @@ -95,7 +94,7 @@ impl<'interner> TypeChecker<'interner> { arr_type } HirLiteral::Array(HirArrayLiteral::Repeated { repeated_element, length }) => { - let elem_type = self.check_expression(&repeated_element, allow_unsafe_call); + let elem_type = self.check_expression(&repeated_element); let length = match length { Type::Constant(length) => { Type::constant_variable(length, self.interner) @@ -112,8 +111,7 @@ impl<'interner> TypeChecker<'interner> { } HirLiteral::FmtStr(string, idents) => { let len = Type::Constant(string.len() as u64); - let types = - vecmap(&idents, |elem| self.check_expression(elem, allow_unsafe_call)); + let types = vecmap(&idents, |elem| self.check_expression(elem)); Type::FmtString(Box::new(len), Box::new(Type::Tuple(types))) } HirLiteral::Unit => Type::Unit, @@ -121,8 +119,8 @@ impl<'interner> TypeChecker<'interner> { } HirExpression::Infix(infix_expr) => { // The type of the infix expression must be looked up from a type table - let lhs_type = self.check_expression(&infix_expr.lhs, allow_unsafe_call); - let rhs_type = self.check_expression(&infix_expr.rhs, allow_unsafe_call); + let lhs_type = self.check_expression(&infix_expr.lhs); + let rhs_type = self.check_expression(&infix_expr.rhs); let lhs_span = self.interner.expr_span(&infix_expr.lhs); let rhs_span = self.interner.expr_span(&infix_expr.rhs); @@ -151,9 +149,7 @@ impl<'interner> TypeChecker<'interner> { } } } - HirExpression::Index(index_expr) => { - self.check_index_expression(expr_id, index_expr, allow_unsafe_call) - } + HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr), HirExpression::Call(call_expr) => { // Need to setup these flags here as `self` is borrowed mutably to type check the rest of the call expression // These flags are later used to type check calls to unconstrained functions from constrained functions @@ -165,15 +161,15 @@ impl<'interner> TypeChecker<'interner> { self.check_if_deprecated(&call_expr.func); - let function = self.check_expression(&call_expr.func, allow_unsafe_call); + let function = self.check_expression(&call_expr.func); let args = vecmap(&call_expr.arguments, |arg| { - let typ = self.check_expression(arg, allow_unsafe_call); + let typ = self.check_expression(arg); (typ, *arg, self.interner.expr_span(arg)) }); if is_current_func_constrained && is_unconstrained_call { - if !allow_unsafe_call { + if !self.allow_unsafe { self.errors.push(TypeCheckError::Unsafe { span: self.interner.expr_span(expr_id), }); @@ -208,8 +204,7 @@ impl<'interner> TypeChecker<'interner> { return_type } HirExpression::MethodCall(mut method_call) => { - let mut object_type = - self.check_expression(&method_call.object, allow_unsafe_call).follow_bindings(); + let mut object_type = self.check_expression(&method_call.object).follow_bindings(); let method_name = method_call.method.0.contents.as_str(); match self.lookup_method(&object_type, method_name, expr_id) { Some(method_ref) => { @@ -244,24 +239,30 @@ impl<'interner> TypeChecker<'interner> { // Type check the new call now that it has been changed from a method call // to a function call. This way we avoid duplicating code. - self.check_expression(expr_id, allow_unsafe_call) + self.check_expression(expr_id) } None => Type::Error, } } HirExpression::Cast(cast_expr) => { // Evaluate the LHS - let lhs_type = self.check_expression(&cast_expr.lhs, allow_unsafe_call); + let lhs_type = self.check_expression(&cast_expr.lhs); let span = self.interner.expr_span(expr_id); self.check_cast(lhs_type, cast_expr.r#type, span) } HirExpression::Block(block_expr) => { let mut block_type = Type::Unit; - let allow_unsafe = allow_unsafe_call || block_expr.is_unsafe; + // Before entering the block we cache the old value of `allow_unsafe` so it can be restored. + let old_allow_unsafe = self.allow_unsafe; + + // If we're already in an unsafe block then entering a new block should preserve this even if + // the inner block isn't marked as unsafe. + self.allow_unsafe |= block_expr.is_unsafe; + let statements = block_expr.statements(); for (i, stmt) in statements.iter().enumerate() { - let expr_type = self.check_statement(stmt, allow_unsafe); + let expr_type = self.check_statement(stmt); if let crate::hir_def::stmt::HirStatement::Semi(expr) = self.interner.statement(stmt) @@ -282,24 +283,23 @@ impl<'interner> TypeChecker<'interner> { } } + // Finally, we restore the original value of `self.allow_unsafe`. + self.allow_unsafe = old_allow_unsafe; + block_type } HirExpression::Prefix(prefix_expr) => { - let rhs_type = self.check_expression(&prefix_expr.rhs, allow_unsafe_call); + let rhs_type = self.check_expression(&prefix_expr.rhs); let span = self.interner.expr_span(&prefix_expr.rhs); self.type_check_prefix_operand(&prefix_expr.operator, &rhs_type, span) } - HirExpression::If(if_expr) => self.check_if_expr(&if_expr, expr_id, allow_unsafe_call), - HirExpression::Constructor(constructor) => { - self.check_constructor(constructor, expr_id, allow_unsafe_call) - } - HirExpression::MemberAccess(access) => { - self.check_member_access(access, *expr_id, allow_unsafe_call) - } + HirExpression::If(if_expr) => self.check_if_expr(&if_expr, expr_id), + HirExpression::Constructor(constructor) => self.check_constructor(constructor, expr_id), + HirExpression::MemberAccess(access) => self.check_member_access(access, *expr_id), HirExpression::Error => Type::Error, - HirExpression::Tuple(elements) => Type::Tuple(vecmap(&elements, |elem| { - self.check_expression(elem, allow_unsafe_call) - })), + HirExpression::Tuple(elements) => { + Type::Tuple(vecmap(&elements, |elem| self.check_expression(elem))) + } HirExpression::Lambda(lambda) => { let captured_vars = vecmap(lambda.captures, |capture| { self.interner.definition_type(capture.ident.id) @@ -313,7 +313,7 @@ impl<'interner> TypeChecker<'interner> { typ }); - let actual_return = self.check_expression(&lambda.body, allow_unsafe_call); + let actual_return = self.check_expression(&lambda.body); let span = self.interner.expr_span(&lambda.body); self.unify(&actual_return, &lambda.return_type, || TypeCheckError::TypeMismatch { @@ -542,9 +542,8 @@ impl<'interner> TypeChecker<'interner> { &mut self, id: &ExprId, mut index_expr: expr::HirIndexExpression, - allow_unsafe_call: bool, ) -> Type { - let index_type = self.check_expression(&index_expr.index, allow_unsafe_call); + let index_type = self.check_expression(&index_expr.index); let span = self.interner.expr_span(&index_expr.index); index_type.unify( @@ -559,7 +558,7 @@ impl<'interner> TypeChecker<'interner> { // When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many // times as needed to get the underlying array. - let lhs_type = self.check_expression(&index_expr.collection, allow_unsafe_call); + let lhs_type = self.check_expression(&index_expr.collection); let (new_lhs, lhs_type) = self.insert_auto_dereferences(index_expr.collection, lhs_type); index_expr.collection = new_lhs; self.interner.replace_expr(id, HirExpression::Index(index_expr)); @@ -612,14 +611,9 @@ impl<'interner> TypeChecker<'interner> { } } - fn check_if_expr( - &mut self, - if_expr: &expr::HirIfExpression, - expr_id: &ExprId, - allow_unsafe_call: bool, - ) -> Type { - let cond_type = self.check_expression(&if_expr.condition, allow_unsafe_call); - let then_type = self.check_expression(&if_expr.consequence, allow_unsafe_call); + fn check_if_expr(&mut self, if_expr: &expr::HirIfExpression, expr_id: &ExprId) -> Type { + let cond_type = self.check_expression(&if_expr.condition); + let then_type = self.check_expression(&if_expr.consequence); let expr_span = self.interner.expr_span(&if_expr.condition); @@ -632,7 +626,7 @@ impl<'interner> TypeChecker<'interner> { match if_expr.alternative { None => Type::Unit, Some(alternative) => { - let else_type = self.check_expression(&alternative, allow_unsafe_call); + let else_type = self.check_expression(&alternative); let expr_span = self.interner.expr_span(expr_id); self.unify(&then_type, &else_type, || { @@ -662,7 +656,6 @@ impl<'interner> TypeChecker<'interner> { &mut self, constructor: expr::HirConstructorExpression, expr_id: &ExprId, - allow_unsafe_call: bool, ) -> Type { let typ = constructor.r#type; let generics = constructor.struct_generics; @@ -682,7 +675,7 @@ impl<'interner> TypeChecker<'interner> { // mismatch here as long as we continue typechecking the rest of the program to the best // of our ability. if param_name == arg_ident.0.contents { - let arg_type = self.check_expression(&arg, allow_unsafe_call); + let arg_type = self.check_expression(&arg); let span = self.interner.expr_span(expr_id); self.unify_with_coercions(&arg_type, ¶m_type, arg, || { @@ -698,13 +691,8 @@ impl<'interner> TypeChecker<'interner> { Type::Struct(typ, generics) } - fn check_member_access( - &mut self, - mut access: expr::HirMemberAccess, - expr_id: ExprId, - allow_unsafe_call: bool, - ) -> Type { - let lhs_type = self.check_expression(&access.lhs, allow_unsafe_call).follow_bindings(); + fn check_member_access(&mut self, mut access: expr::HirMemberAccess, expr_id: ExprId) -> Type { + let lhs_type = self.check_expression(&access.lhs).follow_bindings(); let span = self.interner.expr_span(&expr_id); let access_lhs = &mut access.lhs; diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index 57e2e47ed8c..b24f5df69bf 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -26,6 +26,11 @@ pub struct TypeChecker<'interner> { errors: Vec, current_function: Option, + /// Whether the `TypeChecker` should allow unsafe calls. + /// + /// This is generally only set to true within an `unsafe` block. + allow_unsafe: bool, + /// Trait constraints are collected during type checking until they are /// verified at the end of a function. This is because constraints arise /// on each variable, but it is only until function calls when the types @@ -164,11 +169,17 @@ fn function_info(interner: &NodeInterner, function_body_id: &ExprId) -> (noirc_e impl<'interner> TypeChecker<'interner> { fn new(interner: &'interner mut NodeInterner) -> Self { - Self { interner, errors: Vec::new(), trait_constraints: Vec::new(), current_function: None } + Self { + interner, + errors: Vec::new(), + allow_unsafe: false, + trait_constraints: Vec::new(), + current_function: None, + } } fn check_function_body(&mut self, body: &ExprId) -> Type { - self.check_expression(body, false) + self.check_expression(body) } pub fn check_global( @@ -178,11 +189,12 @@ impl<'interner> TypeChecker<'interner> { let mut this = Self { interner, errors: Vec::new(), + allow_unsafe: false, trait_constraints: Vec::new(), current_function: None, }; let statement = this.interner.get_global(id).let_statement; - this.check_statement(&statement, false); + this.check_statement(&statement); this.errors } diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index 13595bc5bc5..d644891a5b0 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -19,7 +19,7 @@ impl<'interner> TypeChecker<'interner> { /// All statements have a unit type `()` as their type so the type of the statement /// is not interesting. Type checking must still be done on statements to ensure any /// expressions used within them are typed correctly. - pub(crate) fn check_statement(&mut self, stmt_id: &StmtId, allow_unsafe_call: bool) -> Type { + pub(crate) fn check_statement(&mut self, stmt_id: &StmtId) -> Type { match self.interner.statement(stmt_id) { // Lets lay out a convincing argument that the handling of // SemiExpressions and Expressions below is correct. @@ -42,27 +42,27 @@ impl<'interner> TypeChecker<'interner> { // // The reason why we still modify the database, is to make sure it is future-proof HirStatement::Expression(expr_id) => { - return self.check_expression(&expr_id, allow_unsafe_call); + return self.check_expression(&expr_id); } HirStatement::Semi(expr_id) => { - self.check_expression(&expr_id, allow_unsafe_call); + self.check_expression(&expr_id); } - HirStatement::Let(let_stmt) => self.check_let_stmt(let_stmt, allow_unsafe_call), + HirStatement::Let(let_stmt) => self.check_let_stmt(let_stmt), HirStatement::Constrain(constrain_stmt) => { - self.check_constrain_stmt(constrain_stmt, allow_unsafe_call); + self.check_constrain_stmt(constrain_stmt); } HirStatement::Assign(assign_stmt) => { - self.check_assign_stmt(assign_stmt, stmt_id, allow_unsafe_call); + self.check_assign_stmt(assign_stmt, stmt_id); } - HirStatement::For(for_loop) => self.check_for_loop(for_loop, allow_unsafe_call), + HirStatement::For(for_loop) => self.check_for_loop(for_loop), HirStatement::Error => (), } Type::Unit } - fn check_for_loop(&mut self, for_loop: HirForStatement, allow_unsafe_call: bool) { - let start_range_type = self.check_expression(&for_loop.start_range, allow_unsafe_call); - let end_range_type = self.check_expression(&for_loop.end_range, allow_unsafe_call); + fn check_for_loop(&mut self, for_loop: HirForStatement) { + let start_range_type = self.check_expression(&for_loop.start_range); + let end_range_type = self.check_expression(&for_loop.end_range); let start_span = self.interner.expr_span(&for_loop.start_range); let end_span = self.interner.expr_span(&for_loop.end_range); @@ -85,7 +85,7 @@ impl<'interner> TypeChecker<'interner> { self.interner.push_definition_type(for_loop.identifier.id, start_range_type); - self.check_expression(&for_loop.block, allow_unsafe_call); + self.check_expression(&for_loop.block); } /// Associate a given HirPattern with the given Type, and remember @@ -136,16 +136,10 @@ impl<'interner> TypeChecker<'interner> { } } - fn check_assign_stmt( - &mut self, - assign_stmt: HirAssignStatement, - stmt_id: &StmtId, - allow_unsafe_call: bool, - ) { - let expr_type = self.check_expression(&assign_stmt.expression, allow_unsafe_call); + fn check_assign_stmt(&mut self, assign_stmt: HirAssignStatement, stmt_id: &StmtId) { + let expr_type = self.check_expression(&assign_stmt.expression); let span = self.interner.expr_span(&assign_stmt.expression); - let (lvalue_type, new_lvalue, mutable) = - self.check_lvalue(&assign_stmt.lvalue, span, allow_unsafe_call); + let (lvalue_type, new_lvalue, mutable) = self.check_lvalue(&assign_stmt.lvalue, span); if !mutable { let (name, span) = self.get_lvalue_name_and_span(&assign_stmt.lvalue); @@ -187,12 +181,7 @@ impl<'interner> TypeChecker<'interner> { } /// Type check an lvalue - the left hand side of an assignment statement. - fn check_lvalue( - &mut self, - lvalue: &HirLValue, - assign_span: Span, - allow_unsafe_call: bool, - ) -> (Type, HirLValue, bool) { + fn check_lvalue(&mut self, lvalue: &HirLValue, assign_span: Span) -> (Type, HirLValue, bool) { match lvalue { HirLValue::Ident(ident, _) => { let mut mutable = true; @@ -211,8 +200,7 @@ impl<'interner> TypeChecker<'interner> { (typ.clone(), HirLValue::Ident(ident.clone(), typ), mutable) } HirLValue::MemberAccess { object, field_name, .. } => { - let (lhs_type, object, mut mutable) = - self.check_lvalue(object, assign_span, allow_unsafe_call); + let (lhs_type, object, mut mutable) = self.check_lvalue(object, assign_span); let mut object = Box::new(object); let span = field_name.span(); let field_name = field_name.clone(); @@ -244,7 +232,7 @@ impl<'interner> TypeChecker<'interner> { (object_type, lvalue, mutable) } HirLValue::Index { array, index, .. } => { - let index_type = self.check_expression(index, allow_unsafe_call); + let index_type = self.check_expression(index); let expr_span = self.interner.expr_span(index); index_type.unify( @@ -258,7 +246,7 @@ impl<'interner> TypeChecker<'interner> { ); let (mut lvalue_type, mut lvalue, mut mutable) = - self.check_lvalue(array, assign_span, allow_unsafe_call); + self.check_lvalue(array, assign_span); // Before we check that the lvalue is an array, try to dereference it as many times // as needed to unwrap any &mut wrappers. @@ -288,8 +276,7 @@ impl<'interner> TypeChecker<'interner> { (typ.clone(), HirLValue::Index { array, index: *index, typ }, mutable) } HirLValue::Dereference { lvalue, element_type: _ } => { - let (reference_type, lvalue, _) = - self.check_lvalue(lvalue, assign_span, allow_unsafe_call); + let (reference_type, lvalue, _) = self.check_lvalue(lvalue, assign_span); let lvalue = Box::new(lvalue); let element_type = Type::type_variable(self.interner.next_type_variable_id()); @@ -307,20 +294,23 @@ impl<'interner> TypeChecker<'interner> { } } - fn check_let_stmt(&mut self, let_stmt: HirLetStatement, allow_unsafe_call: bool) { - let resolved_type = - self.check_declaration(let_stmt.expression, let_stmt.r#type, allow_unsafe_call); + fn check_let_stmt(&mut self, let_stmt: HirLetStatement) { + let resolved_type = self.check_declaration(let_stmt.expression, let_stmt.r#type); // Set the type of the pattern to be equal to the annotated type self.bind_pattern(&let_stmt.pattern, resolved_type); } - fn check_constrain_stmt(&mut self, stmt: HirConstrainStatement, allow_unsafe_call: bool) { - let expr_type = self.check_expression(&stmt.0, allow_unsafe_call); + fn check_constrain_stmt(&mut self, stmt: HirConstrainStatement) { + let expr_type = self.check_expression(&stmt.0); let expr_span = self.interner.expr_span(&stmt.0); // Must type check the assertion message expression so that we instantiate bindings - stmt.2.map(|assert_msg_expr| self.check_expression(&assert_msg_expr, true)); + // We always allow unsafe calls for assert messages as these may be a dynamic assert message call. + let old_allow_unsafe = self.allow_unsafe; + self.allow_unsafe = true; + stmt.2.map(|assert_msg_expr| self.check_expression(&assert_msg_expr)); + self.allow_unsafe = old_allow_unsafe; self.unify(&expr_type, &Type::Bool, || TypeCheckError::TypeMismatch { expr_typ: expr_type.to_string(), @@ -332,14 +322,9 @@ impl<'interner> TypeChecker<'interner> { /// All declaration statements check that the user specified type(UST) is equal to the /// expression on the RHS, unless the UST is unspecified in which case /// the type of the declaration is inferred to match the RHS. - fn check_declaration( - &mut self, - rhs_expr: ExprId, - annotated_type: Type, - allow_unsafe_call: bool, - ) -> Type { + fn check_declaration(&mut self, rhs_expr: ExprId, annotated_type: Type) -> Type { // Type check the expression on the RHS - let expr_type = self.check_expression(&rhs_expr, allow_unsafe_call); + let expr_type = self.check_expression(&rhs_expr); // First check if the LHS is unspecified // If so, then we give it the same type as the expression