Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Yamux specific unit tests #298

Merged
merged 5 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libp2p/src/main/kotlin/io/libp2p/etc/util/netty/mux/MuxId.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package io.libp2p.etc.util.netty.mux
import io.netty.channel.ChannelId

data class MuxId(val parentId: ChannelId, val id: Long, val initiator: Boolean) : ChannelId {
override fun asShortText() = "$parentId/$id/$initiator"
override fun asLongText() = asShortText()
override fun asShortText() = "${parentId.asShortText()}/$id/$initiator"
override fun asLongText() = "${parentId.asLongText()}/$id/$initiator"
override fun compareTo(other: ChannelId?) = asShortText().compareTo(other?.asShortText() ?: "")
override fun toString() = asLongText()
}
14 changes: 8 additions & 6 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrame.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ import io.netty.buffer.Unpooled

/**
* Contains the fields that comprise a yamux frame.
* @param streamId the ID of the stream.
* @param flag the flag value for this frame.
* @param id the ID of the stream.
* @param flags the flags value for this frame.
* @param length the length field for this frame.
* @param data the data segment.
*/
class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val lenData: Long, val data: ByteBuf? = null) :
class YamuxFrame(val id: MuxId, val type: Int, val flags: Int, val length: Long, val data: ByteBuf? = null) :
DefaultByteBufHolder(data ?: Unpooled.EMPTY_BUFFER) {

override fun toString(): String {
if (data == null)
return "YamuxFrame(id=$id, type=$type, flag=$flags)"
return "YamuxFrame(id=$id, type=$type, flag=$flags, data=${String(data.toByteArray())})"
if (data == null) {
return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length)"
}
return "YamuxFrame(id=$id, type=$type, flags=$flags, length=$length, data=${String(data.toByteArray())})"
}
}
30 changes: 21 additions & 9 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxFrameCodec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class YamuxFrameCodec(
out.writeByte(msg.type)
out.writeShort(msg.flags)
out.writeInt(msg.id.id.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.lenData.toInt())
out.writeInt(msg.data?.readableBytes() ?: msg.length.toInt())
out.writeBytes(msg.data ?: Unpooled.EMPTY_BUFFER)
}

Expand All @@ -42,32 +42,44 @@ class YamuxFrameCodec(
*/
override fun decode(ctx: ChannelHandlerContext, msg: ByteBuf, out: MutableList<Any>) {
while (msg.isReadable) {
if (msg.readableBytes() < 12)
if (msg.readableBytes() < 12) {
return
}
val readerIndex = msg.readerIndex()
msg.readByte(); // version always 0
val type = msg.readUnsignedByte()
val flags = msg.readUnsignedShort()
val streamId = msg.readUnsignedInt()
val lenData = msg.readUnsignedInt()
val length = msg.readUnsignedInt()
if (type.toInt() != YamuxType.DATA) {
val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData)
val yamuxFrame = YamuxFrame(
MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2) == 1).not()),
type.toInt(),
flags,
length
)
out.add(yamuxFrame)
continue
}
if (lenData > maxFrameDataLength) {
if (length > maxFrameDataLength) {
msg.skipBytes(msg.readableBytes())
throw ProtocolViolationException("Yamux frame is too large: $lenData")
throw ProtocolViolationException("Yamux frame is too large: $length")
}
if (msg.readableBytes() < lenData) {
if (msg.readableBytes() < length) {
// not enough data to read the frame content
// will wait for more ...
msg.readerIndex(readerIndex)
return
}
val data = msg.readSlice(lenData.toInt())
val data = msg.readSlice(length.toInt())
data.retain() // MessageToMessageCodec releases original buffer, but it needs to be relayed
val yamuxFrame = YamuxFrame(MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2).equals(1)).not()), type.toInt(), flags, lenData, data)
val yamuxFrame = YamuxFrame(
MuxId(ctx.channel().id(), streamId, isInitiator.xor(streamId.mod(2) == 1).not()),
type.toInt(),
flags,
length,
data
)
out.add(yamuxFrame)
}
}
Expand Down
59 changes: 37 additions & 22 deletions libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxHandler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxHandler
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import org.slf4j.LoggerFactory
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

const val INITIAL_WINDOW_SIZE = 256 * 1024
const val MAX_BUFFERED_CONNECTION_WRITES = 1024 * 1024

private val log = LoggerFactory.getLogger(YamuxHandler::class.java)

open class YamuxHandler(
override val multistreamProtocol: MultistreamProtocol,
override val maxFrameDataLength: Int,
Expand All @@ -39,7 +42,7 @@ open class YamuxHandler(

fun flush(sendWindow: AtomicInteger, id: MuxId): Int {
var written = 0
while (! buffered.isEmpty()) {
while (!buffered.isEmpty()) {
val buf = buffered.first()
val readableBytes = buf.readableBytes()
if (readableBytes + written < sendWindow.get()) {
Expand All @@ -65,38 +68,48 @@ open class YamuxHandler(
YamuxType.DATA -> handleDataRead(msg)
YamuxType.WINDOW_UPDATE -> handleWindowUpdate(msg)
YamuxType.PING -> handlePing(msg)
YamuxType.GO_AWAY -> onRemoteClose(msg.id)
YamuxType.GO_AWAY -> handleGoAway(msg)
}
}

fun handlePing(msg: YamuxFrame) {
private fun handlePing(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> ctx.writeAndFlush(YamuxFrame(MuxId(msg.id.parentId, 0, msg.id.initiator), YamuxType.PING, YamuxFlags.ACK, msg.lenData))
YamuxFlags.SYN -> ctx.writeAndFlush(
YamuxFrame(
MuxId(msg.id.parentId, 0, msg.id.initiator),
YamuxType.PING,
YamuxFlags.ACK,
msg.length
)
)

YamuxFlags.ACK -> {}
}
}

fun handleFlags(msg: YamuxFrame) {
private fun handleFlags(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
when (msg.flags) {
YamuxFlags.SYN -> {
// ACK the new stream
onRemoteYamuxOpen(msg.id)
ctx.writeAndFlush(YamuxFrame(msg.id, YamuxType.WINDOW_UPDATE, YamuxFlags.ACK, 0))
}

YamuxFlags.FIN -> onRemoteDisconnect(msg.id)
YamuxFlags.RST -> onRemoteClose(msg.id)
}
}

fun handleDataRead(msg: YamuxFrame) {
private fun handleDataRead(msg: YamuxFrame) {
val ctx = getChannelHandlerContext()
val size = msg.lenData
val size = msg.length
handleFlags(msg)
if (size.toInt() == 0)
if (size.toInt() == 0) {
return
val recWindow = receiveWindows.get(msg.id)
}
val recWindow = receiveWindows[msg.id]
if (recWindow == null) {
releaseMessage(msg.data!!)
throw Libp2pException("No receive window for " + msg.id)
Expand All @@ -111,36 +124,38 @@ open class YamuxHandler(
childRead(msg.id, msg.data!!)
}

fun handleWindowUpdate(msg: YamuxFrame) {
private fun handleWindowUpdate(msg: YamuxFrame) {
handleFlags(msg)
val size = msg.lenData.toInt()
if (size == 0)
return
val sendWindow = sendWindows.get(msg.id)
if (sendWindow == null) {
val size = msg.length.toInt()
if (size == 0) {
return
}
val sendWindow = sendWindows[msg.id] ?: return
sendWindow.addAndGet(size)
val buffer = sendBuffers.get(msg.id)
val buffer = sendBuffers[msg.id]
if (buffer != null) {
val writtenBytes = buffer.flush(sendWindow, msg.id)
totalBufferedWrites.addAndGet(-writtenBytes)
}
}

private fun handleGoAway(msg: YamuxFrame) {
log.debug("Session will be terminated. Go Away message with with error code ${msg.length} has been received.")
onRemoteClose(msg.id)
}

override fun onChildWrite(child: MuxChannel<ByteBuf>, data: ByteBuf) {
val ctx = getChannelHandlerContext()

val sendWindow = sendWindows.get(child.id)
if (sendWindow == null) {
throw Libp2pException("No send window for " + child.id)
}
val sendWindow = sendWindows[child.id] ?: throw Libp2pException("No send window for " + child.id)

if (sendWindow.get() <= 0) {
// wait until the window is increased to send more data
val buffer = sendBuffers.getOrPut(child.id, { SendBuffer(ctx) })
val buffer = sendBuffers.getOrPut(child.id) { SendBuffer(ctx) }
buffer.add(data)
if (totalBufferedWrites.addAndGet(data.readableBytes()) > MAX_BUFFERED_CONNECTION_WRITES)
if (totalBufferedWrites.addAndGet(data.readableBytes()) > MAX_BUFFERED_CONNECTION_WRITES) {
throw Libp2pException("Overflowed send buffer for connection")
}
return
}
sendBlocks(ctx, data, sendWindow, child.id)
Expand Down
2 changes: 1 addition & 1 deletion libp2p/src/main/kotlin/io/libp2p/mux/yamux/YamuxType.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.libp2p.mux.yamux

/**
* Contains all the permissible values for flags in the <code>yamux</code> protocol.
* Contains all the permissible values for types in the <code>yamux</code> protocol.
*/
object YamuxType {
const val DATA = 0
Expand Down
30 changes: 15 additions & 15 deletions libp2p/src/test/kotlin/io/libp2p/mux/MuxHandlerAbstractTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import io.libp2p.core.StreamHandler
import io.libp2p.etc.types.fromHex
import io.libp2p.etc.types.getX
import io.libp2p.etc.types.toHex
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.etc.util.netty.mux.RemoteWriteClosed
import io.libp2p.etc.util.netty.nettyInitializer
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
import io.libp2p.mux.MuxHandlerAbstractTest.TestEventHandler
import io.libp2p.tools.TestChannel
import io.libp2p.tools.readAllBytesAndRelease
import io.netty.buffer.ByteBuf
Expand All @@ -20,10 +22,7 @@ import io.netty.handler.logging.LoggingHandler
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.data.Index
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import java.util.concurrent.CompletableFuture
Expand Down Expand Up @@ -95,10 +94,11 @@ abstract class MuxHandlerAbstractTest {
enum class Flag { Open, Data, Close, Reset }
}

fun Long.toMuxId() = MuxId(parentChannelId, this, true)

abstract fun writeFrame(frame: AbstractTestMuxFrame)
abstract fun readFrame(): AbstractTestMuxFrame?
fun readFrameOrThrow() = readFrame() ?: throw AssertionError("No outbound frames")

fun openStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Open))
fun writeStream(id: Long, msg: String) = writeFrame(AbstractTestMuxFrame(id, Data, msg))
fun closeStream(id: Long) = writeFrame(AbstractTestMuxFrame(id, Close))
Expand Down Expand Up @@ -478,66 +478,66 @@ abstract class MuxHandlerAbstractTest {
override fun handlerAdded(ctx: ChannelHandlerContext) {
assertFalse(isHandlerAdded)
isHandlerAdded = true
println("MultiplexHandlerTest.handlerAdded")
println("MuxHandlerAbstractTest.handlerAdded")
this.ctx = ctx
}

override fun channelRegistered(ctx: ChannelHandlerContext?) {
assertTrue(isHandlerAdded)
assertFalse(isRegistered)
isRegistered = true
println("MultiplexHandlerTest.channelRegistered")
println("MuxHandlerAbstractTest.channelRegistered")
}

override fun channelActive(ctx: ChannelHandlerContext) {
assertTrue(isRegistered)
assertFalse(isActivated)
isActivated = true
println("MultiplexHandlerTest.channelActive")
println("MuxHandlerAbstractTest.channelActive")
activeEventHandlers.forEach { it.handle(this) }
}

override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
assertTrue(isActivated)
println("MultiplexHandlerTest.channelRead")
println("MuxHandlerAbstractTest.channelRead")
msg as ByteBuf
inboundMessages += msg.readAllBytesAndRelease().toHex()
}

override fun channelReadComplete(ctx: ChannelHandlerContext?) {
readCompleteEventCount++
println("MultiplexHandlerTest.channelReadComplete")
println("MuxHandlerAbstractTest.channelReadComplete")
}

override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
userEvents += evt
println("MultiplexHandlerTest.userEventTriggered: $evt")
println("MuxHandlerAbstractTest.userEventTriggered: $evt")
}

override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
exceptions += cause
println("MultiplexHandlerTest.exceptionCaught")
println("MuxHandlerAbstractTest.exceptionCaught")
}

override fun channelInactive(ctx: ChannelHandlerContext) {
assertTrue(isActivated)
assertFalse(isInactivated)
isInactivated = true
println("MultiplexHandlerTest.channelInactive")
println("MuxHandlerAbstractTest.channelInactive")
}

override fun channelUnregistered(ctx: ChannelHandlerContext?) {
assertTrue(isInactivated)
assertFalse(isUnregistered)
isUnregistered = true
println("MultiplexHandlerTest.channelUnregistered")
println("MuxHandlerAbstractTest.channelUnregistered")
}

override fun handlerRemoved(ctx: ChannelHandlerContext?) {
assertTrue(isUnregistered)
assertFalse(isHandlerRemoved)
isHandlerRemoved = true
println("MultiplexHandlerTest.handlerRemoved")
println("MuxHandlerAbstractTest.handlerRemoved")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import io.libp2p.core.StreamHandler
import io.libp2p.core.multistream.MultistreamProtocolV1
import io.libp2p.etc.types.fromHex
import io.libp2p.etc.types.toHex
import io.libp2p.etc.util.netty.mux.MuxId
import io.libp2p.mux.MuxHandler
import io.libp2p.mux.MuxHandlerAbstractTest
import io.libp2p.mux.MuxHandlerAbstractTest.AbstractTestMuxFrame.Flag.*
Expand All @@ -28,6 +27,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() {
}

override fun writeFrame(frame: AbstractTestMuxFrame) {
val muxId = frame.streamId.toMuxId()
val mplexFlag = when (frame.flag) {
Open -> MplexFlag.Type.OPEN
Data -> MplexFlag.Type.DATA
Expand All @@ -39,7 +39,7 @@ class MplexHandlerTest : MuxHandlerAbstractTest() {
else -> frame.data.fromHex().toByteBuf(allocateBuf())
}
val mplexFrame =
MplexFrame(MuxId(parentChannelId, frame.streamId, true), MplexFlag.getByType(mplexFlag, true), data)
MplexFrame(muxId, MplexFlag.getByType(mplexFlag, true), data)
ech.writeInbound(mplexFrame)
}

Expand All @@ -51,10 +51,9 @@ class MplexHandlerTest : MuxHandlerAbstractTest() {
MplexFlag.Type.DATA -> Data
MplexFlag.Type.CLOSE -> Close
MplexFlag.Type.RESET -> Reset
else -> throw AssertionError("Unknown mplex flag: ${mplexFrame.flag}")
}
val sData = maybeMplexFrame.data.readAllBytesAndRelease().toHex()
AbstractTestMuxFrame(mplexFrame.id.id, flag, sData)
val data = maybeMplexFrame.data.readAllBytesAndRelease().toHex()
AbstractTestMuxFrame(mplexFrame.id.id, flag, data)
}
}
}
Loading