Skip to content

Commit

Permalink
feat(avm): shift relations
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Apr 19, 2024
1 parent e14cffa commit 7478991
Show file tree
Hide file tree
Showing 22 changed files with 1,890 additions and 344 deletions.
1 change: 1 addition & 0 deletions barretenberg/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
set(CMAKE_CXX_EXTENSIONS ON)

if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-fbracket-depth=512)
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "14")
message(WARNING "Clang <14 is not supported")
endif()
Expand Down
118 changes: 107 additions & 11 deletions barretenberg/cpp/pil/avm/avm_alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ namespace avm_alu(256);
pol commit op_lt;
pol commit op_lte;
pol commit cmp_sel; // Predicate if LT or LTE is set
pol commit rng_chk_sel; // Predicate representing a range check row used in LT/LTE.
pol commit rng_chk_sel; // Predicate representing a range check row.
pol commit op_shl;
pol commit op_shr;
pol commit shift_sel; // Predicate if SHR or SHR is set

// Instruction tag (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field) copied from Main table
pol commit in_tag;
Expand Down Expand Up @@ -61,8 +64,9 @@ namespace avm_alu(256);
pol commit cf;

// Compute predicate telling whether there is a row entry in the ALU table.
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte;
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte + op_shr + op_shl;
cmp_sel = op_lt + op_lte;
shift_sel = op_shl + op_shr;

// ========= Type Constraints =============================================
// TODO: Range constraints
Expand Down Expand Up @@ -355,7 +359,7 @@ namespace avm_alu(256);
// (b) IS_GT = 1 - ic = 0
// (c) res_lo = B_SUB_A_LO and res_hi = B_SUB_A_HI
// (d) res_lo = y_lo - x_lo + borrow * 2**128 and res_hi = y_hi - x_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> y_lo >= x_lo && y_hi >= x_hi
Expand All @@ -368,7 +372,7 @@ namespace avm_alu(256);
// (b) IS_GT = 1 - ic = 1
// (c) res_lo = A_SUB_B_LO and res_hi = A_SUB_B_HI
// (d) res_lo = x_lo - y_lo - 1 + borrow * 2**128 and res_hi = x_hi - y_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> x_lo > y_lo && x_hi >= y_hi
Expand All @@ -383,7 +387,7 @@ namespace avm_alu(256);
// (b) IS_GT = ic = 1
// (c) res_lo = A_SUB_B_LO and res_hi = A_SUB_B_HI, **remember we have swapped inputs**
// (d) res_lo = y_lo - x_lo - 1 + borrow * 2**128 and res_hi = y_hi - x_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> y_lo > x_lo && y_hi >= x_hi
Expand All @@ -395,8 +399,8 @@ namespace avm_alu(256);
// (a) We DO swap the operands, so a = y and b = x,
// (b) IS_GT = ic = 0
// (c) res_lo = B_SUB_A_LO and res_hi = B_SUB_A_HI, **remember we have swapped inputs**
// (d) res_lo = x_lo - y_lo + borrow * 2**128 and res_hi = x_hi - y_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (d) res_lo = a_lo - y_lo + borrow * 2**128 and res_hi = a_hi - y_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> x_lo >= y_lo && x_hi >= y_hi
Expand Down Expand Up @@ -434,11 +438,13 @@ namespace avm_alu(256);
cmp_rng_ctr * ((1 - rng_chk_sel) * (1 - op_eq_diff_inv) + op_eq_diff_inv) - rng_chk_sel = 0;

// We perform a range check if we have some range checks remaining or we are performing a comparison op
pol RNG_CHK_OP = rng_chk_sel + cmp_sel + op_cast + op_cast_prev;
pol RNG_CHK_OP = rng_chk_sel + cmp_sel + op_cast + op_cast_prev + shift_lt_bit_len;

pol commit rng_chk_lookup_selector;
// TODO: Possible optimisation here if we swap the op_shl and op_shr with shift_lt_bit_len.
// Shift_lt_bit_len is a more restrictive form therefore we can avoid performing redundant range checks when we know the result == 0.
#[RNG_CHK_LOOKUP_SELECTOR]
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag + op_cast' + op_cast_prev';
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag + op_cast' + op_cast_prev' + op_shl' + op_shr';

