Skip to content

Commit

Permalink
Refactor YamuxHandler.SendBuffer (#328)
Browse files Browse the repository at this point in the history
* Introduce ByteBufQueue
* Add ByteBufQueue tests
* Writing exec path is always through fill/drain buffer
* Adopt/fix existing tests
* Add new test checking correct handling of negative sendWindowSize
  • Loading branch information
Nashatyrev authored Oct 10, 2023
1 parent 26efe02 commit 7dc2fa2
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 87 deletions.
44 changes: 44 additions & 0 deletions libp2p/src/main/kotlin/io/libp2p/etc/util/netty/ByteBufQueue.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package io.libp2p.etc.util.netty

import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled

class ByteBufQueue {
private val data: MutableList<ByteBuf> = mutableListOf()

fun push(buf: ByteBuf) {
data += buf
}

fun take(maxLength: Int): ByteBuf {
val wholeBuffers = mutableListOf<ByteBuf>()
var size = 0
while (data.isNotEmpty()) {
val bufLen = data.first().readableBytes()
if (size + bufLen > maxLength) break
size += bufLen
wholeBuffers += data.removeFirst()
if (size == maxLength) break
}

val partialBufferSlice =
when {
data.isEmpty() -> null
size == maxLength -> null
else -> data.first()
}
?.let { buf ->
val remainingBytes = maxLength - size
buf.readRetainedSlice(remainingBytes)
}

val allBuffers = wholeBuffers + listOfNotNull(partialBufferSlice)
return Unpooled.wrappedBuffer(*allBuffers.toTypedArray())
}

fun dispose() {
data.forEach { it.release() }
}

fun readableBytes(): Int = data.sumOf { it.readableBytes() }
}
100 changes: 34 additions & 66 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocol
import io.libp2p.core.mux.StreamMuxer
import io.libp2p.etc.types.sliceMaxSize
import io.libp2p.etc.util.netty.ByteBufQueue
import io.libp2p.etc.util.netty.mux.MuxChannel
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.InvalidFrameMuxerException
Expand All @@ -17,6 +18,7 @@ import io.netty.channel.ChannelHandlerContext
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import kotlin.math.max

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val DEFAULT_MAX_BUFFERED_CONNECTION_WRITES = 10 * 1024 * 1024 // 10 MiB
Expand All @@ -36,15 +38,21 @@ open class YamuxHandler(
) {
val sendWindowSize = AtomicInteger(initialWindowSize)
val receiveWindowSize = AtomicInteger(initialWindowSize)
val sendBuffer = SendBuffer(id)
val sendBuffer = ByteBufQueue()

fun dispose() {
sendBuffer.dispose()
}

fun handleDataRead(msg: YamuxFrame) {
fun handleFrameRead(msg: YamuxFrame) {
handleFlags(msg)
when (msg.type) {
YamuxType.DATA -> handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg)
}
}

private fun handleDataRead(msg: YamuxFrame) {
val size = msg.length.toInt()
if (size == 0) {
return
Expand All @@ -60,17 +68,11 @@ open class YamuxHandler(
childRead(msg.id, msg.data!!)
}

fun handleWindowUpdate(msg: YamuxFrame) {
handleFlags(msg)

private fun handleWindowUpdate(msg: YamuxFrame) {
val delta = msg.length.toInt()
if (delta == 0) {
return
}

sendWindowSize.addAndGet(delta)
// try to send any buffered messages after the window update
sendBuffer.flush(sendWindowSize)
drainBuffer()
}

private fun handleFlags(msg: YamuxFrame) {
Expand All @@ -85,33 +87,33 @@ open class YamuxHandler(
}
}

fun sendData(
data: ByteBuf
) {
data.sliceMaxSize(maxFrameDataLength)
.forEach { slicedData ->
if (sendWindowSize.get() > 0 && sendBuffer.isEmpty()) {
val length = slicedData.readableBytes()
sendWindowSize.addAndGet(-length)
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData))
} else {
// wait until the window is increased to send
addToSendBuffer(data)
}
}
}

private fun addToSendBuffer(data: ByteBuf) {
sendBuffer.add(data)
private fun fillBuffer(data: ByteBuf) {
sendBuffer.push(data)
val totalBufferedWrites = calculateTotalBufferedWrites()
if (totalBufferedWrites > maxBufferedConnectionWrites) {
if (totalBufferedWrites > maxBufferedConnectionWrites + sendWindowSize.get()) {
onLocalClose()
throw WriteBufferOverflowMuxerException(
"Overflowed send buffer ($totalBufferedWrites/$maxBufferedConnectionWrites). Last stream attempting to write: $id"
)
}
}

private fun drainBuffer() {
val maxSendLength = max(0, sendWindowSize.get())
val data = sendBuffer.take(maxSendLength)
sendWindowSize.addAndGet(-data.readableBytes())
data.sliceMaxSize(maxFrameDataLength)
.forEach { slicedData ->
val length = slicedData.readableBytes()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), slicedData))
}
}

fun sendData(data: ByteBuf) {
fillBuffer(data)
drainBuffer()
}

fun onLocalOpen() {
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.SYN, 0))
}
Expand All @@ -122,7 +124,7 @@ open class YamuxHandler(

fun onLocalDisconnect() {
// TODO: this implementation drops remaining data
sendBuffer.flush(sendWindowSize)
drainBuffer()
sendBuffer.dispose()
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, YamuxFlags.FIN, 0))
}
Expand All @@ -134,36 +136,6 @@ open class YamuxHandler(
}
}

