Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(avm): CAST opcode implementation #5477

Merged
merged 8 commits into from
Apr 18, 2024
59 changes: 51 additions & 8 deletions barretenberg/cpp/pil/avm/avm_alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ namespace avm_alu(256);
pol commit op_div;
pol commit op_not;
pol commit op_eq;
pol commit op_cast;
pol commit op_cast_prev; // Predicate on whether op_cast is enabled at previous row
pol commit alu_sel; // Predicate to activate the copy of intermediate registers to ALU table.
pol commit op_lt;
pol commit op_lte;
pol commit cmp_sel; // Predicate if LT or LTE is set
pol commit rng_chk_sel; // Predicate representing a range check row.
pol commit rng_chk_sel; // Predicate representing a range check row used in LT/LTE.

// Instruction tag (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field) copied from Main table
pol commit in_tag;
Expand Down Expand Up @@ -59,7 +61,7 @@ namespace avm_alu(256);
pol commit cf;

// Compute predicate telling whether there is a row entry in the ALU table.
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_lt + op_lte;
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte;
cmp_sel = op_lt + op_lte;

// ========= Type Constraints =============================================
Expand Down Expand Up @@ -282,15 +284,15 @@ namespace avm_alu(256);
// (x - y - 1) * q + (y - x) (1 - q) = result

// If LT, then swap ia and ib else keep the same
pol INPUT_IA = op_lt * ib + op_lte * ia;
pol INPUT_IA = op_lt * ib + (op_lte + op_cast) * ia;
pol INPUT_IB = op_lt * ia + op_lte * ib;

pol commit borrow;
pol commit a_lo;
pol commit a_hi;
// Check INPUT_IA is well formed from its lo and hi limbs
#[INPUT_DECOMP_1]
INPUT_IA = (a_lo + 2 ** 128 * a_hi) * cmp_sel;
INPUT_IA = (a_lo + 2 ** 128 * a_hi) * (cmp_sel + op_cast);

pol commit b_lo;
pol commit b_hi;
Expand All @@ -311,9 +313,9 @@ namespace avm_alu(256);
// First condition is if borrow = 0, second condition is if borrow = 1
// This underflow check is done by the 128-bit check that is performed on each of these lo and hi limbs.
#[SUB_LO_1]
(p_sub_a_lo - (53438638232309528389504892708671455232 - a_lo + p_a_borrow * 2 ** 128)) * cmp_sel = 0;
(p_sub_a_lo - (53438638232309528389504892708671455232 - a_lo + p_a_borrow * 2 ** 128)) * (cmp_sel + op_cast) = 0;
#[SUB_HI_1]
(p_sub_a_hi - (64323764613183177041862057485226039389 - a_hi - p_a_borrow)) * cmp_sel = 0;
(p_sub_a_hi - (64323764613183177041862057485226039389 - a_hi - p_a_borrow)) * (cmp_sel + op_cast) = 0;

pol commit p_sub_b_lo;
pol commit p_sub_b_hi;
Expand Down Expand Up @@ -432,11 +434,11 @@ namespace avm_alu(256);
cmp_rng_ctr * ((1 - rng_chk_sel) * (1 - op_eq_diff_inv) + op_eq_diff_inv) - rng_chk_sel = 0;

// We perform a range check if we have some range checks remaining or we are performing a comparison op
pol RNG_CHK_OP = rng_chk_sel + cmp_sel;
pol RNG_CHK_OP = rng_chk_sel + cmp_sel + op_cast + op_cast_prev;

pol commit rng_chk_lookup_selector;
#[RNG_CHK_LOOKUP_SELECTOR]
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag;
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag + op_cast' + op_cast_prev';