// Perform 128-bit range check on lo part
#[LOWER_CMP_RNG_CHK]
Expand Down Expand Up @@ -469,7 +475,6 @@ namespace avm_alu(256);
(p_sub_b_lo' - res_lo) * rng_chk_sel'= 0;
(p_sub_b_hi' - res_hi) * rng_chk_sel'= 0;


// ========= CAST Operation Constraints ===============================
// We handle the input ia independently of its tag, i.e., we suppose it can take
// any value between 0 and p-1.
Expand Down Expand Up @@ -509,4 +514,95 @@ namespace avm_alu(256);
// 128-bit multiplication and CAST need two rows in ALU trace. We need to ensure
// that another ALU operation does not start in the second row.
#[TWO_LINE_OP_NO_OVERLAP]
(op_mul * ff_tag + op_cast) * alu_sel' = 0;
(op_mul * ff_tag + op_cast) * alu_sel' = 0;

// ========= SHIFT LEFT/RIGHT OPERATIONS ===============================
// Given inputs to a shift operation, a & b, and a memory tag, mem_tag.
// Split a into Big Endian hi and lo limbs, a_hi and a_lo, and the number of bits represented by the memory tag, t.
// QUESTION: SHOULD B BE CONSTRAINED TO BE U8 -> i.e. when would we shift more than 255 bits when the max number of bits of a is 128bits?
// If we are shifting by more than the bit length represented by the memory tag, the result is trivially zero
//
// === Steps when performing SHR
// (1) Prove the correct decomposition: a_hi * 2**b + a_lo = a
// (2) Range check a_hi < 2**(t-b) && a_lo < 2**b, ensure we have not overflowed the limbs during decomp
// (3) Return a_hi
//
// <--(t-b) bits --> | <-- b bits -->
// -------------------|-------------------
// | a_hi | a_lo | --> a
// ---------------------------------------
//
// === Steps when performing SHL
// (1) Prove the correct decomposition: a_hi * 2**(t-b) + a_lo = a
// (2) Range check a_hi < 2**b && a_lo < 2**(t-b)
// (3) Return a_lo * 2**b
//
// <-- b bits --> | <-- (t-b) bits -->
// ------------------|-------------------
// | a_hi | a_lo | --> a
// --------------------------------------

// TODO: Possibly optimised with variable length checks
// Check that a_lo and a_hi are range checked such that that:
// SHR: a_hi < 2**(t - b) && a_lo < 2**b
// SHL: a_hi < 2**b && a_lo < 2**(t - b)

// In lieu of a variable length check, we can utilise 2 fixed range checks instead.
// Given the dynamic range check of 0 <= a_hi <= 2**(t-b), where b < t
// (1) 0 <= a_hi <= 2**t
// (2) 0 <= 2**(t - b) - a_hi <= 2**t
// Note that (1) is guaranteed elsewhere through the tagged memory model, so we focus on (2) here.

// === General Notes:
// There are probably ways to merge various relations for the SHL/SHR, but they are separate
// now while we are still figuring out.

// Indicate if the shift amount < MAX_BITS
pol commit shift_lt_bit_len;
shift_lt_bit_len * (1 - shift_lt_bit_len) = 0;

// The number of bits represented by the memory tag, any shifts greater than this will result in zero.
pol MAX_BITS = u8_tag * 8 +
u16_tag * 16 +
u32_tag * 32 +
u64_tag * 64 +
u128_tag * 128;

// The result of MAX_BITS - ib
pol commit t_sub_b_bits;

// For our assumptions to hold, we must check that b < MAX_BITS. This can be achieved by the following relation.
// We check if b < MAX_BITS || b >= MAX_BITS using boolean shift_le_bit_len.
pol SHIFT_DIFF = shift_lt_bit_len * (MAX_BITS - ib - 1) + (1 - shift_lt_bit_len) * (ib - MAX_BITS);

// Regardless of which side is evaluated, the value of t_sub_b_bits < 2**8
// so it is automatically range checked by the lookup for 2**t_sub_b_bits
#[SHIFT_LT_BIT_LEN]
t_sub_b_bits = shift_sel * (SHIFT_DIFF + shift_lt_bit_len); // Re-add the 1

// Lookups for powers of 2.
// 2**(MAX_BITS - ib)
pol commit two_pow_t_sub_b;
// 2 ** ib
pol commit two_pow_b;

// ========= SHIFT RIGHT OPERATIONS ===============================
// a_hi * 2**b + a_lo = a
// If ib > MAX_BITS, we trivially skip this check since the result will be forced to 0.
#[CHECK_INPUT_DECOMPOSITION_0]
shift_lt_bit_len * op_shr * ((b_hi * two_pow_b + b_lo) - ia) = 0;

// Return hi limb, if ib > MAX_BITS, the output is forced to be 0
#[SHR_OUTPUT_0]
shift_lt_bit_len * op_shr * (b_hi - ic) = 0;

// ========= SHIFT LEFT OPERATIONS ===============================
// a_hi * 2**(t-b) + a_lo = a
// If ib > MAX_BITS, we trivially skip this check since the result will be forced to 0.
#[CHECK_INPUT_DECOMPOSITION_1]
shift_lt_bit_len * op_shl * ((b_hi * two_pow_t_sub_b + b_lo) - ia) = 0;

// Return lo limb a_lo * 2**b, if ib > MAX_BITS, the output is forced to be 0
#[SHL_OUTPUT_1]
shift_lt_bit_len * (op_shl * (b_lo * two_pow_b - ic)) = 0;

26 changes: 23 additions & 3 deletions barretenberg/cpp/pil/avm/avm_main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ namespace avm_main(256);
pol commit sel_rng_8; // Boolean selector for the 8-bit range check lookup
pol commit sel_rng_16; // Boolean selector for the 16-bit range check lookup

//===== Lookup table powers of 2 =============================================
pol commit table_pow_2; // Table of powers of 2 for 8-bit numbers.

//===== CONTROL FLOW ==========================================================
// Program counter
pol commit pc;
Expand Down Expand Up @@ -60,6 +63,10 @@ namespace avm_main(256);
pol commit sel_op_lt;
// LTE
pol commit sel_op_lte;
// SHL
pol commit sel_op_shl;
// SHR
pol commit sel_op_shr;

// Helper selector to characterize an ALU chiplet selector
pol commit alu_sel;
Expand Down Expand Up @@ -138,6 +145,8 @@ namespace avm_main(256);
sel_op_cast * (1 - sel_op_cast) = 0;
sel_op_lt * (1 - sel_op_lt) = 0;
sel_op_lte * (1 - sel_op_lte) = 0;
sel_op_shl * (1 - sel_op_shl) = 0;
sel_op_shr * (1 - sel_op_shr) = 0;

sel_internal_call * (1 - sel_internal_call) = 0;
sel_internal_return * (1 - sel_internal_return) = 0;
Expand Down Expand Up @@ -295,7 +304,7 @@ namespace avm_main(256);

//===== ALU CONSTRAINTS =====================================================
// TODO: when division is moved to the alu, we will need to add the selector in the list below.
pol ALU_R_TAG_SEL = sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte;
pol ALU_R_TAG_SEL = sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte + sel_op_shr + sel_op_shl;
pol ALU_W_TAG_SEL = sel_op_cast;
pol ALU_ALL_SEL = ALU_R_TAG_SEL + ALU_W_TAG_SEL;

Expand All @@ -317,11 +326,11 @@ namespace avm_main(256);
#[PERM_MAIN_ALU]
alu_sel {clk, ia, ib, ic, sel_op_add, sel_op_sub,
sel_op_mul, sel_op_eq, sel_op_not, sel_op_cast,
sel_op_lt, sel_op_lte, alu_in_tag}
sel_op_lt, sel_op_lte, sel_op_shr, sel_op_shl, alu_in_tag}
is
avm_alu.alu_sel {avm_alu.clk, avm_alu.ia, avm_alu.ib, avm_alu.ic, avm_alu.op_add, avm_alu.op_sub,
avm_alu.op_mul, avm_alu.op_eq, avm_alu.op_not, avm_alu.op_cast,
avm_alu.op_lt, avm_alu.op_lte, avm_alu.in_tag};
avm_alu.op_lt, avm_alu.op_lte, avm_alu.op_shr, avm_alu.op_shl, avm_alu.in_tag};

