diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 3406fed594..7b277794ed 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -29,7 +29,6 @@ import { type CancellationToken, TypedEventEmitter } from '../mongo_types'; import type { ReadPreferenceLike } from '../read_preference'; import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions'; import { - abortable, BufferPool, calculateDurationInMs, type Callback, @@ -37,6 +36,7 @@ import { maxWireVersion, type MongoDBNamespace, now, + promiseWithResolvers, uuidV4 } from '../utils'; import type { WriteConcern } from '../write_concern'; @@ -161,15 +161,14 @@ function streamIdentifier(stream: Stream, options: ConnectionOptions): string { export class Connection extends TypedEventEmitter { public id: number | ''; public address: string; - public lastHelloMS?: number; + public lastHelloMS = -1; public serverApi?: ServerApi; - public helloOk?: boolean; + public helloOk = false; public authContext?: AuthContext; public delayedTimeoutId: NodeJS.Timeout | null = null; public generation: number; public readonly description: Readonly; /** - * @public * Represents if the connection has been established: * - TCP handshake * - TLS negotiated @@ -180,15 +179,16 @@ export class Connection extends TypedEventEmitter { public established: boolean; private lastUseTime: number; - private socketTimeoutMS: number; - private monitorCommands: boolean; - private socket: Stream; - private controller: AbortController; - private messageStream: Readable; - private socketWrite: (buffer: Uint8Array) => Promise; private clusterTime: Document | null = null; - /** @internal */ - override mongoLogger: MongoLogger | undefined; + + private readonly socketTimeoutMS: number; + private readonly monitorCommands: boolean; + private readonly socket: Stream; + private readonly controller: AbortController; + private readonly signal: AbortSignal; + private readonly messageStream: Readable; + private readonly socketWrite: (buffer: Uint8Array) => Promise; + private readonly aborted: Promise; /** @event */ static readonly COMMAND_STARTED = COMMAND_STARTED; @@ -221,7 +221,21 @@ export class Connection extends TypedEventEmitter { this.lastUseTime = now(); this.socket = stream; + + // TODO: Remove signal from connection layer this.controller = new AbortController(); + const { signal } = this.controller; + this.signal = signal; + const { promise: aborted, reject } = promiseWithResolvers(); + aborted.then(undefined, () => null); // Prevent unhandled rejection + this.signal.addEventListener( + 'abort', + function onAbort() { + reject(signal.reason); + }, + { once: true } + ); + this.aborted = aborted; this.messageStream = this.socket .on('error', this.onError.bind(this)) @@ -232,13 +246,13 @@ export class Connection extends TypedEventEmitter { const socketWrite = promisify(this.socket.write.bind(this.socket)); this.socketWrite = async buffer => { - return abortable(socketWrite(buffer), { signal: this.controller.signal }); + return Promise.race([socketWrite(buffer), this.aborted]); }; } /** Indicates that the connection (including underlying TCP socket) has been closed. */ public get closed(): boolean { - return this.controller.signal.aborted; + return this.signal.aborted; } public get hello() { @@ -407,7 +421,7 @@ export class Connection extends TypedEventEmitter { } private async *sendWire(message: WriteProtocolMessageType, options: CommandOptions) { - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); if (typeof options.socketTimeoutMS === 'number') { this.socket.setTimeout(options.socketTimeoutMS); @@ -426,7 +440,7 @@ export class Connection extends TypedEventEmitter { return; } - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); for await (const response of this.readMany()) { this.socket.setTimeout(0); @@ -447,7 +461,7 @@ export class Connection extends TypedEventEmitter { } yield document; - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); if (typeof options.socketTimeoutMS === 'number') { this.socket.setTimeout(options.socketTimeoutMS); @@ -481,7 +495,7 @@ export class Connection extends TypedEventEmitter { let document; try { - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); for await (document of this.sendWire(message, options)) { if (!Buffer.isBuffer(document) && document.writeConcernError) { throw new MongoWriteConcernError(document.writeConcernError, document); @@ -511,7 +525,7 @@ export class Connection extends TypedEventEmitter { } yield document; - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); } } catch (error) { if (this.shouldEmitAndLogCommand) { @@ -554,7 +568,7 @@ export class Connection extends TypedEventEmitter { command: Document, options: CommandOptions = {} ): Promise { - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); for await (const document of this.sendCommand(ns, command, options)) { return document; } @@ -568,16 +582,20 @@ export class Connection extends TypedEventEmitter { replyListener: Callback ) { const exhaustLoop = async () => { - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); for await (const reply of this.sendCommand(ns, command, options)) { replyListener(undefined, reply); - this.controller.signal.throwIfAborted(); + this.throwIfAborted(); } throw new MongoUnexpectedServerResponseError('Server ended moreToCome unexpectedly'); }; exhaustLoop().catch(replyListener); } + private throwIfAborted() { + this.signal.throwIfAborted(); + } + /** * @internal * @@ -611,7 +629,7 @@ export class Connection extends TypedEventEmitter { * Note that `for-await` loops call `return` automatically when the loop is exited. */ private async *readMany(): AsyncGenerator { - for await (const message of onData(this.messageStream, { signal: this.controller.signal })) { + for await (const message of onData(this.messageStream, { signal: this.signal })) { const response = await decompressResponse(message); yield response; diff --git a/src/utils.ts b/src/utils.ts index 719367cad2..173de9053a 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1283,36 +1283,6 @@ export function isHostMatch(match: RegExp, host?: string): boolean { return host && match.test(host.toLowerCase()) ? true : false; } -/** - * Takes a promise and races it with a promise wrapping the abort event of the optionally provided signal. - * The given promise is _always_ ordered before the signal's abort promise. - * When given an already rejected promise and an already aborted signal, the promise's rejection takes precedence. - * - * @see https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise/race - * - * @param promise - A promise to discard if the signal aborts - * @param options - An options object carrying an optional signal - */ -export async function abortable( - promise: Promise, - { signal }: { signal: AbortSignal } -): Promise { - const { promise: aborted, reject } = promiseWithResolvers(); - - function rejectOnAbort() { - reject(signal.reason); - } - - if (signal.aborted) rejectOnAbort(); - else signal.addEventListener('abort', rejectOnAbort, { once: true }); - - try { - return await Promise.race([promise, aborted]); - } finally { - signal.removeEventListener('abort', rejectOnAbort); - } -} - export function promiseWithResolvers() { let resolve!: Parameters>[0]>[0]; let reject!: Parameters>[0]>[1]; diff --git a/test/unit/utils.test.ts b/test/unit/utils.test.ts index cf988382a2..b5fcadbffc 100644 --- a/test/unit/utils.test.ts +++ b/test/unit/utils.test.ts @@ -1,9 +1,7 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; -import { setTimeout } from 'timers'; import { - abortable, BufferPool, ByteUtils, compareObjectId, @@ -21,7 +19,6 @@ import { shuffle, TimeoutController } from '../mongodb'; -import { sleep } from '../tools/utils'; import { createTimerSandbox } from './timer_sandbox'; describe('driver utils', function () { @@ -1077,133 +1074,4 @@ describe('driver utils', function () { }); }); }); - - describe('abortable()', () => { - const goodError = new Error('good error'); - const badError = new Error('unexpected bad error!'); - const expectedValue = "don't panic"; - - context('always removes the abort listener it attaches', () => { - let controller; - let removeEventListenerSpy; - let addEventListenerSpy; - - beforeEach(() => { - controller = new AbortController(); - addEventListenerSpy = sinon.spy(controller.signal, 'addEventListener'); - removeEventListenerSpy = sinon.spy(controller.signal, 'removeEventListener'); - }); - - afterEach(() => sinon.restore()); - - const expectListenerCleanup = () => { - expect(addEventListenerSpy).to.have.been.calledOnce; - expect(removeEventListenerSpy).to.have.been.calledOnce; - }; - - it('when promise rejects', async () => { - await abortable(Promise.reject(goodError), { signal: controller.signal }).catch(e => e); - expectListenerCleanup(); - }); - - it('when promise resolves', async () => { - await abortable(Promise.resolve(expectedValue), { signal: controller.signal }); - expectListenerCleanup(); - }); - - it('when signal aborts', async () => { - setTimeout(() => controller.abort(goodError)); - await abortable(new Promise(() => null), { signal: controller.signal }).catch(e => e); - expectListenerCleanup(); - }); - }); - - context('when given already rejected promise with already aborted signal', () => { - it('returns promise rejection', async () => { - const controller = new AbortController(); - const { signal } = controller; - controller.abort(badError); - const result = await abortable(Promise.reject(goodError), { signal }).catch(e => e); - expect(result).to.deep.equal(goodError); - }); - }); - - context('when given already resolved promise with already aborted signal', () => { - it('returns promise resolution', async () => { - const controller = new AbortController(); - const { signal } = controller; - controller.abort(badError); - const result = await abortable(Promise.resolve(expectedValue), { signal }).catch(e => e); - expect(result).to.deep.equal(expectedValue); - }); - }); - - context('when given already rejected promise with not yet aborted signal', () => { - it('returns promise rejection', async () => { - const controller = new AbortController(); - const { signal } = controller; - const result = await abortable(Promise.reject(goodError), { signal }).catch(e => e); - expect(result).to.deep.equal(goodError); - }); - }); - - context('when given already resolved promise with not yet aborted signal', () => { - it('returns promise resolution', async () => { - const controller = new AbortController(); - const { signal } = controller; - const result = await abortable(Promise.resolve(expectedValue), { signal }).catch(e => e); - expect(result).to.deep.equal(expectedValue); - }); - }); - - context('when given unresolved promise with an already aborted signal', () => { - it('returns signal reason', async () => { - const controller = new AbortController(); - const { signal } = controller; - controller.abort(goodError); - const result = await abortable(new Promise(() => null), { signal }).catch(e => e); - expect(result).to.deep.equal(goodError); - }); - }); - - context('when given eventually rejecting promise with not yet aborted signal', () => { - const eventuallyReject = async () => { - await sleep(1); - throw goodError; - }; - - it('returns promise rejection', async () => { - const controller = new AbortController(); - const { signal } = controller; - const result = await abortable(eventuallyReject(), { signal }).catch(e => e); - expect(result).to.deep.equal(goodError); - }); - }); - - context('when given eventually resolving promise with not yet aborted signal', () => { - const eventuallyResolve = async () => { - await sleep(1); - return expectedValue; - }; - - it('returns promise resolution', async () => { - const controller = new AbortController(); - const { signal } = controller; - const result = await abortable(eventuallyResolve(), { signal }).catch(e => e); - expect(result).to.deep.equal(expectedValue); - }); - }); - - context('when given unresolved promise with eventually aborted signal', () => { - it('returns signal reason', async () => { - const controller = new AbortController(); - const { signal } = controller; - - setTimeout(() => controller.abort(goodError), 1); - - const result = await abortable(new Promise(() => null), { signal }).catch(e => e); - expect(result).to.deep.equal(goodError); - }); - }); - }); });