Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(avm-simulator): msm blackbox #7048

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions avm-transpiler/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub enum AvmOpcode {
SHA256, // temp - may be removed, but alot of contracts rely on it
PEDERSEN, // temp - may be removed, but alot of contracts rely on it
ECADD,
MSM,
IlyasRidhuan marked this conversation as resolved.
Show resolved Hide resolved
// Conversions
TORADIXLE,
}
Expand Down Expand Up @@ -165,6 +166,7 @@ impl AvmOpcode {
AvmOpcode::SHA256 => "SHA256 ",
AvmOpcode::PEDERSEN => "PEDERSEN",
AvmOpcode::ECADD => "ECADD",
AvmOpcode::MSM => "MSM",
// Conversions
AvmOpcode::TORADIXLE => "TORADIXLE",
}
Expand Down
24 changes: 24 additions & 0 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,30 @@ fn handle_black_box_function(avm_instrs: &mut Vec<AvmInstruction>, operation: &B
],
..Default::default()
}),
// Temporary while we dont have efficient noir implementations
BlackBoxOp::MultiScalarMul { points, scalars, outputs } => {
// The length of the scalars vector is 2x the length of the points vector due to limb
// decomposition
let points_offset = points.pointer.0;
let num_points = points.size.0;
let scalars_offset = scalars.pointer.0;
// Output array is fixed to 3
assert_eq!(outputs.size, 3, "Output array size must be equal to 3");
let outputs_offset = outputs.pointer.0;
IlyasRidhuan marked this conversation as resolved.
Show resolved Hide resolved
avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::MSM,
indirect: Some(
ZEROTH_OPERAND_INDIRECT | FIRST_OPERAND_INDIRECT | SECOND_OPERAND_INDIRECT,
),
operands: vec![
AvmOperand::U32 { value: points_offset as u32 },
AvmOperand::U32 { value: scalars_offset as u32 },
AvmOperand::U32 { value: outputs_offset as u32 },
AvmOperand::U32 { value: num_points as u32 },
],
..Default::default()
});
}
_ => panic!("Transpiler doesn't know how to process {:?}", operation),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ const std::unordered_map<OpCode, std::vector<OperandType>> OPCODE_WIRE_FORMAT =
OperandType::UINT32, // rhs.y
OperandType::UINT32, // rhs.is_infinite
OperandType::UINT32 } }, // dst_offset
{ OpCode::MSM,
{ OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
// Gadget - Conversion
{ OpCode::TORADIXLE,
{ OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ enum class OpCode : uint8_t {
SHA256,
PEDERSEN,
ECADD,
MSM,
// Conversions
TORADIXLE,
// Future Gadgets -- pending changes in noir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contract AvmTest {
global big_field_136_bits: Field = 0x991234567890abcdef1234567890abcdef;

// Libs
use dep::std::embedded_curve_ops::EmbeddedCurvePoint;
use dep::std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul};
use dep::aztec::protocol_types::constants::CONTRACT_INSTANCE_LENGTH;
use dep::aztec::prelude::{Map, Deserialize};
use dep::aztec::state_vars::PublicMutable;
Expand Down Expand Up @@ -144,6 +144,15 @@ contract AvmTest {
added
}

#[aztec(public)]
fn variable_base_msm() -> [Field; 3] {
let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false };
let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 };
let scalar2 = EmbeddedCurveScalar { lo: 20, hi: 0 };
let triple_g = multi_scalar_mul([g, g], [scalar, scalar2]);
triple_g
}

/************************************************************************
* Misc
************************************************************************/
Expand Down
14 changes: 13 additions & 1 deletion yarn-project/foundation/src/fields/point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,20 @@ export class Point {
return poseidon2Hash(this.toFields());
}

/**
* Check if this is point at infinity.
* Check this is consistent with how bb is encoding the point at infinity
*/
public get inf() {
return this.x == Fr.ZERO;
}
public toFieldsWithInf() {
return [this.x, this.y, new Fr(this.inf)];
}

isOnGrumpkin() {
if (this.isZero()) {
// TODO: Check this against how bb handles curve check and infinity point check
if (this.inf) {
return true;
}

Expand Down
1 change: 1 addition & 0 deletions yarn-project/simulator/src/avm/avm_gas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const BaseGasCosts: Record<Opcode, Gas> = {
[Opcode.SHA256]: DefaultBaseGasCost,
[Opcode.PEDERSEN]: DefaultBaseGasCost,
[Opcode.ECADD]: DefaultBaseGasCost,
[Opcode.MSM]: DefaultBaseGasCost,
// Conversions
[Opcode.TORADIXLE]: DefaultBaseGasCost,
IlyasRidhuan marked this conversation as resolved.
Show resolved Hide resolved
};
Expand Down
14 changes: 14 additions & 0 deletions yarn-project/simulator/src/avm/avm_simulator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ describe('AVM simulator: transpiled Noir contracts', () => {
expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]);
});

it('variable msm operations', async () => {
const context = initContext();

const bytecode = getAvmTestContractBytecode('variable_base_msm');
const results = await new AvmSimulator(context).executeBytecode(bytecode);

expect(results.reverted).toBe(false);
const grumpkin = new Grumpkin();
const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3));
const g20 = grumpkin.mul(grumpkin.generator(), new Fq(20));
const expectedResult = grumpkin.add(g3, g20);
expect(results.output).toEqual([expectedResult.x, expectedResult.y, Fr.ZERO]);
});

