Skip to content

Commit

Permalink
refactor: updating field conversion code without pointer hack (AztecP…
Browse files Browse the repository at this point in the history
…rotocol#4537)

We currently use a pointer hack for functions like `calc_num_bn254_frs`
and `convert_from_bn254_frs`. This PR aims to clean these up using
traits and template metaprogramming.

Also closes AztecProtocol/barretenberg#846, by
just sending and receiving an std::array instead of an AllValues object.
  • Loading branch information
lucasxia01 authored Feb 12, 2024
1 parent 348455d commit 949c9cb
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 467 deletions.
3 changes: 3 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class Bn254FqParams {
// used in msgpack schema serialization
static constexpr char schema_name[] = "fq";
static constexpr bool has_high_2adicity = false;

// The modulus is larger than BN254 scalar field modulus, so it maps to two BN254 scalars
static constexpr size_t NUM_BN254_SCALARS = 2;
};

using fq = field<Bn254FqParams>;
Expand Down
3 changes: 3 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/fr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class Bn254FrParams {
// used in msgpack schema serialization
static constexpr char schema_name[] = "fr";
static constexpr bool has_high_2adicity = true;

// This is a BN254 scalar, so it represents one BN254 scalar
static constexpr size_t NUM_BN254_SCALARS = 1;
};

using fr = field<Bn254FrParams>;
Expand Down
59 changes: 2 additions & 57 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@ namespace bb::field_conversion {
static constexpr uint64_t NUM_LIMB_BITS = plonk::NUM_LIMB_BITS_IN_FIELD_SIMULATION;
static constexpr uint64_t TOTAL_BITS = 254;

bb::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bb::fr* /*unused*/)
{
ASSERT(fr_vec.size() == 1);
return fr_vec[0];
}

bool convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bool* /*unused*/)
{
ASSERT(fr_vec.size() == 1);
return fr_vec[0] != 0;
}

/**
* @brief Converts 2 bb::fr elements to grumpkin::fr
* @details First, this function must take in 2 bb::fr elements because the grumpkin::fr field has a larger modulus than
Expand All @@ -32,7 +20,7 @@ bool convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bool* /*unused*/)
* @param high_bits_in
* @return grumpkin::fr
*/
grumpkin::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, grumpkin::fr* /*unused*/)
grumpkin::fr convert_grumpkin_fr_from_bn254_frs(std::span<const bb::fr> fr_vec)
{
// Combines the two elements into one uint256_t, and then convert that to a grumpkin::fr
ASSERT(uint256_t(fr_vec[0]) < (uint256_t(1) << (NUM_LIMB_BITS * 2))); // lower 136 bits
Expand All @@ -42,25 +30,6 @@ grumpkin::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, grumpkin::fr
return result;
}

curve::BN254::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::BN254::AffineElement* /*unused*/)
{
curve::BN254::AffineElement val;
val.x = convert_from_bn254_frs<grumpkin::fr>(fr_vec.subspan(0, 2));
val.y = convert_from_bn254_frs<grumpkin::fr>(fr_vec.subspan(2, 2));
return val;
}

curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::Grumpkin::AffineElement* /*unused*/)
{
ASSERT(fr_vec.size() == 2);
curve::Grumpkin::AffineElement val;
val.x = fr_vec[0];
val.y = fr_vec[1];
return val;
}

/**
* @brief Converts grumpkin::fr to 2 bb::fr elements
* @details First, this function must return 2 bb::fr elements because the grumpkin::fr field has a larger modulus than
Expand All @@ -74,7 +43,7 @@ curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr
* @param input
* @return std::array<bb::fr, 2>
*/
std::vector<bb::fr> convert_to_bn254_frs(const grumpkin::fr& val)
std::vector<bb::fr> convert_grumpkin_fr_to_bn254_frs(const grumpkin::fr& val)
{
// Goal is to slice up the 64 bit limbs of grumpkin::fr/uint256_t to mirror the 68 bit limbs of bigfield
// We accomplish this by dividing the grumpkin::fr's value into two 68*2=136 bit pieces.
Expand All @@ -89,30 +58,6 @@ std::vector<bb::fr> convert_to_bn254_frs(const grumpkin::fr& val)
return result;
}

