diff --git a/package.json b/package.json index dc5780c810..0b34498f59 100644 --- a/package.json +++ b/package.json @@ -97,6 +97,7 @@ "jose": "^4.3.6", "lexicographic-integer": "^1.1.0", "multiformats": "^9.4.8", + "node-abort-controller": "^3.0.1", "node-forge": "^0.10.0", "pako": "^1.0.11", "prompts": "^2.4.1", diff --git a/src/nodes/NodeConnectionManager.ts b/src/nodes/NodeConnectionManager.ts index 1177afb347..e2949e45dd 100644 --- a/src/nodes/NodeConnectionManager.ts +++ b/src/nodes/NodeConnectionManager.ts @@ -17,6 +17,7 @@ import type { import type { DBTransaction } from '@matrixai/db'; import { withF } from '@matrixai/resources'; import type NodeManager from './NodeManager'; +import type { AbortSignal } from 'node-abort-controller'; import Logger from '@matrixai/logger'; import { ready, StartStop } from '@matrixai/async-init/dist/StartStop'; import { IdInternal } from '@matrixai/id'; @@ -383,15 +384,21 @@ class NodeConnectionManager { * Retrieves the node address. If an entry doesn't exist in the db, then * proceeds to locate it using Kademlia. * @param targetNodeId Id of the node we are tying to find + * @param options */ @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) - public async findNode(targetNodeId: NodeId): Promise { + public async findNode( + targetNodeId: NodeId, + options: { signal?: AbortSignal } = {}, + ): Promise { + const { signal } = { ...options }; // First check if we already have an existing ID -> address record - let address = (await this.nodeGraph.getNode(targetNodeId))?.address; // Otherwise, attempt to locate it by contacting network if (address == null) { - address = await this.getClosestGlobalNodes(targetNodeId); + address = await this.getClosestGlobalNodes(targetNodeId, undefined, { + signal, + }); // TODO: This currently just does one iteration // If not found in this single iteration, we throw an exception if (address == null) { @@ -417,13 +424,16 @@ class NodeConnectionManager { * @param targetNodeId ID of the node attempting to be found (i.e. attempting * to find its IP address and port) * @param timer Connection timeout timer + * @param options * @returns whether the target node was located in the process */ @ready(new nodesErrors.ErrorNodeConnectionManagerNotRunning()) public async getClosestGlobalNodes( targetNodeId: NodeId, timer?: Timer, + options: { signal?: AbortSignal } = {}, ): Promise { + const { signal } = { ...options }; // Let foundTarget: boolean = false; let foundAddress: NodeAddress | undefined = undefined; // Get the closest alpha nodes to the target node (set as shortlist) @@ -443,6 +453,7 @@ class NodeConnectionManager { const contacted: { [nodeId: string]: boolean } = {}; // Iterate until we've found and contacted k nodes while (Object.keys(contacted).length <= this.nodeGraph.nodeBucketLimit) { + if (signal?.aborted) throw new nodesErrors.ErrorNodeAborted(); // While (!foundTarget) { // Remove the node from the front of the array const nextNode = shortlist.shift(); @@ -476,6 +487,7 @@ class NodeConnectionManager { // Check to see if any of these are the target node. At the same time, add // them to the shortlist for (const [nodeId, nodeData] of foundClosest) { + if (signal?.aborted) throw new nodesErrors.ErrorNodeAborted(); // Ignore a`ny nodes that have been contacted if (contacted[nodeId]) { continue; diff --git a/src/nodes/NodeManager.ts b/src/nodes/NodeManager.ts index d82c5fb549..6f9499f828 100644 --- a/src/nodes/NodeManager.ts +++ b/src/nodes/NodeManager.ts @@ -14,8 +14,10 @@ import type { import type { ClaimEncoded } from '../claims/types'; import type { Timer } from '../types'; import type { PromiseType } from '../utils/utils'; +import type { AbortSignal } from 'node-abort-controller'; import Logger from '@matrixai/logger'; import { StartStop, ready } from '@matrixai/async-init/dist/StartStop'; +import { AbortController } from 'node-abort-controller'; import * as nodesErrors from './errors'; import * as nodesUtils from './utils'; import * as networkUtils from '../network/utils'; @@ -57,6 +59,7 @@ class NodeManager { protected refreshBucketQueueRunner: Promise; protected refreshBucketQueuePlug_: PromiseType; protected refreshBucketQueueDrained_: PromiseType; + protected refreshBucketQueueAbortController: AbortController; constructor({ db, @@ -636,8 +639,13 @@ class NodeManager { * Connections during the search will will share node information with other * nodes. * @param bucketIndex + * @param options */ - public async refreshBucket(bucketIndex: NodeBucketIndex) { + public async refreshBucket( + bucketIndex: NodeBucketIndex, + options: { signal?: AbortSignal } = {}, + ) { + const { signal } = { ...options }; // We need to generate a random nodeId for this bucket const nodeId = this.keyManager.getNodeId(); const bucketRandomNodeId = nodesUtils.generateRandomNodeIdForBucket( @@ -645,7 +653,7 @@ class NodeManager { bucketIndex, ); // We then need to start a findNode procedure - await this.nodeConnectionManager.findNode(bucketRandomNodeId); + await this.nodeConnectionManager.findNode(bucketRandomNodeId, { signal }); } // Refresh bucket activity timer methods @@ -741,6 +749,7 @@ class NodeManager { this.refreshBucketQueueRunning = true; this.refreshBucketQueuePlug(); let iterator: IterableIterator | undefined; + this.refreshBucketQueueAbortController = new AbortController(); const pace = async () => { // Wait for plug await this.refreshBucketQueuePlug_.p; @@ -761,7 +770,14 @@ class NodeManager { this.logger.debug( `processing refreshBucket for bucket ${bucketIndex}, ${this.refreshBucketQueue.size} left in queue`, ); - await this.refreshBucket(bucketIndex); + try { + await this.refreshBucket(bucketIndex, { + signal: this.refreshBucketQueueAbortController.signal, + }); + } catch (e) { + if (e instanceof nodesErrors.ErrorNodeAborted) break; + throw e; + } // Remove from queue and update bucket deadline this.refreshBucketQueue.delete(bucketIndex); this.refreshBucketUpdateDeadline(bucketIndex); @@ -771,6 +787,7 @@ class NodeManager { private async stopRefreshBucketQueue(): Promise { // Flag end and await queue finish + this.refreshBucketQueueAbortController.abort(); this.refreshBucketQueueRunning = false; this.refreshBucketQueueUnplug(); } diff --git a/src/nodes/errors.ts b/src/nodes/errors.ts index 42a7145615..4348a21f9e 100644 --- a/src/nodes/errors.ts +++ b/src/nodes/errors.ts @@ -2,6 +2,11 @@ import { ErrorPolykey, sysexits } from '../errors'; class ErrorNodes extends ErrorPolykey {} +class ErrorNodeAborted extends ErrorNodes { + description = 'Operation was aborted'; + exitCode = sysexits.USAGE; +} + class ErrorNodeManagerNotRunning extends ErrorNodes { static description = 'NodeManager is not running'; exitCode = sysexits.USAGE; @@ -79,6 +84,7 @@ class ErrorNodeConnectionHostWildcard extends ErrorNodes { export { ErrorNodes, + ErrorNodeAborted, ErrorNodeManagerNotRunning, ErrorNodeGraphRunning, ErrorNodeGraphNotRunning, diff --git a/tests/nodes/NodeManager.test.ts b/tests/nodes/NodeManager.test.ts index 04880a0ff9..b83be35d84 100644 --- a/tests/nodes/NodeManager.test.ts +++ b/tests/nodes/NodeManager.test.ts @@ -19,6 +19,7 @@ import * as claimsUtils from '@/claims/utils'; import { promise, promisify, sleep } from '@/utils'; import * as nodesUtils from '@/nodes/utils'; import * as utilsPB from '@/proto/js/polykey/v1/utils/utils_pb'; +import * as nodesErrors from '@/nodes/errors'; import * as nodesTestUtils from './utils'; import { generateNodeIdForBucket } from './utils'; @@ -969,4 +970,43 @@ describe(`${NodeManager.name} test`, () => { await nodeManager.stop(); } }); + test('should abort refreshBucket queue when stopping', async () => { + const refreshBucketTimeout = 1000000; + const nodeManager = new NodeManager({ + db, + sigchain: {} as Sigchain, + keyManager, + nodeGraph, + nodeConnectionManager: dummyNodeConnectionManager, + refreshBucketTimerDefault: refreshBucketTimeout, + logger, + }); + const mockRefreshBucket = jest.spyOn( + NodeManager.prototype, + 'refreshBucket', + ); + try { + await nodeManager.start(); + await nodeConnectionManager.start({ nodeManager }); + mockRefreshBucket.mockImplementation( + async (bucket, options: { signal?: AbortSignal } = {}) => { + const { signal } = { ...options }; + const prom = promise(); + signal?.addEventListener('abort', () => + prom.rejectP(new nodesErrors.ErrorNodeAborted()), + ); + await prom.p; + }, + ); + nodeManager.refreshBucketQueueAdd(1); + nodeManager.refreshBucketQueueAdd(2); + nodeManager.refreshBucketQueueAdd(3); + nodeManager.refreshBucketQueueAdd(4); + nodeManager.refreshBucketQueueAdd(5); + await nodeManager.stop(); + } finally { + mockRefreshBucket.mockRestore(); + await nodeManager.stop(); + } + }); });