Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: decompose Instruction::Cast to have an explicit truncation instruction #3946

Merged
merged 16 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 3 additions & 38 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,9 @@ impl Context {
self.acir_context.assert_eq_var(lhs, rhs, assert_message.clone())?;
}
}
Instruction::Cast(value_id, typ) => {
let result_acir_var = self.convert_ssa_cast(value_id, typ, dfg)?;
self.define_result_var(dfg, instruction_id, result_acir_var);
Instruction::Cast(value_id, _) => {
let acir_var = self.convert_numeric_value(*value_id, dfg)?;
self.define_result_var(dfg, instruction_id, acir_var);
}
Instruction::Call { func, arguments } => {
let result_ids = dfg.instruction_results(instruction_id);
Expand Down Expand Up @@ -1636,41 +1636,6 @@ impl Context {
}
}

/// Returns an `AcirVar` that is constrained to fit in the target type by truncating the input.
/// If the target cast is to a `NativeField`, no truncation is required so the cast becomes a
/// no-op.
fn convert_ssa_cast(
&mut self,
value_id: &ValueId,
typ: &Type,
dfg: &DataFlowGraph,
) -> Result<AcirVar, RuntimeError> {
let (variable, incoming_type) = match self.convert_value(*value_id, dfg) {
AcirValue::Var(variable, typ) => (variable, typ),
AcirValue::DynamicArray(_) | AcirValue::Array(_) => {
unreachable!("Cast is only applied to numerics")
}
};
let target_numeric = match typ {
Type::Numeric(numeric) => numeric,
_ => unreachable!("Can only cast to a numeric"),
};
match target_numeric {
NumericType::NativeField => {
// Casting into a Field as a no-op
Ok(variable)
}
NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => {
let max_bit_size = incoming_type.bit_size();
if max_bit_size <= *bit_size {
// Incoming variable already fits into target bit size - this is a no-op
return Ok(variable);
}
self.acir_context.truncate_var(variable, *bit_size, max_bit_size)
}
}
}

/// Returns an `AcirVar`that is constrained to be result of the truncation.
fn convert_ssa_truncate(
&mut self,
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
constants: HashMap<(FieldElement, Type), ValueId>,

/// Contains each function that has been imported into the current function.
/// Each function's Value::Function is uniqued here so any given FunctionId

Check warning on line 50 in compiler/noirc_evaluator/src/ssa/ir/dfg.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (uniqued)
/// will always have the same ValueId within this function.
functions: HashMap<FunctionId, ValueId>,

Expand All @@ -57,7 +57,7 @@
intrinsics: HashMap<Intrinsic, ValueId>,

/// Contains each foreign function that has been imported into the current function.
/// This map is used to ensure that the ValueId for any given foreign functôn is always

Check warning on line 60 in compiler/noirc_evaluator/src/ssa/ir/dfg.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (functôn)
/// represented by only 1 ValueId within this function.
foreign_functions: HashMap<String, ValueId>,

Expand All @@ -71,11 +71,11 @@

/// Source location of each instruction for debugging and issuing errors.
///
/// The `CallStack` here corresponds to the entire callstack of locations. Initially this

Check warning on line 74 in compiler/noirc_evaluator/src/ssa/ir/dfg.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (callstack)
/// only contains the actual location of the instruction. During inlining, a new location
/// will be pushed to each instruction for the location of the function call of the function
/// the instruction was originally located in. Once inlining is complete, the locations Vec
/// here should contain the entire callstack for each instruction.

Check warning on line 78 in compiler/noirc_evaluator/src/ssa/ir/dfg.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (callstack)
///
/// Instructions inserted by internal SSA passes that don't correspond to user code
/// may not have a corresponding location.
Expand Down Expand Up @@ -160,7 +160,7 @@
call_stack: CallStack,
) -> InsertInstructionResult {
use InsertInstructionResult::*;
match instruction.simplify(self, block, ctrl_typevars.clone()) {
match instruction.simplify(self, block, ctrl_typevars.clone(), &call_stack) {
jfecher marked this conversation as resolved.
Show resolved Hide resolved
SimplifyResult::SimplifiedTo(simplification) => SimplifiedTo(simplification),
SimplifyResult::SimplifiedToMultiple(simplification) => {
SimplifiedToMultiple(simplification)
Expand Down
8 changes: 3 additions & 5 deletions compiler/noirc_evaluator/src/ssa/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,7 @@ impl Instruction {
// In ACIR, a division with a false predicate outputs (0,0), so it cannot replace another instruction unless they have the same predicate
bin.operator != BinaryOp::Div
}
Cast(_, _) | Not(_) | ArrayGet { .. } | ArraySet { .. } => true,

// Unclear why this instruction causes problems.
Truncate { .. } => false,
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
Cast(_, _) | Truncate { .. } | Not(_) | ArrayGet { .. } | ArraySet { .. } => true,

// These either have side-effects or interact with memory
Constrain(..)
Expand Down Expand Up @@ -408,6 +405,7 @@ impl Instruction {
dfg: &mut DataFlowGraph,
block: BasicBlockId,
ctrl_typevars: Option<Vec<Type>>,
call_stack: &CallStack,
) -> SimplifyResult {
use SimplifyResult::*;
match self {
Expand Down Expand Up @@ -551,7 +549,7 @@ impl Instruction {
}
}
Instruction::Call { func, arguments } => {
simplify_call(*func, arguments, dfg, block, ctrl_typevars)
simplify_call(*func, arguments, dfg, block, ctrl_typevars, call_stack)
}
Instruction::EnableSideEffects { condition } => {
if let Some(last) = dfg[block].instructions().last().copied() {
Expand Down
20 changes: 19 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub(super) fn simplify_call(
dfg: &mut DataFlowGraph,
block: BasicBlockId,
ctrl_typevars: Option<Vec<Type>>,
call_stack: &CallStack,
) -> SimplifyResult {
let intrinsic = match &dfg[func] {
Value::Intrinsic(intrinsic) => *intrinsic,
Expand Down Expand Up @@ -242,7 +243,24 @@ pub(super) fn simplify_call(
SimplifyResult::SimplifiedToInstruction(instruction)
}
Intrinsic::FromField => {
let instruction = Instruction::Cast(arguments[0], ctrl_typevars.unwrap().remove(0));
let incoming_type = Type::field();
let target_type = ctrl_typevars.unwrap().remove(0);

let truncate = Instruction::Truncate {
value: arguments[0],
bit_size: target_type.bit_size(),
max_bit_size: incoming_type.bit_size(),
};
let truncated_value = dfg
.insert_instruction_and_results(
truncate,
block,
Some(vec![incoming_type]),
call_stack.clone(),
)
.first();

let instruction = Instruction::Cast(truncated_value, target_type);
SimplifyResult::SimplifiedToInstruction(instruction)
}
}
Expand Down
48 changes: 34 additions & 14 deletions compiler/noirc_evaluator/src/ssa/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,28 @@ pub enum NumericType {
NativeField,
}

impl NumericType {
/// Returns the bit size of the provided numeric type.
pub(crate) fn bit_size(self: &NumericType) -> u32 {
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
match self {
NumericType::NativeField => FieldElement::max_num_bits(),
NumericType::Unsigned { bit_size } | NumericType::Signed { bit_size } => *bit_size,
}
}

/// Returns true if the given Field value is within the numeric limits
/// for the current NumericType.
pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool {
match self {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let max = 2u128.pow(bit_size) - 1;
field <= max.into()
}
NumericType::NativeField => true,
}
}
}

/// All types representable in the IR.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub(crate) enum Type {
Expand Down Expand Up @@ -68,6 +90,18 @@ impl Type {
Type::Numeric(NumericType::NativeField)
}

/// Returns the bit size of the provided numeric type.
///
/// # Panics
///
/// Panics if `self` is not a [`Type::Numeric`]
pub(crate) fn bit_size(&self) -> u32 {
match self {
Type::Numeric(numeric_type) => numeric_type.bit_size(),
other => panic!("bit_size: Expected numeric type, found {other}"),
}
}

/// Returns the size of the element type for this array/slice.
/// The size of a type is defined as representing how many Fields are needed
/// to represent the type. This is 1 for every primitive type, and is the number of fields
Expand Down Expand Up @@ -122,20 +156,6 @@ impl Type {
}
}

impl NumericType {
/// Returns true if the given Field value is within the numeric limits
/// for the current NumericType.
pub(crate) fn value_is_within_limits(self, field: FieldElement) -> bool {
match self {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let max = 2u128.pow(bit_size) - 1;
field <= max.into()
}
NumericType::NativeField => true,
}
}
}

/// Composite Types are essentially flattened struct or tuple types.
/// Array types may have these as elements where each flattened field is
/// included in the array sequentially.
Expand Down
10 changes: 5 additions & 5 deletions compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ mod test {
instruction::{BinaryOp, Instruction, TerminatorInstruction},
map::Id,
types::Type,
value::{Value, ValueId},
value::Value,
},
};

Expand Down Expand Up @@ -293,7 +293,7 @@ mod test {
#[test]
fn instruction_deduplication() {
// fn main f0 {
// b0(v0: Field):
// b0(v0: u16):
// v1 = cast v0 as u32
// v2 = cast v0 as u32
// constrain v1 v2
Expand All @@ -308,7 +308,7 @@ mod test {

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);
let v0 = builder.add_parameter(Type::field());
let v0 = builder.add_parameter(Type::unsigned(16));

let v1 = builder.insert_cast(v0, Type::unsigned(32));
let v2 = builder.insert_cast(v0, Type::unsigned(32));
Expand All @@ -322,7 +322,7 @@ mod test {
// Expected output:
//
// fn main f0 {
// b0(v0: Field):
// b0(v0: u16):
// v1 = cast v0 as u32
// }
let ssa = ssa.fold_constants();
Expand All @@ -332,6 +332,6 @@ mod test {
assert_eq!(instructions.len(), 1);
let instruction = &main.dfg[instructions[0]];

assert_eq!(instruction, &Instruction::Cast(ValueId::test_new(0), Type::unsigned(32)));
assert_eq!(instruction, &Instruction::Cast(v0, Type::unsigned(32)));
}
}
48 changes: 38 additions & 10 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ impl<'a> FunctionContext<'a> {
let bit_width =
self.builder.numeric_constant(FieldElement::from(2_i128.pow(bit_size)), Type::field());
let sign_not = self.builder.insert_binary(one, BinaryOp::Sub, sign);

// We use unsafe casts here, this is fine as we're casting to a `field` type.
let as_field = self.builder.insert_cast(input, Type::field());
let sign_field = self.builder.insert_cast(sign, Type::field());
let positive_predicate = self.builder.insert_binary(sign_field, BinaryOp::Mul, as_field);
Expand Down Expand Up @@ -310,12 +312,12 @@ impl<'a> FunctionContext<'a> {
match operator {
BinaryOpKind::Add | BinaryOpKind::Subtract => {
// Result is computed modulo the bit size
let mut result =
self.builder.insert_truncate(result, bit_size, bit_size + 1);
result = self.builder.insert_cast(result, Type::unsigned(bit_size));
let result = self.builder.insert_truncate(result, bit_size, bit_size + 1);
let result =
self.insert_safe_cast(result, Type::unsigned(bit_size), location);

self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.builder.insert_cast(result, result_type)
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::Multiply => {
// Result is computed modulo the bit size
Expand All @@ -324,7 +326,7 @@ impl<'a> FunctionContext<'a> {
result = self.builder.insert_truncate(result, bit_size, 2 * bit_size);

self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.builder.insert_cast(result, result_type)
self.insert_safe_cast(result, result_type, location)
}
BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => {
self.check_shift_overflow(result, rhs, bit_size, location, true)
Expand Down Expand Up @@ -374,8 +376,11 @@ impl<'a> FunctionContext<'a> {
is_signed: bool,
) -> ValueId {
let one = self.builder.numeric_constant(FieldElement::one(), Type::bool());
let rhs =
if is_signed { self.builder.insert_cast(rhs, Type::unsigned(bit_size)) } else { rhs };
let rhs = if is_signed {
self.insert_safe_cast(rhs, Type::unsigned(bit_size), location)
} else {
rhs
};
// Bit-shift with a negative number is an overflow
if is_signed {
// We compute the sign of rhs.
Expand Down Expand Up @@ -431,8 +436,8 @@ impl<'a> FunctionContext<'a> {
Type::unsigned(bit_size),
);
// We compute the sign of the operands. The overflow checks for signed integers depends on these signs
let lhs_as_unsigned = self.builder.insert_cast(lhs, Type::unsigned(bit_size));
let rhs_as_unsigned = self.builder.insert_cast(rhs, Type::unsigned(bit_size));
let lhs_as_unsigned = self.insert_safe_cast(lhs, Type::unsigned(bit_size), location);
let rhs_as_unsigned = self.insert_safe_cast(rhs, Type::unsigned(bit_size), location);
let lhs_sign = self.builder.insert_binary(lhs_as_unsigned, BinaryOp::Lt, half_width);
let mut rhs_sign = self.builder.insert_binary(rhs_as_unsigned, BinaryOp::Lt, half_width);
let message = if is_sub {
Expand Down Expand Up @@ -473,7 +478,7 @@ impl<'a> FunctionContext<'a> {
// Then we check the signed product fits in a signed integer of bit_size-bits
let not_same = self.builder.insert_binary(one, BinaryOp::Sub, same_sign);
let not_same_sign_field =
self.builder.insert_cast(not_same, Type::unsigned(bit_size));
self.insert_safe_cast(not_same, Type::unsigned(bit_size), location);
let positive_maximum_with_offset =
self.builder.insert_binary(half_width, BinaryOp::Add, not_same_sign_field);
let product_overflow_check =
Expand Down Expand Up @@ -663,6 +668,29 @@ impl<'a> FunctionContext<'a> {
reshaped_return_values
}

/// Inserts a cast instruction at the end of the current block and returns the results
/// of the cast.
///
/// Compared to `self.builder.insert_cast`, this version will automatically truncate `value` to be a valid `typ`.
pub(super) fn insert_safe_cast(
&mut self,
mut value: ValueId,
typ: Type,
location: Location,
) -> ValueId {
self.builder.set_location(location);

// To ensure that `value` is a valid `typ`, we insert an `Instruction::Truncate` instruction beforehand if
// we're narrowing the type size.
let incoming_type_size = self.builder.type_of_value(value).bit_size();
let target_type_size = typ.bit_size();
if target_type_size < incoming_type_size {
value = self.builder.insert_truncate(value, target_type_size, incoming_type_size);
}

self.builder.insert_cast(value, typ)
}

/// Create a const offset of an address for an array load or store
pub(super) fn make_offset(&mut self, mut address: ValueId, offset: u128) -> ValueId {
if offset != 0 {
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@
fn codegen_cast(&mut self, cast: &ast::Cast) -> Result<Values, RuntimeError> {
let lhs = self.codegen_non_tuple_expression(&cast.lhs)?;
let typ = Self::convert_non_tuple_type(&cast.r#type);
self.builder.set_location(cast.location);
Ok(self.builder.insert_cast(lhs, typ).into())

Ok(self.insert_safe_cast(lhs, typ, cast.location).into())
}

/// Codegens a for loop, creating three new blocks in the process.
Expand All @@ -448,7 +448,7 @@
/// br loop_entry(v0)
/// loop_entry(i: Field):
/// v2 = lt i v1
/// brif v2, then: loop_body, else: loop_end

Check warning on line 451 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// loop_body():
/// v3 = ... codegen body ...
/// v4 = add 1, i
Expand Down Expand Up @@ -502,7 +502,7 @@
/// For example, the expression `if cond { a } else { b }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: else_block

Check warning on line 505 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block():
/// v1 = ... codegen a ...
/// br end_if(v1)
Expand All @@ -515,7 +515,7 @@
/// As another example, the expression `if cond { a }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: end_block

Check warning on line 518 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block:
/// v1 = ... codegen a ...
/// br end_if()
Expand Down
Loading