describe('U128 addition and overflows', () => {
it('U128 addition', async () => {
const calldata: Fr[] = [
Expand Down
130 changes: 130 additions & 0 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import { Fq, Fr } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js';
import { initContext } from '../fixtures/index.js';
import { MultiScalarMul } from './multi_scalar_mul.js';

describe('MultiScalarMul Opcode', () => {
let context: AvmContext;

beforeEach(async () => {
context = initContext();
});
it('Should (de)serialize correctly', () => {
const buf = Buffer.from([
MultiScalarMul.opcode, // opcode
7, // indirect
...Buffer.from('12345678', 'hex'), // pointsOffset
...Buffer.from('23456789', 'hex'), // scalars Offset
...Buffer.from('3456789a', 'hex'), // outputOffset
...Buffer.from('456789ab', 'hex'), // pointsLengthOffset
]);
const inst = new MultiScalarMul(
/*indirect=*/ 7,
/*pointsOffset=*/ 0x12345678,
/*scalarsOffset=*/ 0x23456789,
/*outputOffset=*/ 0x3456789a,
/*pointsLengthOffset=*/ 0x456789ab,
);

expect(MultiScalarMul.deserialize(buf)).toEqual(inst);
expect(inst.serialize()).toEqual(buf);
});

it('Should perform msm correctly - direct', async () => {
const indirect = 0;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, 3G]
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));

// Pick some big scalars to test the edge cases
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]);
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
const storedPoints: MemoryValue[] = points
.map(p => p.toFieldsWithInf())
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]);
const pointsOffset = 0;
context.machineState.memory.setSlice(pointsOffset, storedPoints);
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context);

const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr());

// We write it out explicitly here
let expectedResult = grumpkin.mul(points[0], scalars[0]);
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1]));
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2]));

expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});

it('Should perform msm correctly - indirect', async () => {
const indirect = 7;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, 3G]
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));

// Pick some big scalars to test the edge cases
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]);
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
const storedPoints: MemoryValue[] = points
.map(p => p.toFieldsWithInf())
.flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]);
const pointsOffset = 0;
context.machineState.memory.setSlice(pointsOffset, storedPoints);
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

// Set up the indirect pointers
const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */
const scalarsIndirectOffset = pointsIndirectOffset + 1;
const outputIndirectOffset = scalarsIndirectOffset + 1;

context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset));
context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset));
context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset));

await new MultiScalarMul(
indirect,
pointsIndirectOffset,
scalarsIndirectOffset,
outputIndirectOffset,
pointsLengthOffset,
).execute(context);

const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr());