private inner class SendBuffer(val id: MuxId) {
private val bufferedData = ArrayDeque<ByteBuf>()

fun add(data: ByteBuf) {
bufferedData.add(data)
}

fun flush(windowSize: AtomicInteger) {
while (!isEmpty() && windowSize.get() > 0) {
val data = bufferedData.removeFirst()
val length = data.readableBytes()
windowSize.addAndGet(-length)
writeAndFlushFrame(YamuxFrame(id, YamuxType.DATA, 0, length.toLong(), data))
}
}

fun isEmpty(): Boolean {
return bufferedData.isEmpty()
}

fun bufferedBytes(): Int {
return bufferedData.sumOf { it.readableBytes() }
}

fun dispose() {
bufferedData.forEach { releaseMessage(it) }
bufferedData.clear()
}
}

private val idGenerator = YamuxStreamIdGenerator(connectionInitiator)

private val streamHandlers: MutableMap<MuxId, YamuxStreamHandler> = ConcurrentHashMap()
Expand Down Expand Up @@ -206,11 +178,7 @@ open class YamuxHandler(
onRemoteYamuxOpen(msg.id)
}

val streamHandler = getStreamHandlerOrReleaseAndThrow(msg.id, msg.data)
when (msg.type) {
YamuxType.DATA -> streamHandler.handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> streamHandler.handleWindowUpdate(msg)
}
getStreamHandlerOrReleaseAndThrow(msg.id, msg.data).handleFrameRead(msg)
}
}
}
Expand Down Expand Up @@ -263,7 +231,7 @@ open class YamuxHandler(
}

private fun calculateTotalBufferedWrites(): Int {
return streamHandlers.values.sumOf { it.sendBuffer.bufferedBytes() }
return streamHandlers.values.sumOf { it.sendBuffer.readableBytes() }
}

private fun handlePing(msg: YamuxFrame) {
Expand Down
167 changes: 167 additions & 0 deletions libp2p/src/test/kotlin/io/libp2p/etc/util/netty/ByteBufQueueTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package io.libp2p.etc.util.netty

import io.libp2p.tools.readAllBytesAndRelease
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Test

class ByteBufQueueTest {

val queue = ByteBufQueue()

val allocatedBufs = mutableListOf<ByteBuf>()

@AfterEach
fun cleanUpAndCheck() {
allocatedBufs.forEach {
assertThat(it.refCnt()).isEqualTo(1)
}
}

fun allocateBuf(): ByteBuf {
val buf = Unpooled.buffer()
buf.retain() // ref counter to 2 to check that exactly 1 ref remains at the end
allocatedBufs += buf
return buf
}

fun allocateData(data: String): ByteBuf =
allocateBuf().writeBytes(data.toByteArray())

fun ByteBuf.readString() = String(this.readAllBytesAndRelease())

@Test
fun emptyTest() {
assertThat(queue.take(100).readString()).isEqualTo("")
}

@Test
fun zeroTest() {
queue.push(allocateData("abc"))
assertThat(queue.take(0).readString()).isEqualTo("")
assertThat(queue.take(100).readString()).isEqualTo("abc")
}

@Test
fun emptyZeroTest() {
assertThat(queue.take(0).readString()).isEqualTo("")
}

@Test
fun emptyBuffersTest1() {
queue.push(allocateData(""))
assertThat(queue.take(10).readString()).isEqualTo("")
}

@Test
fun emptyBuffersTest2() {
queue.push(allocateData(""))
assertThat(queue.take(0).readString()).isEqualTo("")
}

@Test
fun emptyBuffersTest3() {
queue.push(allocateData(""))
queue.push(allocateData("a"))
queue.push(allocateData(""))
assertThat(queue.take(10).readString()).isEqualTo("a")
}

@Test
fun emptyBuffersTest4() {
queue.push(allocateData("a"))
queue.push(allocateData(""))
assertThat(queue.take(10).readString()).isEqualTo("a")
}

@Test
fun emptyBuffersTest5() {
queue.push(allocateData("a"))
queue.push(allocateData(""))
assertThat(queue.take(1).readString()).isEqualTo("a")
assertThat(queue.take(1).readString()).isEqualTo("")
}

@Test
fun emptyBuffersTest6() {
queue.push(allocateData("a"))
queue.push(allocateData(""))
queue.push(allocateData(""))
queue.push(allocateData("b"))
assertThat(queue.take(10).readString()).isEqualTo("ab")
}

@Test
fun pushTake1() {
queue.push(allocateData("abc"))
queue.push(allocateData("def"))

assertThat(queue.take(4).readString()).isEqualTo("abcd")
assertThat(queue.take(1).readString()).isEqualTo("e")
assertThat(queue.take(100).readString()).isEqualTo("f")
assertThat(queue.take(100).readString()).isEqualTo("")
}

@Test
fun pushTake2() {
queue.push(allocateData("abc"))
queue.push(allocateData("def"))

assertThat(queue.take(2).readString()).isEqualTo("ab")
assertThat(queue.take(2).readString()).isEqualTo("cd")
assertThat(queue.take(2).readString()).isEqualTo("ef")
assertThat(queue.take(2).readString()).isEqualTo("")
}

@Test
fun pushTake3() {
queue.push(allocateData("abc"))
queue.push(allocateData("def"))

assertThat(queue.take(1).readString()).isEqualTo("a")
assertThat(queue.take(1).readString()).isEqualTo("b")
assertThat(queue.take(1).readString()).isEqualTo("c")
assertThat(queue.take(1).readString()).isEqualTo("d")
assertThat(queue.take(1).readString()).isEqualTo("e")
assertThat(queue.take(1).readString()).isEqualTo("f")
assertThat(queue.take(1).readString()).isEqualTo("")
}

@Test
fun pushTakePush1() {
queue.push(allocateData("abc"))
assertThat(queue.take(2).readString()).isEqualTo("ab")
queue.push(allocateData("def"))
assertThat(queue.take(2).readString()).isEqualTo("cd")
assertThat(queue.take(100).readString()).isEqualTo("ef")
}

@Test
fun pushTakePush2() {
queue.push(allocateData("abc"))
assertThat(queue.take(3).readString()).isEqualTo("abc")
queue.push(allocateData("def"))
assertThat(queue.take(2).readString()).isEqualTo("de")
assertThat(queue.take(100).readString()).isEqualTo("f")
}

@Test
fun pushTakePush3() {
queue.push(allocateData("abc"))
queue.push(allocateData("def"))
assertThat(queue.take(1).readString()).isEqualTo("a")
queue.push(allocateData("ghi"))
assertThat(queue.take(100).readString()).isEqualTo("bcdefghi")
}

@Test
fun pushTakePush4() {
queue.push(allocateData("abc"))
assertThat(queue.take(1).readString()).isEqualTo("a")
queue.push(allocateData("def"))
queue.push(allocateData("ghi"))
assertThat(queue.take(100).readString()).isEqualTo("bcdefghi")
}
}
Loading

0 comments on commit 7dc2fa2

Please sign in to comment.