Skip to content

Commit

Permalink
feat: support casting in globals (#5164)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #5160

## Summary\*

Adds a case for casts in comptime globals

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: jfecher <jake@aztecprotocol.com>
Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
Co-authored-by: TomAFrench <tom@tomfren.ch>
  • Loading branch information
4 people authored Jun 4, 2024
1 parent 90632da commit 6d3e732
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 14 deletions.
12 changes: 12 additions & 0 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
UnresolvedTypeExpression,
},
hir::{
comptime::{Interpreter, Value},
def_map::ModuleDefId,
resolution::{
errors::ResolverError,
Expand Down Expand Up @@ -498,6 +499,17 @@ impl<'context> Elaborator<'context> {
BinaryOpKind::Modulo => Ok(lhs % rhs),
}
}
HirExpression::Cast(cast) => {
let lhs = self.try_eval_array_length_id_with_fuel(cast.lhs, span, fuel - 1)?;
let lhs_value = Value::Field(lhs.into());
let evaluated_value =
Interpreter::evaluate_cast_one_step(&cast, rhs, lhs_value, self.interner)
.map_err(|error| Some(ResolverError::ArrayLengthInterpreter { error }))?;

evaluated_value
.to_u128()
.ok_or_else(|| Some(ResolverError::InvalidArrayLengthExpr { span }))
}
_other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })),
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/comptime/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use noirc_errors::{CustomDiagnostic, Location};
use super::value::Value;

/// The possible errors that can halt the interpreter.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InterpreterError {
ArgumentCountMismatch { expected: usize, actual: usize, location: Location },
TypeMismatch { expected: Type, value: Value, location: Location },
Expand Down
29 changes: 20 additions & 9 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ impl<'a> Interpreter<'a> {
HirExpression::MemberAccess(access) => self.evaluate_access(access, id),
HirExpression::Call(call) => self.evaluate_call(call, id),
HirExpression::MethodCall(call) => self.evaluate_method_call(call, id),
HirExpression::Cast(cast) => self.evaluate_cast(cast, id),
HirExpression::Cast(cast) => self.evaluate_cast(&cast, id),
HirExpression::If(if_) => self.evaluate_if(if_, id),
HirExpression::Tuple(tuple) => self.evaluate_tuple(tuple),
HirExpression::Lambda(lambda) => self.evaluate_lambda(lambda, id),
Expand Down Expand Up @@ -929,7 +929,18 @@ impl<'a> Interpreter<'a> {
}
}