std::vector<bb::fr> convert_to_bn254_frs(const bb::fr& val)
{
std::vector<bb::fr> fr_vec{ val };
return fr_vec;
}

std::vector<bb::fr> convert_to_bn254_frs(const curve::BN254::AffineElement& val)
{
auto fr_vec_x = convert_to_bn254_frs(val.x);
auto fr_vec_y = convert_to_bn254_frs(val.y);
std::vector<bb::fr> fr_vec(fr_vec_x.begin(), fr_vec_x.end());
fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end());
return fr_vec;
}

std::vector<bb::fr> convert_to_bn254_frs(const curve::Grumpkin::AffineElement& val)
{
auto fr_vec_x = convert_to_bn254_frs(val.x);
auto fr_vec_y = convert_to_bn254_frs(val.y);
std::vector<bb::fr> fr_vec(fr_vec_x.begin(), fr_vec_x.end());
fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end());
return fr_vec;
}

grumpkin::fr convert_to_grumpkin_fr(const bb::fr& f)
{
const uint64_t NUM_BITS_IN_TWO_LIMBS = 2 * NUM_LIMB_BITS; // the number of bits in 2 bigfield limbs which is 136
Expand Down
225 changes: 64 additions & 161 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_conversion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/ecc/curves/grumpkin/grumpkin.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/proof_system/types/circuit_type.hpp"

namespace bb::field_conversion {

Expand All @@ -15,48 +16,22 @@ namespace bb::field_conversion {
* @tparam T
* @return constexpr size_t
*/
template <typename T> constexpr size_t calc_num_bn254_frs();

constexpr size_t calc_num_bn254_frs(bb::fr* /*unused*/)
{
return 1;
}

constexpr size_t calc_num_bn254_frs(grumpkin::fr* /*unused*/)
{
return 2;
}

template <std::integral T> constexpr size_t calc_num_bn254_frs(T* /*unused*/)
{
return 1; // meant for integral types that are less than 254 bits
}

constexpr size_t calc_num_bn254_frs(curve::BN254::AffineElement* /*unused*/)
{
return 2 * calc_num_bn254_frs<curve::BN254::BaseField>();
}

constexpr size_t calc_num_bn254_frs(curve::Grumpkin::AffineElement* /*unused*/)
{
return 2 * calc_num_bn254_frs<curve::Grumpkin::BaseField>();
}

template <typename T, std::size_t N> constexpr size_t calc_num_bn254_frs(std::array<T, N>* /*unused*/)
{
return N * calc_num_bn254_frs<T>();
}

template <typename T, std::size_t N> constexpr size_t calc_num_bn254_frs(bb::Univariate<T, N>* /*unused*/)
{
return N * calc_num_bn254_frs<T>();
}

template <typename T> constexpr size_t calc_num_bn254_frs()
{
return calc_num_bn254_frs(static_cast<T*>(nullptr));
if constexpr (IsAnyOf<T, uint32_t, bool>) {
return 1;
} else if constexpr (IsAnyOf<T, bb::fr, grumpkin::fr>) {
return T::Params::NUM_BN254_SCALARS;
} else if constexpr (IsAnyOf<T, curve::BN254::AffineElement, curve::Grumpkin::AffineElement>) {
return 2 * calc_num_bn254_frs<typename T::Fq>();
} else {
// Array or Univariate
return calc_num_bn254_frs<typename T::value_type>() * (std::tuple_size<T>::value);
}
}

grumpkin::fr convert_grumpkin_fr_from_bn254_frs(std::span<const bb::fr> fr_vec);

/**
* @brief Conversions from vector of bb::fr elements to transcript types.
* @details We want to support the following types: bool, size_t, uint32_t, uint64_t, bb::fr, grumpkin::fr,
Expand All @@ -68,75 +43,40 @@ template <typename T> constexpr size_t calc_num_bn254_frs()
* @param fr_vec
* @return T
*/
template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec);

bool convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bool* /*unused*/);

template <std::integral T> inline T convert_from_bn254_frs(std::span<const bb::fr> fr_vec, T* /*unused*/)
{
ASSERT(fr_vec.size() == 1);
return static_cast<T>(fr_vec[0]);
}

bb::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, bb::fr* /*unused*/);

