diff --git a/packages/transport-tcp/package.json b/packages/transport-tcp/package.json index 6efe16bb30..aed2b8cca9 100644 --- a/packages/transport-tcp/package.json +++ b/packages/transport-tcp/package.json @@ -66,6 +66,7 @@ "@multiformats/multiaddr": "^12.2.3", "@types/sinon": "^17.0.3", "p-defer": "^4.0.1", + "p-event": "^6.0.1", "progress-events": "^1.0.0", "race-event": "^1.3.0", "stream-to-it": "^1.0.1" diff --git a/packages/transport-tcp/src/listener.ts b/packages/transport-tcp/src/listener.ts index b83b35d843..f9b2c9da92 100644 --- a/packages/transport-tcp/src/listener.ts +++ b/packages/transport-tcp/src/listener.ts @@ -1,5 +1,6 @@ import net from 'net' -import { AbortError, AlreadyStartedError, InvalidParametersError, NotStartedError, TypedEventEmitter } from '@libp2p/interface' +import { AlreadyStartedError, InvalidParametersError, NotStartedError, TypedEventEmitter, setMaxListeners } from '@libp2p/interface' +import { pEvent } from 'p-event' import { CODE_P2P } from './constants.js' import { toMultiaddrConnection } from './socket-to-conn.js' import { @@ -67,12 +68,13 @@ type Status = { code: TCPListenerStatusCode.INACTIVE } | { export class TCPListener extends TypedEventEmitter implements Listener { private readonly server: net.Server - /** Keep track of open connections to destroy in case of timeout */ - private readonly connections = new Set() + /** Keep track of open sockets to destroy in case of timeout */ + private readonly sockets = new Set() private status: Status = { code: TCPListenerStatusCode.INACTIVE } private metrics?: TCPListenerMetrics private addr: string private readonly log: Logger + private readonly shutdownController: AbortController constructor (private readonly context: Context) { super() @@ -80,6 +82,9 @@ export class TCPListener extends TypedEventEmitter implements Li context.keepAlive = context.keepAlive ?? true context.noDelay = context.noDelay ?? true + this.shutdownController = new AbortController() + setMaxListeners(Infinity, this.shutdownController.signal) + this.log = context.logger.forComponent('libp2p:tcp:listener') this.addr = 'unknown' this.server = net.createServer(context, this.onSocket.bind(this)) @@ -119,7 +124,7 @@ export class TCPListener extends TypedEventEmitter implements Li help: 'Current active connections in TCP listener', calculate: () => { return { - [this.addr]: this.connections.size + [this.addr]: this.sockets.size } } }) @@ -195,18 +200,20 @@ export class TCPListener extends TypedEventEmitter implements Li } this.log('new inbound connection %s', maConn.remoteAddr) + this.sockets.add(socket) - this.context.upgrader.upgradeInbound(maConn) + this.context.upgrader.upgradeInbound(maConn, { + signal: this.shutdownController.signal + }) .then((conn) => { this.log('inbound connection upgraded %s', maConn.remoteAddr) - this.connections.add(maConn) socket.once('close', () => { - this.connections.delete(maConn) + this.sockets.delete(socket) if ( this.context.closeServerOnMaxConnections != null && - this.connections.size < this.context.closeServerOnMaxConnections.listenBelow + this.sockets.size < this.context.closeServerOnMaxConnections.listenBelow ) { // The most likely case of error is if the port taken by this // application is bound by another process during the time the @@ -227,11 +234,9 @@ export class TCPListener extends TypedEventEmitter implements Li if ( this.context.closeServerOnMaxConnections != null && - this.connections.size >= this.context.closeServerOnMaxConnections.closeAbove + this.sockets.size >= this.context.closeServerOnMaxConnections.closeAbove ) { - this.pause(false).catch(e => { - this.log.error('error attempting to close server once connection count over limit', e) - }) + this.pause() } this.safeDispatchEvent('connection', { detail: conn }) @@ -239,6 +244,7 @@ export class TCPListener extends TypedEventEmitter implements Li .catch(async err => { this.log.error('inbound connection upgrade failed', err) this.metrics?.errors.increment({ [`${this.addr} inbound_upgrade`]: true }) + this.sockets.delete(socket) maConn.abort(err) }) } @@ -300,15 +306,28 @@ export class TCPListener extends TypedEventEmitter implements Li } async close (): Promise { - const err = new AbortError('Listener is closing') + const events: Array> = [] - // synchronously close each connection - this.connections.forEach(conn => { - conn.abort(err) - }) + if (this.server.listening) { + events.push(pEvent(this.server, 'close')) + } // shut down the server socket, permanently - await this.pause(true) + this.pause(true) + + // stop any in-progress connection upgrades + this.shutdownController.abort() + + // synchronously close any open connections - should be done after closing + // the server socket in case new sockets are opened during the shutdown + this.sockets.forEach(socket => { + if (socket.readable) { + events.push(pEvent(socket, 'close')) + socket.destroy() + } + }) + + await Promise.all(events) } /** @@ -332,7 +351,7 @@ export class TCPListener extends TypedEventEmitter implements Li this.log('listening on %s', this.server.address()) } - private async pause (permanent: boolean): Promise { + private pause (permanent: boolean = false): void { if (!this.server.listening && this.status.code === TCPListenerStatusCode.PAUSED && permanent) { this.status = { code: TCPListenerStatusCode.INACTIVE } return @@ -361,15 +380,10 @@ export class TCPListener extends TypedEventEmitter implements Li // during the time the server is closing this.status = permanent ? { code: TCPListenerStatusCode.INACTIVE } : { ...this.status, code: TCPListenerStatusCode.PAUSED } - await new Promise((resolve, reject) => { - this.server.close(err => { - if (err != null) { - reject(err) - return - } - - resolve() - }) - }) + // stop accepting incoming connections - existing connections are maintained + // - any callback passed here would be invoked after existing connections + // close, we want to maintain them so no callback is passed otherwise his + // method will never return + this.server.close() } } diff --git a/packages/transport-tcp/test/connection-limits.spec.ts b/packages/transport-tcp/test/connection-limits.spec.ts index a889c7cbe4..b43163dade 100644 --- a/packages/transport-tcp/test/connection-limits.spec.ts +++ b/packages/transport-tcp/test/connection-limits.spec.ts @@ -1,6 +1,5 @@ import net from 'node:net' import { promisify } from 'util' -import { TypedEventEmitter } from '@libp2p/interface' import { mockUpgrader } from '@libp2p/interface-compliance-tests/mocks' import { defaultLogger } from '@libp2p/logger' import { multiaddr } from '@multiformats/multiaddr' @@ -64,21 +63,25 @@ async function assertServerConnections (listener: TCPListener, connections: numb // Expect server connections but allow time for sockets to connect or disconnect for (let i = 0; i < 100; i++) { // eslint-disable-next-line @typescript-eslint/dot-notation - if (listener['connections'].size === connections) { + if (listener['sockets'].size === connections) { return } else { await promisify(setTimeout)(10) } } // eslint-disable-next-line @typescript-eslint/dot-notation - expect(listener['connections'].size).equals(connections, 'invalid amount of server connections') + expect(listener['sockets'].size).equals(connections, 'invalid amount of server connections') } describe('closeAbove/listenBelow', () => { - const afterEachCallbacks: Array<() => Promise | any> = [] + let afterEachCallbacks: Array<() => Promise | any> = [] + + beforeEach(() => { + afterEachCallbacks = [] + }) + afterEach(async () => { await Promise.all(afterEachCallbacks.map(fn => fn())) - afterEachCallbacks.length = 0 }) it('reject dial of connection above closeAbove', async () => { @@ -86,16 +89,14 @@ describe('closeAbove/listenBelow', () => { const closeAbove = 3 const port = 9900 - const trasnport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ + const transport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ logger: defaultLogger() }) - const upgrader = mockUpgrader({ - events: new TypedEventEmitter() - }) - const listener = trasnport.createListener({ upgrader }) as TCPListener - // eslint-disable-next-line @typescript-eslint/promise-function-async - afterEachCallbacks.push(() => listener.close()) + const upgrader = mockUpgrader() + const listener = transport.createListener({ upgrader }) as TCPListener + afterEachCallbacks.push(async () => listener.close()) + await listener.listen(multiaddr(`/ip4/127.0.0.1/tcp/${port}`)) const { assertConnectedSocket, assertRefusedSocket } = buildSocketAssertions(port, afterEachCallbacks) @@ -115,16 +116,14 @@ describe('closeAbove/listenBelow', () => { const closeAbove = 3 const port = 9900 - const trasnport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ + const transport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ logger: defaultLogger() }) - const upgrader = mockUpgrader({ - events: new TypedEventEmitter() - }) - const listener = trasnport.createListener({ upgrader }) as TCPListener - // eslint-disable-next-line @typescript-eslint/promise-function-async - afterEachCallbacks.push(() => listener.close()) + const upgrader = mockUpgrader() + const listener = transport.createListener({ upgrader }) as TCPListener + afterEachCallbacks.push(async () => listener.close()) + await listener.listen(multiaddr(`/ip4/127.0.0.1/tcp/${port}`)) const { assertConnectedSocket } = buildSocketAssertions(port, afterEachCallbacks) @@ -152,16 +151,13 @@ describe('closeAbove/listenBelow', () => { const closeAbove = 3 const port = 9900 - const trasnport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ + const transport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ logger: defaultLogger() }) - const upgrader = mockUpgrader({ - events: new TypedEventEmitter() - }) - const listener = trasnport.createListener({ upgrader }) as TCPListener - // eslint-disable-next-line @typescript-eslint/promise-function-async - afterEachCallbacks.push(() => listener.close()) + const upgrader = mockUpgrader() + const listener = transport.createListener({ upgrader }) as TCPListener + afterEachCallbacks.push(async () => listener.close()) let closeEventCallCount = 0 listener.addEventListener('close', () => { @@ -185,16 +181,13 @@ describe('closeAbove/listenBelow', () => { const closeAbove = 3 const port = 9900 - const trasnport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ + const transport = tcp({ closeServerOnMaxConnections: { listenBelow, closeAbove } })({ logger: defaultLogger() }) - const upgrader = mockUpgrader({ - events: new TypedEventEmitter() - }) - const listener = trasnport.createListener({ upgrader }) as TCPListener - // eslint-disable-next-line @typescript-eslint/promise-function-async - afterEachCallbacks.push(() => listener.close()) + const upgrader = mockUpgrader() + const listener = transport.createListener({ upgrader }) as TCPListener + afterEachCallbacks.push(async () => listener.close()) let listeningEventCallCount = 0 listener.addEventListener('listening', () => { diff --git a/packages/transport-tcp/test/listen-dial.spec.ts b/packages/transport-tcp/test/listen-dial.spec.ts index fe90b4eecd..81e47a58dc 100644 --- a/packages/transport-tcp/test/listen-dial.spec.ts +++ b/packages/transport-tcp/test/listen-dial.spec.ts @@ -394,4 +394,86 @@ describe('dial', () => { await listener.close() }) + + it('should close before connection upgrade is completed', async () => { + // create a Promise that resolves when the upgrade starts + const upgradeStarted = pDefer() + + // create a listener with the handler + const listener = transport.createListener({ + upgrader: { + async upgradeInbound () { + upgradeStarted.resolve() + + return new Promise(() => {}) + }, + async upgradeOutbound () { + return new Promise(() => {}) + } + } + }) + + // listen on a multiaddr + await listener.listen(multiaddr('/ip4/127.0.0.1/tcp/0')) + + const localAddrs = listener.getAddrs() + expect(localAddrs.length).to.equal(1) + + // dial the listener address + transport.dial(localAddrs[0], { + upgrader + }).catch(() => {}) + + // wait for the upgrade to start + await upgradeStarted.promise + + // close the listener, process should exit normally + await listener.close() + }) + + it('should abort inbound upgrade on close', async () => { + // create a Promise that resolves when the upgrade starts + const upgradeStarted = pDefer() + const abortedUpgrade = pDefer() + + // create a listener with the handler + const listener = transport.createListener({ + upgrader: { + async upgradeInbound (maConn, opts) { + upgradeStarted.resolve() + + opts?.signal?.addEventListener('abort', () => { + abortedUpgrade.resolve() + }, { + once: true + }) + + return new Promise(() => {}) + }, + async upgradeOutbound () { + return new Promise(() => {}) + } + } + }) + + // listen on a multiaddr + await listener.listen(multiaddr('/ip4/127.0.0.1/tcp/0')) + + const localAddrs = listener.getAddrs() + expect(localAddrs.length).to.equal(1) + + // dial the listener address + transport.dial(localAddrs[0], { + upgrader + }).catch(() => {}) + + // wait for the upgrade to start + await upgradeStarted.promise + + // close the listener + await listener.close() + + // should abort the upgrade + await abortedUpgrade.promise + }) })