Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Point::fromXandSign(...) #7455

Merged
merged 11 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions barretenberg/cpp/src/barretenberg/ecc/curves/bn254/c_bind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "../bn254/fr.hpp"
#include "barretenberg/common/wasm_export.hpp"

using namespace bb;

WASM_EXPORT void bn254_fr_sqrt(uint8_t const* input, uint8_t* result)
{
using serialize::write;
auto input_fr = from_buffer<bb::fr>(input);
auto [is_sqr, root] = input_fr.sqrt();

uint8_t* is_sqrt_result_ptr = result;
uint8_t* root_result_ptr = result + 1;

write(is_sqrt_result_ptr, is_sqr);
write(root_result_ptr, root);
}

// NOLINTEND(cert-dcl37-c, cert-dcl51-cpp, bugprone-reserved-identifier)
9 changes: 9 additions & 0 deletions yarn-project/foundation/src/crypto/random/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,12 @@ export const randomBigInt = (max: bigint) => {
const randomBigInt = BigInt(`0x${randomBuffer.toString('hex')}`); // Convert buffer to a large integer.
return randomBigInt % max; // Use modulo to ensure the result is less than max.
};

/**
* Generate a random boolean value.
* @returns A random boolean value.
*/
export const randomBoolean = () => {
const randomByte = randomBytes(1)[0]; // Generate a single random byte.
return randomByte % 2 === 0; // Use modulo to determine if the byte is even or odd.
};
26 changes: 25 additions & 1 deletion yarn-project/foundation/src/fields/fields.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ describe('Bn254 arithmetic', () => {
expect(actual).toEqual(expected);
});