// We write it out explicitly here
let expectedResult = grumpkin.mul(points[0], scalars[0]);
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1]));
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2]));

expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});
});
114 changes: 114 additions & 0 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import { Fq, Point } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { strict as assert } from 'assert';

import { type AvmContext } from '../avm_context.js';
import { Field, TypeTag } from '../avm_memory_types.js';
import { InstructionExecutionError } from '../errors.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';

export class MultiScalarMul extends Instruction {
static type: string = 'MultiScalarMul';
static readonly opcode: Opcode = Opcode.MSM;

// Informs (de)serialization. See Instruction.deserialize.
static readonly wireFormat: OperandType[] = [
OperandType.UINT8 /* opcode */,
OperandType.UINT8 /* indirect */,
OperandType.UINT32 /* points vector offset */,
OperandType.UINT32 /* scalars vector offset */,
OperandType.UINT32 /* output offset (fixed triplet) */,
OperandType.UINT32 /* points length offset */,
];

constructor(
private indirect: number,
private pointsOffset: number,
private scalarsOffset: number,
private outputOffset: number,
private pointsLengthOffset: number,
) {
super();
}

public async execute(context: AvmContext): Promise<void> {
const memory = context.machineState.memory.track(this.type);
// Resolve indirects
const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve(
[this.pointsOffset, this.scalarsOffset, this.outputOffset],
memory,
);

// Length of the points vector should be U32
memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset);
// Get the size of the unrolled (x, y , inf) points vector
IlyasRidhuan marked this conversation as resolved.
Show resolved Hide resolved
const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber();
assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3');
// Divide by 3 since each point is represented as a triplet to get the number of points
const numPoints = pointsReadLength / 3;
// The tag for each triplet will be (Field, Field, Uint8)
for (let i = 0; i < numPoints; i++) {
const offset = pointsOffset + i * 3;
// Check (Field, Field)
memory.checkTagsRange(TypeTag.FIELD, offset, 2);
// Check Uint8 (inf flag)
memory.checkTag(TypeTag.UINT8, offset + 2);
}
// Get the unrolled (x, y, inf) representing the points
const pointsVector = memory.getSlice(pointsOffset, pointsReadLength);

// The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition
const scalarReadLength = numPoints * 2;
// Consume gas prior to performing work
const memoryOperations = {
reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */,
writes: 3 /* output triplet */,
indirect: this.indirect,
};
context.machineState.consumeGas(this.gasCost(memoryOperations));
// Get the unrolled scalar (lo & hi) representing the scalars
const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength);
memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength);

// Now we need to reconstruct the points and scalars into something we can operate on.
const grumpkinPoints: Point[] = [];
for (let i = 0; i < numPoints; i++) {
const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr());
// Include this later when we have a standard for representing infinity
// const isInf = pointsVector[i + 2].toBoolean();

if (!p.isOnGrumpkin()) {
throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`);
}
grumpkinPoints.push(p);
}
// The scalars are read from memory as Fr elements, which are limbs of Fq elements
// So we need to reconstruct them before performing the scalar multiplications
const scalarFqVector: Fq[] = [];
for (let i = 0; i < numPoints; i++) {
const scalarLo = scalarsVector[2 * i].toFr();
const scalarHi = scalarsVector[2 * i + 1].toFr();
const fqScalar = Fq.fromHighLow(scalarHi, scalarLo);
scalarFqVector.push(fqScalar);
}
// TODO: Is there an efficient MSM implementation in ts that we can replace this by?
const grumpkin = new Grumpkin();
IlyasRidhuan marked this conversation as resolved.
Show resolved Hide resolved
// Zip the points and scalars into pairs
const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]);
// Fold the points and scalars into a single point
// We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts
const outputPoint = rest.reduce(
(acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])),
grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]),
);
const output = outputPoint.toFieldsWithInf().map(f => new Field(f));

memory.setSlice(outputOffset, output);

memory.assert(memoryOperations);
context.machineState.incrementPc();
}
}
Loading
Loading