From 8169fb62a285cda4d8d7501393ed0bb0a9a17871 Mon Sep 17 00:00:00 2001 From: IlyasRidhuan Date: Thu, 13 Jun 2024 13:29:21 +0000 Subject: [PATCH 1/2] feat(avm): msm blackbox --- avm-transpiler/src/opcodes.rs | 2 + avm-transpiler/src/transpile.rs | 23 +++ .../contracts/avm_test_contract/src/main.nr | 10 +- yarn-project/foundation/src/fields/point.ts | 10 +- yarn-project/simulator/src/avm/avm_gas.ts | 1 + .../simulator/src/avm/avm_simulator.test.ts | 12 ++ .../src/avm/opcodes/multi_scalar_mul.test.ts | 142 ++++++++++++++++++ .../src/avm/opcodes/multi_scalar_mul.ts | 116 ++++++++++++++ .../serialization/bytecode_serialization.ts | 2 + .../instruction_serialization.ts | 1 + 10 files changed, 317 insertions(+), 2 deletions(-) create mode 100644 yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts create mode 100644 yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts diff --git a/avm-transpiler/src/opcodes.rs b/avm-transpiler/src/opcodes.rs index e4d642dd829..11cd956237d 100644 --- a/avm-transpiler/src/opcodes.rs +++ b/avm-transpiler/src/opcodes.rs @@ -73,6 +73,7 @@ pub enum AvmOpcode { SHA256, // temp - may be removed, but alot of contracts rely on it PEDERSEN, // temp - may be removed, but alot of contracts rely on it ECADD, + MSM, // Conversions TORADIXLE, } @@ -165,6 +166,7 @@ impl AvmOpcode { AvmOpcode::SHA256 => "SHA256 ", AvmOpcode::PEDERSEN => "PEDERSEN", AvmOpcode::ECADD => "ECADD", + AvmOpcode::MSM => "MSM", // Conversions AvmOpcode::TORADIXLE => "TORADIXLE", } diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index 67d0f8043a6..a1ae1b71644 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -855,6 +855,29 @@ fn handle_black_box_function(avm_instrs: &mut Vec, operation: &B ], ..Default::default() }), + // Temporary while we dont have efficient noir implementations + BlackBoxOp::MultiScalarMul { points, scalars, outputs } => { + // The length of the scalars vector is 2x the length of the points vector due to limb + // decomposition + let points_offset = points.pointer.0; + let num_points = points.size.0; + let scalars_offset = scalars.pointer.0; + // Output array is fixed to 3 + let outputs_offset = outputs.pointer.0; + avm_instrs.push(AvmInstruction { + opcode: AvmOpcode::MSM, + indirect: Some( + ZEROTH_OPERAND_INDIRECT | FIRST_OPERAND_INDIRECT | SECOND_OPERAND_INDIRECT, + ), + operands: vec![ + AvmOperand::U32 { value: points_offset as u32 }, + AvmOperand::U32 { value: scalars_offset as u32 }, + AvmOperand::U32 { value: outputs_offset as u32 }, + AvmOperand::U32 { value: num_points as u32 }, + ], + ..Default::default() + }); + } _ => panic!("Transpiler doesn't know how to process {:?}", operation), } } diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index c6271fc42ad..c3d827652f1 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -24,7 +24,7 @@ contract AvmTest { global big_field_136_bits: Field = 0x991234567890abcdef1234567890abcdef; // Libs - use dep::std::embedded_curve_ops::EmbeddedCurvePoint; + use dep::std::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul}; use dep::aztec::protocol_types::constants::CONTRACT_INSTANCE_LENGTH; use dep::aztec::prelude::{Map, Deserialize}; use dep::aztec::state_vars::PublicMutable; @@ -144,6 +144,14 @@ contract AvmTest { added } + #[aztec(public)] + fn variable_base_msm() -> [Field; 3] { + let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; + let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 }; + let triple_g = multi_scalar_mul([g], [scalar]); + triple_g + } + /************************************************************************ * Misc ************************************************************************/ diff --git a/yarn-project/foundation/src/fields/point.ts b/yarn-project/foundation/src/fields/point.ts index b152bffcc6a..490c1fe2f48 100644 --- a/yarn-project/foundation/src/fields/point.ts +++ b/yarn-project/foundation/src/fields/point.ts @@ -137,8 +137,16 @@ export class Point { return poseidon2Hash(this.toFields()); } + /** + * Check if this is point at infinity. + */ + isInfPoint() { + // Check this + return this.x.isZero(); + } + isOnGrumpkin() { - if (this.isZero()) { + if (this.isInfPoint()) { return true; } diff --git a/yarn-project/simulator/src/avm/avm_gas.ts b/yarn-project/simulator/src/avm/avm_gas.ts index 7802d4177d1..d951f20fa32 100644 --- a/yarn-project/simulator/src/avm/avm_gas.ts +++ b/yarn-project/simulator/src/avm/avm_gas.ts @@ -123,6 +123,7 @@ const BaseGasCosts: Record = { [Opcode.SHA256]: DefaultBaseGasCost, [Opcode.PEDERSEN]: DefaultBaseGasCost, [Opcode.ECADD]: DefaultBaseGasCost, + [Opcode.MSM]: DefaultBaseGasCost, // Conversions [Opcode.TORADIXLE]: DefaultBaseGasCost, }; diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index f1e8c98fb90..86d1960577b 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -108,6 +108,18 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); }); + it('variable msm operations', async () => { + const context = initContext(); + + const bytecode = getAvmTestContractBytecode('variable_base_msm'); + const results = await new AvmSimulator(context).executeBytecode(bytecode); + + expect(results.reverted).toBe(false); + const grumpkin = new Grumpkin(); + const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3)); + expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); + }); + describe('U128 addition and overflows', () => { it('U128 addition', async () => { const calldata: Fr[] = [ diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts new file mode 100644 index 00000000000..83a9b79ca31 --- /dev/null +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts @@ -0,0 +1,142 @@ +import { Fq, Fr } from '@aztec/circuits.js'; +import { Grumpkin } from '@aztec/circuits.js/barretenberg'; + +import { type AvmContext } from '../avm_context.js'; +import { Field, Uint8, Uint32 } from '../avm_memory_types.js'; +import { initContext } from '../fixtures/index.js'; +import { MultiScalarMul } from './multi_scalar_mul.js'; + +describe('MultiScalarMul Opcode', () => { + let context: AvmContext; + + beforeEach(async () => { + context = initContext(); + }); + it('Should (de)serialize correctly', () => { + const buf = Buffer.from([ + MultiScalarMul.opcode, // opcode + 7, // indirect + ...Buffer.from('12345678', 'hex'), // pointsOffset + ...Buffer.from('23456789', 'hex'), // scalars Offset + ...Buffer.from('3456789a', 'hex'), // outputOffset + ...Buffer.from('456789ab', 'hex'), // pointsLengthOffset + ]); + const inst = new MultiScalarMul( + /*indirect=*/ 7, + /*pointsOffset=*/ 0x12345678, + /*scalarsOffset=*/ 0x23456789, + /*outputOffset=*/ 0x3456789a, + /*pointsLengthOffset=*/ 0x456789ab, + ); + + expect(MultiScalarMul.deserialize(buf)).toEqual(inst); + expect(inst.serialize()).toEqual(buf); + }); + + it('Should perform msm correctly - direct', async () => { + const indirect = 0; + const grumpkin = new Grumpkin(); + // We need to ensure points are actually on curve, so we just use the generator + // In future we could use a random point, for now we create an array of [G, 2G, 3G] + const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); + + // Pick some big scalars to test the edge cases + const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; + const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory + const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory + // Transform the points and scalars into the format that we will write to memory + // We just store the x and y coordinates here, and handle the infinities when we write to memory + const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); + const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); + + const pointsOffset = 0; + // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] + for (let i = 0; i < points.length; i++) { + const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y + const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf + context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); + context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); + context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); + } + // Store scalars + const scalarsOffset = pointsOffset + pointsReadLength; + context.machineState.memory.setSlice(scalarsOffset, storedScalars); + // Store length of points to read + const pointsLengthOffset = scalarsOffset + scalarsLength; + context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); + const outputOffset = pointsLengthOffset + 1; + + await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); + + const result = context.machineState.memory.getSlice(outputOffset, 3); + + // We write it out explicitly here + let expectedResult = grumpkin.mul(points[0], scalars[0]); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); + + expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + }); + + it('Should perform msm correctly - indirect', async () => { + const indirect = 7; + const grumpkin = new Grumpkin(); + // We need to ensure points are actually on curve, so we just use the generator + // In future we could use a random point, for now we create an array of [G, 2G, 3G] + const points = Array.from({ length: 3 }, (_, i) => grumpkin.mul(grumpkin.generator(), new Fq(i + 1))); + + // Pick some big scalars to test the edge cases + const scalars = [new Fq(Fq.MODULUS - 1n), new Fq(Fq.MODULUS - 2n), new Fq(1n)]; + const pointsReadLength = points.length * 3; // multiplied by 3 since we will store them as triplet in avm memory + const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory + // Transform the points and scalars into the format that we will write to memory + // We just store the x and y coordinates here, and handle the infinities when we write to memory + const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); + const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); + + const pointsOffset = 0; + // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) + // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] + for (let i = 0; i < points.length; i++) { + const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y + const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf + context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); + context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); + context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); + } + // Store scalars + const scalarsOffset = pointsOffset + pointsReadLength; + context.machineState.memory.setSlice(scalarsOffset, storedScalars); + // Store length of points to read + const pointsLengthOffset = scalarsOffset + scalarsLength; + context.machineState.memory.set(pointsLengthOffset, new Uint32(pointsReadLength)); + const outputOffset = pointsLengthOffset + 1; + + // Set up the indirect pointers + const pointsIndirectOffset = outputOffset + 3; /* 3 since the output is a triplet */ + const scalarsIndirectOffset = pointsIndirectOffset + 1; + const outputIndirectOffset = scalarsIndirectOffset + 1; + + context.machineState.memory.set(pointsIndirectOffset, new Uint32(pointsOffset)); + context.machineState.memory.set(scalarsIndirectOffset, new Uint32(scalarsOffset)); + context.machineState.memory.set(outputIndirectOffset, new Uint32(outputOffset)); + + await new MultiScalarMul( + indirect, + pointsIndirectOffset, + scalarsIndirectOffset, + outputIndirectOffset, + pointsLengthOffset, + ).execute(context); + + const result = context.machineState.memory.getSlice(outputOffset, 3); + + // We write it out explicitly here + let expectedResult = grumpkin.mul(points[0], scalars[0]); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); + expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); + + expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + }); +}); diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts new file mode 100644 index 00000000000..70a370b231b --- /dev/null +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts @@ -0,0 +1,116 @@ +import { Fq, Fr, Point } from '@aztec/circuits.js'; +import { Grumpkin } from '@aztec/circuits.js/barretenberg'; + +import { type AvmContext } from '../avm_context.js'; +import { Field, TypeTag } from '../avm_memory_types.js'; +import { InstructionExecutionError } from '../errors.js'; +import { Opcode, OperandType } from '../serialization/instruction_serialization.js'; +import { Addressing } from './addressing_mode.js'; +import { Instruction } from './instruction.js'; + +export class MultiScalarMul extends Instruction { + static type: string = 'MultiScalarMul'; + static readonly opcode: Opcode = Opcode.MSM; + + // Informs (de)serialization. See Instruction.deserialize. + static readonly wireFormat: OperandType[] = [ + OperandType.UINT8 /* opcode */, + OperandType.UINT8 /* indirect */, + OperandType.UINT32 /* points vector offset */, + OperandType.UINT32 /* scalars vector offset */, + OperandType.UINT32 /* output offset (fixed triplet)*/, + OperandType.UINT32 /* points length offset */, + ]; + + constructor( + private indirect: number, + private pointsOffset: number, + private scalarsOffset: number, + private outputOffset: number, + private pointsLengthOffset: number, + ) { + super(); + } + + public async execute(context: AvmContext): Promise { + const memory = context.machineState.memory.track(this.type); + // Resolve indirects + const [pointsOffset, scalarsOffset, outputOffset] = Addressing.fromWire(this.indirect).resolve( + [this.pointsOffset, this.scalarsOffset, this.outputOffset], + memory, + ); + + // Length of the points vector should be U32 + memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); + + // Get the size of the unrolled (x, y , inf) points vector + // TODO: Do we need to assert that the length is a multiple of 3 (x, y, inf)? + const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); + // Divide by 3 since each point is represented as a triplet to get the number of points + const numPoints = pointsReadLength / 3; + // The tag for each triplet will be (Field, Field, Uint8) + for (let i = 0; i < numPoints; i++) { + const offset = pointsOffset + i * 3; + // Check (Field, Field) + memory.checkTagsRange(TypeTag.FIELD, offset, 2); + // Check Uint8 (inf flag) + memory.checkTag(TypeTag.UINT8, offset + 2); + } + // Get the unrolled (x, y, inf) representing the points + const pointsVector = memory.getSlice(pointsOffset, pointsReadLength); + + // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition + const scalarReadLength = numPoints * 2; + // Get the unrolled scalar (lo & hi) representing the scalars + const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); + memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); + + // Now we need to reconstruct the points and scalars into something we can operate on. + const grumpkinPoints: Point[] = []; + for (let i = 0; i < numPoints; i++) { + const p: Point = new Point(pointsVector[3 * i].toFr(), pointsVector[3 * i + 1].toFr()); + // Include this later when we have a standard for representing infinity + // const isInf = pointsVector[i + 2].toBoolean(); + + if (!p.isOnGrumpkin()) { + throw new InstructionExecutionError(`Point ${p.toString()} is not on the curve.`); + } + grumpkinPoints.push(p); + } + // The scalars are read from memory as Fr elements, which are limbs of Fq elements + // So we need to reconstruct them before performing the scalar multiplications + const scalarFqVector: Fq[] = []; + for (let i = 0; i < numPoints; i++) { + const scalarLo = scalarsVector[2 * i].toFr(); + const scalarHi = scalarsVector[2 * i + 1].toFr(); + const fqScalar = Fq.fromHighLow(scalarHi, scalarLo); + scalarFqVector.push(fqScalar); + } + // TODO: Is there an efficient MSM implementation in ts that we can replace this by? + const grumpkin = new Grumpkin(); + // Zip the points and scalars into pairs + const [firstBaseScalarPair, ...rest]: Array<[Point, Fq]> = grumpkinPoints.map((p, idx) => [p, scalarFqVector[idx]]); + // Fold the points and scalars into a single point + // We have to ensure get the first point, since the identity element (point at infinity) isn't quite working in ts + const outputPoint = rest.reduce( + (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), + grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), + ); + // TODO: Check the Infinity flag here + const output: Fr[] = [outputPoint.x, outputPoint.y, outputPoint.isInfPoint() ? Fr.ONE : Fr.ZERO]; + + memory.setSlice( + outputOffset, + output.map(word => new Field(word)), + ); + + const memoryOperations = { + reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, + writes: 3 /* output triplet */, + indirect: this.indirect, + }; + context.machineState.consumeGas(this.gasCost(memoryOperations)); + memory.assert(memoryOperations); + context.machineState.incrementPc(); + } +} diff --git a/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts b/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts index 0cf22ba0a5c..f3afe05e088 100644 --- a/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts +++ b/yarn-project/simulator/src/avm/serialization/bytecode_serialization.ts @@ -53,6 +53,7 @@ import { Version, Xor, } from '../opcodes/index.js'; +import { MultiScalarMul } from '../opcodes/multi_scalar_mul.js'; import { BufferCursor } from './buffer_cursor.js'; import { Opcode } from './instruction_serialization.js'; @@ -143,6 +144,7 @@ const INSTRUCTION_SET = () => [Poseidon2.opcode, Poseidon2], [Sha256.opcode, Sha256], [Pedersen.opcode, Pedersen], + [MultiScalarMul.opcode, MultiScalarMul], // Conversions [ToRadixLE.opcode, ToRadixLE], ]); diff --git a/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts b/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts index d8ccbd91840..0a4ee888fcf 100644 --- a/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts +++ b/yarn-project/simulator/src/avm/serialization/instruction_serialization.ts @@ -77,6 +77,7 @@ export enum Opcode { SHA256, // temp - may be removed, but alot of contracts rely on it PEDERSEN, // temp - may be removed, but alot of contracts rely on it ECADD, + MSM, // Conversion TORADIXLE, } From c3a5399a190425bd70929a67838540378055fa9a Mon Sep 17 00:00:00 2001 From: IlyasRidhuan Date: Thu, 13 Jun 2024 17:07:21 +0000 Subject: [PATCH 2/2] fix: add cpp changes --- avm-transpiler/src/transpile.rs | 1 + .../vm/avm_trace/avm_deserialization.cpp | 2 + .../barretenberg/vm/avm_trace/avm_opcode.hpp | 1 + .../contracts/avm_test_contract/src/main.nr | 3 +- yarn-project/foundation/src/fields/point.ts | 12 ++++-- .../simulator/src/avm/avm_simulator.test.ts | 4 +- .../src/avm/opcodes/multi_scalar_mul.test.ts | 42 +++++++------------ .../src/avm/opcodes/multi_scalar_mul.ts | 30 +++++++------ 8 files changed, 46 insertions(+), 49 deletions(-) diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index a1ae1b71644..f1fdd201b8b 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -863,6 +863,7 @@ fn handle_black_box_function(avm_instrs: &mut Vec, operation: &B let num_points = points.size.0; let scalars_offset = scalars.pointer.0; // Output array is fixed to 3 + assert_eq!(outputs.size, 3, "Output array size must be equal to 3"); let outputs_offset = outputs.pointer.0; avm_instrs.push(AvmInstruction { opcode: AvmOpcode::MSM, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp index 6b10ab40afc..891d2af5695 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp @@ -158,6 +158,8 @@ const std::unordered_map> OPCODE_WIRE_FORMAT = OperandType::UINT32, // rhs.y OperandType::UINT32, // rhs.is_infinite OperandType::UINT32 } }, // dst_offset + { OpCode::MSM, + { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } }, // Gadget - Conversion { OpCode::TORADIXLE, { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } }, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp index 83de32e6568..e3ced1a03e7 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp @@ -105,6 +105,7 @@ enum class OpCode : uint8_t { SHA256, PEDERSEN, ECADD, + MSM, // Conversions TORADIXLE, // Future Gadgets -- pending changes in noir diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index c3d827652f1..d870e8564f8 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -148,7 +148,8 @@ contract AvmTest { fn variable_base_msm() -> [Field; 3] { let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 }; - let triple_g = multi_scalar_mul([g], [scalar]); + let scalar2 = EmbeddedCurveScalar { lo: 20, hi: 0 }; + let triple_g = multi_scalar_mul([g, g], [scalar, scalar2]); triple_g } diff --git a/yarn-project/foundation/src/fields/point.ts b/yarn-project/foundation/src/fields/point.ts index 490c1fe2f48..bf12faf3a13 100644 --- a/yarn-project/foundation/src/fields/point.ts +++ b/yarn-project/foundation/src/fields/point.ts @@ -139,14 +139,18 @@ export class Point { /** * Check if this is point at infinity. + * Check this is consistent with how bb is encoding the point at infinity */ - isInfPoint() { - // Check this - return this.x.isZero(); + public get inf() { + return this.x == Fr.ZERO; + } + public toFieldsWithInf() { + return [this.x, this.y, new Fr(this.inf)]; } isOnGrumpkin() { - if (this.isInfPoint()) { + // TODO: Check this against how bb handles curve check and infinity point check + if (this.inf) { return true; } diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index 86d1960577b..1614305b0be 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -117,7 +117,9 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.reverted).toBe(false); const grumpkin = new Grumpkin(); const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3)); - expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); + const g20 = grumpkin.mul(grumpkin.generator(), new Fq(20)); + const expectedResult = grumpkin.add(g3, g20); + expect(results.output).toEqual([expectedResult.x, expectedResult.y, Fr.ZERO]); }); describe('U128 addition and overflows', () => { diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts index 83a9b79ca31..861d9b4ec71 100644 --- a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts @@ -2,7 +2,7 @@ import { Fq, Fr } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { type AvmContext } from '../avm_context.js'; -import { Field, Uint8, Uint32 } from '../avm_memory_types.js'; +import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js'; import { initContext } from '../fixtures/index.js'; import { MultiScalarMul } from './multi_scalar_mul.js'; @@ -46,19 +46,13 @@ describe('MultiScalarMul Opcode', () => { const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory // Transform the points and scalars into the format that we will write to memory // We just store the x and y coordinates here, and handle the infinities when we write to memory - const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); - - const pointsOffset = 0; - // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] - for (let i = 0; i < points.length; i++) { - const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y - const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf - context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); - context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); - context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); - } + const storedPoints: MemoryValue[] = points + .map(p => p.toFieldsWithInf()) + .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); + const pointsOffset = 0; + context.machineState.memory.setSlice(pointsOffset, storedPoints); // Store scalars const scalarsOffset = pointsOffset + pointsReadLength; context.machineState.memory.setSlice(scalarsOffset, storedScalars); @@ -69,14 +63,14 @@ describe('MultiScalarMul Opcode', () => { await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); - const result = context.machineState.memory.getSlice(outputOffset, 3); + const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); // We write it out explicitly here let expectedResult = grumpkin.mul(points[0], scalars[0]); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); - expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); }); it('Should perform msm correctly - indirect', async () => { @@ -92,19 +86,13 @@ describe('MultiScalarMul Opcode', () => { const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory // Transform the points and scalars into the format that we will write to memory // We just store the x and y coordinates here, and handle the infinities when we write to memory - const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); - - const pointsOffset = 0; - // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] - for (let i = 0; i < points.length; i++) { - const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y - const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf - context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); - context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); - context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); - } + const storedPoints: MemoryValue[] = points + .map(p => p.toFieldsWithInf()) + .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); + const pointsOffset = 0; + context.machineState.memory.setSlice(pointsOffset, storedPoints); // Store scalars const scalarsOffset = pointsOffset + pointsReadLength; context.machineState.memory.setSlice(scalarsOffset, storedScalars); @@ -130,13 +118,13 @@ describe('MultiScalarMul Opcode', () => { pointsLengthOffset, ).execute(context); - const result = context.machineState.memory.getSlice(outputOffset, 3); + const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); // We write it out explicitly here let expectedResult = grumpkin.mul(points[0], scalars[0]); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); - expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); }); }); diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts index 70a370b231b..5c9ffe87e20 100644 --- a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts @@ -1,6 +1,8 @@ -import { Fq, Fr, Point } from '@aztec/circuits.js'; +import { Fq, Point } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; +import { strict as assert } from 'assert'; + import { type AvmContext } from '../avm_context.js'; import { Field, TypeTag } from '../avm_memory_types.js'; import { InstructionExecutionError } from '../errors.js'; @@ -18,7 +20,7 @@ export class MultiScalarMul extends Instruction { OperandType.UINT8 /* indirect */, OperandType.UINT32 /* points vector offset */, OperandType.UINT32 /* scalars vector offset */, - OperandType.UINT32 /* output offset (fixed triplet)*/, + OperandType.UINT32 /* output offset (fixed triplet) */, OperandType.UINT32 /* points length offset */, ]; @@ -42,10 +44,9 @@ export class MultiScalarMul extends Instruction { // Length of the points vector should be U32 memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); - // Get the size of the unrolled (x, y , inf) points vector - // TODO: Do we need to assert that the length is a multiple of 3 (x, y, inf)? const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); + assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3'); // Divide by 3 since each point is represented as a triplet to get the number of points const numPoints = pointsReadLength / 3; // The tag for each triplet will be (Field, Field, Uint8) @@ -61,6 +62,13 @@ export class MultiScalarMul extends Instruction { // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition const scalarReadLength = numPoints * 2; + // Consume gas prior to performing work + const memoryOperations = { + reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, + writes: 3 /* output triplet */, + indirect: this.indirect, + }; + context.machineState.consumeGas(this.gasCost(memoryOperations)); // Get the unrolled scalar (lo & hi) representing the scalars const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); @@ -96,20 +104,10 @@ export class MultiScalarMul extends Instruction { (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), ); - // TODO: Check the Infinity flag here - const output: Fr[] = [outputPoint.x, outputPoint.y, outputPoint.isInfPoint() ? Fr.ONE : Fr.ZERO]; + const output = outputPoint.toFieldsWithInf().map(f => new Field(f)); - memory.setSlice( - outputOffset, - output.map(word => new Field(word)), - ); + memory.setSlice(outputOffset, output); - const memoryOperations = { - reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, - writes: 3 /* output triplet */, - indirect: this.indirect, - }; - context.machineState.consumeGas(this.gasCost(memoryOperations)); memory.assert(memoryOperations); context.machineState.incrementPc(); }