Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
benesjan committed Jul 15, 2024
1 parent bd0fbbc commit 861fae6
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 23 deletions.
33 changes: 33 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "../bn254/fq.hpp"
#include "../bn254/fr.hpp"
#include "barretenberg/common/wasm_export.hpp"

using namespace bb;

WASM_EXPORT void bn254_fr_sqrt(uint8_t const* input, uint8_t* result)
{
using serialize::write;
auto input_fr = from_buffer<bb::fr>(input);
auto [is_sqr, root] = input_fr.sqrt();

uint8_t* is_sqrt_result_ptr = result;
uint8_t* root_result_ptr = result + 1;

write(is_sqrt_result_ptr, is_sqr);
write(root_result_ptr, root);
}

WASM_EXPORT void bn254_fq_sqrt(uint8_t const* input, uint8_t* result)
{
using serialize::write;
auto input_fq = from_buffer<bb::fq>(input);
auto [is_sqr, root] = input_fq.sqrt();

uint8_t* is_sqrt_result_ptr = result;
uint8_t* root_result_ptr = result + 1;

write(is_sqrt_result_ptr, is_sqr);
write(root_result_ptr, root);
}

// NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier)
13 changes: 0 additions & 13 deletions barretenberg/cpp/src/barretenberg/ecc/curves/grumpkin/c_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,4 @@ WASM_EXPORT void ecc_grumpkin__reduce512_buffer_mod_circuit_modulus(uint8_t* inp
write(result, target_output.lo);
}

WASM_EXPORT void grumpkin_fr_sqrt(uint8_t const* input, uint8_t* result)
{
using serialize::write;
auto input_fr = from_buffer<grumpkin::fr>(input);
auto [is_sqr, root] = input_fr.sqrt();

uint8_t* is_sqrt_result_ptr = result;
uint8_t* root_result_ptr = result + 1;

write(is_sqrt_result_ptr, is_sqr);
write(root_result_ptr, root);
}

// NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier)
20 changes: 14 additions & 6 deletions yarn-project/foundation/src/fields/fields.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Fr, GrumpkinScalar } from './fields.js';
import { Fq, Fr, GrumpkinScalar } from './fields.js';

describe('GrumpkinScalar Serialization', () => {
// Test case for GrumpkinScalar.fromHighLow
Expand Down Expand Up @@ -185,12 +185,20 @@ describe('Bn254 arithmetic', () => {
});

describe('Square root', () => {
it('Should return the correct square root', () => {
const a = new Fr(16);
const expected = new Fr(4);
it('Should return the correct square root for Fr', () => {
const a = Fr.random();
const squared = a.mul(a);

const actual = a.sqrt();
expect(actual).toEqual(expected);
const actual = squared.sqrt();
expect(actual!.mul(actual!)).toEqual(squared);
});

it('Should return the correct square root for Fq', () => {
const a = Fq.random();
const squared = a.mul(a);

const actual = squared.sqrt();
expect(actual!.mul(actual!)).toEqual(squared);
});
});

Expand Down
31 changes: 27 additions & 4 deletions yarn-project/foundation/src/fields/fields.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,21 +283,21 @@ export class Fr extends BaseField {
}

/**
* Computes the square root of the field element.
* @returns The square root of the field element (null if it does not exist).
* Computes a square root of the field element.
* @returns A square root of the field element (null if it does not exist).
*/
sqrt(): Fr | null {
const wasm = BarretenbergSync.getSingleton().getWasm();
wasm.writeMemory(0, this.toBuffer());
wasm.call('grumpkin_fr_sqrt', 0, Fr.SIZE_IN_BYTES);
wasm.call('bn254_fr_sqrt', 0, Fr.SIZE_IN_BYTES);
const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES + 1));
const isSqrt = isSqrtBuf[0] === 1;
if (!isSqrt) {
// Field element is not a quadratic residue mod p so it has no square root.
return null;
}

const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1))
const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1));
return Fr.fromBuffer(rootBuf);
}

Expand Down Expand Up @@ -380,6 +380,29 @@ export class Fq extends BaseField {
return new Fq((high.toBigInt() << Fq.HIGH_SHIFT) + low.toBigInt());
}

mul(rhs: Fq) {
return new Fq((this.toBigInt() * rhs.toBigInt()) % Fq.MODULUS);
}

/**
* Computes a square root of the field element.
* @returns A square root of the field element (null if it does not exist).
*/
sqrt(): Fq | null {
const wasm = BarretenbergSync.getSingleton().getWasm();
wasm.writeMemory(0, this.toBuffer());
wasm.call('bn254_fq_sqrt', 0, Fq.SIZE_IN_BYTES);
const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fq.SIZE_IN_BYTES, Fq.SIZE_IN_BYTES + 1));
const isSqrt = isSqrtBuf[0] === 1;
if (!isSqrt) {
// Field element is not a quadratic residue mod p so it has no square root.
return null;
}

const rootBuf = Buffer.from(wasm.getMemorySlice(Fq.SIZE_IN_BYTES + 1, Fq.SIZE_IN_BYTES * 2 + 1));
return Fq.fromBuffer(rootBuf);
}

toJSON() {
return {
type: 'Fq',
Expand Down

0 comments on commit 861fae6

Please sign in to comment.