Skip to content

Commit

Permalink
More generic, less efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
Rumata888 committed Apr 18, 2024
1 parent d24fec3 commit ce98798
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
31 changes: 15 additions & 16 deletions barretenberg/cpp/src/barretenberg/polynomials/univariate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,10 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate
* subtraction: setting Δ = v1-v0, the values of f(X) are f(0)=v0, f(1)= v0 + Δ, v2 = f(1) + Δ, v3 = f(2) + Δ...
*
*/
template <size_t EXTENDED_DOMAIN_END, bool optimised = false> Univariate<Fr, EXTENDED_DOMAIN_END> extend_to() const
template <size_t EXTENDED_DOMAIN_END, size_t NUM_SKIPPED_INDICES = 0>
Univariate<Fr, EXTENDED_DOMAIN_END> extend_to() const
{
const size_t EXTENDED_LENGTH = EXTENDED_DOMAIN_END - domain_start;
const size_t EXTENDED_LENGTH = EXTENDED_DOMAIN_END - domain_start + NUM_SKIPPED_INDICES;
using Data = BarycentricData<Fr, LENGTH, EXTENDED_LENGTH>;
static_assert(EXTENDED_LENGTH >= LENGTH);

Expand All @@ -282,22 +283,13 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate
std::copy(evaluations.begin(), evaluations.end(), result.evaluations.begin());

static constexpr Fr inverse_two = Fr(2).invert();
// static_assert(!optimised || (LENGTH <= 2));
static_assert(NUM_SKIPPED_INDICES < LENGTH);
if constexpr (LENGTH == 2) {
Fr delta = value_at(1) - value_at(0);
static_assert(EXTENDED_LENGTH != 0);
if constexpr (optimised) {
Fr current = result.value_at(1);
for (size_t idx = domain_end - 2; idx < EXTENDED_DOMAIN_END - 1; idx++) {
current += delta;
result.value_at(idx + 1) = current;
}
} else {
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + delta;
}
for (size_t idx = domain_end - 1; idx < EXTENDED_DOMAIN_END - 1; idx++) {
result.value_at(idx + 1) = result.value_at(idx) + delta;
}
return result;
} else if constexpr (LENGTH == 3) {
// Based off https://hackmd.io/@aztec-network/SyR45cmOq?type=view
// The technique used here is the same as the length == 3 case below.
Expand All @@ -313,7 +305,6 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate
result.value_at(idx + 1) = result.value_at(idx) + extra;
extra += a2;
}
return result;
} else if constexpr (LENGTH == 4) {
static constexpr Fr inverse_six = Fr(6).invert(); // computed at compile time for efficiency

Expand Down Expand Up @@ -377,7 +368,6 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate

linear_term += three_a_plus_two_b;
}
return result;
} else {
for (size_t k = domain_end; k != EXTENDED_DOMAIN_END; ++k) {
result.value_at(k) = 0;
Expand All @@ -390,8 +380,17 @@ template <class Fr, size_t domain_end, size_t domain_start = 0> class Univariate
// scale the sum by the the value of of B(x)
result.value_at(k) *= Data::full_numerator_values[k];
}
}
if constexpr (NUM_SKIPPED_INDICES == 0) {
return result;
}
Univariate<Fr, EXTENDED_LENGTH - NUM_SKIPPED_INDICES> optimised_result;
optimised_result.value_at(0) = result.value_at(0);

std::copy(std::next(result.begin(), 1 + NUM_SKIPPED_INDICES),
result.evaluations.end(),
std::next(optimised_result.begin(), 1));
return optimised_result;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ template <class ProverInstances_> class ProtoGalaxyProver_ {
{
auto base_univariates = instances.row_to_univariates(row_idx);
for (auto [extended_univariate, base_univariate] : zip_view(extended_univariates.get_all(), base_univariates)) {
extended_univariate = base_univariate.template extend_to<OptimisedExtendedUnivariate::LENGTH, true>();
extended_univariate =
base_univariate.template extend_to<OptimisedExtendedUnivariate::LENGTH, ProverInstances::NUM - 1>();
}
}

Expand Down

0 comments on commit ce98798

Please sign in to comment.