Skip to content

Commit

Permalink
feat(avm): msm blackbox
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed Jun 17, 2024
1 parent fcbd44b commit 8169fb6
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 2 deletions.
2 changes: 2 additions & 0 deletions avm-transpiler/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub enum AvmOpcode {
SHA256, // temp - may be removed, but alot of contracts rely on it
PEDERSEN, // temp - may be removed, but alot of contracts rely on it
ECADD,
MSM,
// Conversions
TORADIXLE,
}
Expand Down Expand Up @@ -165,6 +166,7 @@ impl AvmOpcode {
AvmOpcode::SHA256 => "SHA256 ",
AvmOpcode::PEDERSEN => "PEDERSEN",
AvmOpcode::ECADD => "ECADD",
AvmOpcode::MSM => "MSM",
// Conversions
AvmOpcode::TORADIXLE => "TORADIXLE",
}
Expand Down
23 changes: 23 additions & 0 deletions avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,29 @@ fn handle_black_box_function(avm_instrs: &mut Vec<AvmInstruction>, operation: &B
],
..Default::default()
}),
// Temporary while we dont have efficient noir implementations
BlackBoxOp::MultiScalarMul { points, scalars, outputs } => {
// The length of the scalars vector is 2x the length of the points vector due to limb
// decomposition
let points_offset = points.pointer.0;
let num_points = points.size.0;
let scalars_offset = scalars.pointer.0;
// Output array is fixed to 3
let outputs_offset = outputs.pointer.0;
avm_instrs.push(AvmInstruction {
opcode: AvmOpcode::MSM,
indirect: Some(
ZEROTH_OPERAND_INDIRECT | FIRST_OPERAND_INDIRECT | SECOND_OPERAND_INDIRECT,
),
operands: vec![
AvmOperand::U32 { value: points_offset as u32 },
AvmOperand::U32 { value: scalars_offset as u32 },
AvmOperand::U32 { value: outputs_offset as u32 },
AvmOperand::U32 { value: num_points as u32 },
],
..Default::default()
});
}
_ => panic!("Transpiler doesn't know how to process {:?}", operation),
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contract AvmTest {
global big_field_136_bits: Field = 0x991234567890abcdef1234567890abcdef;

// Libs
use dep::std::embedded_curve_ops::EmbeddedCurvePoint;
use dep::std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul};
use dep::aztec::protocol_types::constants::CONTRACT_INSTANCE_LENGTH;
use dep::aztec::prelude::{Map, Deserialize};
use dep::aztec::state_vars::PublicMutable;
Expand Down Expand Up @@ -144,6 +144,14 @@ contract AvmTest {
added
}

#[aztec(public)]
fn variable_base_msm() -> [Field; 3] {
let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false };
let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 };
let triple_g = multi_scalar_mul([g], [scalar]);
triple_g
}

/************************************************************************
* Misc
************************************************************************/
Expand Down
10 changes: 9 additions & 1 deletion yarn-project/foundation/src/fields/point.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,16 @@ export class Point {
return poseidon2Hash(this.toFields());
}

/**
* Check if this is point at infinity.
*/
isInfPoint() {
// Check this
return this.x.isZero();
}

isOnGrumpkin() {
if (this.isZero()) {
if (this.isInfPoint()) {
return true;
}

Expand Down
1 change: 1 addition & 0 deletions yarn-project/simulator/src/avm/avm_gas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const BaseGasCosts: Record<Opcode, Gas> = {
[Opcode.SHA256]: DefaultBaseGasCost,
[Opcode.PEDERSEN]: DefaultBaseGasCost,
[Opcode.ECADD]: DefaultBaseGasCost,
[Opcode.MSM]: DefaultBaseGasCost,
// Conversions
[Opcode.TORADIXLE]: DefaultBaseGasCost,
};
Expand Down
12 changes: 12 additions & 0 deletions yarn-project/simulator/src/avm/avm_simulator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ describe('AVM simulator: transpiled Noir contracts', () => {
expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]);
});

it('variable msm operations', async () => {
const context = initContext();

const bytecode = getAvmTestContractBytecode('variable_base_msm');
const results = await new AvmSimulator(context).executeBytecode(bytecode);

expect(results.reverted).toBe(false);
const grumpkin = new Grumpkin();
const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3));
expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]);
});

