From 9cd245aab5c12be91eaa9467e777565bdb7d25cf Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 23 Oct 2023 15:53:06 +1100 Subject: [PATCH] feat: refactored hole punch signalling procedure * general refactor of the signalling protocol. * Added signatures and verification to the signalling requests and relay messages. #148 --- src/nodes/NodeConnectionManager.ts | 245 ++++++---- src/nodes/agent/callers/index.ts | 9 +- .../callers/nodesConnectionSignalFinal.ts | 12 + .../callers/nodesConnectionSignalInitial.ts | 12 + .../callers/nodesHolePunchMessageSend.ts | 12 - src/nodes/agent/errors.ts | 20 +- .../handlers/NodesConnectionSignalFinal.ts | 73 +++ .../handlers/NodesConnectionSignalInitial.ts | 66 +++ .../handlers/NodesHolePunchMessageSend.ts | 113 ----- src/nodes/agent/handlers/index.ts | 9 +- src/nodes/agent/types.ts | 18 +- src/nodes/errors.ts | 18 +- src/rateLimiter/RateLimiter.ts | 151 ++++++ src/rateLimiter/errors.ts | 10 + src/rateLimiter/index.ts | 2 + .../NodeConnectionManager.general.test.ts | 446 +++++++++++------- .../NodeConnectionManager.timeout.test.ts | 39 -- .../nodesConnectionSignalFinal.test.ts | 179 +++++++ .../nodesConnectionSignalInitial.test.ts | 158 +++++++ .../handlers/nodesHolePunchMessage.test.ts | 262 ---------- tests/rateLimiter/RateLimiter.test.ts | 57 +++ 21 files changed, 1207 insertions(+), 704 deletions(-) create mode 100644 src/nodes/agent/callers/nodesConnectionSignalFinal.ts create mode 100644 src/nodes/agent/callers/nodesConnectionSignalInitial.ts delete mode 100644 src/nodes/agent/callers/nodesHolePunchMessageSend.ts create mode 100644 src/nodes/agent/handlers/NodesConnectionSignalFinal.ts create mode 100644 src/nodes/agent/handlers/NodesConnectionSignalInitial.ts delete mode 100644 src/nodes/agent/handlers/NodesHolePunchMessageSend.ts create mode 100644 src/rateLimiter/RateLimiter.ts create mode 100644 src/rateLimiter/errors.ts create mode 100644 src/rateLimiter/index.ts create mode 100644 tests/nodes/agent/handlers/nodesConnectionSignalFinal.test.ts create mode 100644 tests/nodes/agent/handlers/nodesConnectionSignalInitial.test.ts delete mode 100644 tests/nodes/agent/handlers/nodesHolePunchMessage.test.ts create mode 100644 tests/rateLimiter/RateLimiter.test.ts diff --git a/src/nodes/NodeConnectionManager.ts b/src/nodes/NodeConnectionManager.ts index fc0be62e2e..f884af85d7 100644 --- a/src/nodes/NodeConnectionManager.ts +++ b/src/nodes/NodeConnectionManager.ts @@ -1,7 +1,6 @@ import type { LockRequest } from '@matrixai/async-locks'; import type { ResourceAcquire } from '@matrixai/resources'; import type { ContextTimedInput, ContextTimed } from '@matrixai/contexts'; -import type { PromiseCancellable } from '@matrixai/async-cancellable'; import type { ClientCryptoOps, QUICConnection } from '@matrixai/quic'; import type NodeGraph from './NodeGraph'; import type { @@ -21,15 +20,15 @@ import type { TLSConfig, } from '../network/types'; import type { ServerManifest } from '@matrixai/rpc'; -import type { HolePunchRelayMessage } from './agent/types'; import Logger from '@matrixai/logger'; import { withF } from '@matrixai/resources'; import { ready, StartStop } from '@matrixai/async-init/dist/StartStop'; import { IdInternal } from '@matrixai/id'; -import { Lock, LockBox } from '@matrixai/async-locks'; +import { Lock, LockBox, Semaphore } from '@matrixai/async-locks'; import { Timer } from '@matrixai/timer'; import { timedCancellable, context } from '@matrixai/contexts/dist/decorators'; import { AbstractEvent, EventAll } from '@matrixai/events'; +import { PromiseCancellable } from '@matrixai/async-cancellable'; import { QUICSocket, QUICServer, @@ -43,11 +42,11 @@ import * as nodesUtils from './utils'; import * as nodesErrors from './errors'; import * as nodesEvents from './events'; import manifestClientAgent from './agent/callers'; -import * as ids from '../ids'; import * as keysUtils from '../keys/utils'; import * as networkUtils from '../network/utils'; import * as utils from '../utils'; import config from '../config'; +import RateLimiter from '../rateLimiter/RateLimiter'; type ManifestClientAgent = typeof manifestClientAgent; @@ -129,6 +128,22 @@ class NodeConnectionManager { * Default timeout for RPC handlers */ public readonly rpcCallTimeoutTime: number; + /** + * Used to track active hole punching attempts + */ + protected activePunchMap = new Map>(); + /** + * Used to rate limit punch attempts per IP Address + */ + protected activeAddressMap = new Map(); + /** + * Used track active signalling attempts + */ + protected activeSignalSet = new Set>(); + /** + * Used to limit signalling requests on a per requester basis + */ + protected rateLimiter = new RateLimiter(60000, 20, 10, 1); protected logger: Logger; protected keyRing: KeyRing; @@ -456,11 +471,13 @@ class NodeConnectionManager { this.handleEventQUICServerConnection, ); this.quicSocket.addEventListener(EventAll.name, this.handleEventAll); + this.rateLimiter.startRefillInterval(); this.logger.info(`Started ${this.constructor.name}`); } public async stop() { this.logger.info(`Stop ${this.constructor.name}`); + this.rateLimiter.stop(); this.removeEventListener( nodesEvents.EventNodeConnectionManagerError.name, @@ -503,6 +520,16 @@ class NodeConnectionManager { destroyProms.push(destroyProm); } await Promise.all(destroyProms); + const signallingProms: Array> = []; + for (const [, activePunch] of this.activePunchMap) { + signallingProms.push(activePunch); + activePunch.cancel(); + } + for (const activeSignal of this.activeSignalSet) { + signallingProms.push(activeSignal); + activeSignal.cancel(); + } + await Promise.allSettled(signallingProms); await this.quicServer.stop({ force: true }); await this.quicSocket.stop({ force: true }); await this.rpcServer.stop({ force: true }); @@ -904,12 +931,7 @@ class NodeConnectionManager { // 3. if already exists then clean up await connection.destroy({ force: true }); // I can only see this happening as a race condition with creating a forward connection and receiving a reverse. - // FIXME: only here to see if this condition happens. - // this NEEDS to be removed, but I want to know if this branch happens at all. - throw Error( - 'TMP IMP, This should be exceedingly rare, lets see if it happens', - ); - // Return; + return; } // Final setup const newConnAndTimer = this.addConnection(nodeId, connection); @@ -1074,11 +1096,11 @@ class NodeConnectionManager { * @param port port of the target client. * @param ctx */ - public async holePunchReverse( + public holePunchReverse( host: Host, port: Port, - ctx?: Partial, - ): Promise; + ctx?: Partial, + ): PromiseCancellable; @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) @timedCancellable( true, @@ -1378,106 +1400,129 @@ class NodeConnectionManager { } /** - * Performs an RPC request to send a hole-punch message to the target. Used to - * initially establish the NodeConnection from source to target. + * This is used by the `nodesHolePunchRequestHandler` to initiate the hole punch procedure. * - * @param relayNodeId node ID of the relay node (i.e. the seed node) - * @param sourceNodeId node ID of the current node (i.e. the sender) - * @param targetNodeId node ID of the target node to hole punch - * @param address - * @param ctx + * Will validate the message, and initiate hole punching in the background and return immediately */ - public sendSignalingMessage( - relayNodeId: NodeId, - sourceNodeId: NodeId, - targetNodeId: NodeId, - address?: NodeAddress, - ctx?: Partial, - ): PromiseCancellable; - @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) - @timedCancellable( - true, - (nodeConnectionManager: NodeConnectionManager) => - nodeConnectionManager.connectionConnectTimeoutTime, - ) - public async sendSignalingMessage( - relayNodeId: NodeId, + @ready(new nodesErrors.ErrorNodeManagerNotRunning()) + public handleNodesConnectionSignalFinal(host: Host, port: Port) { + const id = `${host}:${port}`; + if (this.activePunchMap.has(id)) return; + // Checking for resource semaphore + let semaphore: Semaphore | undefined = this.activeAddressMap.get(host); + if (semaphore == null) { + semaphore = new Semaphore(3); + this.activeAddressMap.set(host, semaphore); + } + const holePunchAttempt = new PromiseCancellable( + async (res, rej, signal) => { + await semaphore!.withF(async () => { + this.holePunchReverse(host, port, { signal }) + .finally(() => { + this.activePunchMap.delete(id); + if (semaphore!.count === 0) { + this.activeAddressMap.delete(host); + } + }) + .then(res, rej); + }); + }, + ); + this.activePunchMap.set(id, holePunchAttempt); + } + + /** + * The handler used by the RPC to process signalling requests + * @param sourceNodeId - NodeId of the node making the request. Used for rate limiting. + * @param targetNodeId - NodeId of the node that needs to initiate hole punching. + * @param address - Address the target needs to punch to. + * @param requestSignature - `base64url` encoded signature + */ + @ready(new nodesErrors.ErrorNodeManagerNotRunning()) + public handleNodesConnectionSignalInitial( sourceNodeId: NodeId, targetNodeId: NodeId, - address: NodeAddress | undefined, - @context ctx: ContextTimed, - ): Promise { - if ( - this.keyRing.getNodeId().equals(relayNodeId) || - this.keyRing.getNodeId().equals(targetNodeId) - ) { - // Logging and silently dropping operation - this.logger.debug( - 'Attempted to send signaling message to our own NodeId', - ); - return; + address: NodeAddress, + requestSignature: string, + ) { + // Need to get the connection details of the requester and add it to the message. + // Then send the message to the target. + // This would only function with existing connections + if (!this.hasConnection(targetNodeId)) { + throw new nodesErrors.ErrorNodeConnectionManagerConnectionNotFound(); } - const rlyNode = nodesUtils.encodeNodeId(relayNodeId); - const srcNode = nodesUtils.encodeNodeId(sourceNodeId); - const tgtNode = nodesUtils.encodeNodeId(targetNodeId); - const addressString = - address != null ? `, address: ${address.host}:${address.port}` : ''; - this.logger.debug( - `sendSignalingMessage sending Signaling message relay: ${rlyNode}, source: ${srcNode}, target: ${tgtNode}${addressString}`, + // Do other checks. + const sourceNodeIdString = sourceNodeId.toString(); + if (!this.rateLimiter.consume(sourceNodeIdString)) { + throw new nodesErrors.ErrorNodeConnectionManagerRequestRateExceeded(); + } + // Generating relay signature, data is just `
` concatenated + const data = Buffer.concat([ + sourceNodeId, + targetNodeId, + Buffer.from(JSON.stringify(address), 'utf-8'), + Buffer.from(requestSignature, 'base64url'), + ]); + const relaySignature = keysUtils.signWithPrivateKey( + this.keyRing.keyPair, + data, ); - // Send message and ignore any error - await this.withConnF( - relayNodeId, - async (connection) => { - const client = connection.getClient(); - await client.methods.nodesHolePunchMessageSend( - { - srcIdEncoded: srcNode, - dstIdEncoded: tgtNode, - address, - }, - ctx, - ); - }, - ctx, - ).catch(() => {}); + const connProm = this.withConnF(targetNodeId, async (conn) => { + const client = conn.getClient(); + await client.methods.nodesConnectionSignalFinal({ + sourceNodeIdEncoded: nodesUtils.encodeNodeId(sourceNodeId), + targetNodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId), + address, + requestSignature: requestSignature, + relaySignature: relaySignature.toString('base64url'), + }); + }).finally(() => { + this.activeSignalSet.delete(connProm); + }); + this.activeSignalSet.add(connProm); } /** - * Forwards a received hole punch message on to the target. - * If not known, the node ID -> address mapping is attempted to be discovered - * through Kademlia (note, however, this is currently only called by a 'broker' - * node). - * @param message the original relay message (assumed to be created in - * nodeConnection.start()) - * @param sourceAddress + * This till ask a signalling node to signal a target node to hole punch back to this node. + * @param targetNodeId - NodeId of the node that needs to signal back. + * @param signallingNodeId - NodeId of the signalling node. * @param ctx */ - public relaySignalingMessage( - message: HolePunchRelayMessage, - sourceAddress: NodeAddress, - ctx?: Partial, + public holePunchSignalRequest( + targetNodeId: NodeId, + signallingNodeId: NodeId, + ctx?: Partial, ): PromiseCancellable; - @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) + @ready(new nodesErrors.ErrorNodeManagerNotRunning()) @timedCancellable( true, (nodeConnectionManager: NodeConnectionManager) => nodeConnectionManager.connectionConnectTimeoutTime, ) - public async relaySignalingMessage( - message: HolePunchRelayMessage, - sourceAddress: NodeAddress, + public async holePunchSignalRequest( + targetNodeId: NodeId, + signallingNodeId: NodeId, @context ctx: ContextTimed, ): Promise { - // First check if we already have an existing ID -> address record - // If we're relaying then we trust our own node graph records over - // what was provided in the message - const sourceNode = ids.parseNodeId(message.srcIdEncoded); - await this.sendSignalingMessage( - ids.parseNodeId(message.dstIdEncoded), - sourceNode, - ids.parseNodeId(message.dstIdEncoded), - sourceAddress, + await this.withConnF( + signallingNodeId, + async (conn) => { + const client = conn.getClient(); + const sourceNodeId = this.keyRing.getNodeId(); + // Data is just `` concatenated + const data = Buffer.concat([sourceNodeId, targetNodeId]); + const signature = keysUtils.signWithPrivateKey( + this.keyRing.keyPair, + data, + ); + await client.methods.nodesConnectionSignalInitial( + { + targetNodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId), + signature: signature.toString('base64url'), + }, + ctx, + ); + }, ctx, ); } @@ -1657,19 +1702,15 @@ class NodeConnectionManager { const allProms: Array>> = []; for (const targetNodeId of targetNodeIds) { if (!this.isSeedNode(targetNodeId)) { + // Ask seed nodes to signal hole punching for target const holePunchProms = seedNodes.map((seedNodeId) => { return ( - this.sendSignalingMessage( - seedNodeId, - this.keyRing.getNodeId(), - targetNodeId, - undefined, - ctx, - ) + this.holePunchSignalRequest(targetNodeId, seedNodeId, ctx) // Ignore results .then( () => {}, - () => {}, + (e) => + this.logger.debug(`signal request failed with ${e.message}`), ) ); }); diff --git a/src/nodes/agent/callers/index.ts b/src/nodes/agent/callers/index.ts index 32e6c916eb..3bb027081c 100644 --- a/src/nodes/agent/callers/index.ts +++ b/src/nodes/agent/callers/index.ts @@ -1,7 +1,8 @@ import nodesClaimsGet from './nodesClaimsGet'; import nodesClosestLocalNodesGet from './nodesClosestLocalNodesGet'; +import nodesConnectionSignalFinal from './nodesConnectionSignalFinal'; +import nodesConnectionSignalInitial from './nodesConnectionSignalInitial'; import nodesCrossSignClaim from './nodesCrossSignClaim'; -import nodesHolePunchMessageSend from './nodesHolePunchMessageSend'; import notificationsSend from './notificationsSend'; import vaultsGitInfoGet from './vaultsGitInfoGet'; import vaultsGitPackGet from './vaultsGitPackGet'; @@ -13,8 +14,9 @@ import vaultsScan from './vaultsScan'; const manifestClient = { nodesClaimsGet, nodesClosestLocalNodesGet, + nodesConnectionSignalFinal, + nodesConnectionSignalInitial, nodesCrossSignClaim, - nodesHolePunchMessageSend, notificationsSend, vaultsGitInfoGet, vaultsGitPackGet, @@ -26,8 +28,9 @@ export default manifestClient; export { nodesClaimsGet, nodesClosestLocalNodesGet, + nodesConnectionSignalFinal, + nodesConnectionSignalInitial, nodesCrossSignClaim, - nodesHolePunchMessageSend, notificationsSend, vaultsGitInfoGet, vaultsGitPackGet, diff --git a/src/nodes/agent/callers/nodesConnectionSignalFinal.ts b/src/nodes/agent/callers/nodesConnectionSignalFinal.ts new file mode 100644 index 0000000000..2d341f2d46 --- /dev/null +++ b/src/nodes/agent/callers/nodesConnectionSignalFinal.ts @@ -0,0 +1,12 @@ +import type { HandlerTypes } from '@matrixai/rpc'; +import type NodesConnectionSignalFinal from '../handlers/NodesConnectionSignalFinal'; +import { UnaryCaller } from '@matrixai/rpc'; + +type CallerTypes = HandlerTypes; + +const nodesConnectionSignalFinal = new UnaryCaller< + CallerTypes['input'], + CallerTypes['output'] +>(); + +export default nodesConnectionSignalFinal; diff --git a/src/nodes/agent/callers/nodesConnectionSignalInitial.ts b/src/nodes/agent/callers/nodesConnectionSignalInitial.ts new file mode 100644 index 0000000000..eef756e4e3 --- /dev/null +++ b/src/nodes/agent/callers/nodesConnectionSignalInitial.ts @@ -0,0 +1,12 @@ +import type { HandlerTypes } from '@matrixai/rpc'; +import type NodesConnectionSignalInitial from '../handlers/NodesConnectionSignalInitial'; +import { UnaryCaller } from '@matrixai/rpc'; + +type CallerTypes = HandlerTypes; + +const nodesConnectionSignalInitial = new UnaryCaller< + CallerTypes['input'], + CallerTypes['output'] +>(); + +export default nodesConnectionSignalInitial; diff --git a/src/nodes/agent/callers/nodesHolePunchMessageSend.ts b/src/nodes/agent/callers/nodesHolePunchMessageSend.ts deleted file mode 100644 index cea3acc068..0000000000 --- a/src/nodes/agent/callers/nodesHolePunchMessageSend.ts +++ /dev/null @@ -1,12 +0,0 @@ -import type { HandlerTypes } from '@matrixai/rpc'; -import type NodesHolePunchMessageSend from '../handlers/NodesHolePunchMessageSend'; -import { UnaryCaller } from '@matrixai/rpc'; - -type CallerTypes = HandlerTypes; - -const nodesHolePunchMessageSend = new UnaryCaller< - CallerTypes['input'], - CallerTypes['output'] ->(); - -export default nodesHolePunchMessageSend; diff --git a/src/nodes/agent/errors.ts b/src/nodes/agent/errors.ts index b52b6f9ac4..14b8f25dd7 100644 --- a/src/nodes/agent/errors.ts +++ b/src/nodes/agent/errors.ts @@ -8,4 +8,22 @@ class ErrorAgentNodeIdMissing extends ErrorAgent { exitCode = sysexits.UNAVAILABLE; } -export { ErrorAgentNodeIdMissing }; +class ErrorNodesConnectionSignalRequestVerificationFailed< + T, +> extends ErrorAgent { + static description = 'Failed to verify request message signature'; + exitCode = sysexits.TEMPFAIL; +} + +class ErrorNodesConnectionSignalRelayVerificationFailed< + T, +> extends ErrorAgent { + static description = 'Failed to verify relay message signature'; + exitCode = sysexits.TEMPFAIL; +} + +export { + ErrorAgentNodeIdMissing, + ErrorNodesConnectionSignalRequestVerificationFailed, + ErrorNodesConnectionSignalRelayVerificationFailed, +}; diff --git a/src/nodes/agent/handlers/NodesConnectionSignalFinal.ts b/src/nodes/agent/handlers/NodesConnectionSignalFinal.ts new file mode 100644 index 0000000000..4e8b524572 --- /dev/null +++ b/src/nodes/agent/handlers/NodesConnectionSignalFinal.ts @@ -0,0 +1,73 @@ +import type Logger from '@matrixai/logger'; +import type { + AgentRPCRequestParams, + AgentRPCResponseResult, + HolePunchRequestMessage, +} from '../types'; +import type NodeConnectionManager from '../../NodeConnectionManager'; +import type { Host, Port } from '../../../network/types'; +import { UnaryHandler } from '@matrixai/rpc'; +import * as keysUtils from '../../../keys/utils'; +import * as ids from '../../../ids'; +import * as agentErrors from '../errors'; +import * as agentUtils from '../utils'; + +class NodesConnectionSignalFinal extends UnaryHandler< + { + nodeConnectionManager: NodeConnectionManager; + logger: Logger; + }, + AgentRPCRequestParams, + AgentRPCResponseResult +> { + public handle = async ( + input: AgentRPCRequestParams, + _cancel, + meta, + ): Promise => { + const { nodeConnectionManager, logger } = this.container; + // Connections should always be validated + const sourceNodeId = ids.parseNodeId(input.sourceNodeIdEncoded); + const targetNodeId = ids.parseNodeId(input.targetNodeIdEncoded); + const relayingNodeId = agentUtils.nodeIdFromMeta(meta); + if (relayingNodeId == null) { + throw new agentErrors.ErrorAgentNodeIdMissing(); + } + const requestSignature = Buffer.from(input.requestSignature, 'base64url'); + // Checking request requestSignature, requestData is just `` concatenated + const requestData = Buffer.concat([sourceNodeId, targetNodeId]); + const sourcePublicKey = keysUtils.publicKeyFromNodeId(sourceNodeId); + if ( + !keysUtils.verifyWithPublicKey( + sourcePublicKey, + requestData, + requestSignature, + ) + ) { + throw new agentErrors.ErrorNodesConnectionSignalRequestVerificationFailed(); + } + // Checking relay message relaySignature. + // relayData is just `
` concatenated. + const relayData = Buffer.concat([ + sourceNodeId, + targetNodeId, + Buffer.from(JSON.stringify(input.address), 'utf-8'), + requestSignature, + ]); + const relayPublicKey = keysUtils.publicKeyFromNodeId(relayingNodeId); + const relaySignature = Buffer.from(input.relaySignature, 'base64url'); + if ( + !keysUtils.verifyWithPublicKey(relayPublicKey, relayData, relaySignature) + ) { + throw new agentErrors.ErrorNodesConnectionSignalRelayVerificationFailed(); + } + + const host = input.address.host as Host; + const port = input.address.port as Port; + logger.debug(`Received signaling message to target ${host}:${port}`); + nodeConnectionManager.handleNodesConnectionSignalFinal(host, port); + return {}; + }; +} + +export default NodesConnectionSignalFinal; diff --git a/src/nodes/agent/handlers/NodesConnectionSignalInitial.ts b/src/nodes/agent/handlers/NodesConnectionSignalInitial.ts new file mode 100644 index 0000000000..2ee438bbc9 --- /dev/null +++ b/src/nodes/agent/handlers/NodesConnectionSignalInitial.ts @@ -0,0 +1,66 @@ +import type { + AgentRPCRequestParams, + AgentRPCResponseResult, + HolePunchSignalMessage, +} from '../types'; +import type NodeConnectionManager from '../../../nodes/NodeConnectionManager'; +import type { Host, Port } from '../../../network/types'; +import type { NodeAddress } from '../../../nodes/types'; +import type { JSONValue } from '../../../types'; +import { UnaryHandler } from '@matrixai/rpc'; +import * as agentErrors from '../errors'; +import * as agentUtils from '../utils'; +import { never } from '../../../utils'; +import * as keysUtils from '../../../keys/utils'; +import * as ids from '../../../ids'; + +class NodesConnectionSignalInitial extends UnaryHandler< + { + nodeConnectionManager: NodeConnectionManager; + }, + AgentRPCRequestParams, + AgentRPCResponseResult +> { + public handle = async ( + input: AgentRPCRequestParams, + _cancel, + meta: Record | undefined, + ): Promise => { + const { nodeConnectionManager } = this.container; + // Connections should always be validated + const requestingNodeId = agentUtils.nodeIdFromMeta(meta); + if (requestingNodeId == null) { + throw new agentErrors.ErrorAgentNodeIdMissing(); + } + const targetNodeId = ids.parseNodeId(input.targetNodeIdEncoded); + const signature = Buffer.from(input.signature, 'base64url'); + // Checking signature, data is just `` concatenated + const data = Buffer.concat([requestingNodeId, targetNodeId]); + const sourcePublicKey = keysUtils.publicKeyFromNodeId(requestingNodeId); + if (!keysUtils.verifyWithPublicKey(sourcePublicKey, data, signature)) { + throw new agentErrors.ErrorNodesConnectionSignalRelayVerificationFailed(); + } + if (meta == null) never('Missing metadata from stream'); + const remoteHost = meta.remoteHost; + const remotePort = meta.remotePort; + if (remoteHost == null || typeof remoteHost !== 'string') { + never('Missing or invalid remoteHost'); + } + if (remotePort == null || typeof remotePort !== 'number') { + never('Missing or invalid remotePort'); + } + const address: NodeAddress = { + host: remoteHost as Host, + port: remotePort as Port, + }; + nodeConnectionManager.handleNodesConnectionSignalInitial( + requestingNodeId, + targetNodeId, + address, + input.signature, + ); + return {}; + }; +} + +export default NodesConnectionSignalInitial; diff --git a/src/nodes/agent/handlers/NodesHolePunchMessageSend.ts b/src/nodes/agent/handlers/NodesHolePunchMessageSend.ts deleted file mode 100644 index 947fc2b949..0000000000 --- a/src/nodes/agent/handlers/NodesHolePunchMessageSend.ts +++ /dev/null @@ -1,113 +0,0 @@ -import type { DB } from '@matrixai/db'; -import type Logger from '@matrixai/logger'; -import type { - AgentRPCRequestParams, - AgentRPCResponseResult, - HolePunchRelayMessage, -} from '../types'; -import type NodeConnectionManager from '../../NodeConnectionManager'; -import type NodeManager from '../../NodeManager'; -import type KeyRing from '../../../keys/KeyRing'; -import type { Host, Port } from '../../../network/types'; -import type { NodeId } from '../../../ids'; -import { UnaryHandler } from '@matrixai/rpc'; -import * as agentUtils from '../utils'; -import * as agentErrors from '../errors'; -import * as ids from '../../../ids'; -import * as nodesUtils from '../../utils'; -import * as validation from '../../../validation'; -import * as utils from '../../../utils'; - -/** - * Sends a hole punch message to a node - */ -class NodesHolePunchMessageSend extends UnaryHandler< - { - db: DB; - nodeConnectionManager: NodeConnectionManager; - keyRing: KeyRing; - nodeManager: NodeManager; - logger: Logger; - }, - AgentRPCRequestParams, - AgentRPCResponseResult -> { - public handle = async ( - input: AgentRPCRequestParams, - _cancel, - meta, - ): Promise => { - const { db, nodeConnectionManager, keyRing, nodeManager, logger } = - this.container; - const { - targetId, - sourceId, - }: { - targetId: NodeId; - sourceId: NodeId; - } = validation.validateSync( - (keyPath, value) => { - return utils.matchSync(keyPath)( - [['targetId'], ['sourceId'], () => ids.parseNodeId(value)], - () => value, - ); - }, - { - targetId: input.dstIdEncoded, - sourceId: input.srcIdEncoded, - }, - ); - // Connections should always be validated - const requestingNodeId = agentUtils.nodeIdFromMeta(meta); - if (requestingNodeId == null) { - throw new agentErrors.ErrorAgentNodeIdMissing(); - } - const srcNodeId = nodesUtils.encodeNodeId(requestingNodeId); - // Firstly, check if this node is the desired node - // If so, then we want to make this node start sending hole punching packets - // back to the source node. - await db.withTransactionF(async (tran) => { - if (keyRing.getNodeId().equals(targetId)) { - if (input.address != null) { - const host = input.address.host as Host; - const port = input.address.port as Port; - logger.debug( - `Received signaling message to target ${input.srcIdEncoded}@${host}:${port}`, - ); - // Ignore failure - await nodeConnectionManager - .holePunchReverse(host, port) - .catch(() => {}); - } else { - logger.error( - 'Received signaling message, target information was missing, skipping reverse hole punch', - ); - } - } else if (await nodeManager.knowsNode(sourceId, tran)) { - // Otherwise, find if node in table - // If so, ask the nodeManager to relay to the node - const targetNodeId = input.dstIdEncoded; - const agentAddress = { - host: meta.remoteHost, - port: meta.remotePort, - }; - // Checking if the source and destination are the same - if (sourceId?.equals(targetId)) { - // Logging and silently dropping operation - logger.warn('Signaling relay message requested signal to itself'); - return {}; - } - logger.debug( - `Relaying signaling message from ${srcNodeId}@${agentAddress.host}:${agentAddress.port} to ${targetNodeId} with information ${agentAddress}`, - ); - await nodeConnectionManager.relaySignalingMessage(input, { - host: meta.remoteHost, - port: meta.remotePort, - }); - } - }); - return {}; - }; -} - -export default NodesHolePunchMessageSend; diff --git a/src/nodes/agent/handlers/index.ts b/src/nodes/agent/handlers/index.ts index ee8e729149..cb2549c72b 100644 --- a/src/nodes/agent/handlers/index.ts +++ b/src/nodes/agent/handlers/index.ts @@ -10,8 +10,9 @@ import type NotificationsManager from '../../../notifications/NotificationsManag import type VaultManager from '../../../vaults/VaultManager'; import NodesClaimsGet from './NodesClaimsGet'; import NodesClosestLocalNodesGet from './NodesClosestLocalNodesGet'; +import NodesConnectionSignalFinal from './NodesConnectionSignalFinal'; +import NodesConnectionSignalInitial from './NodesConnectionSignalInitial'; import NodesCrossSignClaim from './NodesCrossSignClaim'; -import NodesHolePunchMessageSend from './NodesHolePunchMessageSend'; import NotificationsSend from './NotificationsSend'; import VaultsGitInfoGet from './VaultsGitInfoGet'; import VaultsGitPackGet from './VaultsGitPackGet'; @@ -35,8 +36,9 @@ const manifestServer = (container: { return { nodesClaimsGet: new NodesClaimsGet(container), nodesClosestLocalNodesGet: new NodesClosestLocalNodesGet(container), + nodesConnectionSignalFinal: new NodesConnectionSignalFinal(container), + nodesConnectionSignalInitial: new NodesConnectionSignalInitial(container), nodesCrossSignClaim: new NodesCrossSignClaim(container), - nodesHolePunchMessageSend: new NodesHolePunchMessageSend(container), notificationsSend: new NotificationsSend(container), vaultsGitInfoGet: new VaultsGitInfoGet(container), vaultsGitPackGet: new VaultsGitPackGet(container), @@ -49,8 +51,9 @@ export default manifestServer; export { NodesClaimsGet, NodesClosestLocalNodesGet, + NodesConnectionSignalFinal, + NodesConnectionSignalInitial, NodesCrossSignClaim, - NodesHolePunchMessageSend, NotificationsSend, VaultsGitInfoGet, VaultsGitPackGet, diff --git a/src/nodes/agent/types.ts b/src/nodes/agent/types.ts index 95de729507..a962423964 100644 --- a/src/nodes/agent/types.ts +++ b/src/nodes/agent/types.ts @@ -45,10 +45,17 @@ type AddressMessage = { type NodeAddressMessage = NodeIdMessage & AddressMessage; -type HolePunchRelayMessage = { - srcIdEncoded: NodeIdEncoded; - dstIdEncoded: NodeIdEncoded; - address?: AddressMessage; +type HolePunchRequestMessage = { + sourceNodeIdEncoded: NodeIdEncoded; + targetNodeIdEncoded: NodeIdEncoded; + address: AddressMessage; + requestSignature: string; + relaySignature: string; +}; + +type HolePunchSignalMessage = { + targetNodeIdEncoded: NodeIdEncoded; + signature: string; }; type SignedNotificationEncoded = { @@ -72,7 +79,8 @@ export type { NodeIdMessage, AddressMessage, NodeAddressMessage, - HolePunchRelayMessage, + HolePunchRequestMessage, + HolePunchSignalMessage, SignedNotificationEncoded, VaultInfo, VaultsScanMessage, diff --git a/src/nodes/errors.ts b/src/nodes/errors.ts index faa22173d2..1e06b0b693 100644 --- a/src/nodes/errors.ts +++ b/src/nodes/errors.ts @@ -1,8 +1,6 @@ import ErrorPolykey from '../ErrorPolykey'; import sysexits from '../utils/sysexits'; -// TODO: Some errors may need to be removed here, TBD in stage 2 agent migration - class ErrorNodes extends ErrorPolykey {} class ErrorNodeManager extends ErrorNodes {} @@ -137,6 +135,20 @@ class ErrorNodeConnectionManagerMultiConnectionFailed< exitCode = sysexits.TEMPFAIL; } +class ErrorNodeConnectionManagerConnectionNotFound< + T, +> extends ErrorNodeConnectionManager { + static description = 'No existing connection was found for target NodeId'; + exitCode = sysexits.TEMPFAIL; +} + +class ErrorNodeConnectionManagerRequestRateExceeded< + T, +> extends ErrorNodeConnectionManager { + static description = 'Rate limit exceeded while making request'; + exitCode = sysexits.TEMPFAIL; +} + class ErrorNodePingFailed extends ErrorNodes { static description = 'Failed to ping the node when attempting to authenticate'; @@ -175,6 +187,8 @@ export { ErrorNodeConnectionManagerInternalError, ErrorNodeConnectionManagerNodeIdRequired, ErrorNodeConnectionManagerMultiConnectionFailed, + ErrorNodeConnectionManagerConnectionNotFound, + ErrorNodeConnectionManagerRequestRateExceeded, ErrorNodePingFailed, ErrorNodePermissionDenied, }; diff --git a/src/rateLimiter/RateLimiter.ts b/src/rateLimiter/RateLimiter.ts new file mode 100644 index 0000000000..4e310b5bdd --- /dev/null +++ b/src/rateLimiter/RateLimiter.ts @@ -0,0 +1,151 @@ +import { Timer } from '@matrixai/timer'; +import * as rateLimiterErrors from './errors'; + +/** + * Internal data structure used to track a buckets' information. + * Internal use only so explicitly not exported. + */ +type TokenBucket = { + creationTimestamp: number; + tokens: number; + lastRefillTimestamp: number; + capacity: number; + refillRatePerSecond: number; +}; + +class RateLimiter { + protected tokenBuckets: Map; + protected expirationTimers: Map; + protected refillTimer: Timer | undefined; + + constructor( + protected defaultTTL: number = 60000, + protected defaultCapacity: number = 100, + protected defaultRate: number = 100, + protected defaultConsume: number = 1, + ) { + // Default TTL 1 minute + this.tokenBuckets = new Map(); + this.expirationTimers = new Map(); + } + + /** + * Starts the Refill interval timer + */ + public startRefillInterval(): void { + if (this.refillTimer != null) return; + const handler = () => { + this.refill(); + this.refillTimer = new Timer({ + handler, + delay: 1000, + }); + }; + this.refillTimer = new Timer({ + handler, + delay: 1000, + }); + } + + /** + * Stops the Refill interval timer + */ + public stopRefillInterval(): void { + if (this.refillTimer != null) { + this.refillTimer.cancel(); + delete this.refillTimer; + } + } + + /** + * Refills a second worth of tokens defined by the `refillRatePerSecond`. + */ + public refill(): void { + for (const [, bucket] of this.tokenBuckets) { + bucket.tokens += bucket.refillRatePerSecond; + if (bucket.tokens > bucket.capacity) bucket.tokens = bucket.capacity; + bucket.lastRefillTimestamp = performance.now(); + } + } + + /** + * Consumes an amount of tokens for a given bucket. Will return true if the tokens were available to be consumed. + * Otherwise, returns false if there were insufficient tokens. + * @param key - Key for the given bucket. + * @param tokensToConsume - Number of tokens to consume. + * @returns True if there were sufficient tokens that were consumed. False otherwise with no tokens consumed. + */ + public consume( + key: string, + tokensToConsume: number = this.defaultConsume, + ): boolean { + // Scaled default value for example + const bucket = this.getBucket(key); + if (tokensToConsume <= 0) { + throw new rateLimiterErrors.ErrorRateLimiterInvalidTokens(); + } + // Refreshing TTL + this.expirationTimers.get(key)?.refresh(); + if (bucket.tokens < tokensToConsume) return false; + bucket.tokens -= tokensToConsume; + return true; + } + + /** + * Gets the available tokens for a given bucket. + * @param key - Key for the given bucket. + */ + public tokens(key: string): number { + return this.getBucket(key).tokens; + } + + /** + * Clears all existing `TokenBuckets` . + */ + public clearBuckets(): void { + // Clear timers + for (const [, expirationTimer] of this.expirationTimers) { + expirationTimer.cancel(); + } + this.expirationTimers.clear(); + // Clear buckets + this.tokenBuckets.clear(); + } + + /** + * Stops refreshing and clears all existing `TokenBucket`s + */ + public stop(): void { + this.stopRefillInterval(); + this.clearBuckets(); + } + + protected scheduleExpiration(key: string, ttl: number): void { + const timer = new Timer({ + handler: () => { + this.tokenBuckets.delete(key); + this.expirationTimers.delete(key); + }, + delay: ttl, + }); + this.expirationTimers.set(key, timer); + } + + protected getBucket(key: string, ttl?: number): TokenBucket { + let bucket = this.tokenBuckets.get(key); + if (!bucket) { + bucket = { + capacity: this.defaultCapacity, + creationTimestamp: performance.now(), + lastRefillTimestamp: performance.now(), + refillRatePerSecond: this.defaultRate, + tokens: this.defaultCapacity, + }; + this.tokenBuckets.set(key, bucket); + this.scheduleExpiration(key, ttl || this.defaultTTL); + } + return bucket; + } +} + +export default RateLimiter; diff --git a/src/rateLimiter/errors.ts b/src/rateLimiter/errors.ts new file mode 100644 index 0000000000..a506c40aaa --- /dev/null +++ b/src/rateLimiter/errors.ts @@ -0,0 +1,10 @@ +import { ErrorPolykey, sysexits } from '../errors'; + +class ErrorRateLimiter extends ErrorPolykey {} + +class ErrorRateLimiterInvalidTokens extends ErrorRateLimiter { + static description = 'Consumed tokens must be greater than 0'; + exitCode = sysexits.USAGE; +} + +export { ErrorRateLimiterInvalidTokens }; diff --git a/src/rateLimiter/index.ts b/src/rateLimiter/index.ts new file mode 100644 index 0000000000..12e24d3316 --- /dev/null +++ b/src/rateLimiter/index.ts @@ -0,0 +1,2 @@ +export { default as RateLimiter } from './RateLimiter'; +export * as errors from './errors'; diff --git a/tests/nodes/NodeConnectionManager.general.test.ts b/tests/nodes/NodeConnectionManager.general.test.ts index e18e4ffb9b..84d6a0380b 100644 --- a/tests/nodes/NodeConnectionManager.general.test.ts +++ b/tests/nodes/NodeConnectionManager.general.test.ts @@ -1,5 +1,5 @@ import type { Host, Port, TLSConfig } from '@/network/types'; -import type { NodeId, NodeIdEncoded } from '@/ids'; +import type { NodeId } from '@/ids'; import type { NodeAddress, NodeBucket } from '@/nodes/types'; import fs from 'fs'; import path from 'path'; @@ -16,10 +16,11 @@ import KeyRing from '@/keys/KeyRing'; import ACL from '@/acl/ACL'; import GestaltGraph from '@/gestalts/GestaltGraph'; import NodeGraph from '@/nodes/NodeGraph'; +import * as nodesErrors from '@/nodes/errors'; import Sigchain from '@/sigchain/Sigchain'; import TaskManager from '@/tasks/TaskManager'; -import NodeManager from '@/nodes/NodeManager'; import PolykeyAgent from '@/PolykeyAgent'; +import * as utils from '@/utils'; import * as testNodesUtils from './utils'; import * as tlsTestUtils from '../utils/tls'; @@ -73,11 +74,13 @@ describe(`${NodeConnectionManager.name} general test`, () => { }; let dataDir: string; + let nodePathA: string; + let nodePathB: string; - let remotePolykeyAgent: PolykeyAgent; - let serverAddress: NodeAddress; - let serverNodeId: NodeId; - let serverNodeIdEncoded: NodeIdEncoded; + let remotePolykeyAgentA: PolykeyAgent; + let serverAddressA: NodeAddress; + let serverNodeIdA: NodeId; + let remotePolykeyAgentB: PolykeyAgent; let keyRing: KeyRing; let db: DB; @@ -86,22 +89,30 @@ describe(`${NodeConnectionManager.name} general test`, () => { let nodeGraph: NodeGraph; let sigchain: Sigchain; let taskManager: TaskManager; - let nodeManager: NodeManager; let nodeConnectionManager: NodeConnectionManager; - // Default stream handler, just drop the stream + + // Mocking the relay send + let mockedHolePunchReverse: jest.SpyInstance>; + let mockedPingNode: jest.SpyInstance>; beforeEach(async () => { + mockedHolePunchReverse = jest.spyOn( + NodeConnectionManager.prototype, + 'holePunchReverse', + ); + mockedPingNode = jest.spyOn(NodeConnectionManager.prototype, 'pingNode'); dataDir = await fs.promises.mkdtemp( path.join(os.tmpdir(), 'polykey-test-'), ); // Setting up remote node - const nodePath = path.join(dataDir, 'agentA'); - remotePolykeyAgent = await PolykeyAgent.createPolykeyAgent({ + nodePathA = path.join(dataDir, 'agentA'); + nodePathB = path.join(dataDir, 'agentB'); + remotePolykeyAgentA = await PolykeyAgent.createPolykeyAgent({ password, options: { - nodePath, + nodePath: nodePathA, agentServiceHost: localHost, clientServiceHost: localHost, keys: { @@ -112,8 +123,25 @@ describe(`${NodeConnectionManager.name} general test`, () => { }, logger: logger.getChild('AgentA'), }); - serverNodeId = remotePolykeyAgent.keyRing.getNodeId(); - serverNodeIdEncoded = nodesUtils.encodeNodeId(serverNodeId); + serverNodeIdA = remotePolykeyAgentA.keyRing.getNodeId(); + serverAddressA = { + host: remotePolykeyAgentA.agentServiceHost as Host, + port: remotePolykeyAgentA.agentServicePort as Port, + }; + remotePolykeyAgentB = await PolykeyAgent.createPolykeyAgent({ + password, + options: { + nodePath: nodePathB, + agentServiceHost: localHost, + clientServiceHost: localHost, + keys: { + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }, + }, + logger: logger.getChild('AgentB'), + }); // Setting up client dependencies const keysPath = path.join(dataDir, 'keys'); @@ -154,17 +182,14 @@ describe(`${NodeConnectionManager.name} general test`, () => { db, logger, }); - serverAddress = { - host: remotePolykeyAgent.agentServiceHost, - port: remotePolykeyAgent.agentServicePort, - }; }); afterEach(async () => { logger.info('AFTER EACH'); + mockedHolePunchReverse.mockRestore(); + mockedPingNode.mockRestore(); await taskManager.stopProcessing(); await taskManager.stopTasks(); - await nodeManager?.stop(); await nodeConnectionManager?.stop(); await sigchain.stop(); await sigchain.destroy(); @@ -180,7 +205,8 @@ describe(`${NodeConnectionManager.name} general test`, () => { await keyRing.destroy(); await taskManager.stop(); - await remotePolykeyAgent.stop(); + await remotePolykeyAgentA?.stop(); + await remotePolykeyAgentB?.stop(); }); test('finds node (local)', async () => { @@ -191,17 +217,6 @@ describe(`${NodeConnectionManager.name} general test`, () => { tlsConfig, seedNodes: undefined, }); - nodeManager = new NodeManager({ - db, - gestaltGraph, - keyRing, - nodeConnectionManager, - nodeGraph, - sigchain, - taskManager, - logger, - }); - await nodeManager.start(); await nodeConnectionManager.start({ host: localHost as Host, }); @@ -231,17 +246,6 @@ describe(`${NodeConnectionManager.name} general test`, () => { tlsConfig, seedNodes: undefined, }); - nodeManager = new NodeManager({ - db, - gestaltGraph, - keyRing, - nodeConnectionManager, - nodeGraph, - sigchain, - taskManager, - logger, - }); - await nodeManager.start(); await nodeConnectionManager.start({ host: localHost as Host, }); @@ -256,14 +260,14 @@ describe(`${NodeConnectionManager.name} general test`, () => { ); logger.info('DOING TEST'); - await nodeGraph.setNode(serverNodeId, serverAddress); + await nodeGraph.setNode(serverNodeIdA, serverAddressA); // Adding node information to remote node const nodeId = testNodesUtils.generateRandomNodeId(); const nodeAddress: NodeAddress = { host: localHost as Host, port: 11111 as Port, }; - await remotePolykeyAgent.nodeGraph.setNode(nodeId, nodeAddress); + await remotePolykeyAgentA.nodeGraph.setNode(nodeId, nodeAddress); // Expect no error thrown const findNodePromise = nodeConnectionManager.findNode(nodeId); @@ -282,17 +286,6 @@ describe(`${NodeConnectionManager.name} general test`, () => { tlsConfig, seedNodes: undefined, }); - nodeManager = new NodeManager({ - db, - gestaltGraph, - keyRing, - nodeConnectionManager, - nodeGraph, - sigchain, - taskManager, - logger, - }); - await nodeManager.start(); await nodeConnectionManager.start({ host: localHost as Host, }); @@ -306,7 +299,7 @@ describe(`${NodeConnectionManager.name} general test`, () => { () => new PromiseCancellable((resolve) => resolve(true)), ); - await nodeGraph.setNode(serverNodeId, serverAddress); + await nodeGraph.setNode(serverNodeIdA, serverAddressA); // Adding node information to remote node const nodeId = testNodesUtils.generateRandomNodeId(); @@ -326,17 +319,6 @@ describe(`${NodeConnectionManager.name} general test`, () => { tlsConfig, seedNodes: undefined, }); - nodeManager = new NodeManager({ - db, - gestaltGraph, - keyRing, - nodeConnectionManager, - nodeGraph, - sigchain, - taskManager, - logger, - }); - await nodeManager.start(); await nodeConnectionManager.start({ host: localHost as Host, }); @@ -350,20 +332,20 @@ describe(`${NodeConnectionManager.name} general test`, () => { () => new PromiseCancellable((resolve) => resolve(true)), ); - await nodeGraph.setNode(serverNodeId, serverAddress); + await nodeGraph.setNode(serverNodeIdA, serverAddressA); // Now generate and add 20 nodes that will be close to this node ID const addedClosestNodes: NodeBucket = []; for (let i = 1; i < 101; i += 5) { const closeNodeId = testNodesUtils.generateNodeIdForBucket( - serverNodeId, + serverNodeIdA, i, ); const nodeAddress = { host: (i + '.' + i + '.' + i + '.' + i) as Host, port: i as Port, }; - await remotePolykeyAgent.nodeGraph.setNode(closeNodeId, nodeAddress); + await remotePolykeyAgentA.nodeGraph.setNode(closeNodeId, nodeAddress); addedClosestNodes.push([ closeNodeId, { @@ -379,23 +361,23 @@ describe(`${NodeConnectionManager.name} general test`, () => { host: `${i}.${i}.${i}.${i}`, port: i, } as NodeAddress; - await remotePolykeyAgent.nodeGraph.setNode(farNodeId, nodeAddress); + await remotePolykeyAgentA.nodeGraph.setNode(farNodeId, nodeAddress); } // Get the closest nodes to the target node const closest = await nodeConnectionManager.getRemoteNodeClosestNodes( - serverNodeId, - serverNodeId, + serverNodeIdA, + serverNodeIdA, ); // Sort the received nodes on distance such that we can check its equality // with addedClosestNodes - nodesUtils.bucketSortByDistance(closest, serverNodeId); + nodesUtils.bucketSortByDistance(closest, serverNodeIdA); expect(closest.length).toBe(20); expect(closest).toEqual(addedClosestNodes); await nodeConnectionManager.stop(); }); - test('sendHolePunchMessage', async () => { + test('holePunchSignalRequest with no target node', async () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, logger: logger.getChild(NodeConnectionManager.name), @@ -405,76 +387,87 @@ describe(`${NodeConnectionManager.name} general test`, () => { tlsConfig, seedNodes: undefined, }); - nodeManager = new NodeManager({ - db, - gestaltGraph, + await nodeConnectionManager.start({ + host: localHost as Host, + }); + await taskManager.startProcessing(); + + mockedHolePunchReverse.mockImplementation(() => { + return new PromiseCancellable((res) => { + res(); + }); + }); + + await nodeGraph.setNode(serverNodeIdA, serverAddressA); + + const targetNodeId = testNodesUtils.generateRandomNodeId(); + const relayNodeId = remotePolykeyAgentA.keyRing.getNodeId(); + + await expect( + nodeConnectionManager.holePunchSignalRequest(targetNodeId, relayNodeId), + ).rejects.toThrow(); + await nodeConnectionManager.stop(); + }); + test('holePunchSignalRequest with target node', async () => { + // Establish connection between remote A and B + expect( + await remotePolykeyAgentA.nodeConnectionManager.pingNode( + remotePolykeyAgentB.keyRing.getNodeId(), + remotePolykeyAgentB.agentServiceHost, + remotePolykeyAgentB.agentServicePort, + ), + ).toBeTrue(); + + nodeConnectionManager = new NodeConnectionManager({ keyRing, - nodeConnectionManager, + logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - sigchain, - taskManager, - logger, + connectionKeepAliveTimeoutTime: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, + seedNodes: undefined, }); - await nodeManager.start(); await nodeConnectionManager.start({ host: localHost as Host, }); await taskManager.startProcessing(); - // Mocking pinging to always return true - const mockedPingNode = jest.spyOn( - NodeConnectionManager.prototype, - 'pingNode', - ); - mockedPingNode.mockImplementation( - () => new PromiseCancellable((resolve) => resolve(true)), - ); + mockedHolePunchReverse.mockImplementation(() => { + return new PromiseCancellable((res) => { + res(); + }); + }); + + const serverNodeId = remotePolykeyAgentA.keyRing.getNodeId(); + const serverAddress = { + host: remotePolykeyAgentA.agentServiceHost as Host, + port: remotePolykeyAgentA.agentServicePort as Port, + }; await nodeGraph.setNode(serverNodeId, serverAddress); - // Now generate and add 20 nodes that will be close to this node ID - const addedClosestNodes: NodeBucket = []; - for (let i = 1; i < 101; i += 5) { - const closeNodeId = testNodesUtils.generateNodeIdForBucket( - serverNodeId, - i, - ); - const nodeAddress = { - host: (i + '.' + i + '.' + i + '.' + i) as Host, - port: i as Port, - }; - await remotePolykeyAgent.nodeGraph.setNode(closeNodeId, nodeAddress); - addedClosestNodes.push([ - closeNodeId, - { - address: nodeAddress, - lastUpdated: 0, - }, - ]); - } - // Now create and add 10 more nodes that are far away from this node - for (let i = 1; i <= 10; i++) { - const farNodeId = nodeIdGenerator(i); - const nodeAddress = { - host: `${i}.${i}.${i}.${i}`, - port: i, - } as NodeAddress; - await remotePolykeyAgent.nodeGraph.setNode(farNodeId, nodeAddress); - } + const targetNodeId = remotePolykeyAgentB.keyRing.getNodeId(); + const relayNodeId = remotePolykeyAgentA.keyRing.getNodeId(); - // Get the closest nodes to the target node - const closest = await nodeConnectionManager.getRemoteNodeClosestNodes( - serverNodeId, - serverNodeId, + await nodeConnectionManager.holePunchSignalRequest( + targetNodeId, + relayNodeId, ); - // Sort the received nodes on distance such that we can check its equality - // with addedClosestNodes - nodesUtils.bucketSortByDistance(closest, serverNodeId); - expect(closest.length).toBe(20); - expect(closest).toEqual(addedClosestNodes); - + // Await the FAF signalling to finish. + const signalMapA = + // @ts-ignore: kidnap protected property + remotePolykeyAgentA.nodeConnectionManager.activeSignalSet; + for (const p of signalMapA) { + await p; + } + // @ts-ignore: kidnap protected property + const punchMapB = remotePolykeyAgentB.nodeConnectionManager.activePunchMap; + for await (const [, p] of punchMapB) { + await p; + } + expect(mockedHolePunchReverse).toHaveBeenCalled(); await nodeConnectionManager.stop(); }); - test('relayHolePunchMessage', async () => { + test('holePunchSignalRequest is nonblocking', async () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, logger: logger.getChild(NodeConnectionManager.name), @@ -484,54 +477,183 @@ describe(`${NodeConnectionManager.name} general test`, () => { tlsConfig, seedNodes: undefined, }); - nodeManager = new NodeManager({ - db, - gestaltGraph, + await nodeConnectionManager.start({ + host: localHost as Host, + }); + await taskManager.startProcessing(); + + const { p: waitP, resolveP: waitResolveP } = utils.promise(); + mockedHolePunchReverse.mockImplementation(() => { + return new PromiseCancellable(async (res) => { + await waitP; + res(); + }); + }); + + const serverNodeId = remotePolykeyAgentA.keyRing.getNodeId(); + const serverAddress = { + host: remotePolykeyAgentA.agentServiceHost, + port: remotePolykeyAgentA.agentServicePort, + }; + await nodeGraph.setNode(serverNodeId, serverAddress); + // Establish connection between remote A and B + expect( + await remotePolykeyAgentA.nodeConnectionManager.pingNode( + remotePolykeyAgentB.keyRing.getNodeId(), + remotePolykeyAgentB.agentServiceHost, + remotePolykeyAgentB.agentServicePort, + ), + ).toBeTrue(); + + const targetNodeId = remotePolykeyAgentB.keyRing.getNodeId(); + const relayNodeId = remotePolykeyAgentA.keyRing.getNodeId(); + // Creating 5 concurrent attempts + const holePunchSignalRequests = [1, 2, 3, 4, 5].map(() => + nodeConnectionManager.holePunchSignalRequest(targetNodeId, relayNodeId), + ); + // All should resolve immediately and not block + await Promise.all(holePunchSignalRequests); + + // Await the FAF signalling to finish. + const signalMapA = + // @ts-ignore: kidnap protected property + remotePolykeyAgentA.nodeConnectionManager.activeSignalSet; + for (const p of signalMapA) { + await p; + } + // Only one attempt is being made + // @ts-ignore: kidnap protected property + const punchMapB = remotePolykeyAgentB.nodeConnectionManager.activePunchMap; + expect(punchMapB.size).toBe(1); + // Allow the attempt to complete + waitResolveP(); + for await (const [, p] of punchMapB) { + await p; + } + // Only attempted once + expect(mockedHolePunchReverse).toHaveBeenCalledTimes(1); + await nodeConnectionManager.stop(); + }); + test('holePunchRequest single target with multiple ports is rate limited', async () => { + nodeConnectionManager = new NodeConnectionManager({ keyRing, - nodeConnectionManager, + logger: logger.getChild(NodeConnectionManager.name), nodeGraph, - sigchain, - taskManager, - logger, + connectionKeepAliveTimeoutTime: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, + seedNodes: undefined, }); - await nodeManager.start(); await nodeConnectionManager.start({ host: localHost as Host, }); await taskManager.startProcessing(); - // Mocking the relay send - const mockedHolePunchReverse = jest.spyOn( - NodeConnectionManager.prototype, - 'holePunchReverse', - ); + const { p: waitP, resolveP: waitResolveP } = utils.promise(); mockedHolePunchReverse.mockImplementation(() => { - return new PromiseCancellable((res) => { + return new PromiseCancellable(async (res) => { + await waitP; res(); }); }); - await nodeGraph.setNode(serverNodeId, serverAddress); + nodeConnectionManager.handleNodesConnectionSignalFinal( + '127.0.0.1' as Host, + 55550 as Port, + ); + nodeConnectionManager.handleNodesConnectionSignalFinal( + '127.0.0.1' as Host, + 55551 as Port, + ); + nodeConnectionManager.handleNodesConnectionSignalFinal( + '127.0.0.1' as Host, + 55552 as Port, + ); + nodeConnectionManager.handleNodesConnectionSignalFinal( + '127.0.0.1' as Host, + 55553 as Port, + ); + nodeConnectionManager.handleNodesConnectionSignalFinal( + '127.0.0.1' as Host, + 55554 as Port, + ); + nodeConnectionManager.handleNodesConnectionSignalFinal( + '127.0.0.1' as Host, + 55555 as Port, + ); - const srcNodeId = testNodesUtils.generateRandomNodeId(); - const srcNodeIdEncoded = nodesUtils.encodeNodeId(srcNodeId); + // @ts-ignore: protected property + expect(nodeConnectionManager.activePunchMap.size).toBe(6); + // @ts-ignore: protected property + expect(nodeConnectionManager.activeAddressMap.size).toBe(1); + waitResolveP(); + // @ts-ignore: protected property + for await (const [, p] of nodeConnectionManager.activePunchMap) { + await p; + } - await nodeConnectionManager.relaySignalingMessage( - { - srcIdEncoded: srcNodeIdEncoded, - dstIdEncoded: serverNodeIdEncoded, - address: { - host: '127.0.0.2', - port: 22222, - }, - }, - { - host: '127.0.0.3' as Host, - port: 33333 as Port, - }, - ); + // Only attempted once + expect(mockedHolePunchReverse).toHaveBeenCalledTimes(6); + await nodeConnectionManager.stop(); + }); + test('holePunchSignalRequest rejects excessive requests', async () => { + nodeConnectionManager = new NodeConnectionManager({ + keyRing, + logger: logger.getChild(NodeConnectionManager.name), + nodeGraph, + connectionKeepAliveTimeoutTime: 10000, + connectionKeepAliveIntervalTime: 1000, + tlsConfig, + seedNodes: undefined, + }); + await nodeConnectionManager.start({ + host: localHost as Host, + }); + await taskManager.startProcessing(); - expect(mockedHolePunchReverse).toHaveBeenCalled(); + mockedHolePunchReverse.mockImplementation(() => { + return new PromiseCancellable(async (res) => { + res(); + }); + }); + + expect( + await nodeConnectionManager.pingNode( + remotePolykeyAgentB.keyRing.getNodeId(), + remotePolykeyAgentB.agentServiceHost, + remotePolykeyAgentB.agentServicePort, + ), + ).toBeTrue(); + const keyPair = keysUtils.generateKeyPair(); + const sourceNodeId = keysUtils.publicKeyToNodeId(keyPair.publicKey); + const targetNodeId = remotePolykeyAgentB.keyRing.getNodeId(); + const data = Buffer.concat([sourceNodeId, targetNodeId]); + const signature = keysUtils.signWithPrivateKey(keyPair, data); + expect(() => { + for (let i = 0; i < 30; i++) { + nodeConnectionManager.handleNodesConnectionSignalInitial( + sourceNodeId, + targetNodeId, + { + host: '127.0.0.1' as Host, + port: 55555 as Port, + }, + signature.toString('base64url'), + ); + } + }).toThrow(nodesErrors.ErrorNodeConnectionManagerRequestRateExceeded); + + const signalMapA = + // @ts-ignore: kidnap protected property + nodeConnectionManager.activeSignalSet; + for (const p of signalMapA.values()) { + await p; + } + // @ts-ignore: kidnap protected property + const punchMapB = remotePolykeyAgentB.nodeConnectionManager.activePunchMap; + for (const [, p] of punchMapB) { + await p; + } await nodeConnectionManager.stop(); }); diff --git a/tests/nodes/NodeConnectionManager.timeout.test.ts b/tests/nodes/NodeConnectionManager.timeout.test.ts index 1846677903..44a5d2e3ec 100644 --- a/tests/nodes/NodeConnectionManager.timeout.test.ts +++ b/tests/nodes/NodeConnectionManager.timeout.test.ts @@ -199,45 +199,6 @@ describe(`${NodeConnectionManager.name} timeout test`, () => { await nodeConnectionManager.stop(); }); - test('withConnection should extend timeout', async () => { - nodeConnectionManager = new NodeConnectionManager({ - keyRing, - logger: logger.getChild(NodeConnectionManager.name), - nodeGraph, - tlsConfig, - seedNodes: undefined, - connectionIdleTimeoutTime: 1000, - }); - await nodeConnectionManager.start({ - host: localHost as Host, - }); - - await nodeGraph.setNode(remoteNodeId1, remoteAddress1); - - // @ts-ignore: kidnap connections - const connections = nodeConnectionManager.connections; - // @ts-ignore: kidnap connections - const connectionLocks = nodeConnectionManager.connectionLocks; - await nodeConnectionManager.withConnF(remoteNodeId1, async () => {}); - const midConnAndLock = connections.get( - remoteNodeId1.toString() as NodeIdString, - ); - // Check entry is in map and lock is released - expect(midConnAndLock).toBeDefined(); - expect(connectionLocks.isLocked(remoteNodeId1.toString())).toBeFalsy(); - expect(midConnAndLock?.timer).toBeDefined(); - - // Destroying the connection - // @ts-ignore: private method - await nodeConnectionManager.destroyConnection(remoteNodeId1); - const finalConnAndLock = connections.get( - remoteNodeId1.toString() as NodeIdString, - ); - expect(finalConnAndLock).not.toBeDefined(); - expect(connectionLocks.isLocked(remoteNodeId1.toString())).toBeFalsy(); - - await nodeConnectionManager.stop(); - }); test('Connection can time out', async () => { nodeConnectionManager = new NodeConnectionManager({ keyRing, diff --git a/tests/nodes/agent/handlers/nodesConnectionSignalFinal.test.ts b/tests/nodes/agent/handlers/nodesConnectionSignalFinal.test.ts new file mode 100644 index 0000000000..e334b27e5e --- /dev/null +++ b/tests/nodes/agent/handlers/nodesConnectionSignalFinal.test.ts @@ -0,0 +1,179 @@ +import type { KeyPair } from '@/keys/types'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { QUICClient, QUICServer, events as quicEvents } from '@matrixai/quic'; +import { RPCClient, RPCServer } from '@matrixai/rpc'; +import { nodesConnectionSignalFinal } from '@/nodes/agent/callers'; +import { NodesConnectionSignalFinal } from '@/nodes/agent/handlers'; +import * as keysUtils from '@/keys/utils/index'; +import * as networkUtils from '@/network/utils'; +import * as nodesUtils from '@/nodes/utils'; +import * as tlsTestsUtils from '../../../utils/tls'; +import * as testsNodesUtils from '../../utils'; + +describe('nodesHolePunchRequest', () => { + const logger = new Logger('nodesHolePunchRequest test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const crypto = tlsTestsUtils.createCrypto(); + const localHost = '127.0.0.1'; + + let keyPair: KeyPair; + let rpcServer: RPCServer; + let quicServer: QUICServer; + + const clientManifest = { + nodesConnectionSignalFinal, + }; + type ClientManifest = typeof clientManifest; + let rpcClient: RPCClient; + let quicClient: QUICClient; + const dummyNodeConnectionManager = { + handleNodesConnectionSignalFinal: jest.fn(), + }; + + beforeEach(async () => { + dummyNodeConnectionManager.handleNodesConnectionSignalFinal.mockClear(); + + // Handler dependencies + keyPair = keysUtils.generateKeyPair(); + const tlsConfigClient = await tlsTestsUtils.createTLSConfig(keyPair); + + // Setting up server + const serverManifest = { + nodesConnectionSignalFinal: new NodesConnectionSignalFinal({ + nodeConnectionManager: dummyNodeConnectionManager as any, + logger, + }), + }; + rpcServer = new RPCServer({ + fromError: networkUtils.fromError, + logger, + }); + await rpcServer.start({ manifest: serverManifest }); + const tlsConfig = await tlsTestsUtils.createTLSConfig(keyPair); + quicServer = new QUICServer({ + config: { + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + verifyPeer: true, + verifyCallback: async () => undefined, + }, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, + logger, + }); + const handleStream = async ( + event: quicEvents.EventQUICConnectionStream, + ) => { + // Streams are handled via the RPCServer. + const stream = event.detail; + logger.info('!!!!Handling new stream!!!!!'); + rpcServer.handleStream(stream); + }; + const handleConnection = async ( + event: quicEvents.EventQUICServerConnection, + ) => { + // Needs to setup stream handler + const conn = event.detail; + logger.info('!!!!Handling new Connection!!!!!'); + conn.addEventListener( + quicEvents.EventQUICConnectionStream.name, + handleStream, + ); + conn.addEventListener( + quicEvents.EventQUICConnectionStopped.name, + () => { + conn.removeEventListener( + quicEvents.EventQUICConnectionStream.name, + handleStream, + ); + }, + { once: true }, + ); + }; + quicServer.addEventListener( + quicEvents.EventQUICServerConnection.name, + handleConnection, + ); + quicServer.addEventListener( + quicEvents.EventQUICServerStopped.name, + () => { + quicServer.removeEventListener( + quicEvents.EventQUICServerConnection.name, + handleConnection, + ); + }, + { once: true }, + ); + await quicServer.start({ + host: localHost, + }); + + // Setting up client + rpcClient = new RPCClient({ + manifest: clientManifest, + streamFactory: async () => { + return quicClient.connection.newStream(); + }, + toError: networkUtils.toError, + logger, + }); + quicClient = await QUICClient.createQUICClient({ + crypto: { + ops: crypto, + }, + config: { + key: tlsConfigClient.keyPrivatePem, + cert: tlsConfigClient.certChainPem, + verifyPeer: true, + verifyCallback: async () => undefined, + }, + host: localHost, + port: quicServer.port, + localHost: localHost, + logger, + }); + }); + afterEach(async () => { + await rpcServer.stop({ force: true }); + await quicServer.stop({ force: true }); + }); + + test('should send hole punch relay', async () => { + const requestKeyPair = keysUtils.generateKeyPair(); + const targetNodeId = testsNodesUtils.generateRandomNodeId(); + const sourceNodeId = keysUtils.publicKeyToNodeId(requestKeyPair.publicKey); + // Data is just `` concatenated + const requestData = Buffer.concat([sourceNodeId, targetNodeId]); + const requestSignature = keysUtils.signWithPrivateKey( + requestKeyPair, + requestData, + ); + + // Generating relay signature, data is just `
` concatenated + const address = { + host: quicClient.host, + port: quicClient.port, + }; + const data = Buffer.concat([ + sourceNodeId, + targetNodeId, + Buffer.from(JSON.stringify(address), 'utf-8'), + requestSignature, + ]); + const relaySignature = keysUtils.signWithPrivateKey(keyPair, data); + + await rpcClient.methods.nodesConnectionSignalFinal({ + sourceNodeIdEncoded: nodesUtils.encodeNodeId(sourceNodeId), + targetNodeIdEncoded: nodesUtils.encodeNodeId(targetNodeId), + address, + requestSignature: requestSignature.toString('base64url'), + relaySignature: relaySignature.toString('base64url'), + }); + expect( + dummyNodeConnectionManager.handleNodesConnectionSignalFinal, + ).toHaveBeenCalled(); + }); +}); diff --git a/tests/nodes/agent/handlers/nodesConnectionSignalInitial.test.ts b/tests/nodes/agent/handlers/nodesConnectionSignalInitial.test.ts new file mode 100644 index 0000000000..7ab902fc8c --- /dev/null +++ b/tests/nodes/agent/handlers/nodesConnectionSignalInitial.test.ts @@ -0,0 +1,158 @@ +import type { KeyPair } from '@/keys/types'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { QUICClient, QUICServer, events as quicEvents } from '@matrixai/quic'; +import { RPCClient, RPCServer } from '@matrixai/rpc'; +import { nodesConnectionSignalInitial } from '@/nodes/agent/callers'; +import { NodesConnectionSignalInitial } from '@/nodes/agent/handlers'; +import * as keysUtils from '@/keys/utils/index'; +import * as nodesUtils from '@/nodes/utils'; +import * as networkUtils from '@/network/utils'; +import * as tlsTestsUtils from '../../../utils/tls'; +import * as testsNodesUtils from '../../../nodes/utils'; + +describe('nodesHolePunchSignal', () => { + const logger = new Logger('nodesHolePunchSignal test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const crypto = tlsTestsUtils.createCrypto(); + const localHost = '127.0.0.1'; + + let keyPair: KeyPair; + let rpcServer: RPCServer; + let quicServer: QUICServer; + + const clientManifest = { + nodesConnectionSignalInitial, + }; + type ClientManifest = typeof clientManifest; + let rpcClient: RPCClient; + let quicClient: QUICClient; + const dummyNodeConnectionManager = { + handleNodesConnectionSignalInitial: jest.fn(), + }; + + beforeEach(async () => { + dummyNodeConnectionManager.handleNodesConnectionSignalInitial.mockClear(); + + // Handler dependencies + keyPair = keysUtils.generateKeyPair(); + const tlsConfigClient = await tlsTestsUtils.createTLSConfig(keyPair); + + // Setting up server + const serverManifest = { + nodesConnectionSignalInitial: new NodesConnectionSignalInitial({ + nodeConnectionManager: dummyNodeConnectionManager as any, + }), + }; + rpcServer = new RPCServer({ + fromError: networkUtils.fromError, + logger, + }); + await rpcServer.start({ manifest: serverManifest }); + const tlsConfig = await tlsTestsUtils.createTLSConfig(keyPair); + quicServer = new QUICServer({ + config: { + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + verifyPeer: true, + verifyCallback: async () => undefined, + }, + crypto: { + key: keysUtils.generateKey(), + ops: crypto, + }, + logger, + }); + const handleStream = async ( + event: quicEvents.EventQUICConnectionStream, + ) => { + // Streams are handled via the RPCServer. + const stream = event.detail; + logger.info('!!!!Handling new stream!!!!!'); + rpcServer.handleStream(stream); + }; + const handleConnection = async ( + event: quicEvents.EventQUICServerConnection, + ) => { + // Needs to setup stream handler + const conn = event.detail; + logger.info('!!!!Handling new Connection!!!!!'); + conn.addEventListener( + quicEvents.EventQUICConnectionStream.name, + handleStream, + ); + conn.addEventListener( + quicEvents.EventQUICConnectionStopped.name, + () => { + conn.removeEventListener( + quicEvents.EventQUICConnectionStream.name, + handleStream, + ); + }, + { once: true }, + ); + }; + quicServer.addEventListener( + quicEvents.EventQUICServerConnection.name, + handleConnection, + ); + quicServer.addEventListener( + quicEvents.EventQUICServerStopped.name, + () => { + quicServer.removeEventListener( + quicEvents.EventQUICServerConnection.name, + handleConnection, + ); + }, + { once: true }, + ); + await quicServer.start({ + host: localHost, + }); + + // Setting up client + rpcClient = new RPCClient({ + manifest: clientManifest, + streamFactory: async () => { + return quicClient.connection.newStream(); + }, + toError: networkUtils.toError, + logger, + }); + quicClient = await QUICClient.createQUICClient({ + crypto: { + ops: crypto, + }, + config: { + key: tlsConfigClient.keyPrivatePem, + cert: tlsConfigClient.certChainPem, + verifyPeer: true, + verifyCallback: async () => undefined, + }, + host: localHost, + port: quicServer.port, + localHost: localHost, + logger, + }); + }); + afterEach(async () => { + await rpcServer.stop({ force: true }); + await quicServer.stop({ force: true }); + }); + + test('should send hole punch relay', async () => { + const targetNodeId = testsNodesUtils.generateRandomNodeId(); + const targetNodeIdEncoded = nodesUtils.encodeNodeId(targetNodeId); + const sourceNodeId = keysUtils.publicKeyToNodeId(keyPair.publicKey); + // Data is just `` concatenated + const data = Buffer.concat([sourceNodeId, targetNodeId]); + const signature = keysUtils.signWithPrivateKey(keyPair, data); + await rpcClient.methods.nodesConnectionSignalInitial({ + targetNodeIdEncoded, + signature: signature.toString('base64url'), + }); + expect( + dummyNodeConnectionManager.handleNodesConnectionSignalInitial, + ).toHaveBeenCalled(); + }); +}); diff --git a/tests/nodes/agent/handlers/nodesHolePunchMessage.test.ts b/tests/nodes/agent/handlers/nodesHolePunchMessage.test.ts deleted file mode 100644 index 15cab3a52c..0000000000 --- a/tests/nodes/agent/handlers/nodesHolePunchMessage.test.ts +++ /dev/null @@ -1,262 +0,0 @@ -import type GestaltGraph from '@/gestalts/GestaltGraph'; -import type { Host } from '@/network/types'; -import fs from 'fs'; -import path from 'path'; -import os from 'os'; -import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import { - QUICClient, - QUICServer, - QUICSocket, - events as quicEvents, -} from '@matrixai/quic'; -import { DB } from '@matrixai/db'; -import { RPCClient, RPCServer } from '@matrixai/rpc'; -import KeyRing from '@/keys/KeyRing'; -import * as nodesUtils from '@/nodes/utils'; -import NodeGraph from '@/nodes/NodeGraph'; -import { nodesHolePunchMessageSend } from '@/nodes/agent/callers'; -import NodesHolePunchMessageSend from '@/nodes/agent/handlers/NodesHolePunchMessageSend'; -import NodeConnectionManager from '@/nodes/NodeConnectionManager'; -import NodeManager from '@/nodes/NodeManager'; -import ACL from '@/acl/ACL'; -import Sigchain from '@/sigchain/Sigchain'; -import TaskManager from '@/tasks/TaskManager'; -import * as keysUtils from '@/keys/utils'; -import * as networkUtils from '@/network/utils'; -import * as tlsTestsUtils from '../../../utils/tls'; - -describe('nodesHolePunchMessage', () => { - const logger = new Logger('nodesHolePunchMessage test', LogLevel.WARN, [ - new StreamHandler(), - ]); - const password = 'password'; - const crypto = tlsTestsUtils.createCrypto(); - const localHost = '127.0.0.1'; - - let dataDir: string; - - let keyRing: KeyRing; - let db: DB; - let acl: ACL; - let sigchain: Sigchain; - let taskManager: TaskManager; - let quicSocket: QUICSocket; - let nodeConnectionManager: NodeConnectionManager; - let nodeManager: NodeManager; - let nodeGraph: NodeGraph; - let rpcServer: RPCServer; - let quicServer: QUICServer; - - const clientManifest = { - nodesHolePunchMessageSend, - }; - type ClientManifest = typeof clientManifest; - let rpcClient: RPCClient; - let quicClient: QUICClient; - - beforeEach(async () => { - dataDir = await fs.promises.mkdtemp( - path.join(os.tmpdir(), 'polykey-test-'), - ); - - // Handler dependencies - const keysPath = path.join(dataDir, 'keys'); - keyRing = await KeyRing.createKeyRing({ - keysPath, - password, - passwordOpsLimit: keysUtils.passwordOpsLimits.min, - passwordMemLimit: keysUtils.passwordMemLimits.min, - strictMemoryLock: false, - logger, - }); - const dbPath = path.join(dataDir, 'db'); - db = await DB.createDB({ - dbPath, - logger, - }); - nodeGraph = await NodeGraph.createNodeGraph({ - db, - keyRing, - logger, - }); - acl = await ACL.createACL({ - db, - logger, - }); - sigchain = await Sigchain.createSigchain({ - db, - keyRing, - logger, - }); - nodeGraph = await NodeGraph.createNodeGraph({ - db, - keyRing, - logger: logger.getChild('NodeGraph'), - }); - taskManager = await TaskManager.createTaskManager({ - db, - logger, - lazy: true, - }); - quicSocket = new QUICSocket({ - logger, - }); - await quicSocket.start({ - host: localHost, - }); - const tlsConfigClient = await tlsTestsUtils.createTLSConfig( - keyRing.keyPair, - ); - nodeConnectionManager = new NodeConnectionManager({ - tlsConfig: tlsConfigClient, - keyRing, - nodeGraph, - connectionConnectTimeoutTime: 2000, - connectionIdleTimeoutTime: 2000, - logger: logger.getChild('NodeConnectionManager'), - }); - nodeManager = new NodeManager({ - db, - keyRing, - nodeGraph, - nodeConnectionManager, - sigchain, - taskManager, - gestaltGraph: {} as GestaltGraph, - logger, - }); - await nodeManager.start(); - await nodeConnectionManager.start({ host: localHost as Host }); - await taskManager.startProcessing(); - - // Setting up server - const serverManifest = { - nodesHolePunchMessageSend: new NodesHolePunchMessageSend({ - db, - keyRing, - nodeConnectionManager, - nodeManager: nodeManager, - logger, - }), - }; - rpcServer = new RPCServer({ - fromError: networkUtils.fromError, - logger, - }); - await rpcServer.start({ manifest: serverManifest }); - const tlsConfig = await tlsTestsUtils.createTLSConfig(keyRing.keyPair); - quicServer = new QUICServer({ - config: { - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - verifyPeer: true, - verifyCallback: async () => { - return undefined; - }, - }, - crypto: { - key: keysUtils.generateKey(), - ops: crypto, - }, - logger, - }); - const handleStream = async ( - event: quicEvents.EventQUICConnectionStream, - ) => { - // Streams are handled via the RPCServer. - const stream = event.detail; - logger.info('!!!!Handling new stream!!!!!'); - rpcServer.handleStream(stream); - }; - const handleConnection = async ( - event: quicEvents.EventQUICServerConnection, - ) => { - // Needs to setup stream handler - const conn = event.detail; - logger.info('!!!!Handling new Connection!!!!!'); - conn.addEventListener( - quicEvents.EventQUICConnectionStream.name, - handleStream, - ); - conn.addEventListener( - quicEvents.EventQUICConnectionStopped.name, - () => { - conn.removeEventListener( - quicEvents.EventQUICConnectionStream.name, - handleStream, - ); - }, - { once: true }, - ); - }; - quicServer.addEventListener('serverConnection', handleConnection); - quicServer.addEventListener( - 'serverStop', - () => { - quicServer.removeEventListener('serverConnection', handleConnection); - }, - { once: true }, - ); - await quicServer.start({ - host: localHost, - }); - - // Setting up client - rpcClient = new RPCClient({ - manifest: clientManifest, - streamFactory: async () => { - return quicClient.connection.newStream(); - }, - toError: networkUtils.toError, - logger, - }); - quicClient = await QUICClient.createQUICClient({ - crypto: { - ops: crypto, - }, - config: { - key: tlsConfigClient.keyPrivatePem, - cert: tlsConfigClient.certChainPem, - verifyPeer: true, - verifyCallback: async () => { - return undefined; - }, - }, - host: localHost, - port: quicServer.port, - localHost: localHost, - logger, - }); - }); - afterEach(async () => { - await rpcServer.stop({ force: true }); - await taskManager.stopProcessing(); - await taskManager.stopTasks(); - await quicServer.stop({ force: true }); - await nodeGraph.stop(); - await nodeManager.stop(); - await nodeConnectionManager.stop(); - await taskManager.stop(); - await sigchain.stop(); - await acl.stop(); - await db.stop(); - await keyRing.stop(); - await quicSocket.stop({ force: true }); - }); - - test('dummy test', async () => {}); - // TODO: holding process open for a short time, subject to change in agent migration stage 2 - test.skip('should send hole punch relay', async () => { - const nodeId = nodesUtils.encodeNodeId(keyRing.getNodeId()); - await rpcClient.methods.nodesHolePunchMessageSend({ - srcIdEncoded: nodeId, - dstIdEncoded: nodeId, - address: { - host: quicClient.host, - port: quicClient.port, - }, - }); - // TODO: check if the ping was sent - }); -}); diff --git a/tests/rateLimiter/RateLimiter.test.ts b/tests/rateLimiter/RateLimiter.test.ts new file mode 100644 index 0000000000..fdb6fedc99 --- /dev/null +++ b/tests/rateLimiter/RateLimiter.test.ts @@ -0,0 +1,57 @@ +import RateLimiter from '@/rateLimiter/RateLimiter'; +import { sleep } from '@/utils'; + +describe(`${RateLimiter.name}`, () => { + let rateLimiter: RateLimiter; + + afterEach(() => { + rateLimiter.stop(); + }); + + test('limits rate', async () => { + rateLimiter = new RateLimiter(undefined, 100); + expect(rateLimiter.consume('', 101)).toBeFalse(); + expect(rateLimiter.consume('', 1)).toBeTrue(); + expect(rateLimiter.consume('', 50)).toBeTrue(); + expect(rateLimiter.consume('', 49)).toBeTrue(); + expect(rateLimiter.tokens('')).toBe(0); + expect(rateLimiter.consume('', 1)).toBeFalse(); + expect(rateLimiter.tokens('')).toBe(0); + }); + test('can refresh rate', async () => { + rateLimiter = new RateLimiter(undefined, 100); + expect(rateLimiter.consume('', 50)).toBeTrue(); + expect(rateLimiter.tokens('')).toBe(50); + rateLimiter.refill(); + expect(rateLimiter.tokens('')).toBe(100); + rateLimiter.refill(); + expect(rateLimiter.tokens('')).toBe(100); + }); + test('independent rates', async () => { + rateLimiter = new RateLimiter(undefined, 100); + expect(rateLimiter.consume('a', 50)).toBeTrue(); + expect(rateLimiter.tokens('a')).toBe(50); + expect(rateLimiter.tokens('b')).toBe(100); + expect(rateLimiter.consume('b', 50)).toBeTrue(); + expect(rateLimiter.tokens('b')).toBe(50); + expect(rateLimiter.consume('a', 50)).toBeTrue(); + expect(rateLimiter.consume('a', 1)).toBeFalse(); + expect(rateLimiter.consume('b', 25)).toBeTrue(); + expect(rateLimiter.consume('b', 26)).toBeFalse(); + }); + test('only positive tokens can be consumed', async () => { + rateLimiter = new RateLimiter(undefined, 100); + expect(() => rateLimiter.consume('', 0)).toThrow(); + expect(() => rateLimiter.consume('', -100)).toThrow(); + expect(() => rateLimiter.consume('', -1)).toThrow(); + expect(() => rateLimiter.consume('', -0)).toThrow(); + }); + test('rates refresh on an interval', async () => { + rateLimiter = new RateLimiter(undefined, 100); + expect(rateLimiter.consume('', 50)).toBeTrue(); + rateLimiter.startRefillInterval(); + expect(rateLimiter.tokens('')).toBe(50); + await sleep(1500); + expect(rateLimiter.tokens('')).toBe(100); + }); +});