From 14b9736925c6da33133bd24ee283fb4c199082a5 Mon Sep 17 00:00:00 2001 From: Lucas Xia Date: Fri, 1 Dec 2023 15:46:23 -0500 Subject: [PATCH] feat: new Poseidon2 relations (#3406) New Poseidon2 relations for efficient stdlib implementation of Poseidon2. We create new relations for the external and internal rounds of Poseidon2, in order to be able to execute a permutation with t=4 in just 64 rounds. Added new tests, including consistency checks and manual hardcoded tests. Resolves https://github.com/AztecProtocol/barretenberg/issues/775. --- .../crypto/poseidon2/CMakeLists.txt | 2 +- .../poseidon2/poseidon2_permutation.hpp | 4 +- .../relations/poseidon2_external_relation.hpp | 117 +++++++++++ .../relations/poseidon2_internal_relation.hpp | 97 +++++++++ .../relations/relation_manual.test.cpp | 168 +++++++++++++++ .../ultra_relation_consistency.test.cpp | 196 +++++++++++++++--- 6 files changed, 548 insertions(+), 36 deletions(-) create mode 100644 barretenberg/cpp/src/barretenberg/relations/poseidon2_external_relation.hpp create mode 100644 barretenberg/cpp/src/barretenberg/relations/poseidon2_internal_relation.hpp create mode 100644 barretenberg/cpp/src/barretenberg/relations/relation_manual.test.cpp diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt index dc0157be3a9..e8d8f8e2013 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/CMakeLists.txt @@ -1 +1 @@ -barretenberg_module(crypto_poseidon2 ecc numeric) \ No newline at end of file +barretenberg_module(crypto_poseidon2 ecc) diff --git a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp index 40606b4887b..9e3931cdac3 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/poseidon2/poseidon2_permutation.hpp @@ -70,8 +70,8 @@ template class Poseidon2Permutation { auto t5 = t0 + t0; t5 += t5; t5 += t2; // 4A + 6B + C + D - auto t6 = t3 + t5; // 5A + 7B + 3C + D - auto t7 = t2 + t4; // A + 3B + 5D + 7C + auto t6 = t3 + t5; // 5A + 7B + C + 3D + auto t7 = t2 + t4; // A + 3B + 5C + 7D input[0] = t6; input[1] = t5; input[2] = t7; diff --git a/barretenberg/cpp/src/barretenberg/relations/poseidon2_external_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/poseidon2_external_relation.hpp new file mode 100644 index 00000000000..aba162e56af --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/poseidon2_external_relation.hpp @@ -0,0 +1,117 @@ +#pragma once +#include "barretenberg/relations/relation_types.hpp" +namespace proof_system { + +template class Poseidon2ExternalRelationImpl { + public: + using FF = FF_; + + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 7, // external poseidon2 round sub-relation for first value + 7, // external poseidon2 round sub-relation for second value + 7, // external poseidon2 round sub-relation for third value + 7, // external poseidon2 round sub-relation for fourth value + }; + + /** + * @brief Expression for the poseidon2 external round relation, based on E_i in Section 6 of + * https://eprint.iacr.org/2023/323.pdf. + * @details This relation is defined as C(in(X)...) := + * q_poseidon2_external * ( (v1 - w_1_shift) + \alpha * (v2 - w_2_shift) + + * \alpha^2 * (v3 - w_3_shift) + \alpha^3 * (v4 - w_4_shift) ) = 0 where: + * u1 := (w_1 + q_1)^5 + * u2 := (w_2 + q_2)^5 + * u3 := (w_3 + q_3)^5 + * u4 := (w_4 + q_4)^5 + * t0 := u1 + u2 (1, 1, 0, 0) + * t1 := u3 + u4 (0, 0, 1, 1) + * t2 := 2 * u2 + t1 = 2 * u2 + u3 + u4 (0, 2, 1, 1) + * t3 := 2 * u4 + t0 = u1 + u2 + 2 * u4 (1, 1, 0, 2) + * v4 := 4 * t1 + t3 = u1 + u2 + 4 * u3 + 6 * u4 (1, 1, 4, 6) + * v2 := 4 * t0 + t2 = 4 * u1 + 6 * u2 + u3 + u4 (4, 6, 1, 1) + * v1 := t3 + v2 = 5 * u1 + 7 * u2 + 1 * u3 + 3 * u4 (5, 7, 1, 3) + * v3 := t2 + v4 (1, 3, 5, 7) + * + * @param evals transformed to `evals + C(in(X)...)*scaling_factor` + * @param in an std::array containing the fully extended Univariate edges. + * @param parameters contains beta, gamma, and public_input_delta, .... + * @param scaling_factor optional term to scale the evaluation before adding to evals. + */ + template + void static accumulate(ContainerOverSubrelations& evals, + const AllEntities& in, + const Parameters&, + const FF& scaling_factor) + { + using Accumulator = std::tuple_element_t<0, ContainerOverSubrelations>; + using View = typename Accumulator::View; + auto w_l = View(in.w_l); + auto w_r = View(in.w_r); + auto w_o = View(in.w_o); + auto w_4 = View(in.w_4); + auto w_l_shift = View(in.w_l_shift); + auto w_r_shift = View(in.w_r_shift); + auto w_o_shift = View(in.w_o_shift); + auto w_4_shift = View(in.w_4_shift); + auto q_l = View(in.q_l); + auto q_r = View(in.q_r); + auto q_o = View(in.q_o); + auto q_4 = View(in.q_4); + auto q_poseidon2_external = View(in.q_poseidon2_external); + + // add round constants which are loaded in selectors + auto s1 = w_l + q_l; + auto s2 = w_r + q_r; + auto s3 = w_o + q_o; + auto s4 = w_4 + q_4; + + // apply s-box round + auto u1 = s1 * s1; + u1 *= u1; + u1 *= s1; + auto u2 = s2 * s2; + u2 *= u2; + u2 *= s2; + auto u3 = s3 * s3; + u3 *= u3; + u3 *= s3; + auto u4 = s4 * s4; + u4 *= u4; + u4 *= s4; + + // matrix mul v = M_E * u with 14 additions + auto t0 = u1 + u2; // u_1 + u_2 + auto t1 = u3 + u4; // u_3 + u_4 + auto t2 = u2 + u2; // 2u_2 + t2 += t1; // 2u_2 + u_3 + u_4 + auto t3 = u4 + u4; // 2u_4 + t3 += t0; // u_1 + u_2 + 2u_4 + auto v4 = t1 + t1; + v4 += v4; + v4 += t3; // u_1 + u_2 + 4u_3 + 6u_4 + auto v2 = t0 + t0; + v2 += v2; + v2 += t2; // 4u_1 + 6u_2 + u_3 + u_4 + auto v1 = t3 + v2; // 5u_1 + 7u_2 + u_3 + 3u_4 + auto v3 = t2 + v4; // u_1 + 3u_2 + 5u_3 + 7u_4 + + auto tmp = q_poseidon2_external * (v1 - w_l_shift); + tmp *= scaling_factor; + std::get<0>(evals) += tmp; + + tmp = q_poseidon2_external * (v2 - w_r_shift); + tmp *= scaling_factor; + std::get<1>(evals) += tmp; + + tmp = q_poseidon2_external * (v3 - w_o_shift); + tmp *= scaling_factor; + std::get<2>(evals) += tmp; + + tmp = q_poseidon2_external * (v4 - w_4_shift); + tmp *= scaling_factor; + std::get<3>(evals) += tmp; + }; +}; + +template using Poseidon2ExternalRelation = Relation>; +} // namespace proof_system \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/relations/poseidon2_internal_relation.hpp b/barretenberg/cpp/src/barretenberg/relations/poseidon2_internal_relation.hpp new file mode 100644 index 00000000000..1ec20c0956b --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/poseidon2_internal_relation.hpp @@ -0,0 +1,97 @@ +#pragma once +#include "barretenberg/crypto/poseidon2/poseidon2_params.hpp" +#include "relation_types.hpp" + +namespace proof_system { + +template class Poseidon2InternalRelationImpl { + public: + using FF = FF_; + + static constexpr std::array SUBRELATION_PARTIAL_LENGTHS{ + 7, // internal poseidon2 round sub-relation for first value + 7, // internal poseidon2 round sub-relation for second value + 7, // internal poseidon2 round sub-relation for third value + 7, // internal poseidon2 round sub-relation for fourth value + }; + + /** + * @brief Expression for the poseidon2 internal round relation, based on I_i in Section 6 of + * https://eprint.iacr.org/2023/323.pdf. + * @details This relation is defined as C(in(X)...) := + * q_poseidon2_internal * ( (v1 - w_1_shift) + \alpha * (v2 - w_2_shift) + + * \alpha^2 * (v3 - w_3_shift) + \alpha^3 * (v4 - w_4_shift) ) = 0 where: + * u1 := (w_1 + q_1)^5 + * sum := u1 + w_2 + w_3 + w_4 + * v1 := u1 * D1 + sum + * v2 := w_2 * D2 + sum + * v3 := w_3 * D3 + sum + * v4 := w_4 * D4 + sum + * Di is the ith internal diagonal value - 1 of the internal matrix M_I + * + * @param evals transformed to `evals + C(in(X)...)*scaling_factor` + * @param in an std::array containing the fully extended Univariate edges. + * @param parameters contains beta, gamma, and public_input_delta, .... + * @param scaling_factor optional term to scale the evaluation before adding to evals. + */ + template + void static accumulate(ContainerOverSubrelations& evals, + const AllEntities& in, + const Parameters&, + const FF& scaling_factor) + { + using Accumulator = std::tuple_element_t<0, ContainerOverSubrelations>; + using View = typename Accumulator::View; + auto w_l = View(in.w_l); + auto w_r = View(in.w_r); + auto w_o = View(in.w_o); + auto w_4 = View(in.w_4); + auto w_l_shift = View(in.w_l_shift); + auto w_r_shift = View(in.w_r_shift); + auto w_o_shift = View(in.w_o_shift); + auto w_4_shift = View(in.w_4_shift); + auto q_l = View(in.q_l); + auto q_poseidon2_internal = View(in.q_poseidon2_internal); + + // add round constants + auto s1 = w_l + q_l; + + // apply s-box round + auto u1 = s1 * s1; + u1 *= u1; + u1 *= s1; + auto u2 = w_r; + auto u3 = w_o; + auto u4 = w_4; + + // matrix mul with v = M_I * u 4 muls and 7 additions + auto sum = u1 + u2 + u3 + u4; + + auto v1 = u1 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[0]; + v1 += sum; + auto tmp = q_poseidon2_internal * (v1 - w_l_shift); + tmp *= scaling_factor; + std::get<0>(evals) += tmp; + + auto v2 = u2 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[1]; + v2 += sum; + tmp = q_poseidon2_internal * (v2 - w_r_shift); + tmp *= scaling_factor; + std::get<1>(evals) += tmp; + + auto v3 = u3 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[2]; + v3 += sum; + tmp = q_poseidon2_internal * (v3 - w_o_shift); + tmp *= scaling_factor; + std::get<2>(evals) += tmp; + + auto v4 = u4 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[3]; + v4 += sum; + tmp = q_poseidon2_internal * (v4 - w_4_shift); + tmp *= scaling_factor; + std::get<3>(evals) += tmp; + }; +}; // namespace proof_system + +template using Poseidon2InternalRelation = Relation>; +} // namespace proof_system \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/relations/relation_manual.test.cpp b/barretenberg/cpp/src/barretenberg/relations/relation_manual.test.cpp new file mode 100644 index 00000000000..752852167a1 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/relations/relation_manual.test.cpp @@ -0,0 +1,168 @@ +#include "barretenberg/flavor/flavor.hpp" +#include "barretenberg/relations/poseidon2_external_relation.hpp" +#include "barretenberg/relations/poseidon2_internal_relation.hpp" +#include "barretenberg/relations/relation_parameters.hpp" +#include + +namespace proof_system::relation_manual_tests { + +using FF = barretenberg::fr; + +class RelationManual : public testing::Test {}; + +TEST_F(RelationManual, Poseidon2ExternalRelationZeros) +{ + using Accumulator = std::array; + using Relation = Poseidon2ExternalRelation; + + Accumulator acc{ 0, 0, 0, 0 }; + struct AllPoseidonValues { + FF q_poseidon2_external; + FF w_l; + FF w_r; + FF w_o; + FF w_4; + FF w_l_shift; + FF w_r_shift; + FF w_o_shift; + FF w_4_shift; + FF q_l; + FF q_r; + FF q_o; + FF q_4; + }; + AllPoseidonValues all_poseidon_values{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + const auto parameters = RelationParameters::get_random(); + Relation::accumulate(acc, all_poseidon_values, parameters, 1); + EXPECT_EQ(acc[0], 0); + EXPECT_EQ(acc[1], 0); + EXPECT_EQ(acc[2], 0); + EXPECT_EQ(acc[3], 0); +} + +TEST_F(RelationManual, Poseidon2ExternalRelationRandom) +{ + using Accumulator = std::array; + using Relation = Poseidon2ExternalRelation; + + Accumulator acc{ 0, 0, 0, 0 }; + struct AllPoseidonValues { + FF q_poseidon2_external; + FF w_l; + FF w_r; + FF w_o; + FF w_4; + FF q_l; + FF q_r; + FF q_o; + FF q_4; + FF w_l_shift; + FF w_r_shift; + FF w_o_shift; + FF w_4_shift; + }; + /* + * v1 = w_1 + q_1 = 5 + 6 = 11 + * v2 = w_2 + q_2 = 4 + 9 = 13 + * v3 = w_3 + q_3 = 1 + 8 = 9 + * v4 = w_4 + q_4 = 7 + 3 = 10 + * u1 = v1^5 = 11^5 = 161051 + * u2 = v2^5 = 13^5 = 371293 + * u3 = v3^5 = 9^5 = 59049 + * u4 = v4^5 = 10^5 = 100000 + * matrix mul with calculator: + * 1 3763355 + * 2 3031011 + * 3 2270175 + * 4 1368540 + */ + AllPoseidonValues all_poseidon_values{ 1, 5, 4, 1, 7, 6, 9, 8, 3, 3763355, 3031011, 2270175, 1368540 }; + + const auto parameters = RelationParameters::get_random(); + Relation::accumulate(acc, all_poseidon_values, parameters, 1); + EXPECT_EQ(acc[0], 0); + EXPECT_EQ(acc[1], 0); + EXPECT_EQ(acc[2], 0); + EXPECT_EQ(acc[3], 0); +} + +TEST_F(RelationManual, Poseidon2InternalRelationZeros) +{ + using Accumulator = std::array; + using Relation = Poseidon2InternalRelation; + + Accumulator acc{ 0, 0, 0, 0 }; + struct AllPoseidonValues { + FF q_poseidon2_internal; + FF w_l; + FF w_r; + FF w_o; + FF w_4; + FF w_l_shift; + FF w_r_shift; + FF w_o_shift; + FF w_4_shift; + FF q_l; + FF q_r; + FF q_o; + FF q_4; + }; + AllPoseidonValues all_poseidon_values{ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + const auto parameters = RelationParameters::get_random(); + Relation::accumulate(acc, all_poseidon_values, parameters, 1); + EXPECT_EQ(acc[0], 0); + EXPECT_EQ(acc[1], 0); + EXPECT_EQ(acc[2], 0); + EXPECT_EQ(acc[3], 0); +} + +TEST_F(RelationManual, Poseidon2InternalRelationRandom) +{ + using Accumulator = std::array; + using Relation = Poseidon2InternalRelation; + + Accumulator acc{ 0, 0, 0, 0 }; + struct AllPoseidonValues { + FF q_poseidon2_internal; + FF w_l; + FF w_r; + FF w_o; + FF w_4; + FF q_l; + + FF w_l_shift; + FF w_r_shift; + FF w_o_shift; + FF w_4_shift; + }; + /* + * u1 = (w_1 + q_1)^5 = (1 + 5)^5 = 7776 + * sum = u1 + w_2 + w_3 + w_4 = 7776 + 2 + 3 + 4 = 7785 + * matrix mul with calculator: + * 1 0x122d9ce41e83c533318954d77a4ebc40eb729f6543ebd5f2e4ecb175ced3bc74 + * 2 0x185028b6d489be7c029367a14616776b33bf2eada9bb370950d6719f68b5067f + * 3 0x00fce289a96b3f4a18562d0ef0ab76ca165e613222aa0c24501377003c5622a8 + * 4 0x27e7677799fda1694819803f459b76d2fb1c45fdf0773375c72d61e8efb92893 + */ + AllPoseidonValues all_poseidon_values{ + 1, + 1, + 2, + 3, + 4, + 5, + FF(std::string("0x122d9ce41e83c533318954d77a4ebc40eb729f6543ebd5f2e4ecb175ced3bc74")), + FF(std::string("0x185028b6d489be7c029367a14616776b33bf2eada9bb370950d6719f68b5067f")), + FF(std::string("0x00fce289a96b3f4a18562d0ef0ab76ca165e613222aa0c24501377003c5622a8")), + FF(std::string("0x27e7677799fda1694819803f459b76d2fb1c45fdf0773375c72d61e8efb92893")) + }; + const auto parameters = RelationParameters::get_random(); + Relation::accumulate(acc, all_poseidon_values, parameters, 1); + EXPECT_EQ(acc[0], 0); + EXPECT_EQ(acc[1], 0); + EXPECT_EQ(acc[2], 0); + EXPECT_EQ(acc[3], 0); +} +}; // namespace proof_system::relation_manual_tests \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/relations/ultra_relation_consistency.test.cpp b/barretenberg/cpp/src/barretenberg/relations/ultra_relation_consistency.test.cpp index 9320053a485..65d33ff8c73 100644 --- a/barretenberg/cpp/src/barretenberg/relations/ultra_relation_consistency.test.cpp +++ b/barretenberg/cpp/src/barretenberg/relations/ultra_relation_consistency.test.cpp @@ -18,6 +18,8 @@ #include "barretenberg/relations/gen_perm_sort_relation.hpp" #include "barretenberg/relations/lookup_relation.hpp" #include "barretenberg/relations/permutation_relation.hpp" +#include "barretenberg/relations/poseidon2_external_relation.hpp" +#include "barretenberg/relations/poseidon2_internal_relation.hpp" #include "barretenberg/relations/relation_parameters.hpp" #include "barretenberg/relations/ultra_arithmetic_relation.hpp" #include @@ -28,7 +30,7 @@ namespace proof_system::ultra_relation_consistency_tests { using FF = barretenberg::fr; struct InputElements { - static constexpr size_t NUM_ELEMENTS = 43; + static constexpr size_t NUM_ELEMENTS = 45; std::array _data; static InputElements get_random() @@ -60,38 +62,40 @@ struct InputElements { FF& q_elliptic = std::get<8>(_data); FF& q_aux = std::get<9>(_data); FF& q_lookup = std::get<10>(_data); - FF& sigma_1 = std::get<11>(_data); - FF& sigma_2 = std::get<12>(_data); - FF& sigma_3 = std::get<13>(_data); - FF& sigma_4 = std::get<14>(_data); - FF& id_1 = std::get<15>(_data); - FF& id_2 = std::get<16>(_data); - FF& id_3 = std::get<17>(_data); - FF& id_4 = std::get<18>(_data); - FF& table_1 = std::get<19>(_data); - FF& table_2 = std::get<20>(_data); - FF& table_3 = std::get<21>(_data); - FF& table_4 = std::get<22>(_data); - FF& lagrange_first = std::get<23>(_data); - FF& lagrange_last = std::get<24>(_data); - FF& w_l = std::get<25>(_data); - FF& w_r = std::get<26>(_data); - FF& w_o = std::get<27>(_data); - FF& w_4 = std::get<28>(_data); - FF& sorted_accum = std::get<29>(_data); - FF& z_perm = std::get<30>(_data); - FF& z_lookup = std::get<31>(_data); - FF& table_1_shift = std::get<32>(_data); - FF& table_2_shift = std::get<33>(_data); - FF& table_3_shift = std::get<34>(_data); - FF& table_4_shift = std::get<35>(_data); - FF& w_l_shift = std::get<36>(_data); - FF& w_r_shift = std::get<37>(_data); - FF& w_o_shift = std::get<38>(_data); - FF& w_4_shift = std::get<39>(_data); - FF& sorted_accum_shift = std::get<40>(_data); - FF& z_perm_shift = std::get<41>(_data); - FF& z_lookup_shift = std::get<42>(_data); + FF& q_poseidon2_external = std::get<11>(_data); + FF& q_poseidon2_internal = std::get<12>(_data); + FF& sigma_1 = std::get<13>(_data); + FF& sigma_2 = std::get<14>(_data); + FF& sigma_3 = std::get<15>(_data); + FF& sigma_4 = std::get<16>(_data); + FF& id_1 = std::get<17>(_data); + FF& id_2 = std::get<18>(_data); + FF& id_3 = std::get<19>(_data); + FF& id_4 = std::get<20>(_data); + FF& table_1 = std::get<21>(_data); + FF& table_2 = std::get<22>(_data); + FF& table_3 = std::get<23>(_data); + FF& table_4 = std::get<24>(_data); + FF& lagrange_first = std::get<25>(_data); + FF& lagrange_last = std::get<26>(_data); + FF& w_l = std::get<27>(_data); + FF& w_r = std::get<28>(_data); + FF& w_o = std::get<29>(_data); + FF& w_4 = std::get<30>(_data); + FF& sorted_accum = std::get<31>(_data); + FF& z_perm = std::get<32>(_data); + FF& z_lookup = std::get<33>(_data); + FF& table_1_shift = std::get<34>(_data); + FF& table_2_shift = std::get<35>(_data); + FF& table_3_shift = std::get<36>(_data); + FF& table_4_shift = std::get<37>(_data); + FF& w_l_shift = std::get<38>(_data); + FF& w_r_shift = std::get<39>(_data); + FF& w_o_shift = std::get<40>(_data); + FF& w_4_shift = std::get<41>(_data); + FF& sorted_accum_shift = std::get<42>(_data); + FF& z_perm_shift = std::get<43>(_data); + FF& z_lookup_shift = std::get<44>(_data); }; class UltraRelationConsistency : public testing::Test { @@ -551,4 +555,130 @@ TEST_F(UltraRelationConsistency, AuxiliaryRelation) run_test(/*random_inputs=*/true); }; +TEST_F(UltraRelationConsistency, Poseidon2ExternalRelation) +{ + const auto run_test = []([[maybe_unused]] bool random_inputs) { + using Relation = Poseidon2ExternalRelation; + using SumcheckArrayOfValuesOverSubrelations = typename Relation::SumcheckArrayOfValuesOverSubrelations; + const InputElements input_elements = random_inputs ? InputElements::get_random() : InputElements::get_special(); + + const auto& w_1 = input_elements.w_l; + const auto& w_2 = input_elements.w_r; + const auto& w_3 = input_elements.w_o; + const auto& w_4 = input_elements.w_4; + const auto& w_1_shift = input_elements.w_l_shift; + const auto& w_2_shift = input_elements.w_r_shift; + const auto& w_3_shift = input_elements.w_o_shift; + const auto& w_4_shift = input_elements.w_4_shift; + const auto& q_1 = input_elements.q_l; + const auto& q_2 = input_elements.q_r; + const auto& q_3 = input_elements.q_o; + const auto& q_4 = input_elements.q_4; + const auto& q_poseidon2_external = input_elements.q_poseidon2_external; + SumcheckArrayOfValuesOverSubrelations expected_values; + + // add round constants + auto s1 = w_1 + q_1; + auto s2 = w_2 + q_2; + auto s3 = w_3 + q_3; + auto s4 = w_4 + q_4; + + // apply s-box round + auto u1 = s1 * s1; + u1 *= u1; + u1 *= s1; + auto u2 = s2 * s2; + u2 *= u2; + u2 *= s2; + auto u3 = s3 * s3; + u3 *= u3; + u3 *= s3; + auto u4 = s4 * s4; + u4 *= u4; + u4 *= s4; + + // matrix mul v = M_E * u with 14 additions + auto t0 = u1 + u2; // u_1 + u_2 + auto t1 = u3 + u4; // u_3 + u_4 + auto t2 = u2 + u2; // 2u_2 + t2 += t1; // 2u_2 + u_3 + u_4 + auto t3 = u4 + u4; // 2u_4 + t3 += t0; // u_1 + u_2 + 2u_4 + auto v4 = t1 + t1; + v4 += v4; + v4 += t3; // u_1 + u_2 + 4u_3 + 6u_4 + auto v2 = t0 + t0; + v2 += v2; + v2 += t2; // 4u_1 + 6u_2 + u_3 + u_4 + auto v1 = t3 + v2; // 5u_1 + 7u_2 + u_3 + 3u_4 + auto v3 = t2 + v4; // u_1 + 3u_2 + 5u_3 + 7u_4 + + // output is { v1, v2, v3, v4 } + + expected_values[0] = q_poseidon2_external * (v1 - w_1_shift); + expected_values[1] = q_poseidon2_external * (v2 - w_2_shift); + expected_values[2] = q_poseidon2_external * (v3 - w_3_shift); + expected_values[3] = q_poseidon2_external * (v4 - w_4_shift); + + const auto parameters = RelationParameters::get_random(); + validate_relation_execution(expected_values, input_elements, parameters); + + // validate_relation_execution(expected_values, input_elements, parameters); + }; + run_test(/*random_inputs=*/false); + run_test(/*random_inputs=*/true); +}; + +TEST_F(UltraRelationConsistency, Poseidon2InternalRelation) +{ + const auto run_test = []([[maybe_unused]] bool random_inputs) { + using Relation = Poseidon2InternalRelation; + using SumcheckArrayOfValuesOverSubrelations = typename Relation::SumcheckArrayOfValuesOverSubrelations; + const InputElements input_elements = random_inputs ? InputElements::get_random() : InputElements::get_special(); + + const auto& w_1 = input_elements.w_l; + const auto& w_2 = input_elements.w_r; + const auto& w_3 = input_elements.w_o; + const auto& w_4 = input_elements.w_4; + const auto& w_1_shift = input_elements.w_l_shift; + const auto& w_2_shift = input_elements.w_r_shift; + const auto& w_3_shift = input_elements.w_o_shift; + const auto& w_4_shift = input_elements.w_4_shift; + const auto& q_1 = input_elements.q_l; + const auto& q_poseidon2_internal = input_elements.q_poseidon2_internal; + SumcheckArrayOfValuesOverSubrelations expected_values; + + // add round constants on only first element + auto v1 = w_1 + q_1; + + // apply s-box to only first element + auto u1 = v1 * v1; + u1 *= u1; + u1 *= v1; + + // multiply with internal matrix + auto sum = u1 + w_2 + w_3 + w_4; + auto t0 = u1 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[0]; + t0 += sum; + auto t1 = w_2 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[1]; + t1 += sum; + auto t2 = w_3 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[2]; + t2 += sum; + auto t3 = w_4 * crypto::Poseidon2Bn254ScalarFieldParams::internal_matrix_diagonal[3]; + t3 += sum; + + expected_values[0] = q_poseidon2_internal * (t0 - w_1_shift); + expected_values[1] = q_poseidon2_internal * (t1 - w_2_shift); + expected_values[2] = q_poseidon2_internal * (t2 - w_3_shift); + expected_values[3] = q_poseidon2_internal * (t3 - w_4_shift); + + const auto parameters = RelationParameters::get_random(); + validate_relation_execution(expected_values, input_elements, parameters); + + // validate_relation_execution(expected_values, input_elements, parameters); + }; + run_test(/*random_inputs=*/false); + run_test(/*random_inputs=*/true); +}; + } // namespace proof_system::ultra_relation_consistency_tests