Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add break and continue in unconstrained code #4569

Merged
merged 14 commits into from
Mar 19, 2024
11 changes: 7 additions & 4 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,10 @@ impl FunctionBuilder {
) -> Self {
let mut new_function = Function::new(function_name, function_id);
new_function.set_runtime(runtime);
let current_block = new_function.entry_block();

Self {
current_block: new_function.entry_block(),
current_function: new_function,
current_block,
finished_functions: Vec::new(),
call_stack: CallStack::new(),
}
Expand Down Expand Up @@ -153,9 +152,10 @@ impl FunctionBuilder {
instruction: Instruction,
ctrl_typevars: Option<Vec<Type>>,
) -> InsertInstructionResult {
let block = self.current_block();
self.current_function.dfg.insert_instruction_and_results(
instruction,
self.current_block,
block,
ctrl_typevars,
self.call_stack.clone(),
)
Expand Down Expand Up @@ -310,8 +310,11 @@ impl FunctionBuilder {
}

/// Terminates the current block with the given terminator instruction
/// if the current block does not already have a terminator instruction.
fn terminate_block_with(&mut self, terminator: TerminatorInstruction) {
self.current_function.dfg.set_block_terminator(self.current_block, terminator);
if self.current_function.dfg[self.current_block].terminator().is_none() {
self.current_function.dfg.set_block_terminator(self.current_block, terminator);
}
}

/// Terminate the current block with a jmp instruction to jmp to the given
Expand Down
4 changes: 0 additions & 4 deletions compiler/noirc_evaluator/src/ssa/ir/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ impl ControlFlowGraph {
);
predecessor_node.successors.insert(to);
let successor_node = self.data.entry(to).or_default();
assert!(
successor_node.predecessors.len() < 2,
"ICE: A cfg node cannot have more than two predecessors"
);
successor_node.predecessors.insert(from);
}

Expand Down
13 changes: 8 additions & 5 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,13 @@ impl<'function> PerFunctionContext<'function> {
}
TerminatorInstruction::Return { return_values, call_stack } => {
let return_values = vecmap(return_values, |value| self.translate_value(*value));

// Note that `translate_block` would take us back to the point at which the
// inlining of this source block began. Since additional blocks may have been
// inlined since, we are interested in the block representing the current program
// point, obtained via `current_block`.
let block_id = self.context.builder.current_block();

if self.inlining_entry {
let mut new_call_stack = self.context.call_stack.clone();
new_call_stack.append(call_stack.clone());
Expand All @@ -495,11 +502,7 @@ impl<'function> PerFunctionContext<'function> {
.set_call_stack(new_call_stack)
.terminate_with_return(return_values.clone());
}
// Note that `translate_block` would take us back to the point at which the
// inlining of this source block began. Since additional blocks may have been
// inlined since, we are interested in the block representing the current program
// point, obtained via `current_block`.
let block_id = self.context.builder.current_block();

Some((block_id, return_values))
}
}
Expand Down
33 changes: 32 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ pub(super) struct FunctionContext<'a> {

pub(super) builder: FunctionBuilder,
shared_context: &'a SharedContext,

/// Contains any loops we're currently in the middle of translating.
/// These are ordered such that an inner loop is at the end of the vector and
/// outer loops are at the beginning. When a loop is finished, it is popped.
loops: Vec<Loop>,
}

/// Shared context for all functions during ssa codegen. This is the only
Expand Down Expand Up @@ -72,6 +77,13 @@ pub(super) struct SharedContext {
pub(super) program: Program,
}

#[derive(Copy, Clone)]
pub(super) struct Loop {
pub(super) loop_entry: BasicBlockId,
pub(super) loop_index: ValueId,
pub(super) loop_end: BasicBlockId,
}

/// The queue of functions remaining to compile
type FunctionQueue = Vec<(ast::FuncId, IrFunctionId)>;

