From dab46880b5f6f175e8e4c73dc2bd76cab6200f20 Mon Sep 17 00:00:00 2001 From: Marcus Pousette Date: Mon, 15 May 2023 14:47:00 +0200 Subject: [PATCH] feat: restrict message sizes and buffered amount --- package.json | 1 + src/muxer.ts | 54 +++++++------- src/private-to-private/handler.ts | 13 ++-- src/private-to-private/transport.ts | 2 + src/private-to-public/transport.ts | 2 +- src/stream.ts | 52 ++++++++++++- test/stream.spec.ts | 112 ++++++++++++++++++++++++++++ 7 files changed, 198 insertions(+), 38 deletions(-) create mode 100644 test/stream.spec.ts diff --git a/package.json b/package.json index a95a1b3..df2e476 100644 --- a/package.json +++ b/package.json @@ -158,6 +158,7 @@ "multiformats": "^11.0.2", "multihashes": "^4.0.3", "p-defer": "^4.0.0", + "p-event": "^5.0.1", "protons-runtime": "^5.0.0", "uint8arraylist": "^2.4.3", "uint8arrays": "^4.0.3" diff --git a/src/muxer.ts b/src/muxer.ts index 5d96125..d6221fd 100644 --- a/src/muxer.ts +++ b/src/muxer.ts @@ -1,4 +1,4 @@ -import { WebRTCStream } from './stream.js' +import { type DataChannelConstraintsOpts, WebRTCStream } from './stream.js' import { nopSink, nopSource } from './util.js' import type { Stream } from '@libp2p/interface-connection' import type { CounterGroup } from '@libp2p/interface-metrics' @@ -7,39 +7,48 @@ import type { Source, Sink } from 'it-stream-types' import type { Uint8ArrayList } from 'uint8arraylist' export interface DataChannelMuxerFactoryInit { + /** + * WebRTC Peer Connection + */ peerConnection: RTCPeerConnection + + /** + * Optional metrics for this data channel muxer + */ metrics?: CounterGroup + + /** + * Options data channel tasks + */ + constraints?: DataChannelConstraintsOpts } export class DataChannelMuxerFactory implements StreamMuxerFactory { /** * WebRTC Peer Connection */ - private readonly peerConnection: RTCPeerConnection private streamBuffer: WebRTCStream[] = [] - private readonly metrics?: CounterGroup - constructor (peerConnection: RTCPeerConnection, metrics?: CounterGroup, readonly protocol = '/webrtc') { - this.peerConnection = peerConnection + constructor (readonly init: DataChannelMuxerFactoryInit, readonly protocol = '/webrtc') { // store any datachannels opened before upgrade has been completed - this.peerConnection.ondatachannel = ({ channel }) => { + this.init.peerConnection.ondatachannel = ({ channel }) => { const stream = new WebRTCStream({ channel, stat: { direction: 'inbound', timeline: { open: 0 } }, + constraints: init.constraints, closeCb: (_stream) => { this.streamBuffer = this.streamBuffer.filter(s => !_stream.eq(s)) } }) this.streamBuffer.push(stream) } - this.metrics = metrics } createStreamMuxer (init?: StreamMuxerInit | undefined): StreamMuxer { - return new DataChannelMuxer(this.peerConnection, this.streamBuffer, this.protocol, init, this.metrics) + return new DataChannelMuxer(this.init, this.streamBuffer, this.protocol, init) } } @@ -47,16 +56,6 @@ export class DataChannelMuxerFactory implements StreamMuxerFactory { * A libp2p data channel stream muxer */ export class DataChannelMuxer implements StreamMuxer { - /** - * WebRTC Peer Connection - */ - private readonly peerConnection: RTCPeerConnection - - /** - * Optional metrics for this data channel muxer - */ - private readonly metrics?: CounterGroup - /** * Array of streams in the data channel */ @@ -82,24 +81,19 @@ export class DataChannelMuxer implements StreamMuxer { */ sink: Sink, Promise> = nopSink - constructor (peerConnection: RTCPeerConnection, streams: Stream[], readonly protocol: string = '/webrtc', init?: StreamMuxerInit, metrics?: CounterGroup) { + constructor (readonly dataChannelMuxer: DataChannelMuxerFactoryInit, streams: Stream[], readonly protocol: string = '/webrtc', init?: StreamMuxerInit) { /** * Initialized stream muxer */ this.init = init - /** - * WebRTC Peer Connection - */ - this.peerConnection = peerConnection - /** * Fired when a data channel has been added to the connection has been * added by the remote peer. * * {@link https://developer.mozilla.org/en-US/docs/Web/API/RTCPeerConnection/datachannel_event} */ - this.peerConnection.ondatachannel = ({ channel }) => { + this.dataChannelMuxer.peerConnection.ondatachannel = ({ channel }) => { const stream = new WebRTCStream({ channel, stat: { @@ -108,12 +102,13 @@ export class DataChannelMuxer implements StreamMuxer { open: 0 } }, + constraints: dataChannelMuxer.constraints, closeCb: this.wrapStreamEnd(init?.onIncomingStream) }) this.streams.push(stream) if ((init?.onIncomingStream) != null) { - this.metrics?.increment({ incoming_stream: true }) + this.dataChannelMuxer.metrics?.increment({ incoming_stream: true }) init.onIncomingStream(stream) } } @@ -133,9 +128,9 @@ export class DataChannelMuxer implements StreamMuxer { newStream (): Stream { // The spec says the label SHOULD be an empty string: https://github.com/libp2p/specs/blob/master/webrtc/README.md#rtcdatachannel-label - const channel = this.peerConnection.createDataChannel('') + const channel = this.dataChannelMuxer.peerConnection.createDataChannel('') const closeCb = (stream: Stream): void => { - this.metrics?.increment({ stream_end: true }) + this.dataChannelMuxer.metrics?.increment({ stream_end: true }) this.init?.onStreamEnd?.(stream) } const stream = new WebRTCStream({ @@ -146,10 +141,11 @@ export class DataChannelMuxer implements StreamMuxer { open: 0 } }, + constraints: this.dataChannelMuxer.constraints, closeCb: this.wrapStreamEnd(closeCb) }) this.streams.push(stream) - this.metrics?.increment({ outgoing_stream: true }) + this.dataChannelMuxer.metrics?.increment({ outgoing_stream: true }) return stream } diff --git a/src/private-to-private/handler.ts b/src/private-to-private/handler.ts index cf444f7..314c33c 100644 --- a/src/private-to-private/handler.ts +++ b/src/private-to-private/handler.ts @@ -5,6 +5,7 @@ import pDefer, { type DeferredPromise } from 'p-defer' import { DataChannelMuxerFactory } from '../muxer.js' import { Message } from './pb/message.js' import { readCandidatesUntilConnected, resolveOnConnected } from './util.js' +import type { DataChannelConstraintsOpts } from '../stream.js' import type { Stream } from '@libp2p/interface-connection' import type { IncomingStreamData } from '@libp2p/interface-registrar' import type { StreamMuxerFactory } from '@libp2p/interface-stream-muxer' @@ -13,14 +14,13 @@ const DEFAULT_TIMEOUT = 30 * 1000 const log = logger('libp2p:webrtc:peer') -export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration } & IncomingStreamData +export type IncomingStreamOpts = { rtcConfiguration?: RTCConfiguration, constraints?: DataChannelConstraintsOpts } & IncomingStreamData -export async function handleIncomingStream ({ rtcConfiguration, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { +export async function handleIncomingStream ({ rtcConfiguration, constraints, stream: rawStream }: IncomingStreamOpts): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { const signal = AbortSignal.timeout(DEFAULT_TIMEOUT) const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) const pc = new RTCPeerConnection(rtcConfiguration) - const muxerFactory = new DataChannelMuxerFactory(pc) - + const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, constraints }) const connectedPromise: DeferredPromise = pDefer() const answerSentPromise: DeferredPromise = pDefer() @@ -86,13 +86,14 @@ export interface ConnectOptions { stream: Stream signal: AbortSignal rtcConfiguration?: RTCConfiguration + constraints?: DataChannelConstraintsOpts } -export async function initiateConnection ({ rtcConfiguration, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { +export async function initiateConnection ({ rtcConfiguration, constraints, signal, stream: rawStream }: ConnectOptions): Promise<{ pc: RTCPeerConnection, muxerFactory: StreamMuxerFactory, remoteAddress: string }> { const stream = pbStream(abortableDuplex(rawStream, signal)).pb(Message) // setup peer connection const pc = new RTCPeerConnection(rtcConfiguration) - const muxerFactory = new DataChannelMuxerFactory(pc) + const muxerFactory = new DataChannelMuxerFactory({ peerConnection: pc, constraints }) const connectedPromise: DeferredPromise = pDefer() resolveOnConnected(pc, connectedPromise) diff --git a/src/private-to-private/transport.ts b/src/private-to-private/transport.ts index 92c64ab..8a914c5 100644 --- a/src/private-to-private/transport.ts +++ b/src/private-to-private/transport.ts @@ -7,6 +7,7 @@ import { codes } from '../error.js' import { WebRTCMultiaddrConnection } from '../maconn.js' import { initiateConnection, handleIncomingStream } from './handler.js' import { WebRTCPeerListener } from './listener.js' +import type { DataChannelConstraintsOpts } from '../stream.js' import type { Connection } from '@libp2p/interface-connection' import type { Libp2pEvents } from '@libp2p/interface-libp2p' import type { PeerId } from '@libp2p/interface-peer-id' @@ -23,6 +24,7 @@ const WEBRTC_CODE = protocols('webrtc').code export interface WebRTCTransportInit { rtcConfiguration?: RTCConfiguration + constraints?: Partial } export interface WebRTCTransportComponents { diff --git a/src/private-to-public/transport.ts b/src/private-to-public/transport.ts index 30d658d..d73d8d6 100644 --- a/src/private-to-public/transport.ts +++ b/src/private-to-public/transport.ts @@ -231,7 +231,7 @@ export class WebRTCDirectTransport implements Transport { // Track opened peer connection this.metrics?.dialerEvents.increment({ peer_connection: true }) - const muxerFactory = new DataChannelMuxerFactory(peerConnection, this.metrics?.dialerEvents) + const muxerFactory = new DataChannelMuxerFactory({ peerConnection, metrics: this.metrics?.dialerEvents }) // For outbound connections, the remote is expected to start the noise handshake. // Therefore, we need to secure an inbound noise connection from the remote. diff --git a/src/stream.ts b/src/stream.ts index cf35eed..03ac2df 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -4,6 +4,7 @@ import merge from 'it-merge' import { pipe } from 'it-pipe' import { pushable } from 'it-pushable' import defer, { type DeferredPromise } from 'p-defer' +import { pEvent } from 'p-event' import { Uint8ArrayList } from 'uint8arraylist' import { Message } from './pb/message.js' import type { Stream, StreamStat, Direction } from '@libp2p/interface-connection' @@ -24,6 +25,12 @@ export function defaultStat (dir: Direction): StreamStat { } } +export interface DataChannelConstraintsOpts { + maxDataChannelMessageSize: number + maxDataChannelBufferedAmount: number + dataChannelBufferAmountLowEventTimeout: number +} + interface StreamInitOpts { /** * The network channel used for bidirectional peer-to-peer transfers of @@ -47,6 +54,11 @@ interface StreamInitOpts { * Callback to invoke when the stream is closed. */ closeCb?: (stream: WebRTCStream) => void + + /** + * Constraints options + */ + constraints?: Partial } /* @@ -151,6 +163,15 @@ class StreamState { } } +// Max message size that can be sent to the DataChannel +const MAX_MESSAGE_SIZE = 16 * 1024 + +// How much can be buffered to the DataChannel at once +const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024 + +// How long time we wait for the 'bufferedamountlow' event to be emitted +const BUFFERED_AMOUNT_LOW_TIMEOUT = 30 * 1000 + export class WebRTCStream implements Stream { /** * Unique identifier for a stream @@ -177,6 +198,12 @@ export class WebRTCStream implements Stream { */ streamState = new StreamState() + /** + * DataChannel contraints + */ + + constraints: DataChannelConstraintsOpts + /** * Read unwrapped protobuf data from the underlying datachannel. * _src is exposed to the user via the `source` getter to . @@ -214,8 +241,14 @@ export class WebRTCStream implements Stream { this.channel = opts.channel this.channel.binaryType = 'arraybuffer' this.id = this.channel.label - this.stat = opts.stat + this.constraints = { + dataChannelBufferAmountLowEventTimeout: opts.constraints?.dataChannelBufferAmountLowEventTimeout ?? BUFFERED_AMOUNT_LOW_TIMEOUT, + maxDataChannelBufferedAmount: opts.constraints?.maxDataChannelBufferedAmount ?? MAX_BUFFERED_AMOUNT, + maxDataChannelMessageSize: opts.constraints?.maxDataChannelMessageSize ?? MAX_MESSAGE_SIZE + } + this.closeCb = opts.closeCb + switch (this.channel.readyState) { case 'open': this.opened.resolve() @@ -313,10 +346,25 @@ export class WebRTCStream implements Stream { if (this.streamState.isWriteClosed()) { return } + + if (this.channel.bufferedAmount > this.constraints.maxDataChannelBufferedAmount) { + await pEvent(this.channel, 'bufferedamountlow', { timeout: this.constraints.dataChannelBufferAmountLowEventTimeout }).catch((e) => { + this.close() + throw new Error('Timed out waiting for DataChannel buffer to clear') + }) + } + const msgbuf = Message.encode({ message: buf.subarray() }) const sendbuf = lengthPrefixed.encode.single(msgbuf) - this.channel.send(sendbuf.subarray()) + while (sendbuf.length > 0) { + if (sendbuf.length <= this.constraints.maxDataChannelMessageSize) { + this.channel.send(sendbuf.subarray()) + break + } + this.channel.send(sendbuf.subarray(0, this.constraints.maxDataChannelMessageSize)) + sendbuf.consume(this.constraints.maxDataChannelMessageSize) + } } } diff --git a/test/stream.spec.ts b/test/stream.spec.ts new file mode 100644 index 0000000..95dc49f --- /dev/null +++ b/test/stream.spec.ts @@ -0,0 +1,112 @@ +/* eslint-disable @typescript-eslint/consistent-type-assertions */ + +import { expect } from 'aegir/chai' +import * as lengthPrefixed from 'it-length-prefixed' +import { pushable } from 'it-pushable' +import { Message } from '../src/pb/message' +import * as underTest from '../src/stream' + +const mockDataChannel = (opts: { send: (bytes: Uint8Array) => void, bufferedAmount?: number }): RTCDataChannel => { + return { + readyState: 'open', + close: () => {}, + addEventListener: (_type: string, _listener: () => void) => {}, + removeEventListener: (_type: string, _listener: () => void) => {}, + ...opts + } as RTCDataChannel +} + +const MAX_MESSAGE_SIZE = 16 * 1024 + +describe('Max message size', () => { + it(`sends messages smaller or equal to ${MAX_MESSAGE_SIZE} bytes in one`, async () => { + const sent: Uint8Array[] = [] + const data = new Uint8Array(MAX_MESSAGE_SIZE - 5) + const p = pushable() + + // Make sure that the data that ought to be sent will result in a message with exactly MAX_MESSAGE_SIZE + const messageLengthEncoded = lengthPrefixed.encode.single(Message.encode({ message: data })).subarray() + expect(messageLengthEncoded.length).eq(MAX_MESSAGE_SIZE) + const webrtcStream = new underTest.WebRTCStream({ + channel: mockDataChannel({ + send: (bytes) => { + sent.push(bytes) + if (p.readableLength === 0) { + webrtcStream.close() + } + } + }), + stat: underTest.defaultStat('outbound') + }) + + p.push(data) + p.end() + await webrtcStream.sink(p) + expect(sent).to.deep.equals([messageLengthEncoded]) + }) + + it(`sends messages greather ${MAX_MESSAGE_SIZE} bytes in parts`, async () => { + const sent: Uint8Array[] = [] + const data = new Uint8Array(MAX_MESSAGE_SIZE - 4) + const p = pushable() + + // Make sure that the data that ought to be sent will result in a message with exactly MAX_MESSAGE_SIZE + 1 + const messageLengthEncoded = lengthPrefixed.encode.single(Message.encode({ message: data })).subarray() + expect(messageLengthEncoded.length).eq(MAX_MESSAGE_SIZE + 1) + + const webrtcStream = new underTest.WebRTCStream({ + channel: mockDataChannel({ + send: (bytes) => { + sent.push(bytes) + if (p.readableLength === 0) { + webrtcStream.close() + } + } + }), + stat: underTest.defaultStat('outbound') + }) + + p.push(data) + p.end() + await webrtcStream.sink(p) + + // Message is sent in two parts + expect(sent).to.deep.equals([messageLengthEncoded.subarray(0, messageLengthEncoded.length - 1), messageLengthEncoded.subarray(messageLengthEncoded.length - 1)]) + }) + + it('closes the stream if bufferamountlow timeout', async () => { + const MAX_BUFFERED_AMOUNT = 16 * 1024 * 1024 + 1 + const timeout = 2000 + let closed = false + const webrtcStream = new underTest.WebRTCStream({ + constraints: { dataChannelBufferAmountLowEventTimeout: timeout }, + channel: mockDataChannel({ + send: () => { + throw new Error('Expected to not send') + }, + bufferedAmount: MAX_BUFFERED_AMOUNT + }), + stat: underTest.defaultStat('outbound'), + closeCb: () => { + closed = true + } + }) + + const p = pushable() + p.push(new Uint8Array(1)) + p.end() + + const t0 = Number(new Date()) + + try { + await webrtcStream.sink(p) + throw new Error('Expected to fail') + } catch (error: any) { + expect(error.message).eq('Timed out waiting for DataChannel buffer to clear') + const t1 = Number(new Date()) + expect(t1 - t0).greaterThan(timeout) + expect(t1 - t0).lessThan(timeout + 1000) // Some upper bound + expect(closed).true() + } + }) +})