it('High Bonudary', () => {
it('High Boundary', () => {
// -1 - (-1) = 0
const a = new Fr(Fr.MODULUS - 1n);
const b = new Fr(Fr.MODULUS - 1n);
Expand Down Expand Up @@ -184,6 +184,30 @@ describe('Bn254 arithmetic', () => {
});
});

describe('Square root', () => {
it.each([
[new Fr(0), 0n],
[new Fr(4), 2n],
[new Fr(9), 3n],
[new Fr(16), 4n],
])('Should return the correct square root for %p', (input, expected) => {
const actual = input.sqrt()!.toBigInt();

// The square root can be either the expected value or the modulus - expected value
const isValid = actual == expected || actual == Fr.MODULUS - expected;

expect(isValid).toBeTruthy();
});

it('Should return the correct square root for random value', () => {
const a = Fr.random();
const squared = a.mul(a);

const actual = squared.sqrt();
expect(actual!.mul(actual!)).toEqual(squared);
});
});

describe('Comparison', () => {
it.each([
[new Fr(5), new Fr(10), -1],
Expand Down
21 changes: 21 additions & 0 deletions yarn-project/foundation/src/fields/fields.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { BarretenbergSync } from '@aztec/bb.js';

import { inspect } from 'util';

import { toBigIntBE, toBufferBE } from '../bigint-buffer/index.js';
Expand Down Expand Up @@ -280,6 +282,25 @@ export class Fr extends BaseField {
return new Fr(this.toBigInt() / rhs.toBigInt());
}

/**
* Computes a square root of the field element.
* @returns A square root of the field element (null if it does not exist).
*/
sqrt(): Fr | null {
const wasm = BarretenbergSync.getSingleton().getWasm();
wasm.writeMemory(0, this.toBuffer());
wasm.call('bn254_fr_sqrt', 0, Fr.SIZE_IN_BYTES);
const isSqrtBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES, Fr.SIZE_IN_BYTES + 1));
const isSqrt = isSqrtBuf[0] === 1;
if (!isSqrt) {
// Field element is not a quadratic residue mod p so it has no square root.
return null;
}

const rootBuf = Buffer.from(wasm.getMemorySlice(Fr.SIZE_IN_BYTES + 1, Fr.SIZE_IN_BYTES * 2 + 1));
return Fr.fromBuffer(rootBuf);
}

toJSON() {
return {
type: 'Fr',
Expand Down
35 changes: 35 additions & 0 deletions yarn-project/foundation/src/fields/point.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { Fr } from './fields.js';
import { Point } from './point.js';

describe('Point', () => {
it('converts to and from x and sign of y coordinate', () => {
const p = new Point(
new Fr(0x30426e64aee30e998c13c8ceecda3a77807dbead52bc2f3bf0eae851b4b710c1n),
new Fr(0x113156a068f603023240c96b4da5474667db3b8711c521c748212a15bc034ea6n),
false,
);

const [x, sign] = p.toXAndSign();
const p2 = Point.fromXAndSign(x, sign);

expect(p.equals(p2)).toBeTruthy();
});

it('creates a valid random point', () => {
expect(Point.random().isOnGrumpkin()).toBeTruthy();
});

it('converts to and from buffer', () => {
const p = Point.random();
const p2 = Point.fromBuffer(p.toBuffer());

expect(p.equals(p2)).toBeTruthy();
});

it('converts to and from compressed buffer', () => {
const p = Point.random();
const p2 = Point.fromCompressedBuffer(p.toCompressedBuffer());

expect(p.equals(p2)).toBeTruthy();
});
});
83 changes: 80 additions & 3 deletions yarn-project/foundation/src/fields/point.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { poseidon2Hash } from '../crypto/index.js';
import { poseidon2Hash, randomBoolean } from '../crypto/index.js';
import { BufferReader, FieldReader, serializeToBuffer } from '../serialize/index.js';
import { Fr } from './fields.js';

Expand All @@ -10,6 +10,7 @@ import { Fr } from './fields.js';
export class Point {
static ZERO = new Point(Fr.ZERO, Fr.ZERO, false);
static SIZE_IN_BYTES = Fr.SIZE_IN_BYTES * 2;
static COMPRESSED_SIZE_IN_BYTES = Fr.SIZE_IN_BYTES + 1;

/** Used to differentiate this class from AztecAddress */
public readonly kind = 'point';
Expand Down Expand Up @@ -37,8 +38,17 @@ export class Point {
* @returns A randomly generated Point instance.
*/
static random() {
// TODO make this return an actual point on curve.
nventuro marked this conversation as resolved.
Show resolved Hide resolved
return new Point(Fr.random(), Fr.random(), false);
while (true) {
try {
return Point.fromXAndSign(Fr.random(), randomBoolean());
} catch (e: any) {
if (!(e instanceof NotOnCurveError)) {
throw e;
}
// The random point is not on the curve - we try again
continue;
}
}
}
nventuro marked this conversation as resolved.
Show resolved Hide resolved

/**
Expand All @@ -53,6 +63,18 @@ export class Point {
return new this(Fr.fromBuffer(reader), Fr.fromBuffer(reader), false);
}

/**
* Create a Point instance from a compressed buffer.
* The input 'buffer' should have exactly 33 bytes representing the x coordinate and the sign of the y coordinate.
*
* @param buffer - The buffer containing the x coordinate and the sign of the y coordinate.
* @returns A Point instance.
*/
static fromCompressedBuffer(buffer: Buffer | BufferReader) {
const reader = BufferReader.asReader(buffer);
return this.fromXAndSign(Fr.fromBuffer(reader), reader.readBoolean());
}

/**
* Create a Point instance from a hex-encoded string.
* The input 'address' should be prefixed with '0x' or not, and have exactly 128 hex characters representing the x and y coordinates.
Expand All @@ -78,6 +100,46 @@ export class Point {
return new this(reader.readField(), reader.readField(), reader.readBoolean());
}

/**
* Uses the x coordinate and isPositive flag (+/-) to reconstruct the point.
* @dev The y coordinate can be derived from the x coordinate and the "sign" flag by solving the grumpkin curve
* equation for y.
* @param x - The x coordinate of the point
* @param sign - The "sign" of the y coordinate - note that this is not a sign as is known in integer arithmetic.
* Instead it is a boolean flag that determines whether the y coordinate is <= (Fr.MODULUS - 1) / 2
* @returns The point as an array of 2 fields
*/
static fromXAndSign(x: Fr, sign: boolean) {
// Calculate y^2 = x^3 - 17
const ySquared = x.square().mul(x).sub(new Fr(17));

// Calculate the square root of ySquared
const y = ySquared.sqrt();

// If y is null, the x-coordinate is not on the curve
if (y === null) {
throw new NotOnCurveError();
}

const yPositiveBigInt = y.toBigInt() > (Fr.MODULUS - 1n) / 2n ? Fr.MODULUS - y.toBigInt() : y.toBigInt();
const yNegativeBigInt = Fr.MODULUS - yPositiveBigInt;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const yPositiveBigInt = y.toBigInt() > (Fr.MODULUS - 1n) / 2n ? Fr.MODULUS - y.toBigInt() : y.toBigInt();
const yPositiveBigInt = y.toBigInt() <= (Fr.MODULUS - 1n) / 2n ? y.toBigInt() : Fr.MODULUS - y.toBigInt();

Nit, but this way the sign condition is the same in both places we use it

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah silly automerge

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will sneak it in a PR up the stack 👍


// Choose the positive or negative root based on isPositive
const finalY = sign ? new Fr(yPositiveBigInt) : new Fr(yNegativeBigInt);

// Create and return the new Point
return new this(x, finalY, false);
}

/**
* Returns the x coordinate and the sign of the y coordinate.
* @dev The y sign can be determined by checking if the y coordinate is greater than half of the modulus.
* @returns The x coordinate and the sign of the y coordinate.
*/
toXAndSign(): [Fr, boolean] {
return [this.x, this.y.toBigInt() <= (Fr.MODULUS - 1n) / 2n];
}

/**
* Returns the contents of the point as BigInts.
* @returns The point as BigInts
Expand Down Expand Up @@ -111,6 +173,14 @@ export class Point {
return buf;
}

/**
* Converts the Point instance to a compressed Buffer representation of the coordinates.
* @returns A Buffer representation of the Point instance
*/
toCompressedBuffer() {
return serializeToBuffer(this.toXAndSign());
}

/**
* Convert the Point instance to a hexadecimal string representation.
* The output string is prefixed with '0x' and consists of exactly 128 hex characters,
Expand Down Expand Up @@ -194,3 +264,10 @@ export function isPoint(obj: object): obj is Point {
const point = obj as Point;
return point.kind === 'point' && point.x !== undefined && point.y !== undefined;
}

class NotOnCurveError extends Error {
constructor() {
super('The given x-coordinate is not on the Grumpkin curve');
this.name = 'NotOnCurveError';
}
}
Loading