Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(avm-simulator): make shifts take u8 #5905

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/protocol-specs/public-vm/gen/_instruction-set.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ Bitwise leftward shift (a << b)
- **bOffset**: memory offset of the operation's right input
- **dstOffset**: memory offset specifying where to store operation's result
- **Expression**: `M[dstOffset] = M[aOffset] << M[bOffset]`
- **Tag checks**: `T[aOffset] == T[bOffset] == inTag`
- **Tag checks**: `T[aOffset] == inTag`, `T[bOffset] == u8`
- **Tag updates**: `T[dstOffset] = inTag`
- **Bit-size**: 128

Expand All @@ -781,7 +781,7 @@ Bitwise rightward shift (a >> b)
- **bOffset**: memory offset of the operation's right input
- **dstOffset**: memory offset specifying where to store operation's result
- **Expression**: `M[dstOffset] = M[aOffset] >> M[bOffset]`
- **Tag checks**: `T[aOffset] == T[bOffset] == inTag`
- **Tag checks**: `T[aOffset] == inTag`, `T[bOffset] == u8`
- **Tag updates**: `T[dstOffset] = inTag`
- **Bit-size**: 128

Expand Down
6 changes: 4 additions & 2 deletions docs/src/preprocess/InstructionSet/InstructionSet.js
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ const INSTRUCTION_SET_RAW = [
{
name: "bOffset",
description: "memory offset of the operation's right input",
type: "u8",
},
{
name: "dstOffset",
Expand All @@ -438,7 +439,7 @@ const INSTRUCTION_SET_RAW = [
Expression: "`M[dstOffset] = M[aOffset] << M[bOffset]`",
Summary: "Bitwise leftward shift (a << b)",
Details: "",
"Tag checks": "`T[aOffset] == T[bOffset] == inTag`",
"Tag checks": "`T[aOffset] == inTag`, `T[bOffset] == u8`",
"Tag updates": "`T[dstOffset] = inTag`",
},
{
Expand All @@ -457,6 +458,7 @@ const INSTRUCTION_SET_RAW = [
{
name: "bOffset",
description: "memory offset of the operation's right input",
type: "u8",
},
{
name: "dstOffset",
Expand All @@ -467,7 +469,7 @@ const INSTRUCTION_SET_RAW = [
Expression: "`M[dstOffset] = M[aOffset] >> M[bOffset]`",
Summary: "Bitwise rightward shift (a >> b)",
Details: "",
"Tag checks": "`T[aOffset] == T[bOffset] == inTag`",
"Tag checks": "`T[aOffset] == inTag`, `T[bOffset] == u8`",
"Tag updates": "`T[dstOffset] = inTag`",
},
{
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/simulator/src/avm/avm_memory_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ export class TaggedMemory implements TaggedMemoryInterface {
}

/** Returns a MeteredTaggedMemory instance to track the number of reads and writes if TRACK_MEMORY_ACCESSES is set. */
public track(type: string = 'instruction') {
public track(type: string = 'instruction'): TaggedMemoryInterface {
return TaggedMemory.TRACK_MEMORY_ACCESSES ? new MeteredTaggedMemory(this, type) : this;
}

Expand Down
54 changes: 46 additions & 8 deletions yarn-project/simulator/src/avm/opcodes/bitwise.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type AvmContext } from '../avm_context.js';
import { TypeTag, Uint16, Uint32 } from '../avm_memory_types.js';
import { TypeTag, Uint8, Uint16, Uint32 } from '../avm_memory_types.js';
import { initContext } from '../fixtures/index.js';
import { And, Not, Or, Shl, Shr, Xor } from './bitwise.js';

Expand Down Expand Up @@ -157,13 +157,32 @@ describe('Bitwise instructions', () => {
expect(inst.serialize()).toEqual(buf);
});

it('Should shift correctly 0 positions over integral types', async () => {
it('Should require shift amount to be U8', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint32(0n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);

await expect(
async () =>
await new Shr(
/*indirect=*/ 0,
/*inTag=*/ TypeTag.UINT32,
/*aOffset=*/ 0,
/*bOffset=*/ 1,
/*dstOffset=*/ 2,
).execute(context),
).rejects.toThrow(/got UINT32, expected UINT8/);
});

it('Should shift correctly 0 positions over integral types', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint8(0n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);

await new Shr(
/*indirect=*/ 0,
/*inTag=*/ TypeTag.UINT32,
Expand All @@ -179,7 +198,7 @@ describe('Bitwise instructions', () => {

it('Should shift correctly 2 positions over integral types', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint32(2n);
const b = new Uint8(2n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);
Expand All @@ -199,7 +218,7 @@ describe('Bitwise instructions', () => {

it('Should shift correctly 19 positions over integral types', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint32(19n);
const b = new Uint8(19n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);
Expand Down Expand Up @@ -240,13 +259,32 @@ describe('Bitwise instructions', () => {
expect(inst.serialize()).toEqual(buf);
});

it('Should shift correctly 0 positions over integral types', async () => {
it('Should require shift amount to be U8', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint32(0n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);

await expect(
async () =>
await new Shl(
/*indirect=*/ 0,
/*inTag=*/ TypeTag.UINT32,
/*aOffset=*/ 0,
/*bOffset=*/ 1,
/*dstOffset=*/ 2,
).execute(context),
).rejects.toThrow(/got UINT32, expected UINT8/);
});

it('Should shift correctly 0 positions over integral types', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint8(0n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);

await new Shl(
/*indirect=*/ 0,
/*inTag=*/ TypeTag.UINT32,
Expand All @@ -262,7 +300,7 @@ describe('Bitwise instructions', () => {

it('Should shift correctly 2 positions over integral types', async () => {
const a = new Uint32(0b11111110010011100100n);
const b = new Uint32(2n);
const b = new Uint8(2n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);
Expand All @@ -282,7 +320,7 @@ describe('Bitwise instructions', () => {

it('Should shift correctly over bit limit over integral types', async () => {
const a = new Uint16(0b1110010011100111n);
const b = new Uint16(17n);
const b = new Uint8(17n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);
Expand All @@ -302,7 +340,7 @@ describe('Bitwise instructions', () => {

it('Should truncate when shifting over bit size over integral types', async () => {
const a = new Uint16(0b1110010011100111n);
const b = new Uint16(2n);
const b = new Uint8(2n);

context.machineState.memory.set(0, a);
context.machineState.memory.set(1, b);
Expand Down
25 changes: 18 additions & 7 deletions yarn-project/simulator/src/avm/opcodes/bitwise.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { AvmContext } from '../avm_context.js';
import { type IntegralValue } from '../avm_memory_types.js';
import { type IntegralValue, type TaggedMemoryInterface, TypeTag } from '../avm_memory_types.js';
import { Opcode } from '../serialization/instruction_serialization.js';
import { ThreeOperandInstruction, TwoOperandInstruction } from './instruction_impl.js';

Expand All @@ -9,7 +9,7 @@ abstract class ThreeOperandBitwiseInstruction extends ThreeOperandInstruction {
const memory = context.machineState.memory.track(this.type);
context.machineState.consumeGas(this.gasCost(memoryOperations));

memory.checkTags(this.inTag, this.aOffset, this.bOffset);
this.checkTags(memory, this.inTag, this.aOffset, this.bOffset);

const a = memory.getAs<IntegralValue>(this.aOffset);
const b = memory.getAs<IntegralValue>(this.bOffset);
Expand All @@ -22,13 +22,16 @@ abstract class ThreeOperandBitwiseInstruction extends ThreeOperandInstruction {
}

protected abstract compute(a: IntegralValue, b: IntegralValue): IntegralValue;
protected checkTags(memory: TaggedMemoryInterface, inTag: number, aOffset: number, bOffset: number) {
memory.checkTags(inTag, aOffset, bOffset);
}
}

export class And extends ThreeOperandBitwiseInstruction {
static readonly type: string = 'AND';
static readonly opcode = Opcode.AND;

protected compute(a: IntegralValue, b: IntegralValue): IntegralValue {
protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue {
return a.and(b);
}
}
Expand All @@ -37,7 +40,7 @@ export class Or extends ThreeOperandBitwiseInstruction {
static readonly type: string = 'OR';
static readonly opcode = Opcode.OR;

protected compute(a: IntegralValue, b: IntegralValue): IntegralValue {
protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue {
return a.or(b);
}
}
Expand All @@ -46,7 +49,7 @@ export class Xor extends ThreeOperandBitwiseInstruction {
static readonly type: string = 'XOR';
static readonly opcode = Opcode.XOR;

protected compute(a: IntegralValue, b: IntegralValue): IntegralValue {
protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue {
return a.xor(b);
}
}
Expand All @@ -55,18 +58,26 @@ export class Shl extends ThreeOperandBitwiseInstruction {
static readonly type: string = 'SHL';
static readonly opcode = Opcode.SHL;

protected compute(a: IntegralValue, b: IntegralValue): IntegralValue {
protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue {
return a.shl(b);
}
protected override checkTags(memory: TaggedMemoryInterface, inTag: number, aOffset: number, bOffset: number) {
memory.checkTag(inTag, aOffset);
memory.checkTag(TypeTag.UINT8, bOffset);
}
}

export class Shr extends ThreeOperandBitwiseInstruction {
static readonly type: string = 'SHR';
static readonly opcode = Opcode.SHR;

protected compute(a: IntegralValue, b: IntegralValue): IntegralValue {
protected override compute(a: IntegralValue, b: IntegralValue): IntegralValue {
return a.shr(b);
}
protected override checkTags(memory: TaggedMemoryInterface, inTag: number, aOffset: number, bOffset: number) {
memory.checkTag(inTag, aOffset);
memory.checkTag(TypeTag.UINT8, bOffset);
}
}

export class Not extends TwoOperandInstruction {
Expand Down
Loading