Skip to content

Commit

Permalink
refactor: enqueued calls processor -> public tx simulator (#9919)
Browse files Browse the repository at this point in the history
- Rename EnqueuedCallsProcessor -> PublicTxSimulator
- Introduce PublicTxContext
- Refactor EnqueuedCallsProcessor with functions per phase that use
PublicTxContext
- Split up tracing of enqueued calls and the merging of their state &
traces (too confusing if they're coupled)
- Lots of prep for removal of public kernels 
- Misc cleanup

Peek at the other PRs in this stack to see how things will simplify
further!

I am open to all feedback when it comes to the refactor of
EnqueuedCallsProcessor, this new PublicTxContext, state forking, etc.
  • Loading branch information
dbanks12 authored Nov 15, 2024
1 parent 05e4b27 commit cae7279
Show file tree
Hide file tree
Showing 18 changed files with 908 additions and 767 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { type Tx, TxExecutionPhase, type TxValidator } from '@aztec/circuit-types';
import { type AztecAddress, type Fr, FunctionSelector } from '@aztec/circuits.js';
import { createDebugLogger } from '@aztec/foundation/log';
import { EnqueuedCallsProcessor, computeFeePayerBalanceStorageSlot } from '@aztec/simulator';
import { computeFeePayerBalanceStorageSlot, getExecutionRequestsByPhase } from '@aztec/simulator';

/** Provides a view into public contract state */
export interface PublicStateSource {
Expand Down Expand Up @@ -58,7 +58,7 @@ export class GasTxValidator implements TxValidator<Tx> {
);

// If there is a claim in this tx that increases the fee payer balance in Fee Juice, add it to balance
const setupFns = EnqueuedCallsProcessor.getExecutionRequestsByPhase(tx, TxExecutionPhase.SETUP);
const setupFns = getExecutionRequestsByPhase(tx, TxExecutionPhase.SETUP);
const claimFunctionCall = setupFns.find(
fn =>
fn.callContext.contractAddress.equals(this.#feeJuiceAddress) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
} from '@aztec/circuit-types';
import { type ContractDataSource } from '@aztec/circuits.js';
import { createDebugLogger } from '@aztec/foundation/log';
import { ContractsDataSourcePublicDB, EnqueuedCallsProcessor } from '@aztec/simulator';
import { ContractsDataSourcePublicDB, getExecutionRequestsByPhase } from '@aztec/simulator';

export class PhasesTxValidator implements TxValidator<Tx> {
#log = createDebugLogger('aztec:sequencer:tx_validator:tx_phases');
Expand Down Expand Up @@ -45,7 +45,7 @@ export class PhasesTxValidator implements TxValidator<Tx> {
return true;
}

const setupFns = EnqueuedCallsProcessor.getExecutionRequestsByPhase(tx, TxExecutionPhase.SETUP);
const setupFns = getExecutionRequestsByPhase(tx, TxExecutionPhase.SETUP);
for (const setupFn of setupFns) {
if (!(await this.isOnAllowList(setupFn, this.setupAllowList))) {
this.#log.warn(
Expand Down
4 changes: 2 additions & 2 deletions yarn-project/simulator/src/avm/avm_simulator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ type PcTally = {
export class AvmSimulator {
private log: DebugLogger;
private bytecode: Buffer | undefined;
public opcodeTallies: Map<string, OpcodeTally> = new Map();
public pcTallies: Map<number, PcTally> = new Map();
private opcodeTallies: Map<string, OpcodeTally> = new Map();
private pcTallies: Map<number, PcTally> = new Map();

constructor(private context: AvmContext) {
assert(
Expand Down
63 changes: 13 additions & 50 deletions yarn-project/simulator/src/avm/journal/journal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,16 @@ export class AvmPersistableStateManager {
/**
* Accept nested world state modifications
*/
public acceptForkedState(forkedState: AvmPersistableStateManager) {
public mergeForkedState(forkedState: AvmPersistableStateManager) {
this.publicStorage.acceptAndMerge(forkedState.publicStorage);
this.nullifiers.acceptAndMerge(forkedState.nullifiers);
this.trace.mergeSuccessfulForkedTrace(forkedState.trace);
}

public rejectForkedState(forkedState: AvmPersistableStateManager) {
this.publicStorage.acceptAndMerge(forkedState.publicStorage);
this.nullifiers.acceptAndMerge(forkedState.nullifiers);
this.trace.mergeRevertedForkedTrace(forkedState.trace);
}

/**
Expand Down Expand Up @@ -474,28 +481,23 @@ export class AvmPersistableStateManager {
return undefined;
}
}
/**
* Accept the nested call's state and trace the nested call
*/
public async processNestedCall(

public async traceNestedCall(
forkedState: AvmPersistableStateManager,
nestedEnvironment: AvmExecutionEnvironment,
startGasLeft: Gas,
endGasLeft: Gas,
bytecode: Buffer,
avmCallResults: AvmContractCallResult,
) {
if (!avmCallResults.reverted) {
this.acceptForkedState(forkedState);
}
const functionName = await getPublicFunctionDebugName(
this.worldStateDB,
nestedEnvironment.address,
nestedEnvironment.functionSelector,
nestedEnvironment.calldata,
);

this.log.verbose(`[AVM] Calling nested function ${functionName}`);
this.log.verbose(`[AVM] Tracing nested external contract call ${functionName}`);

this.trace.traceNestedCall(
forkedState.trace,
Expand All @@ -508,47 +510,8 @@ export class AvmPersistableStateManager {
);
}

public async mergeStateForEnqueuedCall(
forkedState: AvmPersistableStateManager,
/** The call request from private that enqueued this call. */
publicCallRequest: PublicCallRequest,
/** The call's calldata */
calldata: Fr[],
/** Did the call revert? */
reverted: boolean,
) {
if (!reverted) {
this.acceptForkedState(forkedState);
}
const functionName = await getPublicFunctionDebugName(
this.worldStateDB,
publicCallRequest.contractAddress,
publicCallRequest.functionSelector,
calldata,
);

this.log.verbose(`[AVM] Encountered enqueued public call starting with function ${functionName}`);

this.trace.traceEnqueuedCall(forkedState.trace, publicCallRequest, calldata, reverted);
}

public mergeStateForPhase(
/** The forked state manager used by app logic */
forkedState: AvmPersistableStateManager,
/** The call requests for each enqueued call in app logic. */
publicCallRequests: PublicCallRequest[],
/** The calldatas for each enqueued call in app logic */
calldatas: Fr[][],
/** Did the any enqueued call in app logic revert? */
reverted: boolean,
) {
if (!reverted) {
this.acceptForkedState(forkedState);
}

this.log.verbose(`[AVM] Encountered app logic phase`);

this.trace.traceExecutionPhase(forkedState.trace, publicCallRequests, calldatas, reverted);
public traceEnqueuedCall(publicCallRequest: PublicCallRequest, calldata: Fr[], reverted: boolean) {
this.trace.traceEnqueuedCall(publicCallRequest, calldata, reverted);
}
}

Expand Down
5 changes: 4 additions & 1 deletion yarn-project/simulator/src/avm/opcodes/external_calls.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ abstract class ExternalCall extends Instruction {
context.machineState.refundGas(gasLeftToGas(nestedContext.machineState));

// Accept the nested call's state and trace the nested call
await context.persistableState.processNestedCall(
if (success) {
context.persistableState.mergeForkedState(nestedContext.persistableState);
}
await context.persistableState.traceNestedCall(
/*nestedState=*/ nestedContext.persistableState,
/*nestedEnvironment=*/ nestedContext.environment,
/*startGasLeft=*/ Gas.from(allocatedGas),
Expand Down
31 changes: 7 additions & 24 deletions yarn-project/simulator/src/public/dual_side_effect_trace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -222,39 +222,22 @@ export class DualSideEffectTrace implements PublicSideEffectTraceInterface {
}

public traceEnqueuedCall(
/** The trace of the enqueued call. */
enqueuedCallTrace: this,
/** The call request from private that enqueued this call. */
publicCallRequest: PublicCallRequest,
/** The call's calldata */
calldata: Fr[],
/** Did the call revert? */
reverted: boolean,
) {
this.enqueuedCallTrace.traceEnqueuedCall(
enqueuedCallTrace.enqueuedCallTrace,
publicCallRequest,
calldata,
reverted,
);
this.enqueuedCallTrace.traceEnqueuedCall(publicCallRequest, calldata, reverted);
}

public traceExecutionPhase(
/** The trace of the enqueued call. */
appLogicTrace: this,
/** The call request from private that enqueued this call. */
publicCallRequests: PublicCallRequest[],
/** The call's calldata */
calldatas: Fr[][],
/** Did the any enqueued call in app logic revert? */
reverted: boolean,
) {
this.enqueuedCallTrace.traceExecutionPhase(
appLogicTrace.enqueuedCallTrace,
publicCallRequests,
calldatas,
reverted,
);
public mergeSuccessfulForkedTrace(nestedTrace: this) {
this.enqueuedCallTrace.mergeSuccessfulForkedTrace(nestedTrace.enqueuedCallTrace);
}

public mergeRevertedForkedTrace(nestedTrace: this) {
this.enqueuedCallTrace.mergeRevertedForkedTrace(nestedTrace.enqueuedCallTrace);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import {
import { computePublicDataTreeLeafSlot, computeVarArgsHash, siloNullifier } from '@aztec/circuits.js/hash';
import { Fr } from '@aztec/foundation/fields';

import { randomBytes, randomInt } from 'crypto';
import { randomInt } from 'crypto';

import { AvmContractCallResult } from '../avm/avm_contract_call_result.js';
import { type AvmExecutionEnvironment } from '../avm/avm_execution_environment.js';
Expand Down Expand Up @@ -70,7 +70,6 @@ describe('Enqueued-call Side Effect Trace', () => {
const endGasLeft = Gas.fromFields([new Fr(randomInt(10000)), new Fr(randomInt(10000))]);
const transactionFee = Fr.random();
const calldata = [Fr.random(), Fr.random(), Fr.random(), Fr.random()];
const bytecode = randomBytes(100);
const returnValues = [Fr.random(), Fr.random()];

const constants = CombinedConstantData.empty();
Expand All @@ -80,7 +79,6 @@ describe('Enqueued-call Side Effect Trace', () => {
transactionFee,
});
const avmCallResults = new AvmContractCallResult(/*reverted=*/ false, returnValues);
const avmCallRevertedResults = new AvmContractCallResult(/*reverted=*/ true, returnValues);

const emptyValidationRequests = PublicValidationRequests.empty();

Expand Down Expand Up @@ -477,8 +475,8 @@ describe('Enqueued-call Side Effect Trace', () => {
});
});

describe.each([avmCallResults, avmCallRevertedResults])('Should trace & absorb nested calls', callResults => {
it(`${callResults.reverted ? 'Reverted' : 'Successful'} calls should be traced and absorbed properly`, () => {
describe.each([false, true])('Should merge forked traces', reverted => {
it(`${reverted ? 'Reverted' : 'Successful'} forked trace should be merged properly`, () => {
const existsDefault = true;

const nestedTrace = new PublicEnqueuedCallSideEffectTrace(startCounter);
Expand Down Expand Up @@ -510,25 +508,28 @@ describe('Enqueued-call Side Effect Trace', () => {
nestedTrace.traceGetContractInstance(address, /*exists=*/ false, contractInstance);
testCounter++;

trace.traceNestedCall(nestedTrace, avmEnvironment, startGasLeft, endGasLeft, bytecode, callResults);
if (reverted) {
trace.mergeRevertedForkedTrace(nestedTrace);
} else {
trace.mergeSuccessfulForkedTrace(nestedTrace);
}

// parent trace adopts nested call's counter
expect(trace.getCounter()).toBe(testCounter);

// parent absorbs child's side effects
const parentSideEffects = trace.getSideEffects();
const childSideEffects = nestedTrace.getSideEffects();
if (callResults.reverted) {
expect(parentSideEffects.publicDataReads).toEqual(childSideEffects.publicDataReads);
expect(parentSideEffects.publicDataWrites).toEqual(childSideEffects.publicDataWrites);
expect(parentSideEffects.noteHashReadRequests).toEqual(childSideEffects.noteHashReadRequests);
// TODO(dbanks12): confirm that all hints were merged from child
if (reverted) {
expect(parentSideEffects.publicDataReads).toEqual([]);
expect(parentSideEffects.publicDataWrites).toEqual([]);
expect(parentSideEffects.noteHashReadRequests).toEqual([]);
expect(parentSideEffects.noteHashes).toEqual([]);
expect(parentSideEffects.nullifierReadRequests).toEqual(childSideEffects.nullifierReadRequests);
expect(parentSideEffects.nullifierNonExistentReadRequests).toEqual(
childSideEffects.nullifierNonExistentReadRequests,
);
expect(parentSideEffects.nullifiers).toEqual(childSideEffects.nullifiers);
expect(parentSideEffects.l1ToL2MsgReadRequests).toEqual(childSideEffects.l1ToL2MsgReadRequests);
expect(parentSideEffects.nullifierReadRequests).toEqual([]);
expect(parentSideEffects.nullifierNonExistentReadRequests).toEqual([]);
expect(parentSideEffects.nullifiers).toEqual([]);
expect(parentSideEffects.l1ToL2MsgReadRequests).toEqual([]);
expect(parentSideEffects.l2ToL1Msgs).toEqual([]);
expect(parentSideEffects.unencryptedLogs).toEqual([]);
expect(parentSideEffects.unencryptedLogsHashes).toEqual([]);
Expand Down
Loading

0 comments on commit cae7279

Please sign in to comment.