diff --git a/package-lock.json b/package-lock.json index c130defb4..09d004c19 100644 --- a/package-lock.json +++ b/package-lock.json @@ -52,6 +52,7 @@ "tslib": "^2.4.0", "tsyringe": "^4.7.0", "utp-native": "^2.5.3", + "uWebSockets.js": "github:uNetworking/uWebSockets.js#v20.19.0", "ws": "^8.12.0" }, "bin": { @@ -11836,6 +11837,10 @@ "uuid": "dist/bin/uuid" } }, + "node_modules/uWebSockets.js": { + "version": "20.19.0", + "resolved": "git+ssh://git@github.com/uNetworking/uWebSockets.js.git#42c9c0d5d31f46ca4115dc75672b0037ec970f28" + }, "node_modules/v8-compile-cache": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", @@ -20895,6 +20900,10 @@ "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==" }, + "uWebSockets.js": { + "version": "git+ssh://git@github.com/uNetworking/uWebSockets.js.git#42c9c0d5d31f46ca4115dc75672b0037ec970f28", + "from": "uWebSockets.js@github:uNetworking/uWebSockets.js#v20.19.0" + }, "v8-compile-cache": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", diff --git a/package.json b/package.json index 929c868c6..3758fd9b1 100644 --- a/package.json +++ b/package.json @@ -122,6 +122,7 @@ "tslib": "^2.4.0", "tsyringe": "^4.7.0", "utp-native": "^2.5.3", + "uWebSockets.js": "github:uNetworking/uWebSockets.js#v20.19.0", "ws": "^8.12.0" }, "devDependencies": { diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index d33d8f053..664811c2f 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -10,10 +10,10 @@ import type { RawHandlerImplementation, ServerHandlerImplementation, UnaryHandlerImplementation, + ConnectionInfo, } from './types'; import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue } from '../types'; -import type { ConnectionInfo } from '../network/types'; import type { RPCErrorEvent } from './utils'; import type { MiddlewareFactory } from './types'; import { ReadableStream } from 'stream/web'; diff --git a/src/RPC/handlers.ts b/src/RPC/handlers.ts index c738c74e8..3ca13ce5b 100644 --- a/src/RPC/handlers.ts +++ b/src/RPC/handlers.ts @@ -2,7 +2,7 @@ import type { JSONValue } from 'types'; import type { ContainerType } from 'RPC/types'; import type { ReadableStream } from 'stream/web'; import type { JsonRpcRequest } from 'RPC/types'; -import type { ConnectionInfo } from '../network/types'; +import type { ConnectionInfo } from './types'; import type { ContextCancellable } from '../contexts/types'; abstract class Handler< diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 4d96fcc0c..e34d0fabc 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,5 +1,4 @@ import type { JSONValue } from '../types'; -import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; import type { ReadableStream, ReadableWritablePair } from 'stream/web'; import type { Handler } from './handlers'; @@ -11,6 +10,8 @@ import type { ClientCaller, UnaryCaller, } from './callers'; +import type { NodeId } from '../nodes/types'; +import type { Certificate } from '../keys/types'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -108,6 +109,24 @@ type JsonRpcMessage = | JsonRpcRequest | JsonRpcResponse; +/** + * Proxy connection information + * @property remoteNodeId - NodeId of the remote connecting node + * @property remoteCertificates - Certificate chain of the remote connecting node + * @property localHost - Proxy host of the local connecting node + * @property localPort - Proxy port of the local connecting node + * @property remoteHost - Proxy host of the remote connecting node + * @property remotePort - Proxy port of the remote connecting node + */ +type ConnectionInfo = Partial<{ + remoteNodeId: NodeId; + remoteCertificates: Array; + localHost: string; + localPort: number; + remoteHost: string; + remotePort: number; +}>; + // Handler types type HandlerImplementation = ( input: I, @@ -218,6 +237,7 @@ export type { JsonRpcRequest, JsonRpcResponse, JsonRpcMessage, + ConnectionInfo, HandlerImplementation, RawHandlerImplementation, DuplexHandlerImplementation, diff --git a/src/clientRPC/errors.ts b/src/clientRPC/errors.ts new file mode 100644 index 000000000..030bad8d5 --- /dev/null +++ b/src/clientRPC/errors.ts @@ -0,0 +1,28 @@ +import { ErrorPolykey, sysexits } from '../errors'; + +class ErrorRPC extends ErrorPolykey {} + +class ErrorRPCClient extends ErrorRPC {} + +class ErrorClientAuthMissing extends ErrorRPCClient { + static description = 'Authorisation metadata is required but missing'; + exitCode = sysexits.NOPERM; +} + +class ErrorClientAuthFormat extends ErrorRPCClient { + static description = 'Authorisation metadata has invalid format'; + exitCode = sysexits.USAGE; +} + +class ErrorClientAuthDenied extends ErrorRPCClient { + static description = 'Authorisation metadata is incorrect or expired'; + exitCode = sysexits.NOPERM; +} + +export { + ErrorRPC, + ErrorRPCClient, + ErrorClientAuthMissing, + ErrorClientAuthFormat, + ErrorClientAuthDenied, +}; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 9b280d77e..8fea99443 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -1,20 +1,9 @@ -import type { SessionToken } from '../sessions/types'; -import type KeyRing from '../keys/KeyRing'; -import type SessionManager from '../sessions/SessionManager'; import type { RPCRequestParams } from './types'; -import type { JsonRpcRequest } from '../RPC/types'; -import type { ReadableWritablePair } from 'stream/web'; -import type Logger from '@matrixai/logger'; -import type { ConnectionInfo, Host, Port } from '../network/types'; -import type RPCServer from '../RPC/RPCServer'; -import type { TLSSocket } from 'tls'; -import type { Server } from 'https'; -import type net from 'net'; -import type https from 'https'; -import { ReadableStream, WritableStream } from 'stream/web'; -import WebSocket, { WebSocketServer } from 'ws'; -import * as clientErrors from '../client/errors'; -import { promise } from '../utils'; +import type SessionManager from 'sessions/SessionManager'; +import type KeyRing from 'keys/KeyRing'; +import type { JsonRpcRequest } from 'RPC/types'; +import type { SessionToken } from 'sessions/types'; +import * as clientErrors from './errors'; async function authenticate( sessionManager: SessionManager, @@ -65,201 +54,4 @@ function encodeAuthFromPassword(password: string): string { return `Basic ${encoded}`; } -function readableFromWebSocket( - ws: WebSocket, - logger: Logger, -): ReadableStream { - return new ReadableStream({ - start: (controller) => { - logger.info('starting'); - const messageHandler = (data) => { - logger.debug(`message: ${data.toString()}`); - ws.pause(); - const message = data as Buffer; - if (message.length === 0) { - logger.info('ENDING'); - ws.removeAllListeners('message'); - try { - controller.close(); - } catch { - // Ignore already closed - } - return; - } - controller.enqueue(message); - }; - ws.on('message', messageHandler); - ws.once('close', () => { - logger.info('closed'); - ws.removeListener('message', messageHandler); - try { - controller.close(); - } catch { - // Ignore already closed - } - }); - ws.once('error', (e) => { - controller.error(e); - }); - }, - cancel: () => { - logger.info('cancelled'); - ws.close(); - }, - pull: () => { - logger.debug('resuming'); - ws.resume(); - }, - }); -} - -function writeableFromWebSocket( - ws: WebSocket, - holdOpen: boolean, - logger: Logger, -): WritableStream { - return new WritableStream({ - start: (controller) => { - logger.info('starting'); - ws.once('error', (e) => { - logger.error(`error: ${e}`); - controller.error(e); - }); - ws.once('close', (code, reason) => { - logger.info( - `ws closing early! with code: ${code} and reason: ${reason.toString()}`, - ); - controller.error(Error('TMP WebSocket Closed early')); - }); - }, - close: () => { - logger.info('stream closing'); - ws.send(Buffer.from([])); - if (!holdOpen) ws.terminate(); - }, - abort: () => { - logger.info('aborting'); - ws.close(); - }, - write: async (chunk, controller) => { - logger.debug(`writing: ${chunk?.toString()}`); - const wait = promise(); - ws.send(chunk, (e) => { - if (e != null) { - logger.error(`error: ${e}`); - controller.error(e); - } - wait.resolveP(); - }); - await wait.p; - }, - }); -} - -function webSocketToWebStreamPair( - ws: WebSocket, - holdOpen: boolean, - logger: Logger, -): ReadableWritablePair { - return { - readable: readableFromWebSocket(ws, logger.getChild('readable')), - writable: writeableFromWebSocket(ws, holdOpen, logger.getChild('writable')), - }; -} - -function startConnection( - host: string, - port: number, - logger: Logger, -): Promise> { - const ws = new WebSocket(`wss://${host}:${port}`, { - // CheckServerIdentity: ( - // servername: string, - // cert: WebSocket.CertMeta, - // ): boolean => { - // console.log('CHECKING IDENTITY'); - // console.log(servername); - // console.log(cert); - // return false; - // }, - rejectUnauthorized: false, - // Ca: tlsConfig.certChainPem - }); - ws.once('close', () => logger.info('CLOSED')); - // Ws.once('upgrade', () => { - // // Const tlsSocket = request.socket as TLSSocket; - // // Console.log(tlsSocket.getPeerCertificate()); - // logger.info('Test early cancellation'); - // // Request.destroy(Error('some error')); - // // tlsSocket.destroy(Error('some error')); - // // ws.close(12345, 'some reason'); - // // TODO: Use the existing verify method from the GRPC implementation - // // TODO: Have this emit an error on verification failure. - // // It's fine for the server side to close abruptly without error - // }); - const prom = promise>(); - ws.once('open', () => { - logger.info('starting connection'); - prom.resolveP(webSocketToWebStreamPair(ws, true, logger)); - }); - return prom.p; -} - -function handleConnection(ws: WebSocket, logger: Logger): void { - ws.once('close', () => logger.info('CLOSED')); - const readable = readableFromWebSocket(ws, logger.getChild('readable')); - const writable = writeableFromWebSocket( - ws, - false, - logger.getChild('writable'), - ); - void readable.pipeTo(writable).catch((e) => logger.error(e)); -} - -function createClientServer( - server: Server, - rpcServer: RPCServer, - logger: Logger, -) { - logger.info('created server'); - const wss = new WebSocketServer({ - server, - }); - wss.on('error', (e) => logger.error(e)); - logger.info('created wss'); - wss.on('connection', (ws, req) => { - logger.info('connection!'); - const socket = req.socket as TLSSocket; - const streamPair = webSocketToWebStreamPair(ws, false, logger); - rpcServer.handleStream(streamPair, { - localHost: socket.localAddress! as Host, - localPort: socket.localPort! as Port, - remoteCertificates: socket.getPeerCertificate(), - remoteHost: socket.remoteAddress! as Host, - remotePort: socket.remotePort! as Port, - } as unknown as ConnectionInfo); - }); - wss.once('close', () => { - wss.removeAllListeners('error'); - wss.removeAllListeners('connection'); - }); - return wss; -} - -async function listen(server: https.Server, host?: string, port?: number) { - await new Promise((resolve) => { - server.listen(port, host ?? '127.0.0.1', undefined, () => resolve()); - }); - const addressInfo = server.address() as net.AddressInfo; - return addressInfo.port; -} - -export { - authenticate, - decodeAuth, - encodeAuthFromPassword, - startConnection, - handleConnection, - createClientServer, - listen, -}; +export { authenticate, decodeAuth, encodeAuthFromPassword }; diff --git a/src/types.ts b/src/types.ts index 9a5289884..2f937bc51 100644 --- a/src/types.ts +++ b/src/types.ts @@ -110,6 +110,7 @@ interface FileSystem { readdir: typeof fs.promises.readdir; rename: typeof fs.promises.rename; open: typeof fs.promises.open; + mkdtemp: typeof fs.promises.mkdtemp; }; constants: typeof fs.constants; } diff --git a/src/websockets/WebSocketClient.ts b/src/websockets/WebSocketClient.ts new file mode 100644 index 000000000..cc7be82c6 --- /dev/null +++ b/src/websockets/WebSocketClient.ts @@ -0,0 +1,363 @@ +import type { TLSSocket } from 'tls'; +import type { NodeId } from 'ids/index'; +import { WritableStream, ReadableStream } from 'stream/web'; +import { createDestroy } from '@matrixai/async-init'; +import Logger from '@matrixai/logger'; +import WebSocket from 'ws'; +import { Timer } from '@matrixai/timer'; +import { Validator } from 'ip-num'; +import WebSocketStream from './WebSocketStream'; +import * as clientRpcUtils from './utils'; +import * as clientRPCErrors from './errors'; +import { promise } from '../utils'; + +const timeoutSymbol = Symbol('TimedOutSymbol'); + +interface WebSocketClient extends createDestroy.CreateDestroy {} +@createDestroy.CreateDestroy() +class WebSocketClient { + static async createWebSocketClient({ + host, + port, + expectedNodeIds, + connectionTimeout, + pingInterval = 1000, + pingTimeout = 10000, + maxReadableStreamBytes = 1000, // About 1kB + logger = new Logger(this.name), + }: { + host: string; + port: number; + expectedNodeIds: Array; + connectionTimeout?: number; + pingInterval?: number; + pingTimeout?: number; + maxReadableStreamBytes?: number; + logger?: Logger; + }): Promise { + logger.info(`Creating ${this.name}`); + const clientClient = new this( + logger, + host, + port, + maxReadableStreamBytes, + expectedNodeIds, + connectionTimeout, + pingInterval, + pingTimeout, + ); + logger.info(`Created ${this.name}`); + return clientClient; + } + + protected host: string; + protected activeConnections: Set = new Set(); + + constructor( + protected logger: Logger, + host: string, + protected port: number, + protected maxReadableStreamBytes: number, + protected expectedNodeIds: Array, + protected connectionTimeout: number | undefined, + protected pingInterval: number, + protected pingTimeout: number, + ) { + if (Validator.isValidIPv4String(host)[0]) { + this.host = host; + } else if (Validator.isValidIPv6String(host)[0]) { + this.host = `[${host}]`; + } else { + throw new clientRPCErrors.ErrorClientInvalidHost(); + } + } + + public async destroy(force: boolean = false) { + this.logger.info(`Destroying ${this.constructor.name}`); + if (force) { + for (const activeConnection of this.activeConnections) { + activeConnection.end(); + } + } + for (const activeConnection of this.activeConnections) { + await activeConnection.endedProm.catch(() => {}); // Ignore errors here + } + this.logger.info(`Destroyed ${this.constructor.name}`); + } + + @createDestroy.ready(new clientRPCErrors.ErrorClientDestroyed()) + public async startConnection({ + timeoutTimer, + }: { + timeoutTimer?: Timer; + } = {}): Promise { + // Use provided timer + let timer: Timer | undefined = timeoutTimer; + // If no timer provided use provided default timeout + if (timeoutTimer == null && this.connectionTimeout != null) { + timer = new Timer({ + delay: this.connectionTimeout, + }); + } + const address = `wss://${this.host}:${this.port}`; + this.logger.info(`Connecting to ${address}`); + const connectProm = promise(); + const authenticateProm = promise(); + const ws = new WebSocket(address, { + rejectUnauthorized: false, + }); + // Handle connection failure + const openErrorHandler = (e) => { + connectProm.rejectP( + new clientRPCErrors.ErrorClientConnectionFailed(undefined, { + cause: e, + }), + ); + }; + ws.once('error', openErrorHandler); + // Authenticate server's certificates + ws.once('upgrade', async (request) => { + const tlsSocket = request.socket as TLSSocket; + const peerCert = tlsSocket.getPeerCertificate(true); + clientRpcUtils + .verifyServerCertificateChain( + this.expectedNodeIds, + clientRpcUtils.detailedToCertChain(peerCert), + ) + .then(authenticateProm.resolveP, authenticateProm.rejectP); + }); + ws.once('open', () => { + this.logger.info('starting connection'); + connectProm.resolveP(); + }); + const earlyCloseProm = promise(); + ws.once('close', () => { + earlyCloseProm.resolveP(); + }); + // There are 3 resolve conditions here. + // 1. Connection established and authenticated + // 2. connection error or authentication failure + // 3. connection timed out + try { + const result = await Promise.race([ + timer?.then(() => timeoutSymbol) ?? new Promise(() => {}), + await Promise.all([authenticateProm.p, connectProm.p]), + ]); + if (result === timeoutSymbol) { + throw new clientRPCErrors.ErrorClientConnectionTimedOut(); + } + } catch (e) { + // Clean up + // unregister handlers + ws.removeAllListeners('error'); + ws.removeAllListeners('upgrade'); + ws.removeAllListeners('open'); + // Close the ws if it's open at this stage + ws.terminate(); + // Ensure the connection is removed from the active connection set before + // returning. + await earlyCloseProm.p; + throw e; + } + // Cleaning up connection error + ws.removeEventListener('error', openErrorHandler); + + // Constructing the `ReadableWritablePair`, the lifecycle is handed off to + // the webSocketStream at this point. + const webSocketStreamClient = new WebSocketStreamClientInternal( + ws, + this.maxReadableStreamBytes, + this.pingInterval, + this.pingTimeout, + this.logger, + ); + // Setting up activeStream map lifecycle + this.activeConnections.add(webSocketStreamClient); + void webSocketStreamClient.endedProm + .catch(() => {}) // Ignore errors + .finally(() => { + this.activeConnections.delete(webSocketStreamClient); + }); + return webSocketStreamClient; + } +} + +// This is the internal implementation of the client's stream pair. +class WebSocketStreamClientInternal extends WebSocketStream { + constructor( + protected ws: WebSocket, + maxReadableStreamBytes: number, + pingInterval: number, + pingTimeout: number, + logger: Logger, + ) { + super(); + const readableLogger = logger.getChild('readable'); + const writableLogger = logger.getChild('writable'); + this.readable = new ReadableStream( + { + start: (controller) => { + readableLogger.info('Starting'); + const messageHandler = (data) => { + readableLogger.debug(`Received ${data.toString()}`); + if (controller.desiredSize == null) { + controller.error(Error('NEVER')); + return; + } + if (controller.desiredSize < 0) { + readableLogger.debug('Applying readable backpressure'); + ws.pause(); + } + const message = data as Buffer; + if (message.length === 0) { + readableLogger.debug('Null message received'); + ws.removeListener('message', messageHandler); + if (!this.readableEnded_) { + this.signalReadableEnd(); + readableLogger.debug('Closing'); + controller.close(); + } + if (this.writableEnded_) { + logger.debug('Closing socket'); + ws.close(); + } + return; + } + controller.enqueue(message); + }; + readableLogger.debug('Registering socket message handler'); + ws.on('message', messageHandler); + ws.once('close', (code, reason) => { + logger.info('Socket closed'); + ws.removeListener('message', messageHandler); + if (!this.readableEnded_) { + readableLogger.debug( + `Closed early, ${code}, ${reason.toString()}`, + ); + const e = new clientRPCErrors.ErrorClientConnectionEndedEarly(); + this.signalReadableEnd(e); + controller.error(e); + } + }); + ws.once('error', (e) => { + if (!this.readableEnded_) { + readableLogger.error(e); + this.signalReadableEnd(e); + controller.error(e); + } + }); + }, + cancel: () => { + readableLogger.debug('Cancelled'); + if (!this.writableEnded_) { + readableLogger.debug('Closing socket'); + this.signalReadableEnd(); + ws.close(); + } + }, + pull: () => { + readableLogger.debug('Releasing backpressure'); + ws.resume(); + }, + }, + { + highWaterMark: maxReadableStreamBytes, + size: (chunk) => chunk?.byteLength ?? 0, + }, + ); + this.writable = new WritableStream({ + start: (controller) => { + writableLogger.info('Starting'); + ws.once('error', (e) => { + if (!this.writableEnded_) { + writableLogger.error(e); + this.signalWritableEnd(e); + controller.error(e); + } + }); + ws.once('close', (code, reason) => { + if (!this.writableEnded_) { + writableLogger.debug(`Closed early, ${code}, ${reason.toString()}`); + const e = new clientRPCErrors.ErrorClientConnectionEndedEarly(); + this.signalWritableEnd(e); + controller.error(e); + } + }); + }, + close: () => { + writableLogger.debug('Closing, sending null message'); + ws.send(Buffer.from([])); + this.signalWritableEnd(); + if (this.readableEnded_) { + writableLogger.debug('Closing socket'); + ws.close(); + } + }, + abort: () => { + writableLogger.debug('Aborted'); + this.signalWritableEnd(Error('TMP ABORTED')); + if (this.readableEnded_) { + writableLogger.debug('Closing socket'); + ws.close(); + } + }, + write: async (chunk, controller) => { + if (this.writableEnded_) return; + writableLogger.debug(`Sending ${chunk?.toString()}`); + const wait = promise(); + ws.send(chunk, (e) => { + if (e != null && !this.writableEnded_) { + // Opting to debug message here and not log an error, sending + // failure is common if we send before the close event. + writableLogger.debug('failed to send'); + const err = new clientRPCErrors.ErrorClientConnectionEndedEarly( + undefined, + { + cause: e, + }, + ); + this.signalWritableEnd(err); + controller.error(err); + } + wait.resolveP(); + }); + await wait.p; + }, + }); + + // Setting up heartbeat + const pingTimer = setInterval(() => { + ws.ping(); + }, pingInterval); + const pingTimeoutTimer = setTimeout(() => { + logger.debug('Ping timed out'); + ws.close(4002, 'Timed out'); + }, pingTimeout); + ws.on('ping', () => { + logger.debug('Received ping'); + ws.pong(); + }); + ws.on('pong', () => { + logger.debug('Received pong'); + pingTimeoutTimer.refresh(); + }); + ws.once('close', (code, reason) => { + logger.debug('WebSocket closed'); + const err = + code !== 1000 + ? Error(`TMP WebSocket ended with code ${code}, ${reason.toString()}`) + : undefined; + this.signalWebSocketEnd(err); + logger.debug('Cleaning up timers'); + // Clean up timers + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimer); + }); + } + + end(): void { + this.ws.close(4001, 'TMP ENDING CONNECTION'); + } +} + +export default WebSocketClient; diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts new file mode 100644 index 000000000..5ab3a80c0 --- /dev/null +++ b/src/websockets/WebSocketServer.ts @@ -0,0 +1,499 @@ +import type { + ReadableStreamController, + ReadableWritablePair, + WritableStreamDefaultController, +} from 'stream/web'; +import type { FileSystem, PromiseDeconstructed } from 'types'; +import type { TLSConfig } from 'network/types'; +import type { + HttpRequest, + HttpResponse, + us_socket_context_t, + WebSocket, +} from 'uWebSockets.js'; +import type { ConnectionInfo } from '../RPC/types'; +import { WritableStream, ReadableStream } from 'stream/web'; +import path from 'path'; +import os from 'os'; +import { startStop } from '@matrixai/async-init'; +import Logger from '@matrixai/logger'; +import uWebsocket from 'uWebSockets.js'; +import WebSocketStream from './WebSocketStream'; +import * as clientRPCErrors from './errors'; +import * as webSocketEvents from './events'; +import { promise } from '../utils'; + +type ConnectionCallback = ( + streamPair: ReadableWritablePair, + connectionInfo: ConnectionInfo, +) => void; + +type Context = { + message: ( + ws: WebSocket, + message: ArrayBuffer, + isBinary: boolean, + ) => void; + drain: (ws: WebSocket) => void; + close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; + pong: (ws: WebSocket, message: ArrayBuffer) => void; + logger: Logger; +}; + +/** + * Events: + * - start + * - stop + * - connection + */ +interface WebSocketServer extends startStop.StartStop {} +@startStop.StartStop() +class WebSocketServer extends EventTarget { + static async createWebSocketServer({ + connectionCallback, + tlsConfig, + basePath, + host, + port, + idleTimeout, + pingInterval = 1000, + pingTimeout = 10000, + fs = require('fs'), + maxReadBufferBytes = 1_000_000_000, // About 1 GB + logger = new Logger(this.name), + }: { + connectionCallback: ConnectionCallback; + tlsConfig: TLSConfig; + basePath?: string; + host?: string; + port?: number; + idleTimeout?: number; + pingInterval?: number; + pingTimeout?: number; + fs?: FileSystem; + maxReadBufferBytes?: number; + logger?: Logger; + }) { + logger.info(`Creating ${this.name}`); + const wsServer = new this( + logger, + fs, + maxReadBufferBytes, + idleTimeout, + pingInterval, + pingTimeout, + ); + await wsServer.start({ + connectionCallback, + tlsConfig, + basePath, + host, + port, + }); + logger.info(`Created ${this.name}`); + return wsServer; + } + + protected server: uWebsocket.TemplatedApp; + protected listenSocket: uWebsocket.us_listen_socket; + protected host: string; + protected connectionEventHandler: ( + event: webSocketEvents.ConnectionEvent, + ) => void; + protected activeSockets: Set = new Set(); + protected connectionIndex: number = 0; + + /** + * + * @param logger + * @param fs + * @param maxReadBufferBytes Max number of bytes stored in read buffer before error + * @param idleTimeout + * @param pingInterval + * @param pingTimeout + */ + constructor( + protected logger: Logger, + protected fs: FileSystem, + protected maxReadBufferBytes, + protected idleTimeout: number | undefined, + protected pingInterval: number, + protected pingTimeout: number, + ) { + super(); + } + + public async start({ + tlsConfig, + basePath = os.tmpdir(), + host, + port = 0, + connectionCallback, + }: { + tlsConfig: TLSConfig; + basePath?: string; + host?: string; + port?: number; + connectionCallback?: ConnectionCallback; + }): Promise { + this.logger.info(`Starting ${this.constructor.name}`); + if (connectionCallback != null) { + this.connectionEventHandler = ( + event: webSocketEvents.ConnectionEvent, + ) => { + connectionCallback( + event.detail.webSocketStream, + event.detail.connectionInfo, + ); + }; + this.addEventListener('connection', this.connectionEventHandler); + } + await this.setupServer(basePath, tlsConfig); + this.server.ws('/*', { + sendPingsAutomatically: true, + idleTimeout: this.idleTimeout, + upgrade: this.upgrade, + open: this.open, + message: this.message, + close: this.close, + drain: this.drain, + pong: this.pong, + // Ping uses default behaviour. + // We don't use subscriptions. + }); + this.server.any('/*', (res, _) => { + // Reject normal requests with an upgrade code + res + .writeStatus('426') + .writeHeader('connection', 'Upgrade') + .writeHeader('upgrade', 'websocket') + .end('426 Upgrade Required', true); + }); + const listenProm = promise(); + const listenCallback = (listenSocket) => { + if (listenSocket) { + this.listenSocket = listenSocket; + listenProm.resolveP(); + } else { + listenProm.rejectP(new clientRPCErrors.ErrorServerPortUnavailable()); + } + }; + if (host != null) { + // With custom host + this.server.listen(host, port ?? 0, listenCallback); + } else { + // With default host + this.server.listen(port, listenCallback); + } + await listenProm.p; + this.logger.debug( + `Listening on port ${uWebsocket.us_socket_local_port(this.listenSocket)}`, + ); + this.host = host ?? '127.0.0.1'; + this.dispatchEvent( + new webSocketEvents.StartEvent({ + detail: { + host: this.host, + port: this.port, + }, + }), + ); + this.logger.info(`Started ${this.constructor.name}`); + } + + public async stop(force: boolean = false): Promise { + this.logger.info(`Stopping ${this.constructor.name}`); + // Close the server by closing the underlying socket + uWebsocket.us_listen_socket_close(this.listenSocket); + // Shutting down active websockets + if (force) { + for (const webSocketStream of this.activeSockets) { + webSocketStream.end(); + } + } + // Wait for all active websockets to close + for (const webSocketStream of this.activeSockets) { + webSocketStream.endedProm.catch(() => {}); // Ignore errors + } + if (this.connectionEventHandler != null) { + this.removeEventListener('connection', this.connectionEventHandler); + } + this.dispatchEvent(new webSocketEvents.StopEvent()); + this.logger.info(`Stopped ${this.constructor.name}`); + } + + get port() { + return uWebsocket.us_socket_local_port(this.listenSocket); + } + + /** + * This creates the pem files and starts the server with them. It ensures that + * files are cleaned up to the best of its ability. + */ + protected async setupServer(basePath: string, tlsConfig: TLSConfig) { + const tmpDir = await this.fs.promises.mkdtemp( + path.join(basePath, 'polykey-'), + ); + // TODO: The key file needs to be in the encrypted format + const keyFile = path.join(tmpDir, 'keyFile.pem'); + const certFile = path.join(tmpDir, 'certFile.pem'); + try { + await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); + await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); + this.server = uWebsocket.SSLApp({ + key_file_name: keyFile, + cert_file_name: certFile, + }); + } finally { + await this.fs.promises.rm(keyFile); + await this.fs.promises.rm(certFile); + } + } + + /** + * Applies default upgrade behaviour and creates a UserData object we can + * mutate for the Context + */ + protected upgrade = ( + res: HttpResponse, + req: HttpRequest, + context: us_socket_context_t, + ) => { + const logger = this.logger.getChild(`Connection ${this.connectionIndex}`); + res.upgrade>( + { + logger, + }, + req.getHeader('sec-websocket-key'), + req.getHeader('sec-websocket-protocol'), + req.getHeader('sec-websocket-extensions'), + context, + ); + this.connectionIndex += 1; + }; + + /** + * Handles the creation of the `ReadableWritablePair` and provides it to the + * StreamPair handler. + */ + protected open = (ws: WebSocket) => { + const webSocketStream = new WebSocketStreamServerInternal( + ws, + this.maxReadBufferBytes, + this.pingInterval, + this.pingTimeout, + ); + // Adding socket to the active sockets map + this.activeSockets.add(webSocketStream); + webSocketStream.endedProm + .catch(() => {}) // Ignore errors here + .finally(() => { + this.activeSockets.delete(webSocketStream); + }); + + // There is not nodeId or certs for the client, and we can't get the remote + // port from the `uWebsocket` library. + const connectionInfo: ConnectionInfo = { + remoteHost: Buffer.from(ws.getRemoteAddressAsText()).toString(), + localHost: this.host, + localPort: this.port, + }; + this.dispatchEvent( + new webSocketEvents.ConnectionEvent({ + detail: { + webSocketStream, + connectionInfo, + }, + }), + ); + }; + + /** + * Routes incoming messages to each stream using the `Context` message + * callback. + */ + protected message = ( + ws: WebSocket, + message: ArrayBuffer, + isBinary: boolean, + ) => { + ws.getUserData().message(ws, message, isBinary); + }; + + protected drain = (ws: WebSocket) => { + ws.getUserData().drain(ws); + }; + + protected close = ( + ws: WebSocket, + code: number, + message: ArrayBuffer, + ) => { + ws.getUserData().close(ws, code, message); + }; + + protected pong = (ws: WebSocket, message: ArrayBuffer) => { + ws.getUserData().pong(ws, message); + }; +} + +class WebSocketStreamServerInternal extends WebSocketStream { + protected backPressure: PromiseDeconstructed | null = null; + protected writeBackpressure: boolean = false; + + constructor( + protected ws: WebSocket, + maxReadBufferBytes: number, + pingInterval: number, + pingTimeout: number, + ) { + super(); + const context = ws.getUserData(); + const logger = context.logger; + logger.info('WS opened'); + let writableController: WritableStreamDefaultController | undefined; + let readableController: ReadableStreamController | undefined; + const writableLogger = logger.getChild('Writable'); + const readableLogger = logger.getChild('Readable'); + // Setting up the writable stream + this.writable = new WritableStream({ + start: (controller) => { + writableController = controller; + }, + write: async (chunk, controller) => { + await this.backPressure?.p; + const writeResult = ws.send(chunk, true); + switch (writeResult) { + default: + case 2: + // Write failure, emit error + writableLogger.error('Send error'); + controller.error(new clientRPCErrors.ErrorServerSendFailed()); + break; + case 0: + writableLogger.info('Write backpressure'); + // Signal backpressure + this.backPressure = promise(); + this.writeBackpressure = true; + this.backPressure.p.finally(() => { + this.writeBackpressure = false; + }); + break; + case 1: + // Success + writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); + break; + } + }, + close: () => { + writableLogger.info('Closed, sending null message'); + if (!this.webSocketEnded_) ws.send(Buffer.from([]), true); + this.signalWritableEnd(); + if (this.readableEnded_ && !this.webSocketEnded_) { + writableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + ws.end(); + } + }, + abort: () => { + writableLogger.info('Aborted'); + if (this.readableEnded_ && !this.webSocketEnded_) { + writableLogger.debug('Ending socket'); + this.signalWebSocketEnd(Error('TMP ERROR ABORTED')); + ws.end(4001, 'ABORTED'); + } + }, + }); + // Setting up the readable stream + this.readable = new ReadableStream( + { + start: (controller) => { + readableController = controller; + context.message = (ws, message, _) => { + const messageBuffer = Buffer.from(message); + readableLogger.debug(`Received ${messageBuffer.toString()}`); + if (message.byteLength === 0) { + readableLogger.debug('Null message received'); + if (!this.readableEnded_) { + readableLogger.debug('Closing'); + this.signalReadableEnd(); + controller.close(); + if (this.writableEnded_ && !this.webSocketEnded_) { + readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + ws.end(); + } + } + return; + } + controller.enqueue(messageBuffer); + if (controller.desiredSize != null && controller.desiredSize < 0) { + readableLogger.error('Read stream buffer full'); + const err = new clientRPCErrors.ErrorServerReadableBufferLimit(); + if (!this.webSocketEnded_) { + this.signalWebSocketEnd(err); + ws.end(4001, 'Read stream buffer full'); + } + controller.error(err); + } + }; + }, + cancel: () => { + this.signalReadableEnd(Error('TMP READABLE CANCELLED')); + if (this.writableEnded_ && !this.webSocketEnded_) { + readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + ws.end(); + } + }, + }, + { + highWaterMark: maxReadBufferBytes, + size: (chunk) => chunk?.byteLength ?? 0, + }, + ); + + const pingTimer = setInterval(() => { + ws.ping(); + }, pingInterval); + const pingTimeoutTimer = setTimeout(() => { + logger.debug('Ping timed out'); + ws.end(); + }, pingTimeout); + context.pong = () => { + logger.debug('Received pong'); + pingTimeoutTimer.refresh(); + }; + context.close = () => { + logger.debug('Closing'); + this.signalWebSocketEnd(); + // Cleaning up timers + logger.debug('Cleaning up timers'); + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimer); + // Closing streams + logger.debug('Cleaning streams'); + const err = new clientRPCErrors.ErrorServerConnectionEndedEarly(); + if (!this.readableEnded_) { + readableLogger.debug('Closing'); + this.signalReadableEnd(err); + readableController?.error(err); + } + if (!this.writableEnded_) { + writableLogger.debug('Closing'); + this.signalWritableEnd(err); + writableController?.error(err); + } + }; + context.drain = () => { + logger.debug('Drained'); + this.backPressure?.resolveP(); + }; + } + + end(): void { + this.ws.end(4001, 'TMP ENDING CONNECTION'); + } +} + +export default WebSocketServer; diff --git a/src/websockets/WebSocketStream.ts b/src/websockets/WebSocketStream.ts new file mode 100644 index 000000000..a91d63a81 --- /dev/null +++ b/src/websockets/WebSocketStream.ts @@ -0,0 +1,130 @@ +import type { + ReadableStream, + ReadableWritablePair, + WritableStream, +} from 'stream/web'; +import { promise } from '../utils'; + +abstract class WebSocketStream + implements ReadableWritablePair +{ + public readable: ReadableStream; + public writable: WritableStream; + + protected readableEnded_ = false; + protected readableEndedProm_ = promise(); + protected writableEnded_ = false; + protected writableEndedProm_ = promise(); + protected webSocketEnded_ = false; + protected webSocketEndedProm_ = promise(); + protected endedProm_: Promise; + + protected constructor() { + // Sanitise promises so they don't result in unhandled rejections + this.readableEndedProm_.p.catch(() => {}); + this.writableEndedProm_.p.catch(() => {}); + this.webSocketEndedProm_.p.catch(() => {}); + // Creating the endedPromise + this.endedProm_ = Promise.allSettled([ + this.readableEndedProm_.p, + this.writableEndedProm_.p, + this.webSocketEndedProm_.p, + ]).then((result) => { + if ( + result[0].status === 'rejected' || + result[1].status === 'rejected' || + result[2].status === 'rejected' + ) { + // Throw a compound error + throw Error('TMP Stream failed', { cause: result }); + } + // Otherwise return nothing + }); + // Ignore errors if it's never used + this.endedProm_.catch(() => {}); + } + + get readableEnded() { + return this.readableEnded_; + } + + /** + * Resolves when the readable has ended and rejects with any errors. + */ + get readableEndedProm() { + return this.readableEndedProm_.p; + } + + get writableEnded() { + return this.writableEnded_; + } + + /** + * Resolves when the writable has ended and rejects with any errors. + */ + get writableEndedProm() { + return this.writableEndedProm_.p; + } + + get webSocketEnded() { + return this.webSocketEnded_; + } + + /** + * Resolves when the webSocket has ended and rejects with any errors. + */ + get webSocketEndedProm() { + return this.webSocketEndedProm_.p; + } + + get ended() { + return this.readableEnded_ && this.writableEnded_; + } + + /** + * Resolves when the stream has fully closed + */ + get endedProm(): Promise { + return this.endedProm_; + } + + /** + * Forces the active stream to end early + */ + abstract end(): void; + + /** + * Signals the end of the ReadableStream. to be used with the extended class + * to track the streams state. + */ + protected signalReadableEnd(e?: Error) { + if (this.readableEnded_) return; + this.readableEnded_ = true; + if (e == null) this.readableEndedProm_.resolveP(); + else this.readableEndedProm_.rejectP(e); + } + + /** + * Signals the end of the WritableStream. to be used with the extended class + * to track the streams state. + */ + protected signalWritableEnd(e?: Error) { + if (this.writableEnded_) return; + this.writableEnded_ = true; + if (e == null) this.writableEndedProm_.resolveP(); + else this.writableEndedProm_.rejectP(e); + } + + /** + * Signals the end of the WebSocket. to be used with the extended class + * to track the streams state. + */ + protected signalWebSocketEnd(e?: Error) { + if (this.webSocketEnded_) return; + this.webSocketEnded_ = true; + if (e == null) this.webSocketEndedProm_.resolveP(); + else this.webSocketEndedProm_.rejectP(e); + } +} + +export default WebSocketStream; diff --git a/src/websockets/errors.ts b/src/websockets/errors.ts new file mode 100644 index 000000000..002282359 --- /dev/null +++ b/src/websockets/errors.ts @@ -0,0 +1,119 @@ +import { ErrorPolykey, sysexits } from '../errors'; + +class ErrorWebSocket extends ErrorPolykey {} + +class ErrorWebSocketClient extends ErrorWebSocket {} + +class ErrorClientDestroyed extends ErrorWebSocketClient { + static description = 'ClientClient has been destroyed'; + exitCode = sysexits.USAGE; +} + +class ErrorClientInvalidHost extends ErrorWebSocketClient { + static description = 'Host must be a valid IPv4 or IPv6 address string'; + exitCode = sysexits.USAGE; +} + +class ErrorClientConnectionFailed extends ErrorWebSocketClient { + static description = 'Failed to establish connection to server'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientConnectionTimedOut extends ErrorWebSocketClient { + static description = 'Connection timed out'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientConnectionEndedEarly extends ErrorWebSocketClient { + static description = 'Connection ended before stream ended'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorWebSocketServer extends ErrorWebSocket {} + +class ErrorServerPortUnavailable extends ErrorWebSocketServer { + static description = 'Failed to bind a free port'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorServerSendFailed extends ErrorWebSocketServer { + static description = 'Failed to send message'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorServerReadableBufferLimit extends ErrorWebSocketServer { + static description = 'Readable buffer is full, messages received too quickly'; + exitCode = sysexits.USAGE; +} + +class ErrorServerConnectionEndedEarly extends ErrorWebSocketServer { + static description = 'Connection ended before stream ended'; + exitCode = sysexits.UNAVAILABLE; +} + +/** + * Used for certificate verification + */ +class ErrorCertChain extends ErrorWebSocket {} + +class ErrorCertChainEmpty extends ErrorCertChain { + static description = 'Certificate chain is empty'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainUnclaimed extends ErrorCertChain { + static description = 'The target node id is not claimed by any certificate'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainBroken extends ErrorCertChain { + static description = 'The signature chain is broken'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainDateInvalid extends ErrorCertChain { + static description = 'Certificate in the chain is expired'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainNameInvalid extends ErrorCertChain { + static description = 'Certificate is missing the common name'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainKeyInvalid extends ErrorCertChain { + static description = 'Certificate public key does not generate the Node ID'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainSignatureInvalid extends ErrorCertChain { + static description = 'Certificate self-signed signature is invalid'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorConnectionNodesEmpty extends ErrorWebSocket { + static description = 'Nodes list to verify against was empty'; + exitCode = sysexits.USAGE; +} + +export { + ErrorWebSocketClient, + ErrorClientDestroyed, + ErrorClientInvalidHost, + ErrorClientConnectionFailed, + ErrorClientConnectionTimedOut, + ErrorClientConnectionEndedEarly, + ErrorWebSocketServer, + ErrorServerPortUnavailable, + ErrorServerSendFailed, + ErrorServerReadableBufferLimit, + ErrorServerConnectionEndedEarly, + ErrorCertChainEmpty, + ErrorCertChainUnclaimed, + ErrorCertChainBroken, + ErrorCertChainDateInvalid, + ErrorCertChainNameInvalid, + ErrorCertChainKeyInvalid, + ErrorCertChainSignatureInvalid, + ErrorConnectionNodesEmpty, +}; diff --git a/src/websockets/events.ts b/src/websockets/events.ts new file mode 100644 index 000000000..aaabb5842 --- /dev/null +++ b/src/websockets/events.ts @@ -0,0 +1,46 @@ +import type WebSocketStream from 'websockets/WebSocketStream'; +import type { ConnectionInfo } from 'RPC/types'; + +class StartEvent extends Event { + public detail: { + host: string; + port: number; + }; + constructor( + options: EventInit & { + detail: { + host: string; + port: number; + }; + }, + ) { + super('start', options); + this.detail = options.detail; + } +} + +class StopEvent extends Event { + constructor(options?: EventInit) { + super('stop', options); + } +} + +class ConnectionEvent extends Event { + public detail: { + webSocketStream: WebSocketStream; + connectionInfo: ConnectionInfo; + }; + constructor( + options: EventInit & { + detail: { + webSocketStream: WebSocketStream; + connectionInfo: ConnectionInfo; + }; + }, + ) { + super('connection', options); + this.detail = options.detail; + } +} + +export { StartEvent, StopEvent, ConnectionEvent }; diff --git a/src/websockets/utils.ts b/src/websockets/utils.ts new file mode 100644 index 000000000..638bdb181 --- /dev/null +++ b/src/websockets/utils.ts @@ -0,0 +1,147 @@ +import type { Certificate } from 'keys/types'; +import type { DetailedPeerCertificate } from 'tls'; +import type { NodeId } from 'ids/index'; +import * as x509 from '@peculiar/x509'; +import * as webSocketErrors from './errors'; +import * as keysUtils from '../keys/utils/index'; + +function detailedToCertChain( + cert: DetailedPeerCertificate, +): Array { + const certChain: Array = []; + let currentCert = cert; + while (true) { + certChain.unshift(new x509.X509Certificate(currentCert.raw)); + if (currentCert === currentCert.issuerCertificate) break; + currentCert = currentCert.issuerCertificate; + } + return certChain; +} + +/** + * Verify the server certificate chain when connecting to it from a client + * This is a custom verification intended to verify that the server owned + * the relevant NodeId. + * It is possible that the server has a new NodeId. In that case we will + * verify that the new NodeId is the true descendant of the target NodeId. + */ +async function verifyServerCertificateChain( + nodeIds: Array, + certChain: Array, +): Promise { + if (!certChain.length) { + throw new webSocketErrors.ErrorCertChainEmpty( + 'No certificates available to verify', + ); + } + if (!nodeIds.length) { + throw new webSocketErrors.ErrorConnectionNodesEmpty( + 'No nodes were provided to verify against', + ); + } + const now = new Date(); + let certClaim: Certificate | null = null; + let certClaimIndex: number | null = null; + let verifiedNodeId: NodeId | null = null; + for (let certIndex = 0; certIndex < certChain.length; certIndex++) { + const cert = certChain[certIndex]; + if (now < cert.notBefore || now > cert.notAfter) { + throw new webSocketErrors.ErrorCertChainDateInvalid( + 'Chain certificate date is invalid', + { + data: { + cert, + certIndex, + notBefore: cert.notBefore, + notAfter: cert.notAfter, + now, + }, + }, + ); + } + const certNodeId = keysUtils.certNodeId(cert); + if (certNodeId == null) { + throw new webSocketErrors.ErrorCertChainNameInvalid( + 'Chain certificate common name attribute is missing', + { + data: { + cert, + certIndex, + }, + }, + ); + } + const certPublicKey = keysUtils.certPublicKey(cert); + if (certPublicKey == null) { + throw new webSocketErrors.ErrorCertChainKeyInvalid( + 'Chain certificate public key is missing', + { + data: { + cert, + certIndex, + }, + }, + ); + } + if (!(await keysUtils.certNodeSigned(cert))) { + throw new webSocketErrors.ErrorCertChainSignatureInvalid( + 'Chain certificate does not have a valid node-signature', + { + data: { + cert, + certIndex, + nodeId: keysUtils.publicKeyToNodeId(certPublicKey), + commonName: certNodeId, + }, + }, + ); + } + for (const nodeId of nodeIds) { + if (certNodeId.equals(nodeId)) { + // Found the certificate claiming the nodeId + certClaim = cert; + certClaimIndex = certIndex; + verifiedNodeId = nodeId; + } + } + // If cert is found then break out of loop + if (verifiedNodeId != null) break; + } + if (certClaimIndex == null || certClaim == null || verifiedNodeId == null) { + throw new webSocketErrors.ErrorCertChainUnclaimed( + 'Node IDs is not claimed by any certificate', + { + data: { nodeIds }, + }, + ); + } + if (certClaimIndex > 0) { + let certParent: Certificate; + let certChild: Certificate; + for (let certIndex = certClaimIndex; certIndex > 0; certIndex--) { + certParent = certChain[certIndex]; + certChild = certChain[certIndex - 1]; + if ( + !keysUtils.certIssuedBy(certParent, certChild) || + !(await keysUtils.certSignedBy( + certParent, + keysUtils.certPublicKey(certChild)!, + )) + ) { + throw new webSocketErrors.ErrorCertChainBroken( + 'Chain certificate is not signed by parent certificate', + { + data: { + cert: certChild, + certIndex: certIndex - 1, + certParent, + }, + }, + ); + } + } + } + return verifiedNodeId; +} + +export { detailedToCertChain, verifyServerCertificateChain }; diff --git a/tests/clientRPC/authenticationMiddleware.test.ts b/tests/clientRPC/authenticationMiddleware.test.ts index 3e5d778c2..e204dda20 100644 --- a/tests/clientRPC/authenticationMiddleware.test.ts +++ b/tests/clientRPC/authenticationMiddleware.test.ts @@ -1,10 +1,8 @@ -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; import type { RPCRequestParams, RPCResponseResult } from '@/clientRPC/types'; +import type { TLSConfig } from '../../src/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -19,6 +17,8 @@ import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import { UnaryCaller } from '@/RPC/callers'; import { UnaryHandler } from '@/RPC/handlers'; import * as middlewareUtils from '@/RPC/middleware'; +import WebSocketServer from '@/websockets/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; import * as testsUtils from '../utils'; describe('agentUnlock', () => { @@ -26,6 +26,7 @@ describe('agentUnlock', () => { new StreamHandler(), ]); const password = 'helloworld'; + const host = '127.0.0.1'; let dataDir: string; let db: DB; let keyRing: KeyRing; @@ -33,9 +34,9 @@ describe('agentUnlock', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; - let server: Server; - let wss: WebSocketServer; - let port: number; + let clientServer: WebSocketServer; + let clientClient: WebSocketClient; + let tlsConfig: TLSConfig; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -72,16 +73,11 @@ describe('agentUnlock', () => { keyRing, logger, }); - const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, '127.0.0.1'); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - wss?.close(); - server.close(); + await clientServer?.stop(true); + await clientClient?.destroy(true); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -111,22 +107,25 @@ describe('agentUnlock', () => { ), logger, }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('server'), - ); + clientServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair, connectionInfo) => { + rpcServer.handleStream(streamPair, connectionInfo); + }, + host, + tlsConfig, + logger, + }); + clientClient = await WebSocketClient.createWebSocketClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + port: clientServer.port, + logger, + }); const rpcClient = await RPCClient.createRPCClient({ manifest: { agentUnlock: new UnaryCaller(), }, - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - '127.0.0.1', - port, - logger.getChild('client'), - ); - }, + streamPairCreateCallback: async () => clientClient.startConnection(), middleware: middlewareUtils.defaultClientMiddlewareWrapper( authMiddleware.authenticationMiddlewareClient(session), ), diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 0ec84f1f2..b40216b99 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -1,9 +1,7 @@ -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; +import type { TLSConfig } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -16,8 +14,9 @@ import { AgentStatusHandler, } from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; -import * as clientRPCUtils from '@/clientRPC/utils'; import * as nodesUtils from '@/nodes/utils'; +import WebSocketClient from '@/websockets/WebSocketClient'; +import WebSocketServer from '@/websockets/WebSocketServer'; import * as testsUtils from '../../utils'; describe('agentStatus', () => { @@ -25,15 +24,15 @@ describe('agentStatus', () => { new StreamHandler(), ]); const password = 'helloworld'; + const host = '127.0.0.1'; let dataDir: string; let db: DB; let keyRing: KeyRing; let taskManager: TaskManager; let certManager: CertManager; - let server: Server; - let wss: WebSocketServer; - const host = '127.0.0.1'; - let port: number; + let clientServer: WebSocketServer; + let clientClient: WebSocketClient; + let tlsConfig: TLSConfig; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -60,16 +59,11 @@ describe('agentStatus', () => { taskManager, logger, }); - const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, host); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - wss?.close(); - server?.close(); + await clientServer?.stop(true); + await clientClient?.destroy(true); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -91,22 +85,25 @@ describe('agentStatus', () => { }, logger: logger.getChild('RPCServer'), }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('server'), - ); + clientServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair, connectionInfo) => { + rpcServer.handleStream(streamPair, connectionInfo); + }, + host, + tlsConfig, + logger, + }); + clientClient = await WebSocketClient.createWebSocketClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + port: clientServer.port, + logger, + }); const rpcClient = await RPCClient.createRPCClient({ manifest: { agentStatus: agentStatusCaller, }, - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - host, - port, - logger.getChild('client'), - ); - }, + streamPairCreateCallback: async () => clientClient.startConnection(), logger: logger.getChild('RPCClient'), }); // Doing the test diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 1b592af3b..41d8c9b44 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -1,9 +1,7 @@ -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; +import type { TLSConfig } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -20,6 +18,8 @@ import { Session, SessionManager } from '@/sessions'; import * as clientRPCUtils from '@/clientRPC/utils'; import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import * as middlewareUtils from '@/RPC/middleware'; +import WebSocketServer from '@/websockets/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { @@ -27,6 +27,7 @@ describe('agentUnlock', () => { new StreamHandler(), ]); const password = 'helloworld'; + const host = '127.0.0.1'; let dataDir: string; let db: DB; let keyRing: KeyRing; @@ -34,9 +35,9 @@ describe('agentUnlock', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; - let server: Server; - let wss: WebSocketServer; - let port: number; + let clientClient: WebSocketClient; + let clientServer: WebSocketServer; + let tlsConfig: TLSConfig; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -73,16 +74,11 @@ describe('agentUnlock', () => { keyRing, logger, }); - const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, '127.0.0.1'); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - wss?.close(); - server.close(); + await clientServer.stop(true); + await clientClient.destroy(true); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -103,22 +99,24 @@ describe('agentUnlock', () => { ), logger, }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('server'), - ); + clientServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair, connectionInfo) => + rpcServer.handleStream(streamPair, connectionInfo), + host, + tlsConfig, + logger, + }); + clientClient = await WebSocketClient.createWebSocketClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + logger, + port: clientServer.port, + }); const rpcClient = await RPCClient.createRPCClient({ manifest: { agentUnlock: agentUnlockCaller, }, - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - '127.0.0.1', - port, - logger.getChild('client'), - ); - }, + streamPairCreateCallback: async () => clientClient.startConnection(), middleware: middlewareUtils.defaultClientMiddlewareWrapper( authMiddleware.authenticationMiddlewareClient(session), ), diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts deleted file mode 100644 index 813935ab7..000000000 --- a/tests/clientRPC/websocket.test.ts +++ /dev/null @@ -1,114 +0,0 @@ -import type { TLSConfig } from '@/network/types'; -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; -import type { ClientManifest } from '@/RPC/types'; -import type { JSONValue } from '@/types'; -import fs from 'fs'; -import path from 'path'; -import os from 'os'; -import { createServer } from 'https'; -import Logger, { LogLevel, StreamHandler, formatting } from '@matrixai/logger'; -import RPCServer from '@/RPC/RPCServer'; -import RPCClient from '@/RPC/RPCClient'; -import { KeyRing } from '@/keys/index'; -import * as clientRPCUtils from '@/clientRPC/utils'; -import { UnaryHandler } from '@/RPC/handlers'; -import { UnaryCaller } from '@/RPC/callers'; -import * as testsUtils from '../utils/index'; - -describe('websocket', () => { - const logger = new Logger('websocket test', LogLevel.WARN, [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, - ), - ]); - let dataDir: string; - let keyRing: KeyRing; - let tlsConfig: TLSConfig; - let server: Server; - let wss: WebSocketServer; - const host = '127.0.0.1'; - let port: number; - let rpcServer: RPCServer; - let rpcClient_: RPCClient; - - beforeEach(async () => { - dataDir = await fs.promises.mkdtemp( - path.join(os.tmpdir(), 'polykey-test-'), - ); - const keysPath = path.join(dataDir, 'keys'); - keyRing = await KeyRing.createKeyRing({ - keysPath: keysPath, - password: 'password', - logger: logger.getChild('keyRing'), - }); - tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, host); - }); - afterEach(async () => { - await rpcClient_?.destroy(); - await rpcServer?.destroy(); - wss?.close(); - server.close(); - await keyRing.stop(); - await fs.promises.rm(dataDir, { force: true, recursive: true }); - }); - - test('websocket should work with RPC', async () => { - // Setting up server - class Test1 extends UnaryHandler { - public async handle(input: JSONValue): Promise { - return input; - } - } - class Test2 extends UnaryHandler { - public async handle(): Promise { - return { hello: 'not world' }; - } - } - rpcServer = await RPCServer.createRPCServer({ - manifest: { - test1: new Test1({}), - test2: new Test2({}), - }, - logger: logger.getChild('RPCServer'), - }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('client'), - ); - - // Setting up client - const rpcClient = await RPCClient.createRPCClient({ - manifest: { - test1: new UnaryCaller(), - test2: new UnaryCaller(), - }, - logger: logger.getChild('RPCClient'), - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - host, - port, - logger.getChild('Connection'), - ); - }, - }); - rpcClient_ = rpcClient; - - // Making the call - await expect( - rpcClient.methods.test1({ hello: 'world2' }), - ).resolves.toStrictEqual({ hello: 'world2' }); - await expect( - rpcClient.methods.test2({ hello: 'world2' }), - ).resolves.toStrictEqual({ hello: 'not world' }); - await expect( - rpcClient.unaryCaller('test3', { hello: 'world2' }), - ).toReject(); - }); -}); diff --git a/tests/utils/utils.ts b/tests/utils/utils.ts index 79c529794..a3d7e8ac0 100644 --- a/tests/utils/utils.ts +++ b/tests/utils/utils.ts @@ -3,6 +3,7 @@ import type { NodeId, CertId } from '@/ids/types'; import type { StatusLive } from '@/status/types'; import type { TLSConfig } from '@/network/types'; import type { CertificatePEMChain, KeyPair } from '@/keys/types'; +import type { Certificate } from '@/keys/types'; import path from 'path'; import fs from 'fs'; import readline from 'readline'; @@ -126,6 +127,38 @@ async function createTLSConfig( }; } +async function createTLSConfigWithChain( + keyPairs: Array, + generateCertId?: () => CertId, +): Promise { + if (keyPairs.length === 0) throw Error('Must have at least 1 keypair'); + generateCertId = generateCertId ?? keysUtils.createCertIdGenerator(); + let previousCert: Certificate | null = null; + let previousKeyPair: KeyPair | null = null; + const certChain: Array = []; + for (const keyPair of keyPairs) { + const newCert = await keysUtils.generateCertificate({ + certId: generateCertId(), + duration: 31536000, + issuerPrivateKey: previousKeyPair?.privateKey ?? keyPair.privateKey, + subjectKeyPair: keyPair, + issuerAttrsExtra: previousCert?.subjectName.toJSON(), + }); + certChain.unshift(newCert); + previousCert = newCert; + previousKeyPair = keyPair; + } + let certChainPEM = ''; + for (const certificate of certChain) { + certChainPEM += keysUtils.certToPEM(certificate); + } + + return { + keyPrivatePem: keysUtils.privateKeyToPEM(previousKeyPair!.privateKey), + certChainPem: certChainPEM as CertificatePEMChain, + }; +} + export { setupTestAgent, generateRandomNodeId, @@ -133,4 +166,5 @@ export { testIf, describeIf, createTLSConfig, + createTLSConfigWithChain, }; diff --git a/tests/websockets/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts new file mode 100644 index 000000000..d098d60e8 --- /dev/null +++ b/tests/websockets/WebSocket.test.ts @@ -0,0 +1,859 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { TLSConfig } from '@/network/types'; +import type { KeyPair } from '@/keys/types'; +import type http from 'http'; +import type WebSocketStream from '@/websockets/WebSocketStream'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import https from 'https'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import { testProp, fc } from '@fast-check/jest'; +import { Timer } from '@matrixai/timer'; +import { KeyRing } from '@/keys/index'; +import WebSocketServer from '@/websockets/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; +import { promise } from '@/utils'; +import * as keysUtils from '@/keys/utils'; +import * as webSocketErrors from '@/websockets/errors'; +import * as nodesUtils from '@/nodes/utils'; +import * as testNodeUtils from '../nodes/utils'; +import * as testsUtils from '../utils'; + +// This file tests both the client and server together. They're too interlinked +// to be separate. +describe('WebSocket', () => { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + let dataDir: string; + let keyRing: KeyRing; + let tlsConfig: TLSConfig; + const host = '127.0.0.2'; + let webSocketServer: WebSocketServer; + let webSocketClient: WebSocketClient; + + const messagesArb = fc.array( + fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), + ); + const streamsArb = fc.array(messagesArb, { minLength: 1 }).noShrink(); + const asyncReadWrite = async ( + messages: Array, + streamPair: ReadableWritablePair, + ) => { + await Promise.allSettled([ + (async () => { + const writer = streamPair.writable.getWriter(); + for (const message of messages) { + await writer.write(message); + } + await writer.close(); + })(), + (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })(), + ]); + }; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + keyRing = await KeyRing.createKeyRing({ + keysPath: keysPath, + password: 'password', + logger: logger.getChild('keyRing'), + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + }); + afterEach(async () => { + logger.info('AFTEREACH'); + await webSocketServer?.stop(true); + await webSocketClient?.destroy(true); + await keyRing.stop(); + await fs.promises.rm(dataDir, { force: true, recursive: true }); + }); + + // These tests are share between client and server + test('makes a connection', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + const message1 = Buffer.from('1request1'); + await writer.write(message1); + expect((await reader.read()).value).toStrictEqual(message1); + const message2 = Buffer.from('1request2'); + await writer.write(message2); + expect((await reader.read()).value).toStrictEqual(message2); + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + test('makes a connection over IPv6', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host: '::1', + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host: '::1', + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + const message1 = Buffer.from('1request1'); + await writer.write(message1); + expect((await reader.read()).value).toStrictEqual(message1); + const message2 = Buffer.from('1request2'); + await writer.write(message2); + expect((await reader.read()).value).toStrictEqual(message2); + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + test('Handles a connection and closes before message', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + await websocket.writable.close(); + const reader = websocket.readable.getReader(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + testProp( + 'Handles multiple connections', + [streamsArb], + async (streamsData) => { + try { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + + const testStream = async (messages: Array) => { + const websocket = await webSocketClient.startConnection(); + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + for (const message of messages) { + await writer.write(message); + const response = await reader.read(); + expect(response.done).toBeFalse(); + expect(response.value?.toString()).toStrictEqual( + message.toString(), + ); + } + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + }; + const streams = streamsData.map((messages) => testStream(messages)); + await Promise.all(streams); + + logger.info('ending'); + } finally { + await webSocketServer.stop(true); + } + }, + ); + test('reverse backpressure', async () => { + const backpressure = promise(); + const resumeWriting = promise(); + let webSocketStream: WebSocketStream | null = null; + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void Promise.allSettled([ + (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })(), + (async () => { + // Kidnap the context + // @ts-ignore: kidnap protected property + for (const websocket of webSocketServer.activeSockets.values()) { + webSocketStream = websocket; + } + if (webSocketStream == null) { + await streamPair.writable.close(); + return; + } + // Write until backPressured + const message = Buffer.alloc(128, 0xf0); + const writer = streamPair.writable.getWriter(); + // @ts-ignore: kidnap protected property + while (!webSocketStream.writeBackpressure) { + await writer.write(message); + } + logger.info('BACK PRESSURED'); + backpressure.resolveP(); + await resumeWriting.p; + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + await writer.close(); + logger.info('WRITING ENDED'); + })(), + ]).catch((e) => logger.error(e.toString())); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + await websocket.writable.close(); + + await backpressure.p; + // @ts-ignore: kidnap protected property + expect(webSocketStream.writeBackpressure).toBeTrue(); + resumeWriting.resolveP(); + // Consume all the back-pressured data + for await (const _ of websocket.readable) { + // No touch, only consume + } + // @ts-ignore: kidnap protected property + expect(webSocketStream.writeBackpressure).toBeFalse(); + logger.info('ending'); + }); + // Readable backpressure is not actually supported. We're dealing with it by + // using a buffer with a provided limit that can be very large. + test('Exceeding readable buffer limit causes error', async () => { + const startReading = promise(); + const handlingProm = promise(); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + Promise.all([ + (async () => { + await startReading.p; + logger.info('Starting consumption'); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + logger.info('Reads ended'); + })(), + (async () => { + await streamPair.writable.close(); + })(), + ]) + .catch(() => {}) + .finally(() => handlingProm.resolveP()); + }, + basePath: dataDir, + tlsConfig, + host, + // Setting a really low buffer limit + maxReadBufferBytes: 1500, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + const message = Buffer.alloc(1_000, 0xf0); + const writer = websocket.writable.getWriter(); + logger.info('Starting writes'); + await expect(async () => { + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + }).rejects.toThrow(); + startReading.resolveP(); + logger.info('writes ended'); + await expect(async () => { + for await (const _ of websocket.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + await handlingProm.p; + logger.info('ending'); + }); + test('client ends connection abruptly', async () => { + const streamPairProm = + promise>(); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + streamPairProm.resolveP(streamPair); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + + const testProcess = await testsUtils.spawn( + 'ts-node', + [ + '--project', + testsUtils.tsConfigPath, + `${globalThis.testDir}/websockets/testClient.ts`, + ], + { + env: { + PK_TEST_HOST: host, + PK_TEST_PORT: `${webSocketServer.port}`, + PK_TEST_NODE_ID: nodesUtils.encodeNodeId(keyRing.getNodeId()), + }, + }, + logger, + ); + const startedProm = promise(); + testProcess.stdout!.on('data', (data) => { + startedProm.resolveP(data.toString()); + }); + testProcess.stderr!.on('data', (data) => + startedProm.rejectP(data.toString()), + ); + const exitedProm = promise(); + testProcess.once('exit', () => exitedProm.resolveP()); + await startedProm.p; + + // Killing the client + testProcess.kill('SIGTERM'); + await exitedProm.p; + + const streamPair = await streamPairProm.p; + // Everything should throw after websocket ends early + await expect(async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const serverWritable = streamPair.writable.getWriter(); + await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); + logger.info('ending'); + }); + test('Server ends connection abruptly', async () => { + const testProcess = await testsUtils.spawn( + 'ts-node', + [ + '--project', + testsUtils.tsConfigPath, + `${globalThis.testDir}/websockets/testServer.ts`, + ], + { + env: { + PK_TEST_KEY_PRIVATE_PEM: tlsConfig.keyPrivatePem, + PK_TEST_CERT_CHAIN_PEM: tlsConfig.certChainPem, + PK_TEST_HOST: host, + }, + }, + logger, + ); + const startedProm = promise(); + testProcess.stdout!.on('data', (data) => { + startedProm.resolveP(parseInt(data.toString())); + }); + testProcess.stderr!.on('data', (data) => + startedProm.rejectP(data.toString()), + ); + const exitedProm = promise(); + testProcess.once('exit', () => exitedProm.resolveP()); + + logger.info(`Server started on port ${await startedProm.p}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: await startedProm.p, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + + // Killing the server + testProcess.kill('SIGTERM'); + await exitedProm.p; + + // Waiting for connections to end + await webSocketClient.destroy(); + // Checking client's response to connection dropping + await expect(async () => { + for await (const _ of websocket.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const clientWritable = websocket.writable.getWriter(); + await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); + logger.info('ending'); + }); + + // These describe blocks contains tests specific to either the client or server + describe('WebSocketServer', () => { + testProp( + 'allows half closed writable closes first', + [messagesArb, messagesArb], + async (messages1, messages2) => { + try { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await webSocketServer.stop(true); + } + }, + ); + testProp( + 'allows half closed readable closes first', + [messagesArb, messagesArb], + async (messages1, messages2) => { + try { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await webSocketServer.stop(true); + } + }, + ); + testProp( + 'handles early close of readable', + [messagesArb, messagesArb], + async (messages1, messages2) => { + try { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + await streamPair.readable.cancel(); + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await webSocketServer.stop(true); + } + }, + ); + test('Destroying ClientServer stops all connections', async () => { + const streamPairProm = + promise>(); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + streamPairProm.resolveP(streamPair); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + await webSocketServer.stop(true); + const streamPair = await streamPairProm.p; + // Everything should throw after websocket ends early + await expect(async () => { + for await (const _ of websocket.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + await expect(async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const clientWritable = websocket.writable.getWriter(); + const serverWritable = streamPair.writable.getWriter(); + await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); + await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); + logger.info('ending'); + }); + test('Server rejects normal HTTPS requests', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + const getResProm = promise(); + https.get( + `https://${host}:${webSocketServer.port}/`, + { rejectUnauthorized: false }, + getResProm.resolveP, + ); + const res = await getResProm.p; + const contentProm = promise(); + res.once('data', (d) => contentProm.resolveP(d.toString())); + const endProm = promise(); + res.on('error', endProm.rejectP); + res.on('close', endProm.resolveP); + + expect(res.statusCode).toBe(426); + await expect(contentProm.p).resolves.toBe('426 Upgrade Required'); + expect(res.headers['connection']).toBe('Upgrade'); + expect(res.headers['upgrade']).toBe('websocket'); + }); + test('ping timeout', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (_) => { + logger.info('inside callback'); + // Hang connection + }, + basePath: dataDir, + tlsConfig, + host, + pingTimeout: 100, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + await webSocketClient.startConnection(); + await webSocketClient.destroy(); + logger.info('ending'); + }); + }); + describe('WebSocketClient', () => { + test('Destroying ClientClient stops all connections', async () => { + const streamPairProm = + promise>(); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + streamPairProm.resolveP(streamPair); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + // Destroying the client, force close connections + await webSocketClient.destroy(true); + const streamPair = await streamPairProm.p; + // Everything should throw after websocket ends early + await expect(async () => { + for await (const _ of websocket.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + await expect(async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const clientWritable = websocket.writable.getWriter(); + const serverWritable = streamPair.writable.getWriter(); + await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); + await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); + await webSocketServer.stop(); + logger.info('ending'); + }); + test('Authentication rejects bad server certificate', async () => { + const invalidNodeId = testNodeUtils.generateRandomNodeId(); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [invalidNodeId], + logger: logger.getChild('clientClient'), + }); + await expect(webSocketClient.startConnection()).rejects.toThrow( + webSocketErrors.ErrorCertChainUnclaimed, + ); + // @ts-ignore: kidnap protected property + const activeConnections = webSocketClient.activeConnections; + expect(activeConnections.size).toBe(0); + await webSocketServer.stop(); + logger.info('ending'); + }); + test('Authenticates with multiple certs in chain', async () => { + const keyPairs: Array = [ + keyRing.keyPair, + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + ]; + const tlsConfig = await testsUtils.createTLSConfigWithChain(keyPairs); + const nodeId = keysUtils.publicKeyToNodeId(keyPairs[1].publicKey); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [nodeId], + logger: logger.getChild('clientClient'), + }); + const connProm = webSocketClient.startConnection(); + await connProm; + await expect(connProm).toResolve(); + // @ts-ignore: kidnap protected property + const activeConnections = webSocketClient.activeConnections; + expect(activeConnections.size).toBe(1); + logger.info('ending'); + }); + test('Authenticates with multiple expected nodes', async () => { + const alternativeNodeId = testNodeUtils.generateRandomNodeId(); + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], + logger: logger.getChild('clientClient'), + }); + await expect(webSocketClient.startConnection()).toResolve(); + // @ts-ignore: kidnap protected property + const activeConnections = webSocketClient.activeConnections; + expect(activeConnections.size).toBe(1); + logger.info('ending'); + }); + test('Connection times out', async () => { + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: 12345, + expectedNodeIds: [keyRing.getNodeId()], + connectionTimeout: 0, + logger: logger.getChild('clientClient'), + }); + await expect(webSocketClient.startConnection({})).rejects.toThrow(); + await expect( + webSocketClient.startConnection({ + timeoutTimer: new Timer({ delay: 0 }), + }), + ).rejects.toThrow(); + logger.info('ending'); + }); + test('ping timeout', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (_) => { + logger.info('inside callback'); + // Hang connection + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.port, + expectedNodeIds: [keyRing.getNodeId()], + pingTimeout: 100, + logger: logger.getChild('clientClient'), + }); + await webSocketClient.startConnection(); + await webSocketClient.destroy(); + logger.info('ending'); + }); + }); +}); diff --git a/tests/websockets/testClient.ts b/tests/websockets/testClient.ts new file mode 100644 index 000000000..52179d0c3 --- /dev/null +++ b/tests/websockets/testClient.ts @@ -0,0 +1,31 @@ +/** + * This is spawned as a background process for use in some NodeConnection.test.ts tests + * This process will not preserve jest testing environment, + * any usage of jest globals will result in an error + * Beware of propagated usage of jest globals through the script dependencies + * @module + */ +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import WebSocketClient from '@/websockets/WebSocketClient'; +import * as nodesUtils from '@/nodes/utils'; + +async function main() { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const clientClient = await WebSocketClient.createWebSocketClient({ + expectedNodeIds: [nodesUtils.decodeNodeId(process.env.PK_TEST_NODE_ID!)!], + host: process.env.PK_TEST_HOST ?? '127.0.0.1', + port: parseInt(process.env.PK_TEST_PORT!), + logger, + }); + // Ignore streams, make connection hang + await clientClient.startConnection(); + process.stdout.write(`ready`); +} + +if (require.main === module) { + void main(); +} + +export default main; diff --git a/tests/websockets/testServer.ts b/tests/websockets/testServer.ts new file mode 100644 index 000000000..0a7aac880 --- /dev/null +++ b/tests/websockets/testServer.ts @@ -0,0 +1,36 @@ +/** + * This is spawned as a background process for use in some NodeConnection.test.ts tests + * This process will not preserve jest testing environment, + * any usage of jest globals will result in an error + * Beware of propagated usage of jest globals through the script dependencies + * @module + */ +import type { CertificatePEMChain, PrivateKeyPEM } from '@/keys/types'; +import type { TLSConfig } from '@/network/types'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import WebSocketServer from '@/websockets/WebSocketServer'; + +async function main() { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const tlsConfig: TLSConfig = { + keyPrivatePem: process.env.PK_TEST_KEY_PRIVATE_PEM as PrivateKeyPEM, + certChainPem: process.env.PK_TEST_CERT_CHAIN_PEM as CertificatePEMChain, + }; + const clientServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (_) => { + // Ignore streams and hang connections + }, + host: process.env.PK_TEST_HOST ?? '127.0.0.1', + tlsConfig, + logger, + }); + process.stdout.write(`${clientServer.port}`); +} + +if (require.main === module) { + void main(); +} + +export default main;