From 2325991bf8027efe5fbc48e555571aff84b0d5c2 Mon Sep 17 00:00:00 2001 From: fcarreiro Date: Mon, 22 Apr 2024 09:57:42 +0000 Subject: [PATCH] chore(avm-simulator): make shifts take u8 --- .../public-vm/gen/_instruction-set.mdx | 4 +- .../InstructionSet/InstructionSet.js | 6 ++- .../simulator/src/avm/avm_memory_types.ts | 2 +- .../simulator/src/avm/opcodes/bitwise.test.ts | 54 ++++++++++++++++--- .../simulator/src/avm/opcodes/bitwise.ts | 25 ++++++--- 5 files changed, 71 insertions(+), 20 deletions(-) diff --git a/docs/docs/protocol-specs/public-vm/gen/_instruction-set.mdx b/docs/docs/protocol-specs/public-vm/gen/_instruction-set.mdx index d1f4fa379c3..b0929f294d2 100644 --- a/docs/docs/protocol-specs/public-vm/gen/_instruction-set.mdx +++ b/docs/docs/protocol-specs/public-vm/gen/_instruction-set.mdx @@ -760,7 +760,7 @@ Bitwise leftward shift (a << b) - **bOffset**: memory offset of the operation's right input - **dstOffset**: memory offset specifying where to store operation's result - **Expression**: `M[dstOffset] = M[aOffset] << M[bOffset]` -- **Tag checks**: `T[aOffset] == T[bOffset] == inTag` +- **Tag checks**: `T[aOffset] == inTag`, `T[bOffset] == u8` - **Tag updates**: `T[dstOffset] = inTag` - **Bit-size**: 128 @@ -781,7 +781,7 @@ Bitwise rightward shift (a >> b) - **bOffset**: memory offset of the operation's right input - **dstOffset**: memory offset specifying where to store operation's result - **Expression**: `M[dstOffset] = M[aOffset] >> M[bOffset]` -- **Tag checks**: `T[aOffset] == T[bOffset] == inTag` +- **Tag checks**: `T[aOffset] == inTag`, `T[bOffset] == u8` - **Tag updates**: `T[dstOffset] = inTag` - **Bit-size**: 128 diff --git a/docs/src/preprocess/InstructionSet/InstructionSet.js b/docs/src/preprocess/InstructionSet/InstructionSet.js index 496b2a8128d..5fb7011da18 100644 --- a/docs/src/preprocess/InstructionSet/InstructionSet.js +++ b/docs/src/preprocess/InstructionSet/InstructionSet.js @@ -428,6 +428,7 @@ const INSTRUCTION_SET_RAW = [ { name: "bOffset", description: "memory offset of the operation's right input", + type: "u8", }, { name: "dstOffset", @@ -438,7 +439,7 @@ const INSTRUCTION_SET_RAW = [ Expression: "`M[dstOffset] = M[aOffset] << M[bOffset]`", Summary: "Bitwise leftward shift (a << b)", Details: "", - "Tag checks": "`T[aOffset] == T[bOffset] == inTag`", + "Tag checks": "`T[aOffset] == inTag`, `T[bOffset] == u8`", "Tag updates": "`T[dstOffset] = inTag`", }, { @@ -457,6 +458,7 @@ const INSTRUCTION_SET_RAW = [ { name: "bOffset", description: "memory offset of the operation's right input", + type: "u8", }, { name: "dstOffset", @@ -467,7 +469,7 @@ const INSTRUCTION_SET_RAW = [ Expression: "`M[dstOffset] = M[aOffset] >> M[bOffset]`", Summary: "Bitwise rightward shift (a >> b)", Details: "", - "Tag checks": "`T[aOffset] == T[bOffset] == inTag`", + "Tag checks": "`T[aOffset] == inTag`, `T[bOffset] == u8`", "Tag updates": "`T[dstOffset] = inTag`", }, { diff --git a/yarn-project/simulator/src/avm/avm_memory_types.ts b/yarn-project/simulator/src/avm/avm_memory_types.ts index b1c053e8b01..9509b0297bb 100644 --- a/yarn-project/simulator/src/avm/avm_memory_types.ts +++ b/yarn-project/simulator/src/avm/avm_memory_types.ts @@ -225,7 +225,7 @@ export class TaggedMemory implements TaggedMemoryInterface { } /** Returns a MeteredTaggedMemory instance to track the number of reads and writes if TRACK_MEMORY_ACCESSES is set. */ - public track(type: string = 'instruction') { + public track(type: string = 'instruction'): TaggedMemoryInterface { return TaggedMemory.TRACK_MEMORY_ACCESSES ? new MeteredTaggedMemory(this, type) : this; } diff --git a/yarn-project/simulator/src/avm/opcodes/bitwise.test.ts b/yarn-project/simulator/src/avm/opcodes/bitwise.test.ts index afd6fa26126..2d8e5308cc6 100644 --- a/yarn-project/simulator/src/avm/opcodes/bitwise.test.ts +++ b/yarn-project/simulator/src/avm/opcodes/bitwise.test.ts @@ -1,5 +1,5 @@ import { type AvmContext } from '../avm_context.js'; -import { TypeTag, Uint16, Uint32 } from '../avm_memory_types.js'; +import { TypeTag, Uint8, Uint16, Uint32 } from '../avm_memory_types.js'; import { initContext } from '../fixtures/index.js'; import { And, Not, Or, Shl, Shr, Xor } from './bitwise.js'; @@ -157,13 +157,32 @@ describe('Bitwise instructions', () => { expect(inst.serialize()).toEqual(buf); }); - it('Should shift correctly 0 positions over integral types', async () => { + it('Should require shift amount to be U8', async () => { const a = new Uint32(0b11111110010011100100n); const b = new Uint32(0n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); + await expect( + async () => + await new Shr( + /*indirect=*/ 0, + /*inTag=*/ TypeTag.UINT32, + /*aOffset=*/ 0, + /*bOffset=*/ 1, + /*dstOffset=*/ 2, + ).execute(context), + ).rejects.toThrow(/got UINT32, expected UINT8/); + }); + + it('Should shift correctly 0 positions over integral types', async () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint8(0n); + + context.machineState.memory.set(0, a); + context.machineState.memory.set(1, b); + await new Shr( /*indirect=*/ 0, /*inTag=*/ TypeTag.UINT32, @@ -179,7 +198,7 @@ describe('Bitwise instructions', () => { it('Should shift correctly 2 positions over integral types', async () => { const a = new Uint32(0b11111110010011100100n); - const b = new Uint32(2n); + const b = new Uint8(2n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); @@ -199,7 +218,7 @@ describe('Bitwise instructions', () => { it('Should shift correctly 19 positions over integral types', async () => { const a = new Uint32(0b11111110010011100100n); - const b = new Uint32(19n); + const b = new Uint8(19n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); @@ -240,13 +259,32 @@ describe('Bitwise instructions', () => { expect(inst.serialize()).toEqual(buf); }); - it('Should shift correctly 0 positions over integral types', async () => { + it('Should require shift amount to be U8', async () => { const a = new Uint32(0b11111110010011100100n); const b = new Uint32(0n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); + await expect( + async () => + await new Shl( + /*indirect=*/ 0, + /*inTag=*/ TypeTag.UINT32, + /*aOffset=*/ 0, + /*bOffset=*/ 1, + /*dstOffset=*/ 2, + ).execute(context), + ).rejects.toThrow(/got UINT32, expected UINT8/); + }); + + it('Should shift correctly 0 positions over integral types', async () => { + const a = new Uint32(0b11111110010011100100n); + const b = new Uint8(0n); + + context.machineState.memory.set(0, a); + context.machineState.memory.set(1, b); + await new Shl( /*indirect=*/ 0, /*inTag=*/ TypeTag.UINT32, @@ -262,7 +300,7 @@ describe('Bitwise instructions', () => { it('Should shift correctly 2 positions over integral types', async () => { const a = new Uint32(0b11111110010011100100n); - const b = new Uint32(2n); + const b = new Uint8(2n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); @@ -282,7 +320,7 @@ describe('Bitwise instructions', () => { it('Should shift correctly over bit limit over integral types', async () => { const a = new Uint16(0b1110010011100111n); - const b = new Uint16(17n); + const b = new Uint8(17n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); @@ -302,7 +340,7 @@ describe('Bitwise instructions', () => { it('Should truncate when shifting over bit size over integral types', async () => { const a = new Uint16(0b1110010011100111n); - const b = new Uint16(2n); + const b = new Uint8(2n); context.machineState.memory.set(0, a); context.machineState.memory.set(1, b); diff --git a/yarn-project/simulator/src/avm/opcodes/bitwise.ts b/yarn-project/simulator/src/avm/opcodes/bitwise.ts index 5d7e3cf8d67..ae8f69db372 100644 --- a/yarn-project/simulator/src/avm/opcodes/bitwise.ts +++ b/yarn-project/simulator/src/avm/opcodes/bitwise.ts @@ -1,5 +1,5 @@ import type { AvmContext } from '../avm_context.js'; -import { type IntegralValue } from '../avm_memory_types.js'; +import { type IntegralValue, type TaggedMemoryInterface, TypeTag } from '../avm_memory_types.js'; import { Opcode } from '../serialization/instruction_serialization.js'; import { ThreeOperandInstruction, TwoOperandInstruction } from './instruction_impl.js'; @@ -9,7 +9,7 @@ abstract class ThreeOperandBitwiseInstruction extends ThreeOperandInstruction { const memory = context.machineState.memory.track(this.type); context.machineState.consumeGas(this.gasCost(memoryOperations)); - memory.checkTags(this.inTag, this.aOffset, this.bOffset); + this.checkTags(memory, this.inTag, this.aOffset, this.bOffset); const a = memory.getAs(this.aOffset); const b = memory.getAs(this.bOffset); @@ -22,13 +22,16 @@ abstract class ThreeOperandBitwiseInstruction extends ThreeOperandInstruction { } protected abstract compute(a: IntegralValue, b: IntegralValue): IntegralValue; + protected checkTags(memory: TaggedMemoryInterface, inTag: number, aOffset: number, bOffset: number) { + memory.checkTags(inTag, aOffset, bOffset); + } } export class And extends ThreeOperandBitwiseInstruction { static readonly type: string = 'AND'; static readonly opcode = Opcode.AND; - protected compute(a: IntegralValue, b: IntegralValue): IntegralValue { + protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue { return a.and(b); } } @@ -37,7 +40,7 @@ export class Or extends ThreeOperandBitwiseInstruction { static readonly type: string = 'OR'; static readonly opcode = Opcode.OR; - protected compute(a: IntegralValue, b: IntegralValue): IntegralValue { + protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue { return a.or(b); } } @@ -46,7 +49,7 @@ export class Xor extends ThreeOperandBitwiseInstruction { static readonly type: string = 'XOR'; static readonly opcode = Opcode.XOR; - protected compute(a: IntegralValue, b: IntegralValue): IntegralValue { + protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue { return a.xor(b); } } @@ -55,18 +58,26 @@ export class Shl extends ThreeOperandBitwiseInstruction { static readonly type: string = 'SHL'; static readonly opcode = Opcode.SHL; - protected compute(a: IntegralValue, b: IntegralValue): IntegralValue { + protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue { return a.shl(b); } + protected override checkTags(memory: TaggedMemoryInterface, inTag: number, aOffset: number, bOffset: number) { + memory.checkTag(inTag, aOffset); + memory.checkTag(TypeTag.UINT8, bOffset); + } } export class Shr extends ThreeOperandBitwiseInstruction { static readonly type: string = 'SHR'; static readonly opcode = Opcode.SHR; - protected compute(a: IntegralValue, b: IntegralValue): IntegralValue { + protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue { return a.shr(b); } + protected override checkTags(memory: TaggedMemoryInterface, inTag: number, aOffset: number, bOffset: number) { + memory.checkTag(inTag, aOffset); + memory.checkTag(TypeTag.UINT8, bOffset); + } } export class Not extends TwoOperandInstruction {