Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Mutability in the comptime interpreter #5517

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
}
} else {
let name = self.interner.function_name(&function);
unreachable!("Non-builtin, lowlevel or oracle builtin fn '{name}'")

Check warning on line 164 in compiler/noirc_frontend/src/hir/comptime/interpreter.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (lowlevel)
}
}

Expand Down Expand Up @@ -241,6 +241,8 @@
Ok(())
}
HirPattern::Mutable(pattern, _) => {
// Create a mutable reference to store to
let argument = Value::Pointer(Shared::new(argument), true);
self.define_pattern(pattern, typ, argument, location)
}
HirPattern::Tuple(pattern_fields, _) => match (argument, typ) {
Expand Down Expand Up @@ -334,8 +336,19 @@
}
}

/// Evaluate an expression and return the result
/// Evaluate an expression and return the result.
/// This will automatically dereference a mutable variable if used.
pub fn evaluate(&mut self, id: ExprId) -> IResult<Value> {
match self.evaluate_no_dereference(id)? {
Value::Pointer(elem, true) => Ok(elem.borrow().clone()),
other => Ok(other),
}
}

/// Evaluating a mutable variable will dereference it automatically.
/// This function should be used when that is not desired - e.g. when
/// compiling a `&mut var` expression to grab the original reference.
fn evaluate_no_dereference(&mut self, id: ExprId) -> IResult<Value> {
match self.interner.expression(&id) {
HirExpression::Ident(ident, _) => self.evaluate_ident(ident, id),
HirExpression::Literal(literal) => self.evaluate_literal(literal, id),
Expand Down Expand Up @@ -592,7 +605,10 @@
}

fn evaluate_prefix(&mut self, prefix: HirPrefixExpression, id: ExprId) -> IResult<Value> {
let rhs = self.evaluate(prefix.rhs)?;
let rhs = match prefix.operator {
UnaryOp::MutableReference => self.evaluate_no_dereference(prefix.rhs)?,
_ => self.evaluate(prefix.rhs)?,
};
self.evaluate_prefix_with_value(rhs, prefix.operator, id)
}

Expand Down Expand Up @@ -634,9 +650,17 @@
Err(InterpreterError::InvalidValueForUnary { value, location, operator: "not" })
}
},
UnaryOp::MutableReference => Ok(Value::Pointer(Shared::new(rhs))),
UnaryOp::MutableReference => {
// If this is a mutable variable (auto_deref = true), turn this into an explicit
// mutable reference just by switching the value of `auto_deref`. Otherwise, wrap
// the value in a fresh reference.
match rhs {
Value::Pointer(elem, true) => Ok(Value::Pointer(elem, false)),
other => Ok(Value::Pointer(Shared::new(other), false)),
}
}
UnaryOp::Dereference { implicitly_added: _ } => match rhs {
Value::Pointer(element) => Ok(element.borrow().clone()),
Value::Pointer(element, _) => Ok(element.borrow().clone()),
value => {
let location = self.interner.expr_location(&id);
Err(InterpreterError::NonPointerDereferenced { value, location })
Expand Down Expand Up @@ -1303,7 +1327,7 @@
HirLValue::Ident(ident, typ) => self.mutate(ident.id, rhs, ident.location),
HirLValue::Dereference { lvalue, element_type: _, location } => {
match self.evaluate_lvalue(&lvalue)? {
Value::Pointer(value) => {
Value::Pointer(value, _) => {
*value.borrow_mut() = rhs;
Ok(())
}
Expand Down Expand Up @@ -1353,10 +1377,13 @@

fn evaluate_lvalue(&mut self, lvalue: &HirLValue) -> IResult<Value> {
match lvalue {
HirLValue::Ident(ident, _) => self.lookup(ident),
HirLValue::Ident(ident, _) => match self.lookup(ident)? {
Value::Pointer(elem, true) => Ok(elem.borrow().clone()),
other => Ok(other),
},
HirLValue::Dereference { lvalue, element_type: _, location } => {
match self.evaluate_lvalue(lvalue)? {
Value::Pointer(value) => Ok(value.borrow().clone()),
Value::Pointer(value, _) => Ok(value.borrow().clone()),
value => {
Err(InterpreterError::NonPointerDereferenced { value, location: *location })
}
Expand Down
12 changes: 12 additions & 0 deletions compiler/noirc_frontend/src/hir/comptime/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ fn mutating_mutable_references() {
assert_eq!(result, Value::I64(4));
}

#[test]
fn mutation_leaks() {
let program = "comptime fn main() -> pub i8 {
let mut x = 3;
let y = &mut x;
*y = 5;
x
}";
let result = interpret(program, vec!["main".into()]);
assert_eq!(result, Value::I8(5));
}

#[test]
fn mutating_arrays() {
let program = "comptime fn main() -> pub u8 {
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub enum Value {
Closure(HirLambda, Vec<Value>, Type),
Tuple(Vec<Value>),
Struct(HashMap<Rc<String>, Value>, Type),
Pointer(Shared<Value>),
Pointer(Shared<Value>, /* auto_deref */ bool),
Array(Vector<Value>, Type),
Slice(Vector<Value>, Type),
Code(Rc<Tokens>),
Expand Down Expand Up @@ -79,7 +79,7 @@ impl Value {
Value::Slice(_, typ) => return Cow::Borrowed(typ),
Value::Code(_) => Type::Quoted(QuotedType::Quoted),
Value::StructDefinition(_) => Type::Quoted(QuotedType::StructDefinition),
Value::Pointer(element) => {
Value::Pointer(element, _) => {
let element = element.borrow().get_type().into_owned();
Type::MutableReference(Box::new(element))
}
Expand Down Expand Up @@ -199,7 +199,7 @@ impl Value {
}
};
}
Value::Pointer(_)
Value::Pointer(..)
| Value::StructDefinition(_)
| Value::TraitDefinition(_)
| Value::FunctionDefinition(_)
Expand Down Expand Up @@ -309,7 +309,7 @@ impl Value {
HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements)))
}
Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)),
Value::Pointer(_)
Value::Pointer(..)
| Value::StructDefinition(_)
| Value::TraitDefinition(_)
| Value::FunctionDefinition(_)
Expand Down Expand Up @@ -400,7 +400,7 @@ impl Display for Value {
let fields = vecmap(fields, |(name, value)| format!("{}: {}", name, value));
write!(f, "{typename} {{ {} }}", fields.join(", "))
}
Value::Pointer(value) => write!(f, "&mut {}", value.borrow()),
Value::Pointer(value, _) => write!(f, "&mut {}", value.borrow()),
Value::Array(values, _) => {
let values = vecmap(values, ToString::to_string);
write!(f, "[{}]", values.join(", "))
Expand Down
Loading