grumpkin::fr convert_from_bn254_frs(std::span<const bb::fr> fr_vec, grumpkin::fr* /*unused*/);

curve::BN254::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::BN254::AffineElement* /*unused*/);

curve::Grumpkin::AffineElement convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
curve::Grumpkin::AffineElement* /*unused*/);

template <size_t N>
inline std::array<bb::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec, std::array<bb::fr, N>* /*unused*/)
{
std::array<bb::fr, N> val;
for (size_t i = 0; i < N; ++i) {
val[i] = fr_vec[i];
}
return val;
}

template <size_t N>
inline std::array<grumpkin::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
std::array<grumpkin::fr, N>* /*unused*/)
{
std::array<grumpkin::fr, N> val;
for (size_t i = 0; i < N; ++i) {
std::vector<bb::fr> fr_vec_tmp{ fr_vec[2 * i],
fr_vec[2 * i + 1] }; // each pair of consecutive elements is a grumpkin::fr
val[i] = convert_from_bn254_frs<grumpkin::fr>(fr_vec_tmp);
}
return val;
}

template <size_t N>
inline Univariate<bb::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec, Univariate<bb::fr, N>* /*unused*/)
{
Univariate<bb::fr, N> val;
for (size_t i = 0; i < N; ++i) {
val.evaluations[i] = fr_vec[i];
}
return val;
}

template <size_t N>
inline Univariate<grumpkin::fr, N> convert_from_bn254_frs(std::span<const bb::fr> fr_vec,
Univariate<grumpkin::fr, N>* /*unused*/)
template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec)
{
Univariate<grumpkin::fr, N> val;
for (size_t i = 0; i < N; ++i) {
std::vector<bb::fr> fr_vec_tmp{ fr_vec[2 * i], fr_vec[2 * i + 1] };
val.evaluations[i] = convert_from_bn254_frs<grumpkin::fr>(fr_vec_tmp);
if constexpr (IsAnyOf<T, bool>) {
ASSERT(fr_vec.size() == 1);
return bool(fr_vec[0]);
} else if constexpr (IsAnyOf<T, uint32_t, bb::fr>) {
ASSERT(fr_vec.size() == 1);
return static_cast<T>(fr_vec[0]);
} else if constexpr (IsAnyOf<T, grumpkin::fr>) {
ASSERT(fr_vec.size() == 2);
return convert_grumpkin_fr_from_bn254_frs(fr_vec);
} else if constexpr (IsAnyOf<T, curve::BN254::AffineElement, curve::Grumpkin::AffineElement>) {
using BaseField = typename T::Fq;
constexpr size_t BASE_FIELD_SCALAR_SIZE = calc_num_bn254_frs<BaseField>();
ASSERT(fr_vec.size() == 2 * BASE_FIELD_SCALAR_SIZE);
T val;
val.x = convert_from_bn254_frs<BaseField>(fr_vec.subspan(0, BASE_FIELD_SCALAR_SIZE));
val.y = convert_from_bn254_frs<BaseField>(fr_vec.subspan(BASE_FIELD_SCALAR_SIZE, BASE_FIELD_SCALAR_SIZE));
return val;
} else {
// Array or Univariate
T val;
constexpr size_t FieldScalarSize = calc_num_bn254_frs<typename T::value_type>();
ASSERT(fr_vec.size() == FieldScalarSize * std::tuple_size<T>::value);
size_t i = 0;
for (auto& x : val) {
x = convert_from_bn254_frs<typename T::value_type>(fr_vec.subspan(FieldScalarSize * i, FieldScalarSize));
++i;
}
return val;
}
return val;
}

template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec)
{
return convert_from_bn254_frs(fr_vec, static_cast<T*>(nullptr));
}
std::vector<bb::fr> convert_grumpkin_fr_to_bn254_frs(const grumpkin::fr& val);

/**
* @brief Conversion from transcript values to bb::frs
Expand All @@ -147,65 +87,28 @@ template <typename T> T convert_from_bn254_frs(std::span<const bb::fr> fr_vec)
* @param val
* @return std::vector<bb::fr>
*/
template <std::integral T> std::vector<bb::fr> inline convert_to_bn254_frs(const T& val)
{
std::vector<bb::fr> fr_vec{ val };
return fr_vec;
}

std::vector<bb::fr> convert_to_bn254_frs(const grumpkin::fr& val);

std::vector<bb::fr> convert_to_bn254_frs(const bb::fr& val);

std::vector<bb::fr> convert_to_bn254_frs(const curve::BN254::AffineElement& val);

std::vector<bb::fr> convert_to_bn254_frs(const curve::Grumpkin::AffineElement& val);

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const std::array<bb::fr, N>& val)
{
std::vector<bb::fr> fr_vec(val.begin(), val.end());
return fr_vec;
}

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const std::array<grumpkin::fr, N>& val)
{
std::vector<bb::fr> fr_vec;
for (size_t i = 0; i < N; ++i) {
auto tmp_vec = convert_to_bn254_frs(val[i]);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const bb::Univariate<bb::fr, N>& val)
{
std::vector<bb::fr> fr_vec;
for (size_t i = 0; i < N; ++i) {
auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}

template <size_t N> std::vector<bb::fr> inline convert_to_bn254_frs(const bb::Univariate<grumpkin::fr, N>& val)
{
std::vector<bb::fr> fr_vec;
for (size_t i = 0; i < N; ++i) {
auto tmp_vec = convert_to_bn254_frs(val.evaluations[i]);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}

template <typename AllValues> std::vector<bb::fr> inline convert_to_bn254_frs(const AllValues& val)
{
auto data = val.get_all();
std::vector<bb::fr> fr_vec;
for (auto& item : data) {
auto tmp_vec = convert_to_bn254_frs(item);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
template <typename T> std::vector<bb::fr> convert_to_bn254_frs(const T& val)
{
if constexpr (IsAnyOf<T, bool, uint32_t, bb::fr>) {
std::vector<bb::fr> fr_vec{ val };
return fr_vec;
} else if constexpr (IsAnyOf<T, grumpkin::fr>) {
return convert_grumpkin_fr_to_bn254_frs(val);
} else if constexpr (IsAnyOf<T, curve::BN254::AffineElement, curve::Grumpkin::AffineElement>) {
auto fr_vec_x = convert_to_bn254_frs(val.x);
auto fr_vec_y = convert_to_bn254_frs(val.y);
std::vector<bb::fr> fr_vec(fr_vec_x.begin(), fr_vec_x.end());
fr_vec.insert(fr_vec.end(), fr_vec_y.begin(), fr_vec_y.end());
return fr_vec;
} else {
// Array or Univariate
std::vector<bb::fr> fr_vec;
for (auto& x : val) {
auto tmp_vec = convert_to_bn254_frs(x);
fr_vec.insert(fr_vec.end(), tmp_vec.begin(), tmp_vec.end());
}
return fr_vec;
}
return fr_vec;
}

grumpkin::fr convert_to_grumpkin_fr(const bb::fr& f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
namespace bb::group_elements {
template <typename T>
concept SupportsHashToCurve = T::can_hash_to_curve;
template <typename Fq, typename Fr, typename Params> class alignas(64) affine_element {
template <typename Fq_, typename Fr_, typename Params> class alignas(64) affine_element {
public:
using Fq = Fq_;
using Fr = Fr_;

using in_buf = const uint8_t*;
using vec_in_buf = const uint8_t*;
using out_buf = uint8_t*;
Expand Down
Loading

0 comments on commit 949c9cb

Please sign in to comment.