Skip to content

Commit

Permalink
feat(avm): CAST opcode implementation (#5477)
Browse files Browse the repository at this point in the history
Resolves #5466
  • Loading branch information
jeanmon authored Apr 18, 2024
1 parent 2422891 commit a821bcc
Show file tree
Hide file tree
Showing 24 changed files with 995 additions and 344 deletions.
59 changes: 51 additions & 8 deletions barretenberg/cpp/pil/avm/avm_alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ namespace avm_alu(256);
pol commit op_div;
pol commit op_not;
pol commit op_eq;
pol commit op_cast;
pol commit op_cast_prev; // Predicate on whether op_cast is enabled at previous row
pol commit alu_sel; // Predicate to activate the copy of intermediate registers to ALU table.
pol commit op_lt;
pol commit op_lte;
pol commit cmp_sel; // Predicate if LT or LTE is set
pol commit rng_chk_sel; // Predicate representing a range check row.
pol commit rng_chk_sel; // Predicate representing a range check row used in LT/LTE.

// Instruction tag (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field) copied from Main table
pol commit in_tag;
Expand Down Expand Up @@ -59,7 +61,7 @@ namespace avm_alu(256);
pol commit cf;

// Compute predicate telling whether there is a row entry in the ALU table.
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_lt + op_lte;
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte;
cmp_sel = op_lt + op_lte;

// ========= Type Constraints =============================================
Expand Down Expand Up @@ -282,15 +284,15 @@ namespace avm_alu(256);
// (x - y - 1) * q + (y - x) (1 - q) = result

// If LT, then swap ia and ib else keep the same
pol INPUT_IA = op_lt * ib + op_lte * ia;
pol INPUT_IA = op_lt * ib + (op_lte + op_cast) * ia;
pol INPUT_IB = op_lt * ia + op_lte * ib;

pol commit borrow;
pol commit a_lo;
pol commit a_hi;
// Check INPUT_IA is well formed from its lo and hi limbs
#[INPUT_DECOMP_1]
INPUT_IA = (a_lo + 2 ** 128 * a_hi) * cmp_sel;
INPUT_IA = (a_lo + 2 ** 128 * a_hi) * (cmp_sel + op_cast);

pol commit b_lo;
pol commit b_hi;
Expand All @@ -311,9 +313,9 @@ namespace avm_alu(256);
// First condition is if borrow = 0, second condition is if borrow = 1
// This underflow check is done by the 128-bit check that is performed on each of these lo and hi limbs.
#[SUB_LO_1]
(p_sub_a_lo - (53438638232309528389504892708671455232 - a_lo + p_a_borrow * 2 ** 128)) * cmp_sel = 0;
(p_sub_a_lo - (53438638232309528389504892708671455232 - a_lo + p_a_borrow * 2 ** 128)) * (cmp_sel + op_cast) = 0;
#[SUB_HI_1]
(p_sub_a_hi - (64323764613183177041862057485226039389 - a_hi - p_a_borrow)) * cmp_sel = 0;
(p_sub_a_hi - (64323764613183177041862057485226039389 - a_hi - p_a_borrow)) * (cmp_sel + op_cast) = 0;

pol commit p_sub_b_lo;
pol commit p_sub_b_hi;
Expand Down Expand Up @@ -432,11 +434,11 @@ namespace avm_alu(256);
cmp_rng_ctr * ((1 - rng_chk_sel) * (1 - op_eq_diff_inv) + op_eq_diff_inv) - rng_chk_sel = 0;

// We perform a range check if we have some range checks remaining or we are performing a comparison op
pol RNG_CHK_OP = rng_chk_sel + cmp_sel;
pol RNG_CHK_OP = rng_chk_sel + cmp_sel + op_cast + op_cast_prev;

pol commit rng_chk_lookup_selector;
#[RNG_CHK_LOOKUP_SELECTOR]
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag;
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag + op_cast' + op_cast_prev';

// Perform 128-bit range check on lo part
#[LOWER_CMP_RNG_CHK]
Expand Down Expand Up @@ -467,3 +469,44 @@ namespace avm_alu(256);
(p_sub_b_lo' - res_lo) * rng_chk_sel'= 0;
(p_sub_b_hi' - res_hi) * rng_chk_sel'= 0;


// ========= CAST Operation Constraints ===============================
// We handle the input ia independently of its tag, i.e., we suppose it can take
// any value between 0 and p-1.
// We decompose the input ia in 8-bit/16-bit limbs and prove that the decomposition
// sums up to ia over the integers (i.e., no modulo p wrapping). To prove this, we
// re-use techniques above from LT/LTE opcode. The following relations are toggled for CAST:
// - #[INPUT_DECOMP_1] shows a = a_lo + 2 ** 128 * a_hi
// - #[SUB_LO_1] and #[SUB_LO_1] shows that the above does not overflow modulo p.
// - We toggle RNG_CHK_OP with op_cast to show that a_lo, a_hi are correctly decomposed
// over the 8/16-bit ALU registers in #[LOWER_CMP_RNG_CHK] and #[UPPER_CMP_RNG_CHK].
// - The 128-bit range checks for a_lo, a_hi are activated in #[RNG_CHK_LOOKUP_SELECTOR].
// - We copy p_sub_a_lo resp. p_sub_a_hi into next row in columns a_lo resp. a_hi so
// that decomposition into the 8/16-bit ALU registers and range checks are performed.
// Copy is done in #[OP_CAST_RNG_CHECK_P_SUB_A_LOW/HIGH] below.
// Activation of decomposition and range check is achieved by adding op_cast_prev in
// #[LOWER_CMP_RNG_CHK], #[UPPER_CMP_RNG_CHK] and #[RNG_CHK_LOOKUP_SELECTOR].
// - Finally, the truncated result SUM_TAG is copied in ic as part of #[ALU_OP_CAST] below.
// - Note that the tag of return value must be constrained to be in_tag and is enforced in
// the main and memory traces.
//
// TODO: Potential optimization is to un-toggle all CAST relevant operations when ff_tag is
// enabled. In this case, ic = ia trivially.
// Another one is to activate range checks in a more granular way depending on the tag.

#[OP_CAST_PREV_LINE]
op_cast_prev' = op_cast;

#[ALU_OP_CAST]
op_cast * (SUM_TAG + ff_tag * ia - ic) = 0;

#[OP_CAST_RNG_CHECK_P_SUB_A_LOW]
op_cast * (a_lo' - p_sub_a_lo) = 0;

#[OP_CAST_RNG_CHECK_P_SUB_A_HIGH]
op_cast * (a_hi' - p_sub_a_hi) = 0;

// 128-bit multiplication and CAST need two rows in ALU trace. We need to ensure
// that another ALU operation does not start in the second row.
#[TWO_LINE_OP_NO_OVERLAP]
(op_mul * ff_tag + op_cast) * alu_sel' = 0;
35 changes: 26 additions & 9 deletions barretenberg/cpp/pil/avm/avm_main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ namespace avm_main(256);
pol commit sel_op_or;
// XOR
pol commit sel_op_xor;
// CAST
pol commit sel_op_cast;
// LT
pol commit sel_op_lt;
// LTE
Expand All @@ -68,6 +70,7 @@ namespace avm_main(256);
// Instruction memory tags read/write (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field)
pol commit r_in_tag;
pol commit w_in_tag;
pol commit alu_in_tag; // Copy of r_in_tag or w_in_tag depending of the operation. It is sent to ALU trace.

// Errors
pol commit op_err; // Boolean flag pertaining to an operation error
Expand Down Expand Up @@ -121,7 +124,8 @@ namespace avm_main(256);
pol commit last;

// Relations on type constraints

// TODO: Very likely, we can remove these constraints as the selectors should be derived during
// opcode decomposition.
sel_op_add * (1 - sel_op_add) = 0;
sel_op_sub * (1 - sel_op_sub) = 0;
sel_op_mul * (1 - sel_op_mul) = 0;
Expand All @@ -131,6 +135,7 @@ namespace avm_main(256);
sel_op_and * (1 - sel_op_and) = 0;
sel_op_or * (1 - sel_op_or) = 0;
sel_op_xor * (1 - sel_op_xor) = 0;
sel_op_cast * (1 - sel_op_cast) = 0;
sel_op_lt * (1 - sel_op_lt) = 0;
sel_op_lte * (1 - sel_op_lte) = 0;

Expand Down Expand Up @@ -243,7 +248,8 @@ namespace avm_main(256);

//===== CONTROL_FLOW_CONSISTENCY ============================================
pol INTERNAL_CALL_STACK_SELECTORS = (first + sel_internal_call + sel_internal_return + sel_halt);
pol OPCODE_SELECTORS = (sel_op_add + sel_op_sub + sel_op_div + sel_op_mul + sel_op_not + sel_op_eq + sel_op_and + sel_op_or + sel_op_xor);
pol OPCODE_SELECTORS = (sel_op_add + sel_op_sub + sel_op_div + sel_op_mul + sel_op_not
+ sel_op_eq + sel_op_and + sel_op_or + sel_op_xor + sel_op_cast);

// Program counter must increment if not jumping or returning
#[PC_INCREMENT]
Expand Down Expand Up @@ -287,24 +293,35 @@ namespace avm_main(256);
#[MOV_MAIN_SAME_TAG]
(sel_mov + sel_cmov) * (r_in_tag - w_in_tag) = 0;

//===== ALU CONSTRAINTS =====================================================
// TODO: when division is moved to the alu, we will need to add the selector in the list below.
pol ALU_R_TAG_SEL = sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte;
pol ALU_W_TAG_SEL = sel_op_cast;
pol ALU_ALL_SEL = ALU_R_TAG_SEL + ALU_W_TAG_SEL;

// Predicate to activate the copy of intermediate registers to ALU table. If tag_err == 1,
// the operation is not copied to the ALU table.
alu_sel = ALU_ALL_SEL * (1 - tag_err);

// Dispatch the correct in_tag for alu
ALU_R_TAG_SEL * (alu_in_tag - r_in_tag) = 0;
ALU_W_TAG_SEL * (alu_in_tag - w_in_tag) = 0;

//====== Inter-table Constraints ============================================
#[INCL_MAIN_TAG_ERR]
avm_mem.tag_err {avm_mem.clk} in tag_err {clk};

#[INCL_MEM_TAG_ERR]
tag_err {clk} in avm_mem.tag_err {avm_mem.clk};

// Predicate to activate the copy of intermediate registers to ALU table. If tag_err == 1,
// the operation is not copied to the ALU table.
// TODO: when division is moved to the alu, we will need to add the selector in the list below:
alu_sel = (sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte) * (1 - tag_err);

#[PERM_MAIN_ALU]
alu_sel {clk, ia, ib, ic, sel_op_add, sel_op_sub,
sel_op_mul, sel_op_eq, sel_op_not, sel_op_lt, sel_op_lte, r_in_tag}
sel_op_mul, sel_op_eq, sel_op_not, sel_op_cast,
sel_op_lt, sel_op_lte, alu_in_tag}
is
avm_alu.alu_sel {avm_alu.clk, avm_alu.ia, avm_alu.ib, avm_alu.ic, avm_alu.op_add, avm_alu.op_sub,
avm_alu.op_mul, avm_alu.op_eq, avm_alu.op_not, avm_alu.op_lt, avm_alu.op_lte, avm_alu.in_tag};
avm_alu.op_mul, avm_alu.op_eq, avm_alu.op_not, avm_alu.op_cast,
avm_alu.op_lt, avm_alu.op_lte, avm_alu.in_tag};

// Based on the boolean selectors, we derive the binary op id to lookup in the table;
// TODO: Check if having 4 columns (op_id + 3 boolean selectors) is more optimal that just using the op_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ template <class Params_> struct alignas(32) field {
return static_cast<bool>(out.data[0]);
}

constexpr explicit operator uint8_t() const
{
field out = from_montgomery_form();
return static_cast<uint8_t>(out.data[0]);
}

constexpr explicit operator uint16_t() const
{
field out = from_montgomery_form();
return static_cast<uint16_t>(out.data[0]);
}

constexpr explicit operator uint32_t() const
{
field out = from_montgomery_form();
Expand Down
Loading

0 comments on commit a821bcc

Please sign in to comment.