// Perform 128-bit range check on lo part
#[LOWER_CMP_RNG_CHK]
Expand Down Expand Up @@ -467,3 +469,44 @@ namespace avm_alu(256);
(p_sub_b_lo' - res_lo) * rng_chk_sel'= 0;
(p_sub_b_hi' - res_hi) * rng_chk_sel'= 0;


// ========= CAST Operation Constraints ===============================
// We handle the input ia independently of its tag, i.e., we suppose it can take
// any value between 0 and p-1.
// We decompose the input ia in 8-bit/16-bit limbs and prove that the decomposition
// sums up to ia over the integers (i.e., no modulo p wrapping). To prove this, we
// re-use techniques above from LT/LTE opcode. The following relations are toggled for CAST:
// - #[INPUT_DECOMP_1] shows a = a_lo + 2 ** 128 * a_hi
// - #[SUB_LO_1] and #[SUB_LO_1] shows that the above does not overflow modulo p.
// - We toggle RNG_CHK_OP with op_cast to show that a_lo, a_hi are correctly decomposed
// over the 8/16-bit ALU registers in #[LOWER_CMP_RNG_CHK] and #[UPPER_CMP_RNG_CHK].
// - The 128-bit range checks for a_lo, a_hi are activated in #[RNG_CHK_LOOKUP_SELECTOR].
// - We copy p_sub_a_lo resp. p_sub_a_hi into next row in columns a_lo resp. a_hi so
// that decomposition into the 8/16-bit ALU registers and range checks are performed.
// Copy is done in #[OP_CAST_RNG_CHECK_P_SUB_A_LOW/HIGH] below.
// Activation of decomposition and range check is achieved by adding op_cast_prev in
// #[LOWER_CMP_RNG_CHK], #[UPPER_CMP_RNG_CHK] and #[RNG_CHK_LOOKUP_SELECTOR].
// - Finally, the truncated result SUM_TAG is copied in ic as part of #[ALU_OP_CAST] below.
// - Note that the tag of return value must be constrained to be in_tag and is enforced in
// the main and memory traces.
//
// TODO: Potential optimization is to un-toggle all CAST relevant operations when ff_tag is
// enabled. In this case, ic = ia trivially.
// Another one is to activate range checks in a more granular way depending on the tag.

#[OP_CAST_PREV_LINE]
op_cast_prev' = op_cast;

#[ALU_OP_CAST]
op_cast * (SUM_TAG + ff_tag * ia - ic) = 0;

#[OP_CAST_RNG_CHECK_P_SUB_A_LOW]
op_cast * (a_lo' - p_sub_a_lo) = 0;

#[OP_CAST_RNG_CHECK_P_SUB_A_HIGH]
op_cast * (a_hi' - p_sub_a_hi) = 0;

// 128-bit multiplication and CAST need two rows in ALU trace. We need to ensure
// that another ALU operation does not start in the second row.
#[TWO_LINE_OP_NO_OVERLAP]
(op_mul * ff_tag + op_cast) * alu_sel' = 0;
35 changes: 26 additions & 9 deletions barretenberg/cpp/pil/avm/avm_main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ namespace avm_main(256);
pol commit sel_op_or;
// XOR
pol commit sel_op_xor;
// CAST
pol commit sel_op_cast;
// LT
pol commit sel_op_lt;
// LTE
Expand All @@ -68,6 +70,7 @@ namespace avm_main(256);
// Instruction memory tags read/write (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field)
pol commit r_in_tag;
pol commit w_in_tag;
pol commit alu_in_tag; // Copy of r_in_tag or w_in_tag depending of the operation. It is sent to ALU trace.

// Errors
pol commit op_err; // Boolean flag pertaining to an operation error
Expand Down Expand Up @@ -121,7 +124,8 @@ namespace avm_main(256);
pol commit last;

// Relations on type constraints

// TODO: Very likely, we can remove these constraints as the selectors should be derived during
// opcode decomposition.
sel_op_add * (1 - sel_op_add) = 0;
sel_op_sub * (1 - sel_op_sub) = 0;
sel_op_mul * (1 - sel_op_mul) = 0;
Expand All @@ -131,6 +135,7 @@ namespace avm_main(256);
sel_op_and * (1 - sel_op_and) = 0;
sel_op_or * (1 - sel_op_or) = 0;
sel_op_xor * (1 - sel_op_xor) = 0;
sel_op_cast * (1 - sel_op_cast) = 0;
sel_op_lt * (1 - sel_op_lt) = 0;
sel_op_lte * (1 - sel_op_lte) = 0;

Expand Down Expand Up @@ -243,7 +248,8 @@ namespace avm_main(256);

//===== CONTROL_FLOW_CONSISTENCY ============================================
pol INTERNAL_CALL_STACK_SELECTORS = (first + sel_internal_call + sel_internal_return + sel_halt);
pol OPCODE_SELECTORS = (sel_op_add + sel_op_sub + sel_op_div + sel_op_mul + sel_op_not + sel_op_eq + sel_op_and + sel_op_or + sel_op_xor);
pol OPCODE_SELECTORS = (sel_op_add + sel_op_sub + sel_op_div + sel_op_mul + sel_op_not
+ sel_op_eq + sel_op_and + sel_op_or + sel_op_xor + sel_op_cast);

// Program counter must increment if not jumping or returning
#[PC_INCREMENT]
Expand Down Expand Up @@ -287,24 +293,35 @@ namespace avm_main(256);
#[MOV_MAIN_SAME_TAG]
(sel_mov + sel_cmov) * (r_in_tag - w_in_tag) = 0;

//===== ALU CONSTRAINTS =====================================================
// TODO: when division is moved to the alu, we will need to add the selector in the list below.
pol ALU_R_TAG_SEL = sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte;
pol ALU_W_TAG_SEL = sel_op_cast;
pol ALU_ALL_SEL = ALU_R_TAG_SEL + ALU_W_TAG_SEL;

// Predicate to activate the copy of intermediate registers to ALU table. If tag_err == 1,
// the operation is not copied to the ALU table.
alu_sel = ALU_ALL_SEL * (1 - tag_err);

// Dispatch the correct in_tag for alu
ALU_R_TAG_SEL * (alu_in_tag - r_in_tag) = 0;
ALU_W_TAG_SEL * (alu_in_tag - w_in_tag) = 0;

//====== Inter-table Constraints ============================================
#[INCL_MAIN_TAG_ERR]
avm_mem.tag_err {avm_mem.clk} in tag_err {clk};

#[INCL_MEM_TAG_ERR]
tag_err {clk} in avm_mem.tag_err {avm_mem.clk};

// Predicate to activate the copy of intermediate registers to ALU table. If tag_err == 1,
// the operation is not copied to the ALU table.
// TODO: when division is moved to the alu, we will need to add the selector in the list below:
alu_sel = (sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte) * (1 - tag_err);

#[PERM_MAIN_ALU]
alu_sel {clk, ia, ib, ic, sel_op_add, sel_op_sub,
sel_op_mul, sel_op_eq, sel_op_not, sel_op_lt, sel_op_lte, r_in_tag}
sel_op_mul, sel_op_eq, sel_op_not, sel_op_cast,
sel_op_lt, sel_op_lte, alu_in_tag}
is
avm_alu.alu_sel {avm_alu.clk, avm_alu.ia, avm_alu.ib, avm_alu.ic, avm_alu.op_add, avm_alu.op_sub,
avm_alu.op_mul, avm_alu.op_eq, avm_alu.op_not, avm_alu.op_lt, avm_alu.op_lte, avm_alu.in_tag};
avm_alu.op_mul, avm_alu.op_eq, avm_alu.op_not, avm_alu.op_cast,
avm_alu.op_lt, avm_alu.op_lte, avm_alu.in_tag};

// Based on the boolean selectors, we derive the binary op id to lookup in the table;
// TODO: Check if having 4 columns (op_id + 3 boolean selectors) is more optimal that just using the op_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ template <class Params_> struct alignas(32) field {
return static_cast<bool>(out.data[0]);
}

constexpr explicit operator uint8_t() const
{
field out = from_montgomery_form();
return static_cast<uint8_t>(out.data[0]);
}

constexpr explicit operator uint16_t() const
{
field out = from_montgomery_form();
return static_cast<uint16_t>(out.data[0]);
}

constexpr explicit operator uint32_t() const
{
field out = from_montgomery_form();
Expand Down
Loading
Loading