Skip to content

Commit

Permalink
fix: Fix duplicated WebSocket connection created when acquiring conne…
Browse files Browse the repository at this point in the history
…ction before one established
  • Loading branch information
neet committed Oct 20, 2024
1 parent e309a2e commit dbfd4c8
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 62 deletions.
6 changes: 2 additions & 4 deletions src/adapters/action/dispatcher-ws.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ describe("DispatcherWs", () => {
}).toThrow(MastoUnexpectedError);
});

it("can be disposed", async () => {
it("can be disposed", () => {
const connector = new WebSocketConnectorImpl({
constructorParameters: ["wss://example.com"],
});
Expand All @@ -41,8 +41,6 @@ describe("DispatcherWs", () => {
);

dispatcher[Symbol.dispose]();
await expect(() => connector.acquire()).rejects.toThrow(
MastoWebSocketError,
);
expect(() => connector.acquire()).toThrow(MastoWebSocketError);
});
});
4 changes: 2 additions & 2 deletions src/adapters/action/dispatcher-ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export class WebSocketActionDispatcher

dispatch<T>(action: WebSocketAction): T {
if (action.type === "close") {
this.connector.close();
this.connector.kill();

Check warning on line 27 in src/adapters/action/dispatcher-ws.ts

View check run for this annotation

Codecov / codecov/patch

src/adapters/action/dispatcher-ws.ts#L27

Added line #L27 was not covered by tests
return {} as T;
}

Expand All @@ -50,6 +50,6 @@ export class WebSocketActionDispatcher
}

[Symbol.dispose](): void {
this.connector.close();
this.connector.kill();
}
}
4 changes: 2 additions & 2 deletions src/adapters/ws/web-socket-connector.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ describe("WebSocketConnector", () => {
expect(ws1).toBe(ws2);

server.close();
connector.close();
connector.kill();
});

