Skip to content

Commit

Permalink
feat: abort ongoing proving jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexghr committed May 1, 2024
1 parent 1c305be commit 885cda7
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 65 deletions.
2 changes: 1 addition & 1 deletion yarn-project/circuit-types/src/interfaces/proving-job.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export type ProvingRequestPublicInputs = {
export type ProvingRequestResult<T extends ProvingRequestType> = ProvingRequestPublicInputs[T];

export interface ProvingJobSource {
getProvingJob(): Promise<ProvingJob<ProvingRequest> | null>;
getProvingJob(): Promise<ProvingJob<ProvingRequest> | undefined>;

resolveProvingJob<T extends ProvingRequestType>(jobId: string, result: ProvingRequestResult<T>): Promise<void>;

Expand Down
5 changes: 5 additions & 0 deletions yarn-project/foundation/src/error/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ export class InterruptError extends Error {}
* An error thrown when an action times out.
*/
export class TimeoutError extends Error {}

/**
* Represents an error thrown when an operation is aborted.
*/
export class AbortedError extends Error {}
4 changes: 2 additions & 2 deletions yarn-project/prover-client/src/dummy-prover.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ export class DummyProver implements ProverClient {
}

class DummyProvingJobSource implements ProvingJobSource {
getProvingJob(): Promise<ProvingJob<ProvingRequest> | null> {
return Promise.resolve(null);
getProvingJob(): Promise<ProvingJob<ProvingRequest> | undefined> {
return Promise.resolve(undefined);
}

rejectProvingJob(): Promise<void> {
Expand Down
53 changes: 43 additions & 10 deletions yarn-project/prover-client/src/orchestrator/orchestrator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import {
} from '@aztec/circuits.js';
import { makeTuple } from '@aztec/foundation/array';
import { padArrayEnd } from '@aztec/foundation/collection';
import { AbortedError } from '@aztec/foundation/error';
import { createDebugLogger } from '@aztec/foundation/log';
import { promiseWithResolvers } from '@aztec/foundation/promise';
import { type Tuple } from '@aztec/foundation/serialize';
Expand Down Expand Up @@ -82,6 +83,8 @@ const KernelTypesWithoutFunctions: Set<PublicKernelType> = new Set<PublicKernelT
*/
export class ProvingOrchestrator {
private provingState: ProvingState | undefined = undefined;
private pendingProvingJobs: AbortController[] = [];

constructor(private db: MerkleTreeOperations, private prover: CircuitProver) {}

/**
Expand Down Expand Up @@ -211,6 +214,10 @@ export class ProvingOrchestrator {
* Cancel any further proving of the block
*/
public cancelBlock() {
for (const controller of this.pendingProvingJobs) {
controller.abort();
}

this.provingState?.cancel();
}

Expand Down Expand Up @@ -303,30 +310,56 @@ export class ProvingOrchestrator {
*/
private deferredProving<T>(
provingState: ProvingState | undefined,
request: () => Promise<T>,
request: (signal: AbortSignal) => Promise<T>,
callback: (result: T, durationMs: number) => void | Promise<void>,
) {
if (!provingState?.verifyState()) {
logger.debug(`Not enqueuing job, state no longer valid`);
return;
}

const controller = new AbortController();
this.pendingProvingJobs.push(controller);

// We use a 'safeJob'. We don't want promise rejections in the proving pool, we want to capture the error here
// and reject the proving job whilst keeping the event loop free of rejections
const safeJob = async () => {
try {
// there's a delay between enqueueing this job and it actually running
if (controller.signal.aborted) {
return;
}

const timer = new Timer();
const result = await request();
const result = await request(controller.signal);
const duration = timer.ms();

if (!provingState?.verifyState()) {
logger.debug(`State no longer valid, discarding result`);
return;
}

// we could have been cancelled whilst waiting for the result
// and the prover ignored the signal. Drop the result in that case
if (controller.signal.aborted) {
return;
}

await callback(result, duration);
} catch (err) {
if (err instanceof AbortedError) {
// operation was cancelled, probably because the block was cancelled
// drop this result
return;
}

logger.error(`Error thrown when proving job`);
provingState!.reject(`${err}`);
} finally {
const index = this.pendingProvingJobs.indexOf(controller);
if (index > -1) {
this.pendingProvingJobs.splice(index, 1);
}
}
};

Expand Down Expand Up @@ -441,7 +474,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getBaseRollupProof(tx.baseRollupInputs),
signal => this.prover.getBaseRollupProof(tx.baseRollupInputs, signal),
(result, duration) => {
this.emitCircuitSimulationStats(
'base-rollup',
Expand Down Expand Up @@ -472,7 +505,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getMergeRollupProof(inputs),
signal => this.prover.getMergeRollupProof(inputs, signal),
(result, duration) => {
this.emitCircuitSimulationStats(
'merge-rollup',
Expand Down Expand Up @@ -508,7 +541,7 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
() => this.prover.getRootRollupProof(inputs),
signal => this.prover.getRootRollupProof(inputs, signal),
(result, duration) => {
this.emitCircuitSimulationStats(
'root-rollup',
Expand All @@ -533,7 +566,7 @@ export class ProvingOrchestrator {
private enqueueBaseParityCircuit(provingState: ProvingState, inputs: BaseParityInputs, index: number) {
this.deferredProving(
provingState,
() => this.prover.getBaseParityProof(inputs),
signal => this.prover.getBaseParityProof(inputs, signal),
(rootInput, duration) => {
this.emitCircuitSimulationStats(
'base-parity',
Expand All @@ -560,7 +593,7 @@ export class ProvingOrchestrator {
private enqueueRootParityCircuit(provingState: ProvingState | undefined, inputs: RootParityInputs) {
this.deferredProving(
provingState,
() => this.prover.getRootParityProof(inputs),
signal => this.prover.getRootParityProof(inputs, signal),
async (rootInput, duration) => {
this.emitCircuitSimulationStats(
'root-parity',
Expand Down Expand Up @@ -674,11 +707,11 @@ export class ProvingOrchestrator {

this.deferredProving(
provingState,
(): Promise<PublicInputsAndProof<KernelCircuitPublicInputs | PublicKernelCircuitPublicInputs>> => {
(signal): Promise<PublicInputsAndProof<KernelCircuitPublicInputs | PublicKernelCircuitPublicInputs>> => {
if (request.type === PublicKernelType.TAIL) {
return this.prover.getPublicTailProof(request);
return this.prover.getPublicTailProof(request, signal);
} else {
return this.prover.getPublicKernelProof(request);
return this.prover.getPublicKernelProof(request, signal);
}
},
(result, duration) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import { PROVING_STATUS, type ProvingFailure } from '@aztec/circuit-types';
import { type GlobalVariables, NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP } from '@aztec/circuits.js';
import { fr } from '@aztec/circuits.js/testing';
import {
type GlobalVariables,
NUMBER_OF_L1_L2_MESSAGES_PER_ROLLUP,
NUM_BASE_PARITY_PER_ROOT_PARITY,
} from '@aztec/circuits.js';
import { fr, makeGlobalVariables } from '@aztec/circuits.js/testing';
import { range } from '@aztec/foundation/array';
import { createDebugLogger } from '@aztec/foundation/log';
import { type PromiseWithResolvers, promiseWithResolvers } from '@aztec/foundation/promise';
import { sleep } from '@aztec/foundation/sleep';

import { jest } from '@jest/globals';

import { makeBloatedProcessedTx, makeEmptyProcessedTestTx, makeGlobals } from '../mocks/fixtures.js';
import { TestContext } from '../mocks/test_context.js';
import { type CircuitProver } from '../prover/interface.js';
import { TestCircuitProver } from '../prover/test_circuit_prover.js';
import { ProvingOrchestrator } from './orchestrator.js';

const logger = createDebugLogger('aztec:orchestrator-lifecycle');

Expand Down Expand Up @@ -124,5 +135,27 @@ describe('prover/orchestrator/lifecycle', () => {

expect(finalisedBlock.block.number).toEqual(101);
}, 60000);

it('cancels proving requests', async () => {
const prover: CircuitProver = new TestCircuitProver();
const orchestrator = new ProvingOrchestrator(context.actualDb, prover);

const spy = jest.spyOn(prover, 'getBaseParityProof');
const deferredPromises: PromiseWithResolvers<any>[] = [];
spy.mockImplementation(() => {
const deferred = promiseWithResolvers<any>();
deferredPromises.push(deferred);
return deferred.promise;
});
await orchestrator.startNewBlock(2, makeGlobalVariables(1), [], await makeEmptyProcessedTestTx(context.actualDb));

await sleep(1);

expect(spy).toHaveBeenCalledTimes(NUM_BASE_PARITY_PER_ROOT_PARITY);
expect(spy.mock.calls.every(([_, signal]) => !signal?.aborted)).toBeTruthy();

orchestrator.cancelBlock();
expect(spy.mock.calls.every(([_, signal]) => signal?.aborted)).toBeTruthy();
});
});
});
Loading

0 comments on commit 885cda7

Please sign in to comment.