Expand All @@ -97,7 +109,8 @@ impl<'a> FunctionContext<'a> {
.1;

let builder = FunctionBuilder::new(function_name, function_id, runtime);
let mut this = Self { definitions: HashMap::default(), builder, shared_context };
let definitions = HashMap::default();
let mut this = Self { definitions, builder, shared_context, loops: Vec::new() };
this.add_parameters_to_scope(parameters);
this
}
Expand Down Expand Up @@ -1053,6 +1066,24 @@ impl<'a> FunctionContext<'a> {
self.builder.decrement_array_reference_count(parameter);
}
}

pub(crate) fn enter_loop(
&mut self,
loop_entry: BasicBlockId,
loop_index: ValueId,
loop_end: BasicBlockId,
) {
self.loops.push(Loop { loop_entry, loop_index, loop_end });
}

pub(crate) fn exit_loop(&mut self) {
self.loops.pop();
}

pub(crate) fn current_loop(&self) -> Loop {
// The frontend should ensure break/continue are never used outside a loop
*self.loops.last().expect("current_loop: not in a loop!")
}
}

/// True if the given operator cannot be encoded directly and needs
Expand Down
22 changes: 22 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@
}
Expression::Assign(assign) => self.codegen_assign(assign),
Expression::Semi(semi) => self.codegen_semi(semi),
Expression::Break => Ok(self.codegen_break()),
Expression::Continue => Ok(self.codegen_continue()),
}
}

Expand Down Expand Up @@ -461,7 +463,7 @@
/// br loop_entry(v0)
/// loop_entry(i: Field):
/// v2 = lt i v1
/// brif v2, then: loop_body, else: loop_end

Check warning on line 466 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// loop_body():
/// v3 = ... codegen body ...
/// v4 = add 1, i
Expand All @@ -477,6 +479,10 @@
let index_type = Self::convert_non_tuple_type(&for_expr.index_type);
let loop_index = self.builder.add_block_parameter(loop_entry, index_type);

// Remember the blocks and variable used in case there are break/continue instructions
// within the loop which need to jump to them.
self.enter_loop(loop_entry, loop_index, loop_end);

self.builder.set_location(for_expr.start_range_location);
let start_index = self.codegen_non_tuple_expression(&for_expr.start_range)?;

Expand Down Expand Up @@ -507,6 +513,7 @@

// Finish by switching back to the end of the loop
self.builder.switch_to_block(loop_end);
self.exit_loop();
Ok(Self::unit_value())
}

Expand All @@ -515,7 +522,7 @@
/// For example, the expression `if cond { a } else { b }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: else_block

Check warning on line 525 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block():
/// v1 = ... codegen a ...
/// br end_if(v1)
Expand All @@ -528,7 +535,7 @@
/// As another example, the expression `if cond { a }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: end_block

Check warning on line 538 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block:
/// v1 = ... codegen a ...
/// br end_if()
Expand Down Expand Up @@ -736,4 +743,19 @@
self.codegen_expression(expr)?;
Ok(Self::unit_value())
}

fn codegen_break(&mut self) -> Values {
let loop_end = self.current_loop().loop_end;
self.builder.terminate_with_jmp(loop_end, Vec::new());
Self::unit_value()
}

