diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt index 49f0474c45..90e412f571 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt @@ -28,8 +28,7 @@ import org.jitsi.utils.OrderedJsonObject * * @author George Politis */ -abstract class RtpLayerDesc -constructor( +abstract class RtpLayerDesc( /** * The index of this instance's encoding in the source encoding array. */ @@ -54,7 +53,7 @@ constructor( * represents. The actual frame rate may be less due to bad network or * system load. [NO_FRAME_RATE] for unknown. */ - val frameRate: Double, + var frameRate: Double, ) { abstract fun copy(height: Int = this.height, tid: Int = this.tid, inherit: Boolean = true): RtpLayerDesc @@ -63,6 +62,8 @@ constructor( */ protected var bitrateTracker = BitrateCalculator.createBitrateTracker() + var targetBitrate: Bandwidth? = null + /** * @return the "id" of this layer within this encoding. This is a server-side id and should * not be confused with any encoding id defined in the client (such as the @@ -87,6 +88,7 @@ constructor( */ internal open fun inheritFrom(other: RtpLayerDesc) { inheritStatistics(other.bitrateTracker) + targetBitrate = other.targetBitrate } /** @@ -110,12 +112,6 @@ constructor( */ abstract fun getBitrate(nowMs: Long): Bandwidth - /** - * Expose [getBitrate] as a [Double] in order to make it accessible from java (since [Bandwidth] is an inline - * class). - */ - fun getBitrateBps(nowMs: Long): Double = getBitrate(nowMs).bps - /** * Recursively checks this layer and its dependencies to see if the bitrate is zero. * Note that unlike [calcBitrate] this does not avoid double-visiting layers; the overhead @@ -131,6 +127,7 @@ constructor( addNumber("height", height) addNumber("index", index) addNumber("bitrate_bps", getBitrate(System.currentTimeMillis()).bps) + addNumber("target_bitrate", targetBitrate?.bps ?: 0) } fun debugState(): OrderedJsonObject = getNodeStats().toJson().apply { put("indexString", indexString()) } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt index 079001f481..85b33a5e8a 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpReceiverImpl.kt @@ -56,6 +56,7 @@ import org.jitsi.nlj.transform.node.incoming.VideoBitrateCalculator import org.jitsi.nlj.transform.node.incoming.VideoMuteNode import org.jitsi.nlj.transform.node.incoming.VideoParser import org.jitsi.nlj.transform.node.incoming.VideoQualityLayerLookup +import org.jitsi.nlj.transform.node.incoming.VlaReaderNode import org.jitsi.nlj.transform.packetPath import org.jitsi.nlj.transform.pipeline import org.jitsi.nlj.util.Bandwidth @@ -248,6 +249,7 @@ class RtpReceiverImpl @JvmOverloads constructor( node(videoParser) node(VideoQualityLayerLookup(logger)) node(videoBitrateCalculator) + node(VlaReaderNode(streamInformationStore, logger)) node(packetHandlerWrapper) } } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSender.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSender.kt index 241fbde39e..c6d943a0d1 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSender.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSender.kt @@ -16,6 +16,7 @@ package org.jitsi.nlj import org.jitsi.nlj.rtp.LossListener +import org.jitsi.nlj.rtp.RtpExtensionType import org.jitsi.nlj.rtp.TransportCcEngine import org.jitsi.nlj.rtp.bandwidthestimation.BandwidthEstimator import org.jitsi.nlj.srtp.SrtpTransformers @@ -47,6 +48,7 @@ abstract class RtpSender : abstract fun setFeature(feature: Features, enabled: Boolean) abstract fun isFeatureEnabled(feature: Features): Boolean abstract fun tearDown() + abstract fun addRtpExtensionToRetain(extensionType: RtpExtensionType) /** * An optional function to be executed for each RTP packet, as the first step of the send pipeline. diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSenderImpl.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSenderImpl.kt index 6a89f12b99..15895389f2 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSenderImpl.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpSenderImpl.kt @@ -23,6 +23,7 @@ import org.jitsi.nlj.rtcp.NackHandler import org.jitsi.nlj.rtcp.RtcpEventNotifier import org.jitsi.nlj.rtcp.RtcpSrUpdater import org.jitsi.nlj.rtp.LossListener +import org.jitsi.nlj.rtp.RtpExtensionType import org.jitsi.nlj.rtp.TransportCcEngine import org.jitsi.nlj.rtp.bandwidthestimation.BandwidthEstimator import org.jitsi.nlj.rtp.bandwidthestimation.GoogleCcEstimator @@ -111,6 +112,7 @@ class RtpSenderImpl( private val srtcpEncryptWrapper = SrtcpEncryptNode() private val toggleablePcapWriter = ToggleablePcapWriter(logger, "$id-tx") private val outgoingPacketCache = PacketCacher() + private val headerExtensionStripper = HeaderExtStripper(streamInformationStore) private val absSendTime = AbsSendTime(streamInformationStore) private val statsTracker = OutgoingStatisticsTracker() private val packetStreamStats = PacketStreamStatsNode() @@ -144,7 +146,7 @@ class RtpSenderImpl( outgoingRtpRoot = pipeline { node(PluggableTransformerNode("RTP pre-processor") { preProcesor }) node(AudioRedHandler(streamInformationStore, logger)) - node(HeaderExtStripper(streamInformationStore)) + node(headerExtensionStripper) node(outgoingPacketCache) node(absSendTime) node(statsTracker) @@ -333,6 +335,10 @@ class RtpSenderImpl( toggleablePcapWriter.disable() } + override fun addRtpExtensionToRetain(extensionType: RtpExtensionType) { + headerExtensionStripper.addRtpExtensionToRetain(extensionType) + } + companion object { var queueErrorCounter = CountingErrorHandler() diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/Transceiver.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/Transceiver.kt index 8c4b468490..b0c066e0d2 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/Transceiver.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/Transceiver.kt @@ -18,6 +18,7 @@ package org.jitsi.nlj import org.jitsi.nlj.format.PayloadType import org.jitsi.nlj.rtcp.RtcpEventNotifier import org.jitsi.nlj.rtp.RtpExtension +import org.jitsi.nlj.rtp.RtpExtensionType import org.jitsi.nlj.rtp.bandwidthestimation.BandwidthEstimator import org.jitsi.nlj.srtp.SrtpTransformers import org.jitsi.nlj.srtp.SrtpUtil @@ -211,6 +212,10 @@ class Transceiver( rtpReceiver.handleEvent(localSsrcSetEvent) } + fun addRtpExtensionToRetain(extensionType: RtpExtensionType) { + rtpSender.addRtpExtensionToRetain(extensionType) + } + fun receivesSsrc(ssrc: Long): Boolean = streamInformationStore.receiveSsrcs.contains(ssrc) val receiveSsrcs: Set diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt index 90799e11a8..8d86e8af7e 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/RtpExtensions.kt @@ -103,7 +103,13 @@ enum class RtpExtensionType(val uri: String) { */ AV1_DEPENDENCY_DESCRIPTOR( "https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension" - ); + ), + + /** + * Video Layers Allocation + * https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00 + */ + VLA("http://www.webrtc.org/experiments/rtp-hdrext/video-layers-allocation00"); companion object { private val uriMap = RtpExtensionType.values().associateBy(RtpExtensionType::uri) diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VlaReaderNode.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VlaReaderNode.kt new file mode 100644 index 0000000000..b4375482de --- /dev/null +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/VlaReaderNode.kt @@ -0,0 +1,105 @@ +/* + * Copyright @ 2024-Present 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.nlj.transform.node.incoming + +import org.jitsi.nlj.Event +import org.jitsi.nlj.MediaSourceDesc +import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.SetMediaSourcesEvent +import org.jitsi.nlj.findRtpSource +import org.jitsi.nlj.rtp.RtpExtensionType.VLA +import org.jitsi.nlj.transform.node.ObserverNode +import org.jitsi.nlj.util.ReadOnlyStreamInformationStore +import org.jitsi.nlj.util.kbps +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.rtp.header_extensions.VlaExtension +import org.jitsi.utils.logging2.Logger +import org.jitsi.utils.logging2.LoggerImpl +import org.jitsi.utils.logging2.cdebug +import org.jitsi.utils.logging2.createChildLogger + +/** + * A node which reads the Video Layers Allocation (VLA) RTP header extension and updates the media sources. + */ +class VlaReaderNode( + streamInformationStore: ReadOnlyStreamInformationStore, + parentLogger: Logger = LoggerImpl(VlaReaderNode::class.simpleName) +) : ObserverNode("Video Layers Allocation reader") { + private val logger = createChildLogger(parentLogger) + private var vlaExtId: Int? = null + private var mediaSourceDescs: Array = arrayOf() + + init { + streamInformationStore.onRtpExtensionMapping(VLA) { + vlaExtId = it + logger.debug("VLA extension ID set to $it") + } + } + + override fun handleEvent(event: Event) { + when (event) { + is SetMediaSourcesEvent -> { + mediaSourceDescs = event.mediaSourceDescs.copyOf() + logger.cdebug { "Media sources changed:\n${mediaSourceDescs.joinToString()}" } + } + } + } + + override fun observe(packetInfo: PacketInfo) { + val rtpPacket = packetInfo.packetAs() + vlaExtId?.let { + rtpPacket.getHeaderExtension(it)?.let { ext -> + val vla = try { + VlaExtension.parse(ext) + } catch (e: Exception) { + logger.warn("Failed to parse VLA extension", e) + return + } + + val sourceDesc = mediaSourceDescs.findRtpSource(rtpPacket) + + logger.debug("Found VLA=$vla for sourceDesc=$sourceDesc") + + vla.forEachIndexed { streamIdx, stream -> + val rtpEncoding = sourceDesc?.rtpEncodings?.get(streamIdx) + stream.spatialLayers.forEach { spatialLayer -> + spatialLayer.targetBitratesKbps.forEachIndexed { tlIdx, targetBitrateKbps -> + rtpEncoding?.layers?.find { + // With VP8 simulcast all layers have sid -1 + (it.sid == spatialLayer.id || it.sid == -1) && it.tid == tlIdx + }?.let { layer -> + logger.debug( + "Setting target bitrate for rtpEncoding=$rtpEncoding layer=$layer to " + + "${targetBitrateKbps.kbps} (res=${spatialLayer.res})" + ) + layer.targetBitrate = targetBitrateKbps.kbps + spatialLayer.res?.let { res -> + if (layer.height > 0 && layer.height != res.height) { + logger.warn("Updating layer height from ${layer.height} to ${res.height}") + } + layer.height = res.height + layer.frameRate = res.maxFramerate.toDouble() + } + } + } + } + } + } + } + } + + override fun trace(f: () -> Unit) {} +} diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt index b697d9a763..979f271946 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/outgoing/HeaderExtStripper.kt @@ -23,34 +23,40 @@ import org.jitsi.nlj.util.ReadOnlyStreamInformationStore import org.jitsi.rtp.rtp.RtpPacket /** - * Strip all hop-by-hop header extensions. Currently this leaves ssrc-audio-level and video-orientation, + * Strip all hop-by-hop header extensions. By default, this leaves ssrc-audio-level and video-orientation, * plus the AV1 dependency descriptor if the packet is an Av1DDPacket. */ class HeaderExtStripper( - streamInformationStore: ReadOnlyStreamInformationStore + streamInformationStore: ReadOnlyStreamInformationStore, ) : ModifierNode("Strip header extensions") { private var retainedExts: Set = emptySet() private var retainedExtsWithAv1DD: Set = emptySet() + private var retainedExtTypes = defaultRetainedExtTypes init { retainedExtTypes.forEach { rtpExtensionType -> streamInformationStore.onRtpExtensionMapping(rtpExtensionType) { it?.let { - retainedExts = retainedExts.plus(it) - retainedExtsWithAv1DD = retainedExtsWithAv1DD.plus(it) + retainedExts += it + retainedExtsWithAv1DD += it } } } streamInformationStore.onRtpExtensionMapping(RtpExtensionType.AV1_DEPENDENCY_DESCRIPTOR) { - it?.let { retainedExtsWithAv1DD = retainedExtsWithAv1DD.plus(it) } + it?.let { retainedExtsWithAv1DD += it } } } + fun addRtpExtensionToRetain(extensionType: RtpExtensionType) { + retainedExtTypes += extensionType + } + override fun modify(packetInfo: PacketInfo): PacketInfo { val rtpPacket = packetInfo.packetAs() val retained = if (rtpPacket is Av1DDPacket) retainedExtsWithAv1DD else retainedExts + // TODO: we should also retain any extensions that were not signaled. rtpPacket.removeHeaderExtensionsExcept(retained) return packetInfo @@ -59,7 +65,7 @@ class HeaderExtStripper( override fun trace(f: () -> Unit) = f.invoke() companion object { - private val retainedExtTypes: Set = setOf( + val defaultRetainedExtTypes: Set = setOf( RtpExtensionType.SSRC_AUDIO_LEVEL, RtpExtensionType.VIDEO_ORIENTATION ) diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt index bed8c99437..8aaa5d9f64 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/BitrateController.kt @@ -19,6 +19,7 @@ import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo import org.jitsi.nlj.format.PayloadType import org.jitsi.nlj.format.PayloadTypeEncoding +import org.jitsi.nlj.util.Bandwidth import org.jitsi.nlj.util.bps import org.jitsi.rtp.rtcp.RtcpSrPacket import org.jitsi.utils.event.SyncEventEmitter @@ -192,13 +193,26 @@ class BitrateController @JvmOverloads constructor( val nowMs = clock.instant().toEpochMilli() val allocation = bandwidthAllocator.allocation - allocation.allocations.forEach { - it.targetLayer?.getBitrate(nowMs)?.let { targetBitrate -> - totalTargetBitrate += targetBitrate - it.mediaSource?.primarySSRC?.let { primarySsrc -> activeSsrcs.add(primarySsrc) } + allocation.allocations.forEach { singleAllocation -> + val allocationTargetBitrate: Bandwidth? = if (config.useVlaTargetBitrate) { + singleAllocation.targetLayer?.targetBitrate ?: singleAllocation.targetLayer?.getBitrate(nowMs) + } else { + singleAllocation.targetLayer?.getBitrate(nowMs) + } + + allocationTargetBitrate?.let { + totalTargetBitrate += it + singleAllocation.mediaSource?.primarySSRC?.let { primarySsrc -> activeSsrcs.add(primarySsrc) } } - it.idealLayer?.getBitrate(nowMs)?.let { idealBitrate -> - totalIdealBitrate += idealBitrate + + val allocationIdealBitrate: Bandwidth? = if (config.useVlaTargetBitrate) { + singleAllocation.idealLayer?.targetBitrate ?: singleAllocation.idealLayer?.getBitrate(nowMs) + } else { + singleAllocation.idealLayer?.getBitrate(nowMs) + } + + allocationIdealBitrate?.let { + totalIdealBitrate += it } } @@ -220,18 +234,24 @@ class BitrateController @JvmOverloads constructor( var totalTargetBps = 0.0 var totalIdealBps = 0.0 + var totalTargetMeasuredBps = 0.0 + var totalIdealMeasuredBps = 0.0 allocation.allocations.forEach { it.targetLayer?.getBitrate(nowMs)?.let { bitrate -> totalTargetBps += bitrate.bps } it.idealLayer?.getBitrate(nowMs)?.let { bitrate -> totalIdealBps += bitrate.bps } + it.targetLayer?.targetBitrate?.let { bitrate -> totalTargetMeasuredBps += bitrate.bps } + it.idealLayer?.targetBitrate?.let { bitrate -> totalIdealMeasuredBps += bitrate.bps } trace( diagnosticContext .makeTimeSeriesPoint("allocation_for_source", nowMs) .addField("remote_endpoint_id", it.endpointId) .addField("target_idx", it.targetLayer?.index ?: -1) .addField("ideal_idx", it.idealLayer?.index ?: -1) - .addField("target_bps", it.targetLayer?.getBitrate(nowMs)?.bps ?: -1) - .addField("ideal_bps", it.idealLayer?.getBitrate(nowMs)?.bps ?: -1) + .addField("target_bps_measured", it.targetLayer?.getBitrate(nowMs)?.bps ?: -1) + .addField("target_bps", it.targetLayer?.targetBitrate?.bps ?: -1) + .addField("ideal_bps_measured", it.idealLayer?.getBitrate(nowMs)?.bps ?: -1) + .addField("ideal_bps", it.idealLayer?.targetBitrate?.bps ?: -1) ) } @@ -240,6 +260,8 @@ class BitrateController @JvmOverloads constructor( .makeTimeSeriesPoint("allocation", nowMs) .addField("total_target_bps", totalTargetBps) .addField("total_ideal_bps", totalIdealBps) + .addField("total_target_measured_bps", totalTargetMeasuredBps) + .addField("total_ideal_measured_bps", totalIdealMeasuredBps) ) } diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt index d69a22f465..12f730edaf 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/allocation/SingleSourceAllocation.kt @@ -269,7 +269,16 @@ internal class SingleSourceAllocation( if (constraints.maxHeight == 0 || !source.hasRtpLayers()) { return Layers.noLayers } - val layers = source.rtpLayers.map { LayerSnapshot(it, it.getBitrateBps(nowMs)) } + val layers = source.rtpLayers.map { + LayerSnapshot( + it, + if (config.useVlaTargetBitrate) { + it.targetBitrate?.bps ?: it.getBitrate(nowMs).bps + } else { + it.getBitrate(nowMs).bps + } + ) + } return when (source.videoType) { VideoType.CAMERA -> selectLayersForCamera(layers, constraints) diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/config/BitrateControllerConfig.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/config/BitrateControllerConfig.kt index 5c44519e43..1f9f3ed82b 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/cc/config/BitrateControllerConfig.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/cc/config/BitrateControllerConfig.kt @@ -110,6 +110,10 @@ class BitrateControllerConfig private constructor() { .convertFrom { Bandwidth.fromString(it) } } + val useVlaTargetBitrate: Boolean by config { + "videobridge.cc.use-vla-target-bitrate".from(JitsiConfig.newConfig) + } + companion object { @JvmField val config = BitrateControllerConfig() diff --git a/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt b/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt index 6d0beff72e..b5d2b35cd5 100644 --- a/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt +++ b/jvb/src/main/kotlin/org/jitsi/videobridge/relay/Relay.kt @@ -285,6 +285,7 @@ class Relay @JvmOverloads constructor( }, external = true ) + addRtpExtensionToRetain(RtpExtensionType.VLA) } /** diff --git a/jvb/src/main/resources/reference.conf b/jvb/src/main/resources/reference.conf index 7eaf35696c..9b90f6cf0b 100644 --- a/jvb/src/main/resources/reference.conf +++ b/jvb/src/main/resources/reference.conf @@ -77,6 +77,10 @@ videobridge { # If set allows receivers to override bandwidth estimation (BWE) with a specific value signaled over the bridge # channel (limited to the configured value). If not set, receivers are not allowed to override BWE. // assumed-bandwidth-limit = 10 Mbps + + # Whether to use the target bitrate signaled in the VLA extension for allocation. When disabled we use the measured + # bitrate instead (preserving previous behavior). + use-vla-target-bitrate = false } # Whether to indicate support for cryptex header extension encryption (RFC 9335) cryptex { diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/rtp/header_extensions/VlaExtension.kt b/rtp/src/main/kotlin/org/jitsi/rtp/rtp/header_extensions/VlaExtension.kt new file mode 100644 index 0000000000..7efbee9ac1 --- /dev/null +++ b/rtp/src/main/kotlin/org/jitsi/rtp/rtp/header_extensions/VlaExtension.kt @@ -0,0 +1,129 @@ +/* + * Copyright @ 2024-Present 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.rtp.rtp.header_extensions + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import org.jitsi.rtp.rtp.RtpPacket +import org.jitsi.rtp.rtp.header_extensions.VlaExtension.Stream +import org.jitsi.rtp.util.BitReader + +/** + * A parser for the Video Layers Allocation RTP header extension. + * https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00 + */ +@SuppressFBWarnings("SF_SWITCH_NO_DEFAULT", justification = "False positive") +class VlaExtension { + companion object { + fun parse(ext: RtpPacket.HeaderExtension): ParsedVla { + val empty = ext.dataLengthBytes == 1 && ext.buffer[ext.dataOffset] == 0.toByte() + if (empty) { + return emptyList() + } + + val reader = BitReader(ext.buffer, ext.dataOffset, ext.dataLengthBytes) + reader.skipBits(2) // RID + val ns = reader.bits(2) + 1 + val slBm = reader.bits(4) + val slBms = IntArray(4) { i -> if (i < ns) slBm else 0 } + if (slBm == 0) { + slBms[0] = reader.bits(4) + if (ns > 1) { + slBms[1] = reader.bits(4) + if (ns > 2) { + slBms[2] = reader.bits(4) + if (ns > 3) { + slBms[3] = reader.bits(4) + } + } + } + if (ns == 1 || ns == 3) { + reader.skipBits(4) + } + } + + val slCount = slBms.sumOf { it.countOneBits() } + val tlCountLenBytes = slCount / 4 + if (slCount % 4 != 0) 1 else 0 + + val tlCountReader = reader.clone(tlCountLenBytes) + reader.skipBits(tlCountLenBytes * 8) + + val streams = ArrayList(ns) + (0 until ns).forEach { streamIdx -> + val spatialLayers = ArrayList() + val stream = Stream(streamIdx, spatialLayers) + streams.add(stream) + + (0 until 4).forEach { slIdx -> + if ((slBms[streamIdx] and (1 shl slIdx)) != 0) { + val targetBitrates = buildList { + repeat(tlCountReader.bits(2) + 1) { + add(reader.leb128()) + } + } + spatialLayers.add( + SpatialLayer( + slIdx, + targetBitrates, + null + ) + ) + } + } + } + + (0 until ns).forEach outer@{ streamIdx -> + (0 until 4).forEach { slIdx -> + if ((slBms[streamIdx] and (1 shl slIdx)) != 0) { + if (reader.remainingBits() < 40) { + return@outer + } + val sl = streams[streamIdx].spatialLayers[slIdx] + sl.res = ResolutionAndFrameRate( + reader.bits(16) + 1, + reader.bits(16) + 1, + reader.bits(8) + ) + } + } + } + + return streams + } + } + + data class ResolutionAndFrameRate( + val width: Int, + val height: Int, + val maxFramerate: Int + ) + + data class SpatialLayer( + val id: Int, + // The target bitrates for each temporal layer in this spatial layer + val targetBitratesKbps: List, + var res: ResolutionAndFrameRate? + ) { + override fun toString(): String = "SpatialLayer(id=$id, targetBitratesKbps=$targetBitratesKbps, " + + "width=${res?.width}, height=${res?.height}, maxFramerate=${res?.maxFramerate})" + } + + data class Stream( + val id: Int, + val spatialLayers: List + ) +} + +typealias ParsedVla = List diff --git a/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt b/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt index d141b7aed4..9dcbd9bc9f 100644 --- a/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt +++ b/rtp/src/main/kotlin/org/jitsi/rtp/util/BitReader.kt @@ -26,6 +26,22 @@ class BitReader(val buf: ByteArray, private val byteOffset: Int = 0, private val private var offset = byteOffset * 8 private val byteBound = byteOffset + byteLength + init { + check(byteOffset >= 0) { "byteOffset must be >= 0" } + check(byteBound <= buf.size) { "byteOffset + byteLength must be <= buf.size" } + } + + /** Clone with the current state (offset) and a new length in bytes. */ + fun clone(newByteLength: Int): BitReader { + check(offset % 8 == 0) { "Cannot clone BitReader with unaligned offset" } + check(offset / 8 + newByteLength <= byteBound) { + "newByteLength $newByteLength exceeds buffer length $byteLength after offset $byteOffset" + } + return BitReader(buf, offset / 8, newByteLength) + } + + fun remainingBits(): Int = byteBound * 8 - offset + /** Read a single bit from the buffer, as a boolean, incrementing the offset. */ fun bitAsBoolean(): Boolean { val byteIdx = offset / 8 @@ -98,6 +114,20 @@ class BitReader(val buf: ByteArray, private val byteOffset: Int = 0, private val return (v shl 1) - m + extraBit } + /** + * Read a LEB128-encoded unsigned integer. + * https://aomediacodec.github.io/av1-spec/#leb128 + */ + fun leb128(): Long { + var value = 0L + (0..8).forEach { i -> + val hasNext = bitAsBoolean() + value = value or (bits(7).toLong() shl (i * 7)) + if (!hasNext) return value + } + return value + } + /** Reset the reader to the beginning of the buffer */ fun reset() { offset = byteOffset * 8 diff --git a/rtp/src/test/kotlin/org/jitsi/rtp/extensions/VlaExtensionTest.kt b/rtp/src/test/kotlin/org/jitsi/rtp/extensions/VlaExtensionTest.kt new file mode 100644 index 0000000000..d712ee5e97 --- /dev/null +++ b/rtp/src/test/kotlin/org/jitsi/rtp/extensions/VlaExtensionTest.kt @@ -0,0 +1,372 @@ +/* + * Copyright @ 2024-Present 8x8, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jitsi.rtp.extensions + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe +import org.jitsi.rtp.rtp.RtpPacket.HeaderExtension +import org.jitsi.rtp.rtp.header_extensions.ParsedVla +import org.jitsi.rtp.rtp.header_extensions.VlaExtension +import org.jitsi.rtp.rtp.header_extensions.VlaExtension.ResolutionAndFrameRate +import org.jitsi.rtp.rtp.header_extensions.VlaExtension.SpatialLayer +import org.jitsi.rtp.rtp.header_extensions.VlaExtension.Stream + +@Suppress("ktlint:standard:no-multi-spaces") +class VlaExtensionTest : ShouldSpec() { + init { + context("Empty") { + parse(0x00) shouldBe emptyList() + } + context("VP8 single stream (with resolution)") { + parse( + 0b0000_0001, + 0b1000_0000.toByte(), + 0b0101_0000, + 0b0111_1000, + 0b1100_1000.toByte(), + 0b0000_0001, + 0b0000_0001, + 0b0011_1111, + 0b0000_0000, + 0b1011_0011.toByte(), + 0b0010_0001 + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(80, 120, 200), + ResolutionAndFrameRate(320, 180, 33) + ) + ) + ) + ) + } + context("VP8 single stream (without resolution)") { + parse( + 0b0000_0001, + 0b1000_0000.toByte(), + 0b0101_0000, + 0b0111_1000, + 0b1100_1000.toByte(), + 0b0000_0001 + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(80, 120, 200), + null + ) + ) + ) + ) + } + context("VP8 simulcast stream (with resolutions)") { + parse( + 0b0010_0001, + 0b1010_1000.toByte(), + 0b0011_1100, + 0b0101_1010, + 0b1001_0110.toByte(), + 0b0000_0001, + 0b1100_1000.toByte(), + 0b0000_0001, + 0b1010_1100.toByte(), + 0b0000_0010, + 0b1111_0100.toByte(), + 0b0000_0011, + 0b1110_1110.toByte(), + 0b0000_0011, + 0b1110_0110.toByte(), + 0b0000_0101, + 0b1101_0100.toByte(), + 0b0000_1001, + 0b0000_0001, + 0b0011_1111, + 0b0000_0000, + 0b1011_0011.toByte(), + 0b0001_1111, + 0b0000_0010, + 0b0111_1111, + 0b0000_0001, + 0b0110_0111, + 0b0001_1111, + 0b0000_0100, + 0b1111_1111.toByte(), + 0b0000_0010, + 0b1100_1111.toByte(), + 0b0001_1111 + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(60, 90, 150), + ResolutionAndFrameRate(320, 180, 31) + ) + ) + ), + Stream( + 1, + listOf( + SpatialLayer( + 0, + listOf(200, 300, 500), + ResolutionAndFrameRate(640, 360, 31) + ) + ) + ), + Stream( + 2, + listOf( + SpatialLayer( + 0, + listOf(494, 742, 1236), + ResolutionAndFrameRate(1280, 720, 31) + ) + ) + ) + ) + } + context("VP8 simulcast stream (without resolutions)") { + parse( + 0b1010_0001.toByte(), + 0b1010_1000.toByte(), + 0b1001_0101.toByte(), + 0b0000_0010, + 0b1001_1111.toByte(), + 0b0000_0011, + 0b1011_0100.toByte(), + 0b0000_0101, + 0b1000_1101.toByte(), + 0b0000_1001, + 0b1101_0011.toByte(), + 0b0000_1101, + 0b1110_0000.toByte(), + 0b0001_0110, + 0b1101_0000.toByte(), + 0b0000_1111, + 0b1011_1000.toByte(), + 0b0001_0111, + 0b1000_1000.toByte(), + 0b0010_0111 + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(277, 415, 692), + null + ) + ) + ), + Stream( + 1, + listOf( + SpatialLayer( + 0, + listOf(1165, 1747, 2912), + null + ) + ) + ), + Stream( + 2, + listOf( + SpatialLayer( + 0, + listOf(2000, 3000, 5000), + null + ) + ) + ) + ) + } + context("VP9 SVC (with resolutions)") { + parse( + 0b0000_0111, + 0b1010_1000.toByte(), + 0b0100_1101, + 0b0110_0100, + 0b1000_1110.toByte(), + 0b0000_0001, + 0b1101_0111.toByte(), + 0b0000_0001, + 0b1001_0111.toByte(), + 0b0000_0010, + 0b1000_1101.toByte(), + 0b0000_0011, + 0b1101_0110.toByte(), + 0b0000_0010, + 0b1011_1101.toByte(), + 0b0000_0011, + 0b1111_1001.toByte(), + 0b0000_0100, + 0b0000_0001, + 0b0011_1111, + 0b0000_0000, + 0b1011_0011.toByte(), + 0b0010_0000, + 0b0000_0010, + 0b0111_1111, + 0b0000_0001, + 0b0110_0111, + 0b0010_0000, + 0b0000_0100, + 0b1111_1111.toByte(), + 0b0000_0010, + 0b1100_1111.toByte(), + 0b0010_0000 + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(77, 100, 142), + ResolutionAndFrameRate(320, 180, 32) + ), + SpatialLayer( + 1, + listOf(215, 279, 397), + ResolutionAndFrameRate(640, 360, 32) + ), + SpatialLayer( + 2, + listOf(342, 445, 633), + ResolutionAndFrameRate(1280, 720, 32) + ) + ) + ) + ) + } + context("VP9 SVC (without resolutions or TLs)") { + parse( + 0b0000_0111, + 0b0000_0000, + 0b1001_0110.toByte(), + 0b0000_0001, + 0b1111_0100.toByte(), + 0b0000_0011, + 0b1010_1010.toByte(), + 0b0000_1011 + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(150), + null + ), + SpatialLayer( + 1, + listOf(500), + null + ), + SpatialLayer( + 2, + listOf(1450), + null + ) + ) + ) + ) + } + context("Invalid VLAs") { + context("Resolution incomplete") { + // Cuts short with not enough bytes to contain the resolution. We succeed with no resolution. + parse( + 0b0000_0001, + 0b1000_0000.toByte(), // #tls + 0b0101_0000, // targetBitrate 1 + 0b0111_1000, // targetBitrate 2 + 0b1100_1000.toByte(), // targetBitrate 3 + 0b0000_0001, // targetBitrate 3 + 0b0000_0001, // width + 0b0011_1111, // width + 0b0000_0000, // height + 0b1011_0011.toByte(), // height + // 0b0010_0001 // maxFramerate + ) shouldBe listOf( + Stream( + 0, + listOf( + SpatialLayer( + 0, + listOf(80, 120, 200), + null + ) + ) + ) + ) + } + context("Invalid leb128") { + // Cuts short in the middle of one of the leb128 encoding of one of the target bitrates. + shouldThrow { + parse( + 0b0000_0001, + 0b1000_0000.toByte(), // #tls + 0b0101_0000, // targetBitrate 1 + 0b0111_1000, // targetBitrate 2 + 0b1100_1000.toByte(), // targetBitrate 3 + // 0b0000_0001, // targetBitrate 3 + ) + } + } + context("Missing target bitrates") { + // Does not contain the expected number of target bitrates + shouldThrow { + parse( + 0b0000_0001, // 1 spatial layer + 0b1000_0000.toByte(), // 3 temporal layers + 0b0101_0000, // targetBitrate 1 + 0b1100_1000.toByte(), // targetBitrate 2 + 0b0000_0001, // targetBitrate 2 + // 0b0111_1000, // targetBitrate 3 + ) + } + } + context("Missing temporal layer counts") { + // Does not contain the expected temporal layer counts + shouldThrow { + parse( + 0b0001_0000, // spatial layer bitmask in the next byte, 2 streams + 0b1111_1111.toByte(), // 4 spatial layers for each stream + 0b0000_0000.toByte(), // 1 temporal layer for each spatial layer of stream 1, more must follow + ) + } + } + } + } +} + +private fun parse(vararg bytes: Byte): ParsedVla = VlaExtension.parse(RawHeaderExtension(bytes)) + +@SuppressFBWarnings("CN_IMPLEMENTS_CLONE_BUT_NOT_CLONEABLE") +class RawHeaderExtension(override val buffer: ByteArray) : HeaderExtension { + override val dataOffset: Int = 0 + override var id: Int = 0 + override val dataLengthBytes: Int = buffer.size + override val totalLengthBytes: Int = buffer.size +} diff --git a/rtp/src/test/kotlin/org/jitsi/rtp/util/BitReaderTest.kt b/rtp/src/test/kotlin/org/jitsi/rtp/util/BitReaderTest.kt new file mode 100644 index 0000000000..2be8e5100d --- /dev/null +++ b/rtp/src/test/kotlin/org/jitsi/rtp/util/BitReaderTest.kt @@ -0,0 +1,132 @@ +/* + * Copyright @ 2024 - present 8x8, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.jitsi.rtp.util + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.core.spec.IsolationMode +import io.kotest.core.spec.style.ShouldSpec +import io.kotest.matchers.shouldBe + +class BitReaderTest : ShouldSpec() { + override fun isolationMode(): IsolationMode = IsolationMode.InstancePerLeaf + + init { + val bitReader = BitReader( + byteArrayOf( + 0b0101_1010.toByte(), + 0b0000_1111.toByte(), + 0b1010_1010.toByte(), + 0b1010_1010.toByte(), + 0b0111_1111.toByte() + ) + ) + val bits = 40 + context("Reading single bits as boolean") { + bitReader.bitAsBoolean() shouldBe false + bitReader.bitAsBoolean() shouldBe true + bitReader.bitAsBoolean() shouldBe false + bitReader.bitAsBoolean() shouldBe true + bitReader.bitAsBoolean() shouldBe true + bitReader.remainingBits() shouldBe bits - 5 + } + + context("Reading single bits as Integer") { + bitReader.bit() shouldBe 0 + bitReader.bit() shouldBe 1 + bitReader.bit() shouldBe 0 + bitReader.bit() shouldBe 1 + bitReader.bit() shouldBe 1 + bitReader.remainingBits() shouldBe bits - 5 + } + + context("Reading multiple bits as Integer") { + bitReader.bits(4) shouldBe 0b0101 + bitReader.bits(4) shouldBe 0b1010 + bitReader.bits(6) shouldBe 0b000011 + bitReader.bits(2) shouldBe 0b11 + bitReader.remainingBits() shouldBe bits - 16 + } + + context("Reading multiple bits as Long") { + bitReader.bitsLong(4) shouldBe 0b0101L + bitReader.bitsLong(4) shouldBe 0b1010L + bitReader.bits(6) shouldBe 0b000011L + bitReader.bits(2) shouldBe 0b11L + bitReader.remainingBits() shouldBe bits - 16 + } + + context("skip bits correctly") { + bitReader.skipBits(4) + bitReader.remainingBits() shouldBe bits - 4 + bitReader.bit() shouldBe 1 + bitReader.remainingBits() shouldBe bits - 5 + } + + context("Reading LEB128-encoded unsigned integers") { + bitReader.leb128() shouldBe 0b0101_1010 + bitReader.remainingBits() shouldBe bits - 8 + bitReader.leb128() shouldBe 0b00001111 + bitReader.remainingBits() shouldBe bits - 16 + bitReader.leb128() shouldBe 0b111_1111__010_1010__010_1010 + bitReader.remainingBits() shouldBe 0 + } + + context("Cloning") { + bitReader.skipBits(8) + val clone = bitReader.clone(2) + bitReader.remainingBits() shouldBe bits - 8 + clone.remainingBits() shouldBe 16 + + bitReader.bits(8) shouldBe 0b0000_1111 + bitReader.remainingBits() shouldBe bits - 16 + clone.remainingBits() shouldBe 16 + + clone.bits(8) shouldBe 0b0000_1111 + bitReader.remainingBits() shouldBe bits - 16 + clone.remainingBits() shouldBe 8 + + clone.skipBits(8) + bitReader.remainingBits() shouldBe bits - 16 + clone.remainingBits() shouldBe 0 + + bitReader.bits(8) shouldBe 0b1010_1010 + shouldThrow { + clone.bits(8) + } + } + + context("Throwing exception when reading past the bounds") { + bitReader.skipBits(bits) + bitReader.remainingBits() shouldBe 0 + shouldThrow { + bitReader.bitAsBoolean() + } + shouldThrow { + bitReader.bit() + } + shouldThrow { + bitReader.bits(4) + } + shouldThrow { + bitReader.leb128() + } + shouldThrow { + BitReader(byteArrayOf(0b1111_0000.toByte())).leb128() + } + } + } +}