diff --git a/doc/relay.md b/doc/relay.md index 879c2f7170..7d75542455 100644 --- a/doc/relay.md +++ b/doc/relay.md @@ -2,8 +2,8 @@ ### Relays Relays (aka secure octo) use ICE and DTLS/SRTP between each pair of bridges, so a secure -network is not required. It uses and requires colibri websockets for the -bridge-bridge connections (endpoints can still use SCTP). +network is not required. It uses and requires either SCTP or colibri websockets for the +bridge-bridge connections. ## Jitsi Videobridge configuration diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/Endpoint.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/Endpoint.kt index 91bdddc305..74d5dea168 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/Endpoint.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/Endpoint.kt @@ -32,7 +32,6 @@ import org.jitsi.nlj.rtp.SsrcAssociationType import org.jitsi.nlj.rtp.VideoRtpPacket import org.jitsi.nlj.srtp.TlsRole import org.jitsi.nlj.stats.EndpointConnectionStats -import org.jitsi.nlj.stats.NodeStatsBlock import org.jitsi.nlj.transform.node.ConsumerNode import org.jitsi.nlj.util.Bandwidth import org.jitsi.nlj.util.LocalSsrcAssociation @@ -69,7 +68,8 @@ import org.jitsi.videobridge.message.SenderSourceConstraintsMessage import org.jitsi.videobridge.message.SenderVideoConstraintsMessage import org.jitsi.videobridge.relay.AudioSourceDesc import org.jitsi.videobridge.rest.root.debug.EndpointDebugFeatures -import org.jitsi.videobridge.sctp.SctpConfig +import org.jitsi.videobridge.sctp.DataChannelHandler +import org.jitsi.videobridge.sctp.SctpHandler import org.jitsi.videobridge.sctp.SctpManager import org.jitsi.videobridge.stats.PacketTransitStats import org.jitsi.videobridge.transport.dtls.DtlsTransport @@ -85,13 +85,11 @@ import org.jitsi_modified.sctp4j.SctpDataCallback import org.jitsi_modified.sctp4j.SctpServerSocket import org.jitsi_modified.sctp4j.SctpSocket import org.json.simple.JSONObject -import java.nio.ByteBuffer import java.security.SecureRandom import java.time.Clock import java.time.Duration import java.time.Instant import java.util.Optional -import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong import java.util.function.Supplier @@ -1205,100 +1203,4 @@ class Endpoint @JvmOverloads constructor( bandwidthProbing.bandwidthEstimationChanged(newValue) } } - - /** - * A node which can be placed in the pipeline to cache Data channel packets - * until the DataChannelStack is ready to handle them. - */ - private class DataChannelHandler : ConsumerNode("Data channel handler") { - private val dataChannelStackLock = Any() - private var dataChannelStack: DataChannelStack? = null - private val cachedDataChannelPackets = LinkedBlockingQueue() - - public override fun consume(packetInfo: PacketInfo) { - synchronized(dataChannelStackLock) { - when (val packet = packetInfo.packet) { - is DataChannelPacket -> { - dataChannelStack?.onIncomingDataChannelPacket( - ByteBuffer.wrap(packet.buffer), packet.sid, packet.ppid - ) ?: run { - cachedDataChannelPackets.add(packetInfo) - } - } - else -> Unit - } - } - } - - fun setDataChannelStack(dataChannelStack: DataChannelStack) { - // Submit this to the pool since we wait on the lock and process any - // cached packets here as well - - // Submit this to the pool since we wait on the lock and process any - // cached packets here as well - TaskPools.IO_POOL.execute { - // We grab the lock here so that we can set the SCTP manager and - // process any previously-cached packets as an atomic operation. - // It also prevents another thread from coming in via - // #doProcessPackets and processing packets at the same time in - // another thread, which would be a problem. - synchronized(dataChannelStackLock) { - this.dataChannelStack = dataChannelStack - cachedDataChannelPackets.forEach { - val dcp = it.packet as DataChannelPacket - dataChannelStack.onIncomingDataChannelPacket( - ByteBuffer.wrap(dcp.buffer), dcp.sid, dcp.ppid - ) - } - } - } - } - - override fun trace(f: () -> Unit) = f.invoke() - } - - /** - * A node which can be placed in the pipeline to cache SCTP packets until - * the SCTPManager is ready to handle them. - */ - private class SctpHandler : ConsumerNode("SCTP handler") { - private val sctpManagerLock = Any() - private var sctpManager: SctpManager? = null - private val numCachedSctpPackets = AtomicLong(0) - private val cachedSctpPackets = LinkedBlockingQueue(100) - - override fun consume(packetInfo: PacketInfo) { - synchronized(sctpManagerLock) { - if (SctpConfig.config.enabled) { - sctpManager?.handleIncomingSctp(packetInfo) ?: run { - numCachedSctpPackets.incrementAndGet() - cachedSctpPackets.add(packetInfo) - } - } - } - } - - override fun getNodeStats(): NodeStatsBlock = super.getNodeStats().apply { - addNumber("num_cached_packets", numCachedSctpPackets.get()) - } - - fun setSctpManager(sctpManager: SctpManager) { - // Submit this to the pool since we wait on the lock and process any - // cached packets here as well - TaskPools.IO_POOL.execute { - // We grab the lock here so that we can set the SCTP manager and - // process any previously-cached packets as an atomic operation. - // It also prevents another thread from coming in via - // #doProcessPackets and processing packets at the same time in - // another thread, which would be a problem. - synchronized(sctpManagerLock) { - this.sctpManager = sctpManager - cachedSctpPackets.forEach { sctpManager.handleIncomingSctp(it) } - cachedSctpPackets.clear() - } - } - } - - override fun trace(f: () -> Unit) = f.invoke() - } } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/colibri2/Colibri2ConferenceHandler.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/colibri2/Colibri2ConferenceHandler.kt index bbf9a354f3..c1dd17603f 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/colibri2/Colibri2ConferenceHandler.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/colibri2/Colibri2ConferenceHandler.kt @@ -74,14 +74,16 @@ class Colibri2ConferenceHandler( } for (r in conferenceModifyIQ.relays) { if (!RelayConfig.config.enabled) { - throw IqProcessingException(Condition.feature_not_implemented, "Octo is disable in configuration.") + throw IqProcessingException(Condition.feature_not_implemented, "Octo is disabled in configuration.") } - if (!WebsocketServiceConfig.config.enabled) { + if (!WebsocketServiceConfig.config.enabled && !SctpConfig.config.enabled) { logger.warn( - "Can not use a colibri2 relay, because colibri web sockets are not enabled. See " + - "https://github.com/jitsi/jitsi-videobridge/blob/master/doc/octo.md" + "Can not use a colibri2 relay, because neither SCTP nor colibri web sockets are enabled. See " + + "https://github.com/jitsi/jitsi-videobridge/blob/master/doc/relay.md" + ) + throw UnsupportedOperationException( + "Colibri websockets or SCTP need to be enabled to use a colibri2 relay." ) - throw UnsupportedOperationException("Colibri websockets need to be enabled to use a colibri2 relay.") } responseBuilder.addRelay(handleColibri2Relay(r)) } @@ -352,16 +354,42 @@ class Colibri2ConferenceHandler( ) } - if (c2relay.transport?.sctp != null) throw IqProcessingException( - Condition.feature_not_implemented, - "SCTP is not supported for relays." - ) + c2relay.transport?.sctp?.let { sctp -> + if (!SctpConfig.config.enabled) { + throw IqProcessingException( + Condition.feature_not_implemented, + "SCTP support is not configured" + ) + } + if (sctp.port != null && sctp.port != SctpManager.DEFAULT_SCTP_PORT) { + throw IqProcessingException( + Condition.bad_request, + "Specific SCTP port requested, not supported." + ) + } + + relay.createSctpConnection(sctp) + } c2relay.transport?.iceUdpTransport?.let { relay.setTransportInfo(it) } if (c2relay.create) { val transBuilder = Transport.getBuilder() transBuilder.setIceUdpExtension(relay.describeTransport()) - respBuilder.setTransport(transBuilder.build()) + c2relay.transport?.sctp?.let { + val role = if (it.role == Sctp.Role.CLIENT) { + Sctp.Role.SERVER + } else { + Sctp.Role.CLIENT + } + transBuilder.setSctp( + Sctp.Builder() + .setPort(SctpManager.DEFAULT_SCTP_PORT) + .setRole(role) + .build() + ) + + respBuilder.setTransport(transBuilder.build()) + } } for (media: Media in c2relay.media) { diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt index 9cea5b165f..5960898b03 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt @@ -71,9 +71,15 @@ import org.jitsi.videobridge.EncodingsManager import org.jitsi.videobridge.Endpoint import org.jitsi.videobridge.PotentialPacketHandler import org.jitsi.videobridge.TransportConfig +import org.jitsi.videobridge.datachannel.DataChannelStack +import org.jitsi.videobridge.datachannel.protocol.DataChannelPacket +import org.jitsi.videobridge.datachannel.protocol.DataChannelProtocolConstants import org.jitsi.videobridge.message.BridgeChannelMessage import org.jitsi.videobridge.message.SourceVideoTypeMessage import org.jitsi.videobridge.rest.root.debug.EndpointDebugFeatures +import org.jitsi.videobridge.sctp.DataChannelHandler +import org.jitsi.videobridge.sctp.SctpHandler +import org.jitsi.videobridge.sctp.SctpManager import org.jitsi.videobridge.stats.PacketTransitStats import org.jitsi.videobridge.transport.dtls.DtlsTransport import org.jitsi.videobridge.transport.ice.IceTransport @@ -82,15 +88,24 @@ import org.jitsi.videobridge.util.TaskPools import org.jitsi.videobridge.util.looksLikeDtls import org.jitsi.videobridge.websocket.colibriWebSocketServiceSupplier import org.jitsi.xmpp.extensions.colibri.WebSocketPacketExtension +import org.jitsi.xmpp.extensions.colibri2.Sctp import org.jitsi.xmpp.extensions.jingle.DtlsFingerprintPacketExtension import org.jitsi.xmpp.extensions.jingle.IceUdpTransportPacketExtension +import org.jitsi_modified.sctp4j.SctpClientSocket +import org.jitsi_modified.sctp4j.SctpDataCallback +import org.jitsi_modified.sctp4j.SctpServerSocket +import org.jitsi_modified.sctp4j.SctpSocket import org.json.simple.JSONObject import java.time.Clock import java.time.Instant +import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong import java.util.function.Supplier +import kotlin.collections.ArrayList +import kotlin.collections.HashMap +import kotlin.collections.HashSet import kotlin.collections.sumOf /** @@ -149,6 +164,9 @@ class Relay @JvmOverloads constructor( */ private var expired = false + private val sctpHandler = SctpHandler() + private val dataChannelHandler = DataChannelHandler() + private val iceTransport = IceTransport(id, iceControlling, useUniquePort, logger, clock) private val dtlsTransport = DtlsTransport(logger).also { it.cryptex = CryptexConfig.relay } @@ -160,6 +178,19 @@ class Relay @JvmOverloads constructor( private val timelineLogger = logger.createChildLogger("timeline.${this.javaClass.name}") + /** + * The [SctpManager] instance we'll use to manage the SCTP connection + */ + private var sctpManager: SctpManager? = null + + private var dataChannelStack: DataChannelStack? = null + + /** + * The [SctpSocket] for this endpoint, if an SCTP connection was + * negotiated. + */ + private var sctpSocket: SctpSocket? = null + private val relayedEndpoints = HashMap() private val endpointsBySsrc = HashMap() private val endpointsLock = Any() @@ -251,6 +282,20 @@ class Relay @JvmOverloads constructor( setErrorHandler(queueErrorCounter) } + /** + * The queue which enforces sequential processing of incoming data channel messages + * to maintain processing order. + */ + private val incomingDataChannelMessagesQueue = PacketInfoQueue( + "${javaClass.simpleName}-incoming-data-channel-queue", + TaskPools.IO_POOL, + { packetInfo -> + dataChannelHandler.consume(packetInfo) + true + }, + TransportConfig.queueSize + ) + val debugState: JSONObject get() = JSONObject().apply { put("iceTransport", iceTransport.getDebugState()) @@ -330,7 +375,7 @@ class Relay @JvmOverloads constructor( private fun setupDtlsTransport() { dtlsTransport.incomingDataHandler = object : DtlsTransport.IncomingDataHandler { override fun dtlsAppDataReceived(buf: ByteArray, off: Int, len: Int) { - // TODO this@Relay.dtlsAppPacketReceived(buf, off, len) + dtlsAppPacketReceived(buf, off, len) } } dtlsTransport.outgoingDataHandler = object : DtlsTransport.OutgoingDataHandler { @@ -346,6 +391,11 @@ class Relay @JvmOverloads constructor( ) { logger.info("DTLS handshake complete") setSrtpInformation(chosenSrtpProtectionProfile, tlsRole, keyingMaterial) + when (val socket = sctpSocket) { + is SctpClientSocket -> connectSctpConnection(socket) + is SctpServerSocket -> acceptSctpConnection(socket) + else -> Unit + } scheduleRelayMessageTransportTimeout() } } @@ -378,6 +428,118 @@ class Relay @JvmOverloads constructor( senders.values.forEach { it.setSrtpInformation(srtpTransformers) } } + /** + * Create an SCTP connection for this Relay. If [sctpDesc.role] is [Sctp.Role.CLIENT], + * we will create the data channel locally, otherwise we will wait for the remote side + * to open it. + */ + fun createSctpConnection(sctpDesc: Sctp) { + val openDataChannelLocally = sctpDesc.role == Sctp.Role.CLIENT + + logger.cdebug { "Creating SCTP manager" } + // Create the SctpManager and provide it a method for sending SCTP data + val sctpManager = SctpManager( + { data, offset, length -> + dtlsTransport.sendDtlsData(data, offset, length) + 0 + }, + logger + ) + this.sctpManager = sctpManager + sctpHandler.setSctpManager(sctpManager) + val socket = if (sctpDesc.role == Sctp.Role.CLIENT) { + sctpManager.createClientSocket() + } else { + sctpManager.createServerSocket() + } + socket.eventHandler = object : SctpSocket.SctpSocketEventHandler { + override fun onReady() { + logger.info("SCTP connection is ready, creating the Data channel stack") + val dataChannelStack = DataChannelStack( + { data, sid, ppid -> socket.send(data, true, sid, ppid) }, + logger + ) + this@Relay.dataChannelStack = dataChannelStack + // This handles if the remote side will be opening the data channel + dataChannelStack.onDataChannelStackEvents { dataChannel -> + logger.info("Remote side opened a data channel.") + messageTransport.setDataChannel(dataChannel) + } + dataChannelHandler.setDataChannelStack(dataChannelStack) + if (openDataChannelLocally) { + // This logic is for opening the data channel locally + logger.info("Will open the data channel.") + val dataChannel = dataChannelStack.createDataChannel( + DataChannelProtocolConstants.RELIABLE, + 0, + 0, + 0, + "default" + ) + messageTransport.setDataChannel(dataChannel) + dataChannel.open() + } else { + logger.info("Will wait for the remote side to open the data channel.") + } + } + + override fun onDisconnected() { + logger.info("SCTP connection is disconnected") + } + } + socket.dataCallback = SctpDataCallback { data, sid, ssn, tsn, ppid, context, flags -> + // We assume all data coming over SCTP will be datachannel data + val dataChannelPacket = DataChannelPacket(data, 0, data.size, sid, ppid.toInt()) + // Post the rest of the task here because the current context is + // holding a lock inside the SctpSocket which can cause a deadlock + // if two endpoints are trying to send datachannel messages to one + // another (with stats broadcasting it can happen often) + incomingDataChannelMessagesQueue.add(PacketInfo(dataChannelPacket)) + } + if (socket is SctpServerSocket) { + socket.listen() + } + sctpSocket = socket + } + + fun connectSctpConnection(sctpClientSocket: SctpClientSocket) { + TaskPools.IO_POOL.execute { + // We don't want to block the thread calling + // onDtlsHandshakeComplete so run the socket acceptance in an IO + // pool thread + logger.info("Attempting to establish SCTP socket connection") + + if (!sctpClientSocket.connect(SctpManager.DEFAULT_SCTP_PORT)) { + logger.error("Failed to establish SCTP connection to remote side") + } + } + } + + fun acceptSctpConnection(sctpServerSocket: SctpServerSocket) { + TaskPools.IO_POOL.execute { + // We don't want to block the thread calling + // onDtlsHandshakeComplete so run the socket acceptance in an IO + // pool thread + // FIXME: This runs forever once the socket is closed ( + // accept never returns true). + logger.info("Attempting to establish SCTP socket connection") + var attempts = 0 + while (!sctpServerSocket.accept()) { + attempts++ + try { + Thread.sleep(100) + } catch (e: InterruptedException) { + break + } + if (attempts > 100) { + logger.error("Timed out waiting for SCTP connection from remote side") + break + } + } + logger.cdebug { "SCTP socket ${sctpServerSocket.hashCode()} accepted connection" } + } + } + /** * Sets the remote transport information (ICE candidates, DTLS fingerprints). * @@ -406,7 +568,7 @@ class Relay @JvmOverloads constructor( iceTransport.startConnectivityEstablishment(transportInfo) val websocketExtension = transportInfo.getFirstChildOfType(WebSocketPacketExtension::class.java) - websocketExtension?.url?.let { messageTransport.connectTo(it) } + websocketExtension?.url?.let { messageTransport.connectToWebsocket(it) } } fun describeTransport(): IceUdpTransportPacketExtension { @@ -414,29 +576,31 @@ class Relay @JvmOverloads constructor( iceTransport.describe(iceUdpTransportPacketExtension) dtlsTransport.describe(iceUdpTransportPacketExtension) - /* TODO: this should be dependent on videobridge.websockets.enabled, if we support that being - * disabled for relay. - */ - if (messageTransport.isActive) { - iceUdpTransportPacketExtension.addChildExtension( - WebSocketPacketExtension().apply { active = true } - ) - } else { - colibriWebSocketServiceSupplier.get()?.let { colibriWebsocketService -> - val urls = colibriWebsocketService.getColibriRelayWebSocketUrls( - conference.id, - id, - iceTransport.icePassword + if (sctpSocket != null) { + /* TODO: this should be dependent on videobridge.websockets.enabled, if we support that being + * disabled for relay. + */ + if (messageTransport.isActive) { + iceUdpTransportPacketExtension.addChildExtension( + WebSocketPacketExtension().apply { active = true } ) - if (urls.isEmpty()) { - logger.warn("No colibri relay URLs configured") - } - urls.forEach { - iceUdpTransportPacketExtension.addChildExtension( - WebSocketPacketExtension().apply { - url = it - } + } else { + colibriWebSocketServiceSupplier.get()?.let { colibriWebsocketService -> + val urls = colibriWebsocketService.getColibriRelayWebSocketUrls( + conference.id, + id, + iceTransport.icePassword ) + if (urls.isEmpty()) { + logger.warn("No colibri relay URLs configured") + } + urls.forEach { + iceUdpTransportPacketExtension.addChildExtension( + WebSocketPacketExtension().apply { + url = it + } + ) + } } } } @@ -561,6 +725,14 @@ class Relay @JvmOverloads constructor( } } + /** + * Handle a DTLS app packet (that is, a packet of some other protocol sent + * over DTLS) which has just been received. + */ + // TODO(brian): change sctp handler to take buf, off, len + fun dtlsAppPacketReceived(data: ByteArray, off: Int, len: Int) = + sctpHandler.processPacket(PacketInfo(UnparsedPacket(data, off, len))) + fun addRemoteEndpoint( id: String, statsId: String?, @@ -923,6 +1095,8 @@ class Relay @JvmOverloads constructor( transceiver.teardown() messageTransport.close() + sctpHandler.stop() + sctpManager?.closeConnection() } catch (t: Throwable) { logger.error("Exception while expiring: ", t) } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/relay/RelayMessageTransport.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/relay/RelayMessageTransport.kt index bd537a5da6..2754ed9ea5 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/relay/RelayMessageTransport.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/relay/RelayMessageTransport.kt @@ -23,6 +23,10 @@ import org.jitsi.utils.logging2.Logger import org.jitsi.videobridge.AbstractEndpointMessageTransport import org.jitsi.videobridge.VersionConfig import org.jitsi.videobridge.Videobridge +import org.jitsi.videobridge.datachannel.DataChannel +import org.jitsi.videobridge.datachannel.DataChannelStack.DataChannelMessageListener +import org.jitsi.videobridge.datachannel.protocol.DataChannelMessage +import org.jitsi.videobridge.datachannel.protocol.DataChannelStringMessage import org.jitsi.videobridge.message.AddReceiverMessage import org.jitsi.videobridge.message.BridgeChannelMessage import org.jitsi.videobridge.message.ClientHelloMessage @@ -35,6 +39,7 @@ import org.jitsi.videobridge.message.VideoTypeMessage import org.jitsi.videobridge.util.TaskPools import org.jitsi.videobridge.websocket.ColibriWebSocket import org.json.simple.JSONObject +import java.lang.ref.WeakReference import java.net.URI import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger @@ -43,14 +48,15 @@ import java.util.function.Supplier /** * Handles the functionality related to sending and receiving COLIBRI messages - * for a [Relay]. + * for a [Relay]. Supports two underlying transport mechanisms -- + * WebRTC data channels and {@code WebSocket}s. */ class RelayMessageTransport( private val relay: Relay, private val statisticsSupplier: Supplier, private val eventHandler: EndpointMessageTransportEventHandler, parentLogger: Logger -) : AbstractEndpointMessageTransport(parentLogger), ColibriWebSocket.EventHandler { +) : AbstractEndpointMessageTransport(parentLogger), ColibriWebSocket.EventHandler, DataChannelMessageListener { /** * The last connected/accepted web-socket by this instance, if any. */ @@ -70,6 +76,16 @@ class RelayMessageTransport( * Use to synchronize access to [webSocket] */ private val webSocketSyncRoot = Any() + + /** + * Whether the last active transport channel (i.e. the last to receive a + * message from the remote endpoint) was the web socket (if `true`), + * or the WebRTC data channel (if `false`). + */ + private var webSocketLastActive = false + + private var dataChannel = WeakReference(null) + private val numOutgoingMessagesDropped = AtomicInteger(0) /** @@ -82,7 +98,7 @@ class RelayMessageTransport( /** * Connect the bridge channel message to the websocket URL specified */ - fun connectTo(url: String) { + fun connectToWebsocket(url: String) { if (this.url != null && this.url == url) { return } @@ -202,11 +218,23 @@ class RelayMessageTransport( super.sendMessage(dst, message) // Log message if (dst is ColibriWebSocket) { sendMessage(dst, message) + } else if (dst is DataChannel) { + sendMessage(dst, message) } else { throw IllegalArgumentException("unknown transport:$dst") } } + /** + * Sends a string via a particular [DataChannel]. + * @param dst the data channel to send through. + * @param message the message to send. + */ + private fun sendMessage(dst: DataChannel, message: BridgeChannelMessage) { + dst.sendString(message.toJson()) + statisticsSupplier.get().dataChannelMessagesSent.inc() + } + /** * Sends a string via a particular [ColibriWebSocket] instance. * @param dst the [ColibriWebSocket] through which to send the message. @@ -220,21 +248,69 @@ class RelayMessageTransport( statisticsSupplier.get().colibriWebSocketMessagesSent.inc() } + override fun onDataChannelMessage(dataChannelMessage: DataChannelMessage?) { + webSocketLastActive = false + statisticsSupplier.get().dataChannelMessagesReceived.inc() + if (dataChannelMessage is DataChannelStringMessage) { + onMessage(dataChannel.get(), dataChannelMessage.data) + } + } + /** * {@inheritDoc} */ public override fun sendMessage(msg: BridgeChannelMessage) { - if (webSocket == null) { + val dst = getActiveTransportChannel() + if (dst == null) { logger.debug("No available transport channel, can't send a message") numOutgoingMessagesDropped.incrementAndGet() } else { sentMessagesCounts.computeIfAbsent(msg.javaClass.simpleName) { AtomicLong() }.incrementAndGet() - sendMessage(webSocket, msg) + sendMessage(dst, msg) + } + } + + /** + * @return the active transport channel for this + * [RelayMessageTransport] (either the [.webSocket], or + * the WebRTC data channel represented by a [DataChannel]). + * + * The "active" channel is determined based on what channels are available, + * and which one was the last to receive data. That is, if only one channel + * is available, it will be returned. If two channels are available, the + * last one to have received data will be returned. Otherwise, `null` + * will be returned. + */ + // TODO(brian): seems like it'd be nice to have the websocket and datachannel + // share a common parent class (or, at least, have a class that is returned + // here and provides a common API but can wrap either a websocket or + // datachannel) + private fun getActiveTransportChannel(): Any? { + val dataChannel = dataChannel.get() + val webSocket = webSocket + var dst: Any? = null + if (webSocketLastActive) { + dst = webSocket + } + + // Either the socket was not the last active channel, + // or it has been closed. + if (dst == null) { + if (dataChannel != null && dataChannel.isReady) { + dst = dataChannel + } + } + + // Maybe the WebRTC data channel is the last active, but it is not + // currently available. If so, and a web-socket is available -- use it. + if (dst == null && webSocket != null) { + dst = webSocket } + return dst } override val isConnected: Boolean - get() = webSocket != null + get() = getActiveTransportChannel() != null val isActive: Boolean get() = outgoingWebsocket != null @@ -248,6 +324,7 @@ class RelayMessageTransport( if (ws != webSocket) { logger.info("Replacing an existing websocket.") webSocket?.session?.close(CloseStatus.NORMAL, "replaced") + webSocketLastActive = true webSocket = ws sendMessage(ws, createServerHello()) } else { @@ -276,6 +353,7 @@ class RelayMessageTransport( synchronized(webSocketSyncRoot) { if (ws == webSocket) { webSocket = null + webSocketLastActive = false logger.debug { "Web socket closed, statusCode $statusCode ( $reason)." } } } @@ -325,9 +403,35 @@ class RelayMessageTransport( return } statisticsSupplier.get().colibriWebSocketMessagesReceived.inc() + webSocketLastActive = true onMessage(ws, message) } + /** + * Sets the data channel for this endpoint. + * @param dataChannel the [DataChannel] to use for this transport + */ + fun setDataChannel(dataChannel: DataChannel) { + val prevDataChannel = this.dataChannel.get() + if (prevDataChannel == null) { + this.dataChannel = WeakReference(dataChannel) + // We install the handler first, otherwise the 'ready' might fire after we check it but before we + // install the handler + dataChannel.onDataChannelEvents { notifyTransportChannelConnected() } + if (dataChannel.isReady) { + notifyTransportChannelConnected() + } + dataChannel.onDataChannelMessage(this) + } else if (prevDataChannel === dataChannel) { + // TODO: i think we should be able to ensure this doesn't happen, + // so throwing for now. if there's a good + // reason for this, we can make this a no-op + throw Error("Re-setting the same data channel") + } else { + throw Error("Overwriting a previous data channel!") + } + } + override val debugState: JSONObject get() { val debugState = super.debugState diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/sctp/DataChannelHandler.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/sctp/DataChannelHandler.kt new file mode 100644 index 0000000000..df69991226 --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/sctp/DataChannelHandler.kt @@ -0,0 +1,75 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.sctp + +import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.transform.node.ConsumerNode +import org.jitsi.videobridge.datachannel.DataChannelStack +import org.jitsi.videobridge.datachannel.protocol.DataChannelPacket +import org.jitsi.videobridge.util.TaskPools +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue + +/** + * A node which can be placed in the pipeline to cache Data channel packets + * until the DataChannelStack is ready to handle them. + */ +class DataChannelHandler : ConsumerNode("Data channel handler") { + private val dataChannelStackLock = Any() + private var dataChannelStack: DataChannelStack? = null + private val cachedDataChannelPackets = LinkedBlockingQueue() + + public override fun consume(packetInfo: PacketInfo) { + synchronized(dataChannelStackLock) { + when (val packet = packetInfo.packet) { + is DataChannelPacket -> { + dataChannelStack?.onIncomingDataChannelPacket( + ByteBuffer.wrap(packet.buffer), packet.sid, packet.ppid + ) ?: run { + cachedDataChannelPackets.add(packetInfo) + } + } + else -> Unit + } + } + } + + fun setDataChannelStack(dataChannelStack: DataChannelStack) { + // Submit this to the pool since we wait on the lock and process any + // cached packets here as well + + // Submit this to the pool since we wait on the lock and process any + // cached packets here as well + TaskPools.IO_POOL.execute { + // We grab the lock here so that we can set the SCTP manager and + // process any previously-cached packets as an atomic operation. + // It also prevents another thread from coming in via + // #doProcessPackets and processing packets at the same time in + // another thread, which would be a problem. + synchronized(dataChannelStackLock) { + this.dataChannelStack = dataChannelStack + cachedDataChannelPackets.forEach { + val dcp = it.packet as DataChannelPacket + dataChannelStack.onIncomingDataChannelPacket( + ByteBuffer.wrap(dcp.buffer), dcp.sid, dcp.ppid + ) + } + } + } + } + + override fun trace(f: () -> Unit) = f.invoke() +} diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/sctp/SctpHandler.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/sctp/SctpHandler.kt new file mode 100644 index 0000000000..3fc1be6ed0 --- /dev/null +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/sctp/SctpHandler.kt @@ -0,0 +1,68 @@ +/* + * Copyright @ 2018 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.videobridge.sctp + +import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.stats.NodeStatsBlock +import org.jitsi.nlj.transform.node.ConsumerNode +import org.jitsi.videobridge.util.TaskPools +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.AtomicLong + +/** + * A node which can be placed in the pipeline to cache SCTP packets until + * the SCTPManager is ready to handle them. + */ +class SctpHandler : ConsumerNode("SCTP handler") { + private val sctpManagerLock = Any() + private var sctpManager: SctpManager? = null + private val numCachedSctpPackets = AtomicLong(0) + private val cachedSctpPackets = LinkedBlockingQueue(100) + + override fun consume(packetInfo: PacketInfo) { + synchronized(sctpManagerLock) { + if (SctpConfig.config.enabled) { + sctpManager?.handleIncomingSctp(packetInfo) ?: run { + numCachedSctpPackets.incrementAndGet() + cachedSctpPackets.add(packetInfo) + } + } + } + } + + override fun getNodeStats(): NodeStatsBlock = super.getNodeStats().apply { + addNumber("num_cached_packets", numCachedSctpPackets.get()) + } + + fun setSctpManager(sctpManager: SctpManager) { + // Submit this to the pool since we wait on the lock and process any + // cached packets here as well + TaskPools.IO_POOL.execute { + // We grab the lock here so that we can set the SCTP manager and + // process any previously-cached packets as an atomic operation. + // It also prevents another thread from coming in via + // #doProcessPackets and processing packets at the same time in + // another thread, which would be a problem. + synchronized(sctpManagerLock) { + this.sctpManager = sctpManager + cachedSctpPackets.forEach { sctpManager.handleIncomingSctp(it) } + cachedSctpPackets.clear() + } + } + } + + override fun trace(f: () -> Unit) = f.invoke() +}