diff --git a/avm-transpiler/src/transpile.rs b/avm-transpiler/src/transpile.rs index a1ae1b71644..f1fdd201b8b 100644 --- a/avm-transpiler/src/transpile.rs +++ b/avm-transpiler/src/transpile.rs @@ -863,6 +863,7 @@ fn handle_black_box_function(avm_instrs: &mut Vec, operation: &B let num_points = points.size.0; let scalars_offset = scalars.pointer.0; // Output array is fixed to 3 + assert_eq!(outputs.size, 3, "Output array size must be equal to 3"); let outputs_offset = outputs.pointer.0; avm_instrs.push(AvmInstruction { opcode: AvmOpcode::MSM, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp index 6b10ab40afc..891d2af5695 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_deserialization.cpp @@ -158,6 +158,8 @@ const std::unordered_map> OPCODE_WIRE_FORMAT = OperandType::UINT32, // rhs.y OperandType::UINT32, // rhs.is_infinite OperandType::UINT32 } }, // dst_offset + { OpCode::MSM, + { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } }, // Gadget - Conversion { OpCode::TORADIXLE, { OperandType::INDIRECT, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32, OperandType::UINT32 } }, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp index 83de32e6568..e3ced1a03e7 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_opcode.hpp @@ -105,6 +105,7 @@ enum class OpCode : uint8_t { SHA256, PEDERSEN, ECADD, + MSM, // Conversions TORADIXLE, // Future Gadgets -- pending changes in noir diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index c3d827652f1..d870e8564f8 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -148,7 +148,8 @@ contract AvmTest { fn variable_base_msm() -> [Field; 3] { let g = EmbeddedCurvePoint { x: 1, y: 17631683881184975370165255887551781615748388533673675138860, is_infinite: false }; let scalar = EmbeddedCurveScalar { lo: 3, hi: 0 }; - let triple_g = multi_scalar_mul([g], [scalar]); + let scalar2 = EmbeddedCurveScalar { lo: 20, hi: 0 }; + let triple_g = multi_scalar_mul([g, g], [scalar, scalar2]); triple_g } diff --git a/yarn-project/foundation/src/fields/point.ts b/yarn-project/foundation/src/fields/point.ts index 490c1fe2f48..bf12faf3a13 100644 --- a/yarn-project/foundation/src/fields/point.ts +++ b/yarn-project/foundation/src/fields/point.ts @@ -139,14 +139,18 @@ export class Point { /** * Check if this is point at infinity. + * Check this is consistent with how bb is encoding the point at infinity */ - isInfPoint() { - // Check this - return this.x.isZero(); + public get inf() { + return this.x == Fr.ZERO; + } + public toFieldsWithInf() { + return [this.x, this.y, new Fr(this.inf)]; } isOnGrumpkin() { - if (this.isInfPoint()) { + // TODO: Check this against how bb handles curve check and infinity point check + if (this.inf) { return true; } diff --git a/yarn-project/simulator/src/avm/avm_simulator.test.ts b/yarn-project/simulator/src/avm/avm_simulator.test.ts index 86d1960577b..1614305b0be 100644 --- a/yarn-project/simulator/src/avm/avm_simulator.test.ts +++ b/yarn-project/simulator/src/avm/avm_simulator.test.ts @@ -117,7 +117,9 @@ describe('AVM simulator: transpiled Noir contracts', () => { expect(results.reverted).toBe(false); const grumpkin = new Grumpkin(); const g3 = grumpkin.mul(grumpkin.generator(), new Fq(3)); - expect(results.output).toEqual([g3.x, g3.y, Fr.ZERO]); + const g20 = grumpkin.mul(grumpkin.generator(), new Fq(20)); + const expectedResult = grumpkin.add(g3, g20); + expect(results.output).toEqual([expectedResult.x, expectedResult.y, Fr.ZERO]); }); describe('U128 addition and overflows', () => { diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts index 83a9b79ca31..861d9b4ec71 100644 --- a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.test.ts @@ -2,7 +2,7 @@ import { Fq, Fr } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; import { type AvmContext } from '../avm_context.js'; -import { Field, Uint8, Uint32 } from '../avm_memory_types.js'; +import { Field, type MemoryValue, Uint8, Uint32 } from '../avm_memory_types.js'; import { initContext } from '../fixtures/index.js'; import { MultiScalarMul } from './multi_scalar_mul.js'; @@ -46,19 +46,13 @@ describe('MultiScalarMul Opcode', () => { const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory // Transform the points and scalars into the format that we will write to memory // We just store the x and y coordinates here, and handle the infinities when we write to memory - const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); - - const pointsOffset = 0; - // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] - for (let i = 0; i < points.length; i++) { - const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y - const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf - context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); - context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); - context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); - } + const storedPoints: MemoryValue[] = points + .map(p => p.toFieldsWithInf()) + .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); + const pointsOffset = 0; + context.machineState.memory.setSlice(pointsOffset, storedPoints); // Store scalars const scalarsOffset = pointsOffset + pointsReadLength; context.machineState.memory.setSlice(scalarsOffset, storedScalars); @@ -69,14 +63,14 @@ describe('MultiScalarMul Opcode', () => { await new MultiScalarMul(indirect, pointsOffset, scalarsOffset, outputOffset, pointsLengthOffset).execute(context); - const result = context.machineState.memory.getSlice(outputOffset, 3); + const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); // We write it out explicitly here let expectedResult = grumpkin.mul(points[0], scalars[0]); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); - expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); }); it('Should perform msm correctly - indirect', async () => { @@ -92,19 +86,13 @@ describe('MultiScalarMul Opcode', () => { const scalarsLength = scalars.length * 2; // multiplied by 2 since we will store them as lo and hi limbs in avm memory // Transform the points and scalars into the format that we will write to memory // We just store the x and y coordinates here, and handle the infinities when we write to memory - const storedPoints: Field[] = points.flatMap(p => [new Field(p.x), new Field(p.y)]); const storedScalars: Field[] = scalars.flatMap(s => [new Field(s.low), new Field(s.high)]); - - const pointsOffset = 0; - // Store points...awkwardly (This would be simpler if ts handled the infinities in the Point struct) // Points are stored as [x1, y1, inf1, x2, y2, inf2, ...] where the types are [Field, Field, Uint8, Field, Field, Uint8, ...] - for (let i = 0; i < points.length; i++) { - const flatPointsOffset = pointsOffset + 2 * i; // 2 since we only have x and y - const memoryOffset = pointsOffset + 3 * i; // 3 since we store x, y, inf - context.machineState.memory.set(memoryOffset, storedPoints[flatPointsOffset]); - context.machineState.memory.set(memoryOffset + 1, storedPoints[flatPointsOffset + 1]); - context.machineState.memory.set(memoryOffset + 2, new Uint8(points[i].isInfPoint() ? 1 : 0)); - } + const storedPoints: MemoryValue[] = points + .map(p => p.toFieldsWithInf()) + .flatMap(([x, y, inf]) => [new Field(x), new Field(y), new Uint8(inf.toNumber())]); + const pointsOffset = 0; + context.machineState.memory.setSlice(pointsOffset, storedPoints); // Store scalars const scalarsOffset = pointsOffset + pointsReadLength; context.machineState.memory.setSlice(scalarsOffset, storedScalars); @@ -130,13 +118,13 @@ describe('MultiScalarMul Opcode', () => { pointsLengthOffset, ).execute(context); - const result = context.machineState.memory.getSlice(outputOffset, 3); + const result = context.machineState.memory.getSlice(outputOffset, 3).map(r => r.toFr()); // We write it out explicitly here let expectedResult = grumpkin.mul(points[0], scalars[0]); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[1], scalars[1])); expectedResult = grumpkin.add(expectedResult, grumpkin.mul(points[2], scalars[2])); - expect(result.map(r => r.toFr())).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); + expect(result).toEqual([expectedResult.x, expectedResult.y, new Fr(0n)]); }); }); diff --git a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts index 70a370b231b..5c9ffe87e20 100644 --- a/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts +++ b/yarn-project/simulator/src/avm/opcodes/multi_scalar_mul.ts @@ -1,6 +1,8 @@ -import { Fq, Fr, Point } from '@aztec/circuits.js'; +import { Fq, Point } from '@aztec/circuits.js'; import { Grumpkin } from '@aztec/circuits.js/barretenberg'; +import { strict as assert } from 'assert'; + import { type AvmContext } from '../avm_context.js'; import { Field, TypeTag } from '../avm_memory_types.js'; import { InstructionExecutionError } from '../errors.js'; @@ -18,7 +20,7 @@ export class MultiScalarMul extends Instruction { OperandType.UINT8 /* indirect */, OperandType.UINT32 /* points vector offset */, OperandType.UINT32 /* scalars vector offset */, - OperandType.UINT32 /* output offset (fixed triplet)*/, + OperandType.UINT32 /* output offset (fixed triplet) */, OperandType.UINT32 /* points length offset */, ]; @@ -42,10 +44,9 @@ export class MultiScalarMul extends Instruction { // Length of the points vector should be U32 memory.checkTag(TypeTag.UINT32, this.pointsLengthOffset); - // Get the size of the unrolled (x, y , inf) points vector - // TODO: Do we need to assert that the length is a multiple of 3 (x, y, inf)? const pointsReadLength = memory.get(this.pointsLengthOffset).toNumber(); + assert(pointsReadLength % 3 === 0, 'Points vector offset should be a multiple of 3'); // Divide by 3 since each point is represented as a triplet to get the number of points const numPoints = pointsReadLength / 3; // The tag for each triplet will be (Field, Field, Uint8) @@ -61,6 +62,13 @@ export class MultiScalarMul extends Instruction { // The size of the scalars vector is twice the NUMBER of points because of the scalar limb decomposition const scalarReadLength = numPoints * 2; + // Consume gas prior to performing work + const memoryOperations = { + reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, + writes: 3 /* output triplet */, + indirect: this.indirect, + }; + context.machineState.consumeGas(this.gasCost(memoryOperations)); // Get the unrolled scalar (lo & hi) representing the scalars const scalarsVector = memory.getSlice(scalarsOffset, scalarReadLength); memory.checkTagsRange(TypeTag.FIELD, scalarsOffset, scalarReadLength); @@ -96,20 +104,10 @@ export class MultiScalarMul extends Instruction { (acc, curr) => grumpkin.add(acc, grumpkin.mul(curr[0], curr[1])), grumpkin.mul(firstBaseScalarPair[0], firstBaseScalarPair[1]), ); - // TODO: Check the Infinity flag here - const output: Fr[] = [outputPoint.x, outputPoint.y, outputPoint.isInfPoint() ? Fr.ONE : Fr.ZERO]; + const output = outputPoint.toFieldsWithInf().map(f => new Field(f)); - memory.setSlice( - outputOffset, - output.map(word => new Field(word)), - ); + memory.setSlice(outputOffset, output); - const memoryOperations = { - reads: 1 + pointsReadLength + scalarReadLength /* points and scalars */, - writes: 3 /* output triplet */, - indirect: this.indirect, - }; - context.machineState.consumeGas(this.gasCost(memoryOperations)); memory.assert(memoryOperations); context.machineState.incrementPc(); }