diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/avm_bitwise.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/avm_bitwise.test.cpp index 16789ac3dc6..d0d8e9032f1 100644 --- a/barretenberg/cpp/src/barretenberg/vm/tests/avm_bitwise.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/tests/avm_bitwise.test.cpp @@ -216,6 +216,66 @@ enum BIT_FAILURES { IncorrectBinSelector, }; +enum SHIFT_FAILURES { + IncorrectShiftPastBitLength, // Incorrect Setting shift_lt_bit_len + IncorrectInputDecomposition, + ShiftOutputIncorrect, +}; + +std::tuple, std::string> gen_mutated_trace_shift(std::vector trace, + std::function&& select_row, + FF const& c_mutated, + SHIFT_FAILURES fail_mode, + bool shr = true) +{ + auto main_trace_row = std::ranges::find_if(trace.begin(), trace.end(), select_row); + auto main_clk = main_trace_row->avm_main_clk; + // The corresponding row in the alu trace as well as the row where start = 1 + auto alu_row = + std::ranges::find_if(trace.begin(), trace.end(), [main_clk](Row r) { return r.avm_alu_clk == main_clk; }); + + std::string failure; + switch (fail_mode) { + case IncorrectShiftPastBitLength: + alu_row->avm_alu_shift_lt_bit_len = FF(0); + update_slice_registers(*alu_row, uint256_t{ 0 }); + alu_row->avm_alu_a_lo = FF(0); + alu_row->avm_alu_a_hi = FF(0); + failure = "SHIFT_LT_BIT_LEN"; + return std::make_tuple(trace, failure); + case IncorrectInputDecomposition: { + // Add one to b_lo and update b_lo + uint256_t b_lo = alu_row->avm_alu_b_lo + 1; + uint256_t b_hi = alu_row->avm_alu_b_hi; + alu_row->avm_alu_b_lo = b_lo; + + // Update the range checks involving b_lo and b_hi so we dont throw an error about the range checks + if (shr) { + uint256_t a_lo = (uint256_t(1) << alu_row->avm_alu_ib) - b_lo; + uint256_t a_hi = (uint256_t(1) << (32 - uint8_t(alu_row->avm_alu_ib))) - b_hi; + alu_row->avm_alu_a_lo = a_lo & ((uint256_t(1) << 128) - 1); + alu_row->avm_alu_a_hi = a_hi; + // Update slice registers + update_slice_registers(*alu_row, a_lo + (a_hi << 128)); + failure = "CHECK_INPUT_DECOMPOSITION_0"; + return std::make_tuple(trace, failure); + } + uint256_t a_lo = (uint256_t(1) << (32 - uint8_t(alu_row->avm_alu_ib))) - b_lo; + uint256_t a_hi = (uint256_t(1) << alu_row->avm_alu_ib) - b_hi; + alu_row->avm_alu_a_lo = a_lo & ((uint256_t(1) << 128) - 1); + alu_row->avm_alu_a_hi = a_hi; + // Update slice registers + update_slice_registers(*alu_row, a_lo + (a_hi << 128)); + failure = "CHECK_INPUT_DECOMPOSITION_1"; + return std::make_tuple(trace, failure); + } + case ShiftOutputIncorrect: + alu_row->avm_alu_ic = c_mutated; + failure = shr ? "SHR_OUTPUT_0" : "SHL_OUTPUT_1"; + return std::make_tuple(trace, failure); + } + return std::make_tuple(trace, failure); +} std::vector gen_mutated_trace_bit(std::vector trace, std::function&& select_row, FF const& c_mutated, @@ -554,6 +614,10 @@ class AvmBitwiseNegativeTestsOr : public AvmBitwiseTests, public testing::WithParamInterface> {}; class AvmBitwiseNegativeTestsXor : public AvmBitwiseTests, public testing::WithParamInterface> {}; +class AvmBitwiseNegativeTestsShr : public AvmBitwiseTests, + public testing::WithParamInterface> {}; +class AvmBitwiseNegativeTestsShl : public AvmBitwiseTests, + public testing::WithParamInterface> {}; class AvmBitwiseNegativeTestsFF : public AvmBitwiseTests {}; class AvmBitwiseNegativeTestsU8 : public AvmBitwiseTests {}; class AvmBitwiseNegativeTestsU16 : public AvmBitwiseTests {}; @@ -570,12 +634,17 @@ std::vector> bit_failures = { { "OP_ID_REL", BIT_FAILURES::InconsistentOpId }, { "BIN_SEL_CTR_REL", BIT_FAILURES::IncorrectBinSelector }, }; +std::vector shift_failures = { SHIFT_FAILURES::IncorrectShiftPastBitLength, + SHIFT_FAILURES::IncorrectInputDecomposition, + SHIFT_FAILURES::ShiftOutputIncorrect }; // For the negative test the output is set to be incorrect so that we can test the byte lookups. // Picking "simple" inputs such as zero also makes it easier when check the byte length lookups as we dont // need to worry about copying the accmulated a & b registers into the main trace. std::vector neg_test_and = { { { 0, 0, 1 }, AvmMemoryTag::U32 } }; std::vector neg_test_or = { { { 0, 0, 1 }, AvmMemoryTag::U32 } }; std::vector neg_test_xor = { { { 0, 0, 1 }, AvmMemoryTag::U32 } }; + +std::vector neg_test_shr = { { { 7, 2, 0 }, AvmMemoryTag::U32 } }; /****************************************************************************** * Negative Tests - FF ******************************************************************************/ @@ -641,6 +710,48 @@ INSTANTIATE_TEST_SUITE_P(AvmBitwiseNegativeTests, AvmBitwiseNegativeTestsXor, testing::Combine(testing::ValuesIn(bit_failures), testing::ValuesIn(neg_test_xor))); +TEST_P(AvmBitwiseNegativeTestsShr, AllNegativeTests) +{ + const auto [failure, params] = GetParam(); + const auto [operands, mem_tag] = params; + const auto [a, b, output] = operands; + auto trace_builder = avm_trace::AvmTraceBuilder(); + trace_builder.op_set(0, uint128_t{ a }, 0, mem_tag); + trace_builder.op_set(0, uint128_t{ b }, 1, mem_tag); + trace_builder.op_shr(0, 0, 1, 2, mem_tag); + trace_builder.halt(); + auto trace = trace_builder.finalize(); + std::function&& select_row = [](Row r) { return r.avm_main_sel_op_shr == FF(1); }; + + auto [mutated_trace, str] = gen_mutated_trace_shift( + std::move(trace), std::move(select_row), FF(uint256_t::from_uint128(output)), failure, true); + EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(mutated_trace)), str); +} +INSTANTIATE_TEST_SUITE_P(AvmBitwiseNegativeTests, + AvmBitwiseNegativeTestsShr, + testing::Combine(testing::ValuesIn(shift_failures), testing::ValuesIn(neg_test_shr))); + +TEST_P(AvmBitwiseNegativeTestsShl, AllNegativeTests) +{ + const auto [failure, params] = GetParam(); + const auto [operands, mem_tag] = params; + const auto [a, b, output] = operands; + auto trace_builder = avm_trace::AvmTraceBuilder(); + trace_builder.op_set(0, uint128_t{ a }, 0, mem_tag); + trace_builder.op_set(0, uint128_t{ b }, 1, mem_tag); + trace_builder.op_shl(0, 0, 1, 2, mem_tag); + trace_builder.halt(); + auto trace = trace_builder.finalize(); + std::function&& select_row = [](Row r) { return r.avm_main_sel_op_shl == FF(1); }; + + auto [mutated_trace, str] = gen_mutated_trace_shift( + std::move(trace), std::move(select_row), FF(uint256_t::from_uint128(output)), failure, false); + EXPECT_THROW_WITH_MESSAGE(validate_trace_check_circuit(std::move(mutated_trace)), str); +} +INSTANTIATE_TEST_SUITE_P(AvmBitwiseNegativeTests, + AvmBitwiseNegativeTestsShl, + testing::Combine(testing::ValuesIn(shift_failures), testing::ValuesIn(neg_test_shr))); + TEST_F(AvmBitwiseNegativeTestsFF, UndefinedOverFF) { auto trace_builder = avm_trace::AvmTraceBuilder(); diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp index c91374202df..290b15585a0 100644 --- a/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.cpp @@ -78,6 +78,44 @@ void mutate_ic_in_trace(std::vector& trace, std::function&& sele mem_row->avm_mem_val = newValue; }; +// TODO: Should be a cleaner way to do this +void update_slice_registers(Row& row, uint256_t a) +{ + row.avm_alu_u8_r0 = static_cast(a); + a >>= 8; + row.avm_alu_u8_r1 = static_cast(a); + a >>= 8; + row.avm_alu_u16_r0 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r1 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r2 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r3 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r4 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r5 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r6 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r7 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r8 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r9 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r10 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r11 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r12 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r13 = static_cast(a); + a >>= 16; + row.avm_alu_u16_r14 = static_cast(a); +} + // TODO: There has to be a better way to do. // This is a helper function to clear the range check counters associated with the alu register decomposition of // "previous_value" so we don't trigger a trivial range_check count error diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.hpp b/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.hpp index b1f4df3924a..fd1f862404d 100644 --- a/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/tests/helpers.test.hpp @@ -29,5 +29,6 @@ void mutate_ic_in_trace(std::vector& trace, FF const& newValue, bool alu = false); void clear_range_check_counters(std::vector& trace, uint256_t previous_value); +void update_slice_registers(Row& row, uint256_t a); } // namespace tests_avm