Skip to content

Commit

Permalink
feat: abort ongoing proving jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexghr committed Apr 26, 2024
1 parent a97ea4e commit e5fe19d
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 62 deletions.
5 changes: 5 additions & 0 deletions yarn-project/foundation/src/error/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ export class InterruptError extends Error {}
* An error thrown when an action times out.
*/
export class TimeoutError extends Error {}

/**
* Represents an error thrown when an operation is aborted.
*/
export class AbortedError extends Error {}
53 changes: 43 additions & 10 deletions yarn-project/prover-client/src/orchestrator/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
} from '@aztec/circuits.js';
import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { AbortedError } from '@aztec/foundation/error';
import { createDebugLogger } from '@aztec/foundation/log';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { type Tuple } from '@aztec/foundation/serialize';
Expand Down Expand Up @@ -80,6 +81,8 @@ const KernelTypesWithoutFunctions: Set<PublicKernelType> = new Set<PublicKernelT
*/
export class ProvingOrchestrator {
private provingState: ProvingState | undefined = undefined;
private pendingProvingJobs: AbortController[] = [];

constructor(private db: MerkleTreeOperations, private prover: CircuitProver) {}

/**
Expand Down Expand Up @@ -209,6 +212,10 @@ export class ProvingOrchestrator {
* Cancel any further proving of the block
*/
public cancelBlock() {
for (const controller of this.pendingProvingJobs) {
controller.abort();
}

this.provingState?.cancel();
}

Expand Down Expand Up @@ -301,30 +308,56 @@ export class ProvingOrchestrator {
*/
private deferredProving<T>(
provingState: ProvingState | undefined,
request: () => Promise<T>,
request: (signal: AbortSignal) => Promise<T>,
callback: (result: T, durationMs: number) => void | Promise<void>,
) {
if (!provingState?.verifyState()) {
logger.debug(`Not enqueuing job, state no longer valid`);
return;
}

const controller = new AbortController();
this.pendingProvingJobs.push(controller);

// We use a 'safeJob'. We don't want promise rejections in the proving pool, we want to capture the error here
// and reject the proving job whilst keeping the event loop free of rejections
const safeJob = async () => {
try {
// there's a delay between enqueueing this job and it actually running
if (controller.signal.aborted) {
return;
}

const timer = new Timer();
const result = await request();
const result = await request(controller.signal);
const duration = timer.ms();

if (!provingState?.verifyState()) {
logger.debug(`State no longer valid, discarding result`);
return;
}

// we could have been cancelled whilst waiting for the result
// and the prover ignored the signal. Drop the result in that case
if (controller.signal.aborted) {
return;
}

await callback(result, duration);
} catch (err) {
if (err instanceof AbortedError) {
// operation was cancelled, probably because the block was cancelled
// drop this result
return;
}

logger.error(`Error thrown when proving job`);
provingState!.reject(`${err}`);
} finally {
const index = this.pendingProvingJobs.indexOf(controller);
if (index > -1) {
this.pendingProvingJobs.splice(index, 1);
}
}
};

Expand Down Expand Up @@ -439,7 +472,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getBaseRollupProof(tx.baseRollupInputs),
signal => this.prover.getBaseRollupProof(tx.baseRollupInputs, signal),
([publicInputs, proof], duration) => {
this.emitCircuitSimulationStats(
'base-rollup',
Expand Down Expand Up @@ -470,7 +503,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getMergeRollupProof(inputs),
signal => this.prover.getMergeRollupProof(inputs, signal),
([publicInputs, proof], duration) => {
this.emitCircuitSimulationStats(
'merge-rollup',
Expand Down Expand Up @@ -506,7 +539,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getRootRollupProof(inputs),
signal => this.prover.getRootRollupProof(inputs, signal),
([publicInputs, proof], duration) => {
this.emitCircuitSimulationStats(
'root-rollup',
Expand All @@ -531,7 +564,7 @@ export class ProvingOrchestrator {
private enqueueBaseParityCircuit(provingState: ProvingState, inputs: BaseParityInputs, index: number) {
this.deferredProving(
provingState,
() => this.prover.getBaseParityProof(inputs),
signal => this.prover.getBaseParityProof(inputs, signal),
([publicInputs, proof], duration) => {
this.emitCircuitSimulationStats(
'base-parity',
Expand All @@ -554,7 +587,7 @@ export class ProvingOrchestrator {
private enqueueRootParityCircuit(provingState: ProvingState | undefined, inputs: RootParityInputs) {
this.deferredProving(
provingState,
() => this.prover.getRootParityProof(inputs),
signal => this.prover.getRootParityProof(inputs, signal),
async ([publicInputs, proof], duration) => {
this.emitCircuitSimulationStats(
'root-parity',
Expand Down Expand Up @@ -669,11 +702,11 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
(): Promise<[KernelCircuitPublicInputs | PublicKernelCircuitPublicInputs, Proof]> => {
(signal): Promise<[KernelCircuitPublicInputs | PublicKernelCircuitPublicInputs, Proof]> => {
if (request.type === PublicKernelType.TAIL) {
return this.prover.getPublicTailProof(request);
return this.prover.getPublicTailProof(request, signal);
} else {
return this.prover.getPublicKernelProof(request);
return this.prover.getPublicKernelProof(request, signal);
}
},
([_, proof], duration) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import { PROVING_STATUS, type ProvingFailure } from '@aztec/circuit-types';
import { type GlobalVariables, NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP } from '@aztec/circuits.js';
import { fr } from '@aztec/circuits.js/testing';
import {
type GlobalVariables,
NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP,
NUM_BASE_PARITY_PER_ROOT_PARITY,
} from '@aztec/circuits.js';
import { fr, makeGlobalVariables } from '@aztec/circuits.js/testing';
import { range } from '@aztec/foundation/array';
import { createDebugLogger } from '@aztec/foundation/log';
import { type PromiseWithResolvers, promiseWithResolvers } from '@aztec/foundation/promise';
import { sleep } from '@aztec/foundation/sleep';

import { jest } from '@jest/globals';

import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, makeGlobals } from '../mocks/fixtures.js';
import { TestContext } from '../mocks/test_context.js';
import { type CircuitProver } from '../prover/interface.js';
import { TestCircuitProver } from '../prover/test_circuit_prover.js';
import { ProvingOrchestrator } from './orchestrator.js';

const logger = createDebugLogger('aztec:orchestrator-lifecycle');

Expand Down Expand Up @@ -124,5 +135,27 @@ describe('prover/orchestrator/lifecycle', () => {

expect(finalisedBlock.block.number).toEqual(101);
}, 60000);

it('cancels proving requests', async () => {
const prover: CircuitProver = new TestCircuitProver();
const orchestrator = new ProvingOrchestrator(context.actualDb, prover);

const spy = jest.spyOn(prover, 'getBaseParityProof');
const deferredPromises: PromiseWithResolvers<any>[] = [];
spy.mockImplementation(() => {
const deferred = promiseWithResolvers<any>();
deferredPromises.push(deferred);
return deferred.promise;
});
await orchestrator.startNewBlock(2, makeGlobalVariables(1), [], await makeEmptyProcessedTestTx(context.actualDb));

await sleep(1);

expect(spy).toHaveBeenCalledTimes(NUM_BASE_PARITY_PER_ROOT_PARITY);
expect(spy.mock.calls.every(([_, signal]) => !signal?.aborted)).toBeTruthy();

orchestrator.cancelBlock();
expect(spy.mock.calls.every(([_, signal]) => signal?.aborted)).toBeTruthy();
});
});
});
Loading

0 comments on commit e5fe19d

Please sign in to comment.