it("rejects if WebSocket closes", async () => {
const connector = new WebSocketConnectorImpl({
constructorParameters: [`ws://localhost:0`],
});
const promise = connector.acquire();
connector.close();
connector.kill();

await expect(promise).rejects.toBeInstanceOf(MastoWebSocketError);
});
Expand Down
32 changes: 12 additions & 20 deletions src/adapters/ws/web-socket-connector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,42 @@ interface WebSocketConnectorImplProps {
export class WebSocketConnectorImpl implements WebSocketConnector {
private ws?: WebSocket;

private killed = false;
private queue: PromiseWithResolvers<WebSocket>[] = [];
private backoff: ExponentialBackoff;

private closed = false;
private initialized = false;

constructor(
private readonly props: WebSocketConnectorImplProps,
private readonly logger?: Logger,
) {
this.backoff = new ExponentialBackoff({
maxAttempts: this.props.maxAttempts,
});
this.spawn();
}

async acquire(): Promise<WebSocket> {
if (this.closed) {
acquire(): Promise<WebSocket> {
if (this.killed) {
throw new MastoWebSocketError("WebSocket closed");
}

this.init();

if (this.ws != undefined) {
return this.ws;
return Promise.resolve(this.ws);
}

const promiseWithResolvers = createPromiseWithResolvers<WebSocket>();
this.queue.push(promiseWithResolvers);
return await promiseWithResolvers.promise;
return promiseWithResolvers.promise;
}

async *[Symbol.asyncIterator](): AsyncIterableIterator<WebSocket> {
while (!this.closed) {
while (!this.killed) {
yield await this.acquire();
}
}

close(): void {
this.closed = true;
kill(): void {
this.killed = true;
this.ws?.close();
this.backoff.clear();

Expand All @@ -67,14 +64,8 @@ export class WebSocketConnectorImpl implements WebSocketConnector {
this.queue = [];
}

private async init() {
if (this.initialized) {
return;
}

this.initialized = true;

while (!this.closed) {
private async spawn() {
while (!this.killed) {
this.ws?.close();

try {
Expand Down Expand Up @@ -114,6 +105,7 @@ export class WebSocketConnectorImpl implements WebSocketConnector {
),
);
}

this.queue = [];
}
}
24 changes: 3 additions & 21 deletions src/adapters/ws/web-socket-subscription.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,6 @@ import { WebSocketSubscription } from "./web-socket-subscription";
import { WebSocketSubscriptionCounterImpl } from "./web-socket-subscription-counter";

describe("WebSocketSubscription", () => {
it("doesn't do anything if no connection was established", async () => {
const logger = createLogger();

const subscription = new WebSocketSubscription(
new WebSocketConnectorImpl(
{ constructorParameters: ["ws://localhost:0"] },
logger,
),
new WebSocketSubscriptionCounterImpl(),
new SerializerNativeImpl(),
"public",
logger,
);

const res = subscription.unsubscribe();
expect(res).toBeUndefined();
});

it("implements async iterator", async () => {
const logger = createLogger();
const port = await getPort();
Expand All @@ -43,12 +25,12 @@ describe("WebSocketSubscription", () => {
});
});

const connection = new WebSocketConnectorImpl(
const connector = new WebSocketConnectorImpl(
{ constructorParameters: [`ws://localhost:${port}`] },
logger,
);
const subscription = new WebSocketSubscription(
connection,
connector,
new WebSocketSubscriptionCounterImpl(),
new SerializerNativeImpl(),
"public",
Expand All @@ -63,7 +45,7 @@ describe("WebSocketSubscription", () => {

expect(value).toBe("123");

connection.close();
connector.kill();
server.close();
});
});
2 changes: 1 addition & 1 deletion src/interfaces/ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { type WebSocket } from "isomorphic-ws";

export interface WebSocketConnector extends AsyncIterable<WebSocket> {
acquire(): Promise<WebSocket>;
close(): void;
kill(): void;
}

export interface WebSocketSubscriptionCounter {
Expand Down
1 change: 0 additions & 1 deletion src/mastodon/streaming/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export interface SubscribeHashtagParams {

export interface Subscription extends AsyncIterable<Event>, Disposable {
values(): AsyncIterableIterator<Event>;

unsubscribe(): void;
}

Expand Down
26 changes: 15 additions & 11 deletions tests/streaming/connections.spec.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
import assert from "node:assert";
import crypto from "node:crypto";

it("maintains connections for the event even if other handlers closed it", async () => {
const tag = `tag_${crypto.randomBytes(4).toString("hex")}`;
await using alice = await sessions.acquire({ waitForWs: true });

using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" });
using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" });
using subscription1 = alice.ws.hashtag.subscribe({ tag });
using subscription2 = alice.ws.hashtag.subscribe({ tag });

const promise1 = subscription1.values().take(1).toArray();
const promise2 = subscription2.values().take(2).toArray();

// Dispatch event for subscription1 to establish connection
const status1 = await alice.rest.v1.statuses.create({
status: "#test",
status: `#${tag}`,
visibility: "public",
});
await promise1;
subscription1.unsubscribe();

// subscription1 is now closed, so status2 will only be dispatched to subscription2
const status2 = await alice.rest.v1.statuses.create({
status: "#test",
status: `#${tag}`,
visibility: "public",
});

Expand All @@ -37,16 +39,17 @@ it("maintains connections for the event even if other handlers closed it", async
});

it("maintains connections for the event if unsubscribe called twice", async () => {
const tag = `tag_${crypto.randomBytes(4).toString("hex")}`;
await using alice = await sessions.acquire({ waitForWs: true });

using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" });
using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" });
using subscription1 = alice.ws.hashtag.subscribe({ tag });
using subscription2 = alice.ws.hashtag.subscribe({ tag });

const promise1 = subscription1.values().take(1).toArray();
const promise2 = subscription2.values().take(2).toArray();

const status1 = await alice.rest.v1.statuses.create({
status: "#test",
status: `#${tag}`,
visibility: "public",
});
await promise1;
Expand All @@ -56,7 +59,7 @@ it("maintains connections for the event if unsubscribe called twice", async () =
subscription1.unsubscribe();

const status2 = await alice.rest.v1.statuses.create({
status: "#test",
status: `#${tag}`,
visibility: "public",
});

Expand All @@ -74,17 +77,18 @@ it("maintains connections for the event if unsubscribe called twice", async () =
});

it("maintains connections for the event if another handler called unsubscribe before connection established", async () => {
const tag = `tag_${crypto.randomBytes(4).toString("hex")}`;
await using alice = await sessions.acquire({ waitForWs: true });

using subscription1 = alice.ws.hashtag.subscribe({ tag: "test" });
using subscription2 = alice.ws.hashtag.subscribe({ tag: "test" });
using subscription1 = alice.ws.hashtag.subscribe({ tag });
using subscription2 = alice.ws.hashtag.subscribe({ tag });

subscription1.unsubscribe();

const promise2 = subscription2.values().take(1).toArray();

const status1 = await alice.rest.v1.statuses.create({
status: "#test",
status: `#${tag}`,
visibility: "public",
});

Expand Down

0 comments on commit dbfd4c8

Please sign in to comment.