describe('U128 addition and overflows', () => {
it('U128 addition', async () => {
const calldata: Fr[] = [
Expand Down
142 changes: 142 additions & 0 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import { Fq, Fr } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, Uint8, Uint32 } from '../avm_memory_types.js';
import { initContext } from '../fixtures/index.js';
import { MultiScalarMul } from './multi_scalar_mul.js';

describe('MultiScalarMul Opcode', () => {
let context: AvmContext;

beforeEach(async () => {
context = initContext();
});
it('Should (de)serialize correctly', () => {
const buf = Buffer.from([
MultiScalarMul.opcode, // opcode
7, // indirect
...Buffer.from('12345678', 'hex'), // pointsOffset
...Buffer.from('23456789', 'hex'), // scalars Offset
...Buffer.from('3456789a', 'hex'), // outputOffset
...Buffer.from('456789ab', 'hex'), // pointsLengthOffset
]);
const inst = new MultiScalarMul(
/*indirect=*/ 7,
/*pointsOffset=*/ 0x12345678,
/*scalarsOffset=*/ 0x23456789,
/*outputOffset=*/ 0x3456789a,
/*pointsLengthOffset=*/ 0x456789ab,
);

expect(MultiScalarMul.deserialize(buf)).toEqual(inst);
expect(inst.serialize()).toEqual(buf);
});

it('Should perform msm correctly - direct', async () => {
const indirect = 0;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, 3G]
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));

// Pick some big scalars to test the edge cases
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]);
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]);

const pointsOffset = 0;
// Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct)
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
for (let i = 0; i < points.length; i++) {
const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y
const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf
context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]);
context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]);
context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0));
}
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context);

const result = context.machineState.memory.getSlice(outputOffset, 3);

// We write it out explicitly here
let expectedResult = grumpkin.mul(points[0], scalars[0]);
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1]));
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2]));

expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});

it('Should perform msm correctly - indirect', async () => {
const indirect = 7;
const grumpkin = new Grumpkin();
// We need to ensure points are actually on curve, so we just use the generator
// In future we could use a random point, for now we create an array of [G, 2G, 3G]
const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1)));

// Pick some big scalars to test the edge cases
const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)];
const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory
const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory
// Transform the points and scalars into the format that we will write to memory
// We just store the x and y coordinates here, and handle the infinities when we write to memory
const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]);
const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]);

const pointsOffset = 0;
// Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct)
// Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...]
for (let i = 0; i < points.length; i++) {
const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y
const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf
context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]);
context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]);
context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0));
}
// Store scalars
const scalarsOffset = pointsOffset + pointsReadLength;
context.machineState.memory.setSlice(scalarsOffset, storedScalars);
// Store length of points to read
const pointsLengthOffset = scalarsOffset + scalarsLength;
context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength));
const outputOffset = pointsLengthOffset + 1;

// Set up the indirect pointers
const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */
const scalarsIndirectOffset = pointsIndirectOffset + 1;
const outputIndirectOffset = scalarsIndirectOffset + 1;

context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset));
context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset));
context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset));

await new MultiScalarMul(
indirect,
pointsIndirectOffset,
scalarsIndirectOffset,
outputIndirectOffset,
pointsLengthOffset,
).execute(context);

const result = context.machineState.memory.getSlice(outputOffset, 3);

// We write it out explicitly here
let expectedResult = grumpkin.mul(points[0], scalars[0]);
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1]));
expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2]));

expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]);
});
});
116 changes: 116 additions & 0 deletions yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import { Fq, Fr, Point } from '@aztec/circuits.js';
import { Grumpkin } from '@aztec/circuits.js/barretenberg';

import { type AvmContext } from '../avm_context.js';
import { Field, TypeTag } from '../avm_memory_types.js';
import { InstructionExecutionError } from '../errors.js';
import { Opcode, OperandType } from '../serialization/instruction_serialization.js';
import { Addressing } from './addressing_mode.js';
import { Instruction } from './instruction.js';

