Skip to content

Commit

Permalink
feat(avm): concurrency improvements (#7495)
Browse files Browse the repository at this point in the history
This PR makes some changes to use the more standard `bb::parallel_for` instead of or custom solution.

This is nice because we can now control/experiment with concurrency limit by using the env variable `HARDWARE_CONCURRENCY` (which defaults to the number of cpus).

I also added parallel computation of logderivative inverses **in the prover** which actually made a huge difference: took it from `5.1s` to `0.1`s and has now become negligible (with the # of cpus of the mainframe).

BEFORE
```
prove/check_circuit: 5120
prove/execute_log_derivative_inverse_commitments_round_ms: 532
*** prove/execute_log_derivative_inverse_round_ms: 5199
prove/execute_pcs_rounds_ms: 413
prove/execute_relation_check_rounds_ms: 1328
prove/execute_wire_commitments_round_ms: 1742
prove/gen_trace: 850
```

AFTER
```
prove/check_circuit: 4859
prove/execute_log_derivative_inverse_commitments_round_ms: 543
*** prove/execute_log_derivative_inverse_round_ms: 162
prove/execute_pcs_rounds_ms: 381
prove/execute_relation_check_rounds_ms: 1089
prove/execute_wire_commitments_round_ms: 1608
prove/gen_trace: 755
```

---------

WARNING: I had to update the handling of exception catching in the tests, because things get complicated w/threads. I mostly just changed the helper, but GTest does complain and we have to do sth about it eventually.

> [WARNING] /mnt/user-data/facundo/aztec-packages/barretenberg/cpp/build/_deps/gtest-src/googletest/src/gtest-death-test.cc:1108:: Death tests use fork(), which is unsafe particularly in a threaded context. For this test, Google Test detected 192 threads. See https://github.com/google/googletest/blob/main/docs/advanced.md#death-tests-and-threads for more explanation and suggested solutions, especially if this is the last message you see before your test times out.

Filed [this issue](#7504).
  • Loading branch information
fcarreiro authored Jul 17, 2024
1 parent 5b323a7 commit 0d5c066
Show file tree
Hide file tree
Showing 81 changed files with 647 additions and 906 deletions.
9 changes: 8 additions & 1 deletion barretenberg/cpp/src/barretenberg/common/thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ inline size_t get_num_cpus_pow2()
return static_cast<size_t>(1ULL << numeric::get_msb(get_num_cpus()));
}

/**
* Creates a thread pool and runs the function in parallel.
* @param num_iterations Number of iterations
* @param func Function to run in parallel
* Observe that num_iterations is NOT the thread pool size.
* The size will be chosen based on the hardware concurrency (i.e., env or cpus)..
*/
void parallel_for(size_t num_iterations, const std::function<void(size_t)>& func);
void run_loop_in_parallel(size_t num_points,
const std::function<void(size_t, size_t)>& func,
Expand All @@ -30,7 +37,7 @@ template <typename FunctionType>
void run_loop_in_parallel_if_effective_internal(
size_t, const FunctionType&, size_t, size_t, size_t, size_t, size_t, size_t, size_t);
/**
* @brief Runs loop in parallel if parallelization if useful (costs less than the algorith)
* @brief Runs loop in parallel if parallelization if useful (costs less than the algorithm)
*
* @details Please see run_loop_in_parallel_if_effective_internal for detailed description
*
Expand Down
203 changes: 103 additions & 100 deletions barretenberg/cpp/src/barretenberg/relations/generated/avm/alu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,105 +129,6 @@ template <typename FF> struct AluRow {
FF alu_u8_tag{};
};

inline std::string get_relation_label_alu(int index)
{
switch (index) {
case 12:
return "ALU_ADD_SUB_1";
case 13:
return "ALU_ADD_SUB_2";
case 14:
return "ALU_MULTIPLICATION_FF";
case 15:
return "ALU_MUL_COMMON_1";
case 16:
return "ALU_MUL_COMMON_2";
case 19:
return "ALU_MULTIPLICATION_OUT_U128";
case 20:
return "ALU_FF_NOT_XOR";
case 21:
return "ALU_OP_NOT";
case 22:
return "ALU_RES_IS_BOOL";
case 23:
return "ALU_OP_EQ";
case 24:
return "INPUT_DECOMP_1";
case 25:
return "INPUT_DECOMP_2";
case 27:
return "SUB_LO_1";
case 28:
return "SUB_HI_1";
case 30:
return "SUB_LO_2";
case 31:
return "SUB_HI_2";
case 32:
return "RES_LO";
case 33:
return "RES_HI";
case 34:
return "CMP_CTR_REL_1";
case 35:
return "CMP_CTR_REL_2";
case 38:
return "CTR_NON_ZERO_REL";
case 39:
return "RNG_CHK_LOOKUP_SELECTOR";
case 40:
return "LOWER_CMP_RNG_CHK";
case 41:
return "UPPER_CMP_RNG_CHK";
case 42:
return "SHIFT_RELS_0";
case 44:
return "SHIFT_RELS_1";
case 46:
return "SHIFT_RELS_2";
case 48:
return "SHIFT_RELS_3";
case 50:
return "OP_CAST_PREV_LINE";
case 51:
return "ALU_OP_CAST";
case 52:
return "OP_CAST_RNG_CHECK_P_SUB_A_LOW";
case 53:
return "OP_CAST_RNG_CHECK_P_SUB_A_HIGH";
case 54:
return "TWO_LINE_OP_NO_OVERLAP";
case 55:
return "SHR_RANGE_0";
case 56:
return "SHR_RANGE_1";
case 57:
return "SHL_RANGE_0";
case 58:
return "SHL_RANGE_1";
case 60:
return "SHIFT_LT_BIT_LEN";
case 61:
return "SHR_INPUT_DECOMPOSITION";
case 62:
return "SHR_OUTPUT";
case 63:
return "SHL_INPUT_DECOMPOSITION";
case 64:
return "SHL_OUTPUT";
case 74:
return "ALU_PROD_DIV";
case 75:
return "REMAINDER_RANGE_CHK";
case 76:
return "CMP_CTR_REL_3";
case 78:
return "DIVISION_RELATION";
}
return std::to_string(index);
}

template <typename FF_> class aluImpl {
public:
using FF = FF_;
Expand Down Expand Up @@ -1040,6 +941,108 @@ template <typename FF_> class aluImpl {
}
};

template <typename FF> using alu = Relation<aluImpl<FF>>;
template <typename FF> class alu : public Relation<aluImpl<FF>> {
public:
static constexpr const char* NAME = "alu";

static std::string get_subrelation_label(size_t index)
{
switch (index) {
case 12:
return "ALU_ADD_SUB_1";
case 13:
return "ALU_ADD_SUB_2";
case 14:
return "ALU_MULTIPLICATION_FF";
case 15:
return "ALU_MUL_COMMON_1";
case 16:
return "ALU_MUL_COMMON_2";
case 19:
return "ALU_MULTIPLICATION_OUT_U128";
case 20:
return "ALU_FF_NOT_XOR";
case 21:
return "ALU_OP_NOT";
case 22:
return "ALU_RES_IS_BOOL";
case 23:
return "ALU_OP_EQ";
case 24:
return "INPUT_DECOMP_1";
case 25:
return "INPUT_DECOMP_2";
case 27:
return "SUB_LO_1";
case 28:
return "SUB_HI_1";
case 30:
return "SUB_LO_2";
case 31:
return "SUB_HI_2";
case 32:
return "RES_LO";
case 33:
return "RES_HI";
case 34:
return "CMP_CTR_REL_1";
case 35:
return "CMP_CTR_REL_2";
case 38:
return "CTR_NON_ZERO_REL";
case 39:
return "RNG_CHK_LOOKUP_SELECTOR";
case 40:
return "LOWER_CMP_RNG_CHK";
case 41:
return "UPPER_CMP_RNG_CHK";
case 42:
return "SHIFT_RELS_0";
case 44:
return "SHIFT_RELS_1";
case 46:
return "SHIFT_RELS_2";
case 48:
return "SHIFT_RELS_3";
case 50:
return "OP_CAST_PREV_LINE";
case 51:
return "ALU_OP_CAST";
case 52:
return "OP_CAST_RNG_CHECK_P_SUB_A_LOW";
case 53:
return "OP_CAST_RNG_CHECK_P_SUB_A_HIGH";
case 54:
return "TWO_LINE_OP_NO_OVERLAP";
case 55:
return "SHR_RANGE_0";
case 56:
return "SHR_RANGE_1";
case 57:
return "SHL_RANGE_0";
case 58:
return "SHL_RANGE_1";
case 60:
return "SHIFT_LT_BIT_LEN";
case 61:
return "SHR_INPUT_DECOMPOSITION";
case 62:
return "SHR_OUTPUT";
case 63:
return "SHL_INPUT_DECOMPOSITION";
case 64:
return "SHL_OUTPUT";
case 74:
return "ALU_PROD_DIV";
case 75:
return "REMAINDER_RANGE_CHK";
case 76:
return "CMP_CTR_REL_3";
case 78:
return "DIVISION_RELATION";
}
return std::to_string(index);
}
};

} // namespace bb::Avm_vm
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,6 @@ template <typename FF> struct BinaryRow {
FF binary_sel_bin{};
};

inline std::string get_relation_label_binary(int index)
{
switch (index) {
case 1:
return "OP_ID_REL";
case 2:
return "MEM_TAG_REL";
case 3:
return "SEL_BIN_CTR_REL";
case 7:
return "ACC_REL_A";
case 8:
return "ACC_REL_B";
case 9:
return "ACC_REL_C";
}
return std::to_string(index);
}

template <typename FF_> class binaryImpl {
public:
using FF = FF_;
Expand Down Expand Up @@ -128,6 +109,28 @@ template <typename FF_> class binaryImpl {
}
};

template <typename FF> using binary = Relation<binaryImpl<FF>>;
template <typename FF> class binary : public Relation<binaryImpl<FF>> {
public:
static constexpr const char* NAME = "binary";

static std::string get_subrelation_label(size_t index)
{
switch (index) {
case 1:
return "OP_ID_REL";
case 2:
return "MEM_TAG_REL";
case 3:
return "SEL_BIN_CTR_REL";
case 7:
return "ACC_REL_A";
case 8:
return "ACC_REL_B";
case 9:
return "ACC_REL_C";
}
return std::to_string(index);
}
};

} // namespace bb::Avm_vm
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ template <typename FF> struct ConversionRow {
FF conversion_sel_to_radix_le{};
};

inline std::string get_relation_label_conversion(int index)
{
switch (index) {}
return std::to_string(index);
}

template <typename FF_> class conversionImpl {
public:
using FF = FF_;
Expand All @@ -37,6 +31,15 @@ template <typename FF_> class conversionImpl {
}
};

template <typename FF> using conversion = Relation<conversionImpl<FF>>;
template <typename FF> class conversion : public Relation<conversionImpl<FF>> {
public:
static constexpr const char* NAME = "conversion";

static std::string get_subrelation_label(size_t index)
{
switch (index) {}
return std::to_string(index);
}
};

} // namespace bb::Avm_vm
17 changes: 10 additions & 7 deletions barretenberg/cpp/src/barretenberg/relations/generated/avm/gas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ template <typename FF> struct GasRow {
FF gas_sel_gas_cost{};
};

inline std::string get_relation_label_gas(int index)
{
switch (index) {}
return std::to_string(index);
}

template <typename FF_> class gasImpl {
public:
using FF = FF_;
Expand Down Expand Up @@ -51,6 +45,15 @@ template <typename FF_> class gasImpl {
}
};

template <typename FF> using gas = Relation<gasImpl<FF>>;
template <typename FF> class gas : public Relation<gasImpl<FF>> {
public:
static constexpr const char* NAME = "gas";

static std::string get_subrelation_label(size_t index)
{
switch (index) {}
return std::to_string(index);
}
};

} // namespace bb::Avm_vm
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class incl_main_tag_err_lookup_settings {
template <typename FF_>
class incl_main_tag_err_relation : public GenericLookupRelation<incl_main_tag_err_lookup_settings, FF_> {
public:
static constexpr const char* NAME = "incl_main_tag_err";
static constexpr const char* NAME = "INCL_MAIN_TAG_ERR";
};
template <typename FF_> using incl_main_tag_err = GenericLookup<incl_main_tag_err_lookup_settings, FF_>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class incl_mem_tag_err_lookup_settings {
template <typename FF_>
class incl_mem_tag_err_relation : public GenericLookupRelation<incl_mem_tag_err_lookup_settings, FF_> {
public:
static constexpr const char* NAME = "incl_mem_tag_err";
static constexpr const char* NAME = "INCL_MEM_TAG_ERR";
};
template <typename FF_> using incl_mem_tag_err = GenericLookup<incl_mem_tag_err_lookup_settings, FF_>;

Expand Down
Loading

0 comments on commit 0d5c066

Please sign in to comment.