Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for the VLA RTP header extension #2263

Merged
merged 11 commits into from
Dec 12, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ import org.jitsi.utils.OrderedJsonObject
*
* @author George Politis
*/
abstract class RtpLayerDesc
constructor(
abstract class RtpLayerDesc(
/**
* The index of this instance's encoding in the source encoding array.
*/
Expand All @@ -54,7 +53,7 @@ constructor(
* represents. The actual frame rate may be less due to bad network or
* system load. [NO_FRAME_RATE] for unknown.
*/
val frameRate: Double,
var frameRate: Double,
) {
abstract fun copy(height: Int = this.height, tid: Int = this.tid, inherit: Boolean = true): RtpLayerDesc

Expand All @@ -63,6 +62,8 @@ constructor(
*/
protected var bitrateTracker = BitrateCalculator.createBitrateTracker()

var targetBitrate: Bandwidth? = null

/**
* @return the "id" of this layer within this encoding. This is a server-side id and should
* not be confused with any encoding id defined in the client (such as the
Expand All @@ -87,6 +88,7 @@ constructor(
*/
internal open fun inheritFrom(other: RtpLayerDesc) {
inheritStatistics(other.bitrateTracker)
targetBitrate = other.targetBitrate
}

/**
Expand All @@ -110,12 +112,6 @@ constructor(
*/
abstract fun getBitrate(nowMs: Long): Bandwidth

/**
* Expose [getBitrate] as a [Double] in order to make it accessible from java (since [Bandwidth] is an inline
* class).
*/
fun getBitrateBps(nowMs: Long): Double = getBitrate(nowMs).bps

/**
* Recursively checks this layer and its dependencies to see if the bitrate is zero.
* Note that unlike [calcBitrate] this does not avoid double-visiting layers; the overhead
Expand All @@ -131,6 +127,7 @@ constructor(
addNumber("height", height)
addNumber("index", index)
addNumber("bitrate_bps", getBitrate(System.currentTimeMillis()).bps)
addNumber("target_bitrate", targetBitrate?.bps ?: 0)
}

fun debugState(): OrderedJsonObject = getNodeStats().toJson().apply { put("indexString", indexString()) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import org.jitsi.nlj.transform.node.incoming.VideoBitrateCalculator
import org.jitsi.nlj.transform.node.incoming.VideoMuteNode
import org.jitsi.nlj.transform.node.incoming.VideoParser
import org.jitsi.nlj.transform.node.incoming.VideoQualityLayerLookup
import org.jitsi.nlj.transform.node.incoming.VlaReaderNode
import org.jitsi.nlj.transform.packetPath
import org.jitsi.nlj.transform.pipeline
import org.jitsi.nlj.util.Bandwidth
Expand Down Expand Up @@ -248,6 +249,7 @@ class RtpReceiverImpl @JvmOverloads constructor(
node(videoParser)
node(VideoQualityLayerLookup(logger))
node(videoBitrateCalculator)
node(VlaReaderNode(streamInformationStore, logger))
node(packetHandlerWrapper)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.jitsi.nlj

import org.jitsi.nlj.rtp.LossListener
import org.jitsi.nlj.rtp.RtpExtensionType
import org.jitsi.nlj.rtp.TransportCcEngine
import org.jitsi.nlj.rtp.bandwidthestimation.BandwidthEstimator
import org.jitsi.nlj.srtp.SrtpTransformers
Expand Down Expand Up @@ -47,6 +48,7 @@ abstract class RtpSender :
abstract fun setFeature(feature: Features, enabled: Boolean)
abstract fun isFeatureEnabled(feature: Features): Boolean
abstract fun tearDown()
abstract fun addRtpExtensionToRetain(extensionType: RtpExtensionType)

/**
* An optional function to be executed for each RTP packet, as the first step of the send pipeline.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.jitsi.nlj.rtcp.NackHandler
import org.jitsi.nlj.rtcp.RtcpEventNotifier
import org.jitsi.nlj.rtcp.RtcpSrUpdater
import org.jitsi.nlj.rtp.LossListener
import org.jitsi.nlj.rtp.RtpExtensionType
import org.jitsi.nlj.rtp.TransportCcEngine
import org.jitsi.nlj.rtp.bandwidthestimation.BandwidthEstimator
import org.jitsi.nlj.rtp.bandwidthestimation.GoogleCcEstimator
Expand Down Expand Up @@ -111,6 +112,7 @@ class RtpSenderImpl(
private val srtcpEncryptWrapper = SrtcpEncryptNode()
private val toggleablePcapWriter = ToggleablePcapWriter(logger, "$id-tx")
private val outgoingPacketCache = PacketCacher()
private val headerExtensionStripper = HeaderExtStripper(streamInformationStore)
private val absSendTime = AbsSendTime(streamInformationStore)
private val statsTracker = OutgoingStatisticsTracker()
private val packetStreamStats = PacketStreamStatsNode()
Expand Down Expand Up @@ -144,7 +146,7 @@ class RtpSenderImpl(
outgoingRtpRoot = pipeline {
node(PluggableTransformerNode("RTP pre-processor") { preProcesor })
node(AudioRedHandler(streamInformationStore, logger))
node(HeaderExtStripper(streamInformationStore))
node(headerExtensionStripper)
node(outgoingPacketCache)
node(absSendTime)
node(statsTracker)
Expand Down Expand Up @@ -333,6 +335,10 @@ class RtpSenderImpl(
toggleablePcapWriter.disable()
}

override fun addRtpExtensionToRetain(extensionType: RtpExtensionType) {
headerExtensionStripper.addRtpExtensionToRetain(extensionType)
}

companion object {
var queueErrorCounter = CountingErrorHandler()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.jitsi.nlj
import org.jitsi.nlj.format.PayloadType
import org.jitsi.nlj.rtcp.RtcpEventNotifier
import org.jitsi.nlj.rtp.RtpExtension
import org.jitsi.nlj.rtp.RtpExtensionType
import org.jitsi.nlj.rtp.bandwidthestimation.BandwidthEstimator
import org.jitsi.nlj.srtp.SrtpTransformers
import org.jitsi.nlj.srtp.SrtpUtil
Expand Down Expand Up @@ -211,6 +212,10 @@ class Transceiver(
rtpReceiver.handleEvent(localSsrcSetEvent)
}

fun addRtpExtensionToRetain(extensionType: RtpExtensionType) {
rtpSender.addRtpExtensionToRetain(extensionType)
}

fun receivesSsrc(ssrc: Long): Boolean = streamInformationStore.receiveSsrcs.contains(ssrc)

val receiveSsrcs: Set<Long>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,13 @@ enum class RtpExtensionType(val uri: String) {
*/
AV1_DEPENDENCY_DESCRIPTOR(
"https://aomediacodec.github.io/av1-rtp-spec/#dependency-descriptor-rtp-header-extension"
);
),

/**
* Video Layers Allocation
* https://webrtc.googlesource.com/src/+/refs/heads/main/docs/native-code/rtp-hdrext/video-layers-allocation00
*/
VLA("http://www.webrtc.org/experiments/rtp-hdrext/video-layers-allocation00");

companion object {
private val uriMap = RtpExtensionType.values().associateBy(RtpExtensionType::uri)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright @ 2024-Present 8x8, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jitsi.nlj.transform.node.incoming

import org.jitsi.nlj.Event
import org.jitsi.nlj.MediaSourceDesc
import org.jitsi.nlj.PacketInfo
import org.jitsi.nlj.SetMediaSourcesEvent
import org.jitsi.nlj.findRtpSource
import org.jitsi.nlj.rtp.RtpExtensionType.VLA
import org.jitsi.nlj.transform.node.ObserverNode
import org.jitsi.nlj.util.ReadOnlyStreamInformationStore
import org.jitsi.nlj.util.kbps
import org.jitsi.rtp.rtp.RtpPacket
import org.jitsi.rtp.rtp.header_extensions.VlaExtension
import org.jitsi.utils.logging2.Logger
import org.jitsi.utils.logging2.LoggerImpl
import org.jitsi.utils.logging2.cdebug
import org.jitsi.utils.logging2.createChildLogger

/**
* A node which reads the Video Layers Allocation (VLA) RTP header extension and updates the media sources.
*/
class VlaReaderNode(
streamInformationStore: ReadOnlyStreamInformationStore,
parentLogger: Logger = LoggerImpl(VlaReaderNode::class.simpleName)
) : ObserverNode("Video Layers Allocation reader") {
private val logger = createChildLogger(parentLogger)
private var vlaExtId: Int? = null
private var mediaSourceDescs: Array<MediaSourceDesc> = arrayOf()

init {
streamInformationStore.onRtpExtensionMapping(VLA) {
vlaExtId = it
logger.debug("VLA extension ID set to $it")
}
}

override fun handleEvent(event: Event) {
when (event) {
is SetMediaSourcesEvent -> {
mediaSourceDescs = event.mediaSourceDescs.copyOf()
logger.cdebug { "Media sources changed:\n${mediaSourceDescs.joinToString()}" }
}
}
}

override fun observe(packetInfo: PacketInfo) {
val rtpPacket = packetInfo.packetAs<RtpPacket>()
vlaExtId?.let {
rtpPacket.getHeaderExtension(it)?.let { ext ->
val vla = try {
VlaExtension.parse(ext)
} catch (e: Exception) {
logger.warn("Failed to parse VLA extension", e)
return
}

val sourceDesc = mediaSourceDescs.findRtpSource(rtpPacket)

logger.debug("Found VLA=$vla for sourceDesc=$sourceDesc")

vla.forEachIndexed { streamIdx, stream ->
val rtpEncoding = sourceDesc?.rtpEncodings?.get(streamIdx)
stream.spatialLayers.forEach { spatialLayer ->
spatialLayer.targetBitratesKbps.forEachIndexed { tlIdx, targetBitrateKbps ->
rtpEncoding?.layers?.find {
// With VP8 simulcast all layers have sid -1
(it.sid == spatialLayer.id || it.sid == -1) && it.tid == tlIdx
}?.let { layer ->
logger.debug(
"Setting target bitrate for rtpEncoding=$rtpEncoding layer=$layer to " +
"${targetBitrateKbps.kbps} (res=${spatialLayer.res})"
)
layer.targetBitrate = targetBitrateKbps.kbps
spatialLayer.res?.let { res ->
if (layer.height > 0 && layer.height != res.height) {
logger.warn("Updating layer height from ${layer.height} to ${res.height}")
}
layer.height = res.height
layer.frameRate = res.maxFramerate.toDouble()
}
}
}
}
}
}
}
}

override fun trace(f: () -> Unit) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,40 @@ import org.jitsi.nlj.util.ReadOnlyStreamInformationStore
import org.jitsi.rtp.rtp.RtpPacket

/**
* Strip all hop-by-hop header extensions. Currently this leaves ssrc-audio-level and video-orientation,
* Strip all hop-by-hop header extensions. By default, this leaves ssrc-audio-level and video-orientation,
* plus the AV1 dependency descriptor if the packet is an Av1DDPacket.
*/
class HeaderExtStripper(
streamInformationStore: ReadOnlyStreamInformationStore
streamInformationStore: ReadOnlyStreamInformationStore,
) : ModifierNode("Strip header extensions") {
private var retainedExts: Set<Int> = emptySet()
private var retainedExtsWithAv1DD: Set<Int> = emptySet()
private var retainedExtTypes = defaultRetainedExtTypes

init {
retainedExtTypes.forEach { rtpExtensionType ->
streamInformationStore.onRtpExtensionMapping(rtpExtensionType) {
it?.let {
retainedExts = retainedExts.plus(it)
retainedExtsWithAv1DD = retainedExtsWithAv1DD.plus(it)
retainedExts += it
retainedExtsWithAv1DD += it
}
}
}
streamInformationStore.onRtpExtensionMapping(RtpExtensionType.AV1_DEPENDENCY_DESCRIPTOR) {
it?.let { retainedExtsWithAv1DD = retainedExtsWithAv1DD.plus(it) }
it?.let { retainedExtsWithAv1DD += it }
}
}

fun addRtpExtensionToRetain(extensionType: RtpExtensionType) {
retainedExtTypes += extensionType
}

override fun modify(packetInfo: PacketInfo): PacketInfo {
val rtpPacket = packetInfo.packetAs<RtpPacket>()

val retained = if (rtpPacket is Av1DDPacket) retainedExtsWithAv1DD else retainedExts

// TODO: we should also retain any extensions that were not signaled.
rtpPacket.removeHeaderExtensionsExcept(retained)

return packetInfo
Expand All @@ -59,7 +65,7 @@ class HeaderExtStripper(
override fun trace(f: () -> Unit) = f.invoke()

companion object {
private val retainedExtTypes: Set<RtpExtensionType> = setOf(
val defaultRetainedExtTypes: Set<RtpExtensionType> = setOf(
RtpExtensionType.SSRC_AUDIO_LEVEL,
RtpExtensionType.VIDEO_ORIENTATION
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.jitsi.nlj.MediaSourceDesc
import org.jitsi.nlj.PacketInfo
import org.jitsi.nlj.format.PayloadType
import org.jitsi.nlj.format.PayloadTypeEncoding
import org.jitsi.nlj.util.Bandwidth
import org.jitsi.nlj.util.bps
import org.jitsi.rtp.rtcp.RtcpSrPacket
import org.jitsi.utils.event.SyncEventEmitter
Expand Down Expand Up @@ -192,13 +193,26 @@ class BitrateController<T : MediaSourceContainer> @JvmOverloads constructor(

val nowMs = clock.instant().toEpochMilli()
val allocation = bandwidthAllocator.allocation
allocation.allocations.forEach {
it.targetLayer?.getBitrate(nowMs)?.let { targetBitrate ->
totalTargetBitrate += targetBitrate
it.mediaSource?.primarySSRC?.let { primarySsrc -> activeSsrcs.add(primarySsrc) }
allocation.allocations.forEach { singleAllocation ->
val allocationTargetBitrate: Bandwidth? = if (config.useVlaTargetBitrate) {
singleAllocation.targetLayer?.targetBitrate ?: singleAllocation.targetLayer?.getBitrate(nowMs)
} else {
singleAllocation.targetLayer?.getBitrate(nowMs)
}

allocationTargetBitrate?.let {
totalTargetBitrate += it
singleAllocation.mediaSource?.primarySSRC?.let { primarySsrc -> activeSsrcs.add(primarySsrc) }
}
it.idealLayer?.getBitrate(nowMs)?.let { idealBitrate ->
totalIdealBitrate += idealBitrate

val allocationIdealBitrate: Bandwidth? = if (config.useVlaTargetBitrate) {
singleAllocation.idealLayer?.targetBitrate ?: singleAllocation.idealLayer?.getBitrate(nowMs)
} else {
singleAllocation.idealLayer?.getBitrate(nowMs)
}

allocationIdealBitrate?.let {
totalIdealBitrate += it
}
}

Expand All @@ -220,18 +234,24 @@ class BitrateController<T : MediaSourceContainer> @JvmOverloads constructor(

var totalTargetBps = 0.0
var totalIdealBps = 0.0
var totalTargetMeasuredBps = 0.0
var totalIdealMeasuredBps = 0.0

allocation.allocations.forEach {
it.targetLayer?.getBitrate(nowMs)?.let { bitrate -> totalTargetBps += bitrate.bps }
it.idealLayer?.getBitrate(nowMs)?.let { bitrate -> totalIdealBps += bitrate.bps }
it.targetLayer?.targetBitrate?.let { bitrate -> totalTargetMeasuredBps += bitrate.bps }
it.idealLayer?.targetBitrate?.let { bitrate -> totalIdealMeasuredBps += bitrate.bps }
trace(
diagnosticContext
.makeTimeSeriesPoint("allocation_for_source", nowMs)
.addField("remote_endpoint_id", it.endpointId)
.addField("target_idx", it.targetLayer?.index ?: -1)
.addField("ideal_idx", it.idealLayer?.index ?: -1)
.addField("target_bps", it.targetLayer?.getBitrate(nowMs)?.bps ?: -1)
.addField("ideal_bps", it.idealLayer?.getBitrate(nowMs)?.bps ?: -1)
.addField("target_bps_measured", it.targetLayer?.getBitrate(nowMs)?.bps ?: -1)
.addField("target_bps", it.targetLayer?.targetBitrate?.bps ?: -1)
.addField("ideal_bps_measured", it.idealLayer?.getBitrate(nowMs)?.bps ?: -1)
.addField("ideal_bps", it.idealLayer?.targetBitrate?.bps ?: -1)
)
}

Expand All @@ -240,6 +260,8 @@ class BitrateController<T : MediaSourceContainer> @JvmOverloads constructor(
.makeTimeSeriesPoint("allocation", nowMs)
.addField("total_target_bps", totalTargetBps)
.addField("total_ideal_bps", totalIdealBps)
.addField("total_target_measured_bps", totalTargetMeasuredBps)
.addField("total_ideal_measured_bps", totalIdealMeasuredBps)
)
}

Expand Down
Loading
Loading