// Based on the boolean selectors, we derive the binary op id to lookup in the table;
// TODO: Check if having 4 columns (op_id + 3 boolean selectors) is more optimal that just using the op_id
Expand Down Expand Up @@ -379,6 +388,17 @@ namespace avm_main(256);
#[PERM_MAIN_MEM_IND_D]
ind_op_d {clk, ind_d, mem_idx_d} is avm_mem.ind_op_d {avm_mem.clk, avm_mem.addr, avm_mem.val};

//====== Inter-table Shift Constraints (Lookups) ============================================
// Currently only used for shift operations but can be generalised for other uses.

// Lookup for 2**(ib)
#[LOOKUP_POW_2_0]
avm_alu.shift_sel {avm_alu.ib, avm_alu.two_pow_b} in sel_rng_8 {clk, table_pow_2};

// Lookup for 2**(t-ib)
#[LOOKUP_POW_2_1]
avm_alu.shift_sel {avm_alu.t_sub_b_bits , avm_alu.two_pow_t_sub_b} in sel_rng_8 {clk, table_pow_2};

//====== Inter-table Constraints (Range Checks) ============================================
// TODO: Investigate optimising these range checks. Handling non-FF elements should require less range checks.
// One can increase the granularity based on the operation and tag. In the most extreme case,
Expand Down
Loading

0 comments on commit 7478991

Please sign in to comment.