diff --git a/src/execution/AbortSignalListener.ts b/src/execution/AbortSignalListener.ts index e93a13c1d1..e175046b9e 100644 --- a/src/execution/AbortSignalListener.ts +++ b/src/execution/AbortSignalListener.ts @@ -28,7 +28,10 @@ export class AbortSignalListener { this.abortSignal.removeEventListener('abort', this.abort); } - cancellablePromise(originalPromise: Promise): Promise { + cancellablePromise( + originalPromise: Promise, + onCancel?: (() => Promise) | undefined, + ): Promise { if (this.abortSignal.aborted) { // eslint-disable-next-line @typescript-eslint/prefer-promise-reject-errors return Promise.reject(this.abortSignal.reason); @@ -40,14 +43,42 @@ export class AbortSignalListener { originalPromise.then( (resolved) => { this._aborts.delete(abort); + onCancel?.().catch(() => { + // ignore + }); resolve(resolved); }, (error: unknown) => { this._aborts.delete(abort); + onCancel?.().catch(() => { + // ignore + }); reject(error); }, ); return promise; } + + cancellableIterable(iterable: AsyncIterable): AsyncIterable { + const iterator = iterable[Symbol.asyncIterator](); + + const earlyReturn = + typeof iterator.return === 'function' + ? iterator.return.bind(iterator) + : undefined; + + const cancellableAsyncIterator: AsyncIterator = { + next: () => this.cancellablePromise(iterator.next(), earlyReturn), + }; + + if (earlyReturn) { + cancellableAsyncIterator.return = () => + this.cancellablePromise(earlyReturn()); + } + + return { + [Symbol.asyncIterator]: () => cancellableAsyncIterator, + }; + } } diff --git a/src/execution/execute.ts b/src/execution/execute.ts index b33cd35502..d391110fb1 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -1387,7 +1387,9 @@ function completeListValue( const itemType = returnType.ofType; if (isAsyncIterable(result)) { - const asyncIterator = result[Symbol.asyncIterator](); + const maybeCancellableIterable = + exeContext.abortSignalListener?.cancellableIterable(result) ?? result; + const asyncIterator = maybeCancellableIterable[Symbol.asyncIterator](); return completeAsyncIteratorValue( exeContext, @@ -2229,7 +2231,7 @@ function executeSubscription( // TODO: add test case /* c8 ignore next */ abortSignalListener?.disconnect(); - return resolved; + return abortSignalListener?.cancellableIterable(resolved) ?? resolved; }, (error: unknown) => { abortSignalListener?.disconnect();