fn codegen_continue(&mut self) -> Values {
let loop_ = self.current_loop();

// Must remember to increment i before jumping
let new_loop_index = self.make_offset(loop_.loop_index, 1);
self.builder.terminate_with_jmp(loop_.loop_entry, vec![new_loop_index]);
Self::unit_value()
}
}
6 changes: 6 additions & 0 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub enum StatementKind {
Expression(Expression),
Assign(AssignStatement),
For(ForLoopStatement),
Break,
Continue,
// This is an expression with a trailing semi-colon
Semi(Expression),
// This statement is the result of a recovered parse error.
Expand All @@ -59,6 +61,8 @@ impl Statement {
| StatementKind::Constrain(_)
| StatementKind::Assign(_)
| StatementKind::Semi(_)
| StatementKind::Break
| StatementKind::Continue
| StatementKind::Error => {
// To match rust, statements always require a semicolon, even at the end of a block
if semi.is_none() {
Expand Down Expand Up @@ -637,6 +641,8 @@ impl Display for StatementKind {
StatementKind::Expression(expression) => expression.fmt(f),
StatementKind::Assign(assign) => assign.fmt(f),
StatementKind::For(for_loop) => for_loop.fmt(f),
StatementKind::Break => write!(f, "break"),
StatementKind::Continue => write!(f, "continue"),
StatementKind::Semi(semi) => write!(f, "{semi};"),
StatementKind::Error => write!(f, "Error"),
}
Expand Down
20 changes: 20 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ pub enum ResolverError {
LowLevelFunctionOutsideOfStdlib { ident: Ident },
#[error("Dependency cycle found, '{item}' recursively depends on itself: {cycle} ")]
DependencyCycle { span: Span, item: String, cycle: String },
#[error("break/continue are only allowed in unconstrained functions")]
JumpInConstrainedFn { is_break: bool, span: Span },
#[error("break/continue are only allowed within loops")]
JumpOutsideLoop { is_break: bool, span: Span },
}

impl ResolverError {
Expand Down Expand Up @@ -322,6 +326,22 @@ impl From<ResolverError> for Diagnostic {
span,
)
},
ResolverError::JumpInConstrainedFn { is_break, span } => {
let item = if is_break { "break" } else { "continue" };
Diagnostic::simple_error(
format!("{item} is only allowed in unconstrained functions"),
"Constrained code must always have a known number of loop iterations".into(),
span,
)
},
ResolverError::JumpOutsideLoop { is_break, span } => {
let item = if is_break { "break" } else { "continue" };
Diagnostic::simple_error(
format!("{item} is only allowed within loops"),
"".into(),
span,
)
},
}
}
}
47 changes: 41 additions & 6 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{
use crate::{
ArrayLiteral, BinaryOpKind, Distinctness, ForRange, FunctionDefinition, FunctionReturnType,
Generics, ItemVisibility, LValue, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern,
Shared, StructType, Type, TypeAlias, TypeVariable, TypeVariableKind, UnaryOp,
Shared, Statement, StructType, Type, TypeAlias, TypeVariable, TypeVariableKind, UnaryOp,
UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
Expand Down Expand Up @@ -115,6 +115,13 @@ pub struct Resolver<'a> {
/// that are captured. We do this in order to create the hidden environment
/// parameter for the lambda function.
lambda_stack: Vec<LambdaContext>,

/// True if we're currently resolving an unconstrained function
in_unconstrained_fn: bool,

/// How many loops we're currently within.
/// This increases by 1 at the start of a loop, and decreases by 1 when it ends.
nested_loops: u32,
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
}

/// ResolverMetas are tagged onto each definition to track how many times they are used
Expand Down Expand Up @@ -155,6 +162,8 @@ impl<'a> Resolver<'a> {
current_item: None,
file,
in_contract,
in_unconstrained_fn: false,
nested_loops: 0,
}
}

Expand Down Expand Up @@ -416,6 +425,11 @@ impl<'a> Resolver<'a> {

fn intern_function(&mut self, func: NoirFunction, id: FuncId) -> (HirFunction, FuncMeta) {
let func_meta = self.extract_meta(&func, id);

if func.def.is_unconstrained {
self.in_unconstrained_fn = true;
}

let hir_func = match func.kind {
FunctionKind::Builtin | FunctionKind::LowLevel | FunctionKind::Oracle => {
HirFunction::empty()
Expand Down Expand Up @@ -1148,7 +1162,7 @@ impl<'a> Resolver<'a> {
})
}

pub fn resolve_stmt(&mut self, stmt: StatementKind) -> HirStatement {
pub fn resolve_stmt(&mut self, stmt: StatementKind, span: Span) -> HirStatement {
match stmt {
StatementKind::Let(let_stmt) => {
let expression = self.resolve_expression(let_stmt.expression);
Expand Down Expand Up @@ -1188,6 +1202,8 @@ impl<'a> Resolver<'a> {
let end_range = self.resolve_expression(end_range);
let (identifier, block) = (for_loop.identifier, for_loop.block);

self.nested_loops += 1;

// TODO: For loop variables are currently mutable by default since we haven't
// yet implemented syntax for them to be optionally mutable.
let (identifier, block) = self.in_new_scope(|this| {
Expand All @@ -1200,6 +1216,8 @@ impl<'a> Resolver<'a> {
(decl, this.resolve_expression(block))
});

self.nested_loops -= 1;

HirStatement::For(HirForStatement {
start_range,
end_range,
Expand All @@ -1210,10 +1228,18 @@ impl<'a> Resolver<'a> {
range @ ForRange::Array(_) => {
let for_stmt =
range.into_for(for_loop.identifier, for_loop.block, for_loop.span);
self.resolve_stmt(for_stmt)
self.resolve_stmt(for_stmt, for_loop.span)
}
}
}
StatementKind::Break => {
self.check_break_continue(true, span);
HirStatement::Break
}
StatementKind::Continue => {
self.check_break_continue(false, span);
HirStatement::Continue
}
StatementKind::Error => HirStatement::Error,
}
}
Expand Down Expand Up @@ -1260,8 +1286,8 @@ impl<'a> Resolver<'a> {
Some(self.resolve_expression(assert_msg_call_expr))
}

pub fn intern_stmt(&mut self, stmt: StatementKind) -> StmtId {
let hir_stmt = self.resolve_stmt(stmt);
pub fn intern_stmt(&mut self, stmt: Statement) -> StmtId {
let hir_stmt = self.resolve_stmt(stmt.kind, stmt.span);
self.interner.push_stmt(hir_stmt)
}

Expand Down Expand Up @@ -1909,7 +1935,7 @@ impl<'a> Resolver<'a> {

fn resolve_block(&mut self, block_expr: BlockExpression) -> HirExpression {
let statements =
self.in_new_scope(|this| vecmap(block_expr.0, |stmt| this.intern_stmt(stmt.kind)));
self.in_new_scope(|this| vecmap(block_expr.0, |stmt| this.intern_stmt(stmt)));
HirExpression::Block(HirBlockExpression(statements))
}

Expand Down Expand Up @@ -2036,6 +2062,15 @@ impl<'a> Resolver<'a> {
}
HirLiteral::FmtStr(str, fmt_str_idents)
}

fn check_break_continue(&mut self, is_break: bool, span: Span) {
if !self.in_unconstrained_fn {
self.push_err(ResolverError::JumpInConstrainedFn { is_break, span });
}
if self.nested_loops == 0 {
self.push_err(ResolverError::JumpOutsideLoop { is_break, span });
}
}
}

/// Gives an error if a user tries to create a mutable reference
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/hir/type_check/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<'interner> TypeChecker<'interner> {
HirStatement::Constrain(constrain_stmt) => self.check_constrain_stmt(constrain_stmt),
HirStatement::Assign(assign_stmt) => self.check_assign_stmt(assign_stmt, stmt_id),
HirStatement::For(for_loop) => self.check_for_loop(for_loop),
HirStatement::Error => (),
HirStatement::Break | HirStatement::Continue | HirStatement::Error => (),
}
Type::Unit
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/noirc_frontend/src/hir_def/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pub enum HirStatement {
Constrain(HirConstrainStatement),
Assign(HirAssignStatement),
For(HirForStatement),
Break,
Continue,
Expression(ExprId),
Semi(ExprId),
Error,
Expand Down
Loading
Loading