fn evaluate_cast(&mut self, cast: HirCastExpression, id: ExprId) -> IResult<Value> {
fn evaluate_cast(&mut self, cast: &HirCastExpression, id: ExprId) -> IResult<Value> {
let evaluated_lhs = self.evaluate(cast.lhs)?;
Self::evaluate_cast_one_step(cast, id, evaluated_lhs, self.interner)
}

/// evaluate_cast without recursion
pub fn evaluate_cast_one_step(
cast: &HirCastExpression,
id: ExprId,
evaluated_lhs: Value,
interner: &NodeInterner,
) -> IResult<Value> {
macro_rules! signed_int_to_field {
($x:expr) => {{
// Need to convert the signed integer to an i128 before
Expand All @@ -943,7 +954,7 @@ impl<'a> Interpreter<'a> {
}};
}

let (mut lhs, lhs_is_negative) = match self.evaluate(cast.lhs)? {
let (mut lhs, lhs_is_negative) = match evaluated_lhs {
Value::Field(value) => (value, false),
Value::U8(value) => ((value as u128).into(), false),
Value::U16(value) => ((value as u128).into(), false),
Expand All @@ -957,7 +968,7 @@ impl<'a> Interpreter<'a> {
(if value { FieldElement::one() } else { FieldElement::zero() }, false)
}
value => {
let location = self.interner.expr_location(&id);
let location = interner.expr_location(&id);
return Err(InterpreterError::NonNumericCasted { value, location });
}
};
Expand All @@ -982,8 +993,8 @@ impl<'a> Interpreter<'a> {
}
Type::Integer(sign, bit_size) => match (sign, bit_size) {
(Signedness::Unsigned, IntegerBitSize::One) => {
let location = self.interner.expr_location(&id);
Err(InterpreterError::TypeUnsupported { typ: cast.r#type, location })
let location = interner.expr_location(&id);
Err(InterpreterError::TypeUnsupported { typ: cast.r#type.clone(), location })
}
(Signedness::Unsigned, IntegerBitSize::Eight) => cast_to_int!(lhs, to_u128, u8, U8),
(Signedness::Unsigned, IntegerBitSize::Sixteen) => {
Expand All @@ -996,8 +1007,8 @@ impl<'a> Interpreter<'a> {
cast_to_int!(lhs, to_u128, u64, U64)
}
(Signedness::Signed, IntegerBitSize::One) => {
let location = self.interner.expr_location(&id);
Err(InterpreterError::TypeUnsupported { typ: cast.r#type, location })
let location = interner.expr_location(&id);
Err(InterpreterError::TypeUnsupported { typ: cast.r#type.clone(), location })
}
(Signedness::Signed, IntegerBitSize::Eight) => cast_to_int!(lhs, to_i128, i8, I8),
(Signedness::Signed, IntegerBitSize::Sixteen) => {
Expand All @@ -1012,7 +1023,7 @@ impl<'a> Interpreter<'a> {
},
Type::Bool => Ok(Value::Bool(!lhs.is_zero() || lhs_is_negative)),
typ => {
let location = self.interner.expr_location(&id);
let location = interner.expr_location(&id);
Err(InterpreterError::CastToNonNumericType { typ, location })
}
}
Expand Down
19 changes: 18 additions & 1 deletion compiler/noirc_frontend/src/hir/comptime/value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Cow, rc::Rc};

use acvm::FieldElement;
use acvm::{AcirField, FieldElement};
use im::Vector;
use iter_extended::{try_vecmap, vecmap};
use noirc_errors::Location;
Expand Down Expand Up @@ -283,6 +283,23 @@ impl Value {
interner.push_expr_type(id, typ);
Ok(id)
}

/// Converts any unsigned `Value` into a `u128`.
/// Returns `None` for negative integers.
pub(crate) fn to_u128(&self) -> Option<u128> {
match self {
Self::Field(value) => Some(value.to_u128()),
Self::I8(value) => (*value >= 0).then_some(*value as u128),
Self::I16(value) => (*value >= 0).then_some(*value as u128),
Self::I32(value) => (*value >= 0).then_some(*value as u128),
Self::I64(value) => (*value >= 0).then_some(*value as u128),
Self::U8(value) => Some(*value as u128),
Self::U16(value) => Some(*value as u128),
Self::U32(value) => Some(*value as u128),
Self::U64(value) => Some(*value as u128),
_ => None,
}
}
}

/// Unwraps an Rc value without cloning the inner value if the reference count is 1. Clones otherwise.
Expand Down
5 changes: 4 additions & 1 deletion compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub use noirc_errors::Span;
use noirc_errors::{CustomDiagnostic as Diagnostic, FileDiagnostic};
use thiserror::Error;

use crate::{ast::Ident, parser::ParserError, Type};
use crate::{ast::Ident, hir::comptime::InterpreterError, parser::ParserError, Type};

use super::import::PathResolutionError;

Expand Down Expand Up @@ -94,6 +94,8 @@ pub enum ResolverError {
NoPredicatesAttributeOnUnconstrained { ident: Ident },
#[error("#[fold] attribute is only allowed on constrained functions")]
FoldAttributeOnUnconstrained { ident: Ident },
#[error("Invalid array length construction")]
ArrayLengthInterpreter { error: InterpreterError },
}

impl ResolverError {
Expand Down Expand Up @@ -386,6 +388,7 @@ impl<'a> From<&'a ResolverError> for Diagnostic {
diag.add_note("The `#[fold]` attribute specifies whether a constrained function should be treated as a separate circuit rather than inlined into the program entry point".to_owned());
diag
}
ResolverError::ArrayLengthInterpreter { error } => Diagnostic::from(error),
}
}
}
17 changes: 16 additions & 1 deletion compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ use crate::ast::{
};
use crate::graph::CrateId;
use crate::hir::def_map::{ModuleDefId, TryFromModuleDefId, MAIN_FUNCTION};
use crate::hir::{def_map::CrateDefMap, resolution::path_resolver::PathResolver};
use crate::hir::{
comptime::{Interpreter, Value},
def_map::CrateDefMap,
resolution::path_resolver::PathResolver,
};
use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern};
use crate::node_interner::{
DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, NodeInterner, StmtId,
Expand Down Expand Up @@ -2067,6 +2071,17 @@ impl<'a> Resolver<'a> {
BinaryOpKind::Modulo => Ok(lhs % rhs),
}
}
HirExpression::Cast(cast) => {
let lhs = self.try_eval_array_length_id_with_fuel(cast.lhs, span, fuel - 1)?;
let lhs_value = Value::Field(lhs.into());
let evaluated_value =
Interpreter::evaluate_cast_one_step(&cast, rhs, lhs_value, self.interner)
.map_err(|error| Some(ResolverError::ArrayLengthInterpreter { error }))?;

evaluated_value
.to_u128()
.ok_or_else(|| Some(ResolverError::InvalidArrayLengthExpr { span }))
}
_other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })),
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/type_check/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub enum Source {
Return(FunctionReturnType, Span),
}

#[derive(Error, Debug, Clone, PartialEq, Eq)]
#[derive(Error, Debug, Clone)]
pub enum TypeCheckError {
#[error("Operator {op:?} cannot be used in a {place:?}")]
OpCannotBeUsed { op: HirBinaryOp, place: &'static str, span: Span },
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "cast_and_shift_global"
type = "bin"
authors = [""]
compiler_version = ">=0.30.0"

[dependencies]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
global THREE: u64 = 3;
global EIGHT: u64 = 1 << THREE as u8;
global SEVEN: u64 = EIGHT - 1;

fn main() {
assert([0; EIGHT] == [0; 8]);
assert([0; SEVEN] == [0; 7]);
}

0 comments on commit 6d3e732

Please sign in to comment.