export class MultiScalarMul extends Instruction {
static type: string = 'MultiScalarMul';
static readonly opcode: Opcode = Opcode.MSM;

// Informs (de)serialization. See Instruction.deserialize.
static readonly wireFormat: OperandType[] = [
OperandType.UINT8 /* opcode */,
OperandType.UINT8 /* indirect */,
OperandType.UINT32 /* points vector offset */,
OperandType.UINT32 /* scalars vector offset */,
OperandType.UINT32 /* output offset (fixed triplet)*/,
OperandType.UINT32 /* points length offset */,
];

constructor(
private indirect: number,
private pointsOffset: number,
private scalarsOffset: number,
private outputOffset: number,
private pointsLengthOffset: number,
) {
super();
}

public async execute(context: AvmContext): Promise<void> {
const memory = context.machineState.memory.track(this.type);
// Resolve indirects
const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve(
[this.pointsOffset, this.scalarsOffset, this.outputOffset],
memory,
);

// Length of the points vector should be U32
memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset);

// Get the size of the unrolled (x, y , inf) points vector
// TODO: Do we need to assert that the length is a multiple of 3 (x, y, inf)?
const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber();
// Divide by 3 since each point is represented as a triplet to get the number of points
const numPoints = pointsReadLength / 3;
// The tag for each triplet will be (Field, Field, Uint8)
for (let i = 0; i < numPoints; i++) {
const offset = pointsOffset + i * 3;
// Check (Field, Field)
memory.checkTagsRange(TypeTag.FIELD, offset, 2);
// Check Uint8 (inf flag)
memory.checkTag(TypeTag.UINT8, offset + 2);
}
// Get the unrolled (x, y, inf) representing the points
const pointsVector = memory.getSlice(pointsOffset, pointsReadLength);

// The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition
const scalarReadLength = numPoints * 2;
// Get the unrolled scalar (lo & hi) representing the scalars
const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength);
memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength);

// Now we need to reconstruct the points and scalars into something we can operate on.
const grumpkinPoints: Point[] = [];
for (let i = 0; i < numPoints; i++) {
const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr());
// Include this later when we have a standard for representing infinity
// const isInf = pointsVector[i + 2].toBoolean();

if (!p.isOnGrumpkin()) {
throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`);
}
grumpkinPoints.push(p);
}
// The scalars are read from memory as Fr elements, which are limbs of Fq elements
// So we need to reconstruct them before performing the scalar multiplications
const scalarFqVector: Fq[] = [];
for (let i = 0; i < numPoints; i++) {
const scalarLo = scalarsVector[2 * i].toFr();
const scalarHi = scalarsVector[2 * i + 1].toFr();
const fqScalar = Fq.fromHighLow(scalarHi, scalarLo);
scalarFqVector.push(fqScalar);
}
// TODO: Is there an efficient MSM implementation in ts that we can replace this by?
const grumpkin = new Grumpkin();
// Zip the points and scalars into pairs
const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]);
// Fold the points and scalars into a single point
// We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts
const outputPoint = rest.reduce(
(acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])),
grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]),
);
// TODO: Check the Infinity flag here
const output: Fr[] = [outputPoint.x, outputPoint.y, outputPoint.isInfPoint() ? Fr.ONE : Fr.ZERO];

memory.setSlice(
outputOffset,
output.map(word => new Field(word)),
);

const memoryOperations = {
reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */,
writes: 3 /* output triplet */,
indirect: this.indirect,
};
context.machineState.consumeGas(this.gasCost(memoryOperations));
memory.assert(memoryOperations);
context.machineState.incrementPc();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import {
Version,
Xor,
} from '../opcodes/index.js';
import { MultiScalarMul } from '../opcodes/multi_scalar_mul.js';
import { BufferCursor } from './buffer_cursor.js';
import { Opcode } from './instruction_serialization.js';

Expand Down Expand Up @@ -143,6 +144,7 @@ const INSTRUCTION_SET = () =>
[Poseidon2.opcode, Poseidon2],
[Sha256.opcode, Sha256],
[Pedersen.opcode, Pedersen],
[MultiScalarMul.opcode, MultiScalarMul],
// Conversions
[ToRadixLE.opcode, ToRadixLE],
]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ export enum Opcode {
SHA256, // temp - may be removed, but alot of contracts rely on it
PEDERSEN, // temp - may be removed, but alot of contracts rely on it
ECADD,
MSM,
// Conversion
TORADIXLE,
}
Expand Down

0 comments on commit 8169fb6

Please sign in to comment.