diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp new file mode 100644 index 000000000000..85070ee080f1 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp @@ -0,0 +1,33 @@ +#include "../bn254/fq.hpp" +#include "../bn254/fr.hpp" +#include "barretenberg/common/wasm_export.hpp" + +using namespace bb; + +WASM_EXPORT void bn254_fr_sqrt(uint8_t const* input, uint8_t* result) +{ + using serialize::write; + auto input_fr = from_buffer(input); + auto [is_sqr, root] = input_fr.sqrt(); + + uint8_t* is_sqrt_result_ptr = result; + uint8_t* root_result_ptr = result + 1; + + write(is_sqrt_result_ptr, is_sqr); + write(root_result_ptr, root); +} + +WASM_EXPORT void bn254_fq_sqrt(uint8_t const* input, uint8_t* result) +{ + using serialize::write; + auto input_fq = from_buffer(input); + auto [is_sqr, root] = input_fq.sqrt(); + + uint8_t* is_sqrt_result_ptr = result; + uint8_t* root_result_ptr = result + 1; + + write(is_sqrt_result_ptr, is_sqr); + write(root_result_ptr, root); +} + +// NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier) \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp index b206430a3f33..78bb4a04fc54 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp +++ b/barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp @@ -63,17 +63,4 @@ WASM_EXPORT void ecc_grumpkin__reduce512_buffer_mod_circuit_modulus(uint8_t* inp write(result, target_output.lo); } -WASM_EXPORT void grumpkin_fr_sqrt(uint8_t const* input, uint8_t* result) -{ - using serialize::write; - auto input_fr = from_buffer(input); - auto [is_sqr, root] = input_fr.sqrt(); - - uint8_t* is_sqrt_result_ptr = result; - uint8_t* root_result_ptr = result + 1; - - write(is_sqrt_result_ptr, is_sqr); - write(root_result_ptr, root); -} - // NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier) \ No newline at end of file diff --git a/yarn-project/foundation/src/fields/fields.test.ts b/yarn-project/foundation/src/fields/fields.test.ts index d1a728e67d77..bd31319bcc83 100644 --- a/yarn-project/foundation/src/fields/fields.test.ts +++ b/yarn-project/foundation/src/fields/fields.test.ts @@ -1,4 +1,4 @@ -import { Fr, GrumpkinScalar } from './fields.js'; +import { Fq, Fr, GrumpkinScalar } from './fields.js'; describe('GrumpkinScalar Serialization', () => { // Test case for GrumpkinScalar.fromHighLow @@ -185,12 +185,20 @@ describe('Bn254 arithmetic', () => { }); describe('Square root', () => { - it('Should return the correct square root', () => { - const a = new Fr(16); - const expected = new Fr(4); + it('Should return the correct square root for Fr', () => { + const a = Fr.random(); + const squared = a.mul(a); - const actual = a.sqrt(); - expect(actual).toEqual(expected); + const actual = squared.sqrt(); + expect(actual!.mul(actual!)).toEqual(squared); + }); + + it('Should return the correct square root for Fq', () => { + const a = Fq.random(); + const squared = a.mul(a); + + const actual = squared.sqrt(); + expect(actual!.mul(actual!)).toEqual(squared); }); }); diff --git a/yarn-project/foundation/src/fields/fields.ts b/yarn-project/foundation/src/fields/fields.ts index 7718d9d0abc4..c84b547cf3ef 100644 --- a/yarn-project/foundation/src/fields/fields.ts +++ b/yarn-project/foundation/src/fields/fields.ts @@ -283,13 +283,13 @@ export class Fr extends BaseField { } /** - * Computes the square root of the field element. - * @returns The square root of the field element (null if it does not exist). + * Computes a square root of the field element. + * @returns A square root of the field element (null if it does not exist). */ sqrt(): Fr | null { const wasm = BarretenbergSync.getSingleton().getWasm(); wasm.writeMemory(0, this.toBuffer()); - wasm.call('grumpkin_fr_sqrt', 0, Fr.SIZE_IN_BYTES); + wasm.call('bn254_fr_sqrt', 0, Fr.SIZE_IN_BYTES); const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES + 1)); const isSqrt = isSqrtBuf[0] === 1; if (!isSqrt) { @@ -297,7 +297,7 @@ export class Fr extends BaseField { return null; } - const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1)) + const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1)); return Fr.fromBuffer(rootBuf); } @@ -380,6 +380,29 @@ export class Fq extends BaseField { return new Fq((high.toBigInt() << Fq.HIGH_SHIFT) + low.toBigInt()); } + mul(rhs: Fq) { + return new Fq((this.toBigInt() * rhs.toBigInt()) % Fq.MODULUS); + } + + /** + * Computes a square root of the field element. + * @returns A square root of the field element (null if it does not exist). + */ + sqrt(): Fq | null { + const wasm = BarretenbergSync.getSingleton().getWasm(); + wasm.writeMemory(0, this.toBuffer()); + wasm.call('bn254_fq_sqrt', 0, Fq.SIZE_IN_BYTES); + const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fq.SIZE_IN_BYTES, Fq.SIZE_IN_BYTES + 1)); + const isSqrt = isSqrtBuf[0] === 1; + if (!isSqrt) { + // Field element is not a quadratic residue mod p so it has no square root. + return null; + } + + const rootBuf = Buffer.from(wasm.getMemorySlice(Fq.SIZE_IN_BYTES + 1, Fq.SIZE_IN_BYTES * 2 + 1)); + return Fq.fromBuffer(rootBuf); + } + toJSON() { return { type: 'Fq',