Skip to content

Commit

Permalink
Pass PeerId along with message when subscribed to a pubsub topic
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanBratanov committed Mar 5, 2024
1 parent 562ce10 commit d9e13b4
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 51 deletions.
12 changes: 6 additions & 6 deletions libp2p/src/main/kotlin/io/libp2p/core/pubsub/PubsubApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import io.libp2p.pubsub.PubsubRouter
import io.netty.buffer.ByteBuf
import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import java.util.function.Function
import java.util.function.BiConsumer
import java.util.function.BiFunction
import kotlin.random.Random.Default.nextLong

fun createPubsubApi(router: PubsubRouter): PubsubApi =
Expand All @@ -36,8 +36,8 @@ enum class ValidationResult {
Ignore
}

typealias Subscriber = Consumer<MessageApi>
typealias Validator = Function<MessageApi, CompletableFuture<ValidationResult>>
typealias Subscriber = BiConsumer<PeerId, MessageApi>
typealias Validator = BiFunction<PeerId, MessageApi, CompletableFuture<ValidationResult>>

val RESULT_VALID = CompletableFuture.completedFuture(ValidationResult.Valid)
val RESULT_INVALID = CompletableFuture.completedFuture(ValidationResult.Invalid)
Expand Down Expand Up @@ -76,8 +76,8 @@ interface PubsubSubscriberApi {
*/
fun subscribe(receiver: Subscriber, vararg topics: Topic): PubsubSubscription {
return subscribe(
Validator {
receiver.accept(it)
Validator { peerId, messageApi ->
receiver.accept(peerId, messageApi)
RESULT_VALID
},
*topics
Expand Down
22 changes: 14 additions & 8 deletions libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ import io.netty.handler.codec.protobuf.ProtobufEncoder
import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender
import org.slf4j.LoggerFactory
import pubsub.pb.Rpc
import java.util.*
import java.util.Collections.singletonList
import java.util.Optional
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ScheduledExecutorService
import java.util.function.BiConsumer
import java.util.function.Consumer

// 1 MB default max message size
const val DEFAULT_MAX_PUBSUB_MESSAGE_SIZE = 1 shl 20

typealias PubsubMessageHandler = (PubsubMessage) -> CompletableFuture<ValidationResult>
typealias PubsubMessageHandler = (PeerId, PubsubMessage) -> CompletableFuture<ValidationResult>

open class DefaultPubsubMessage(override val protobufMessage: Rpc.Message) : AbstractPubsubMessage() {
override val messageId: MessageId = protobufMessage.from.toWBytes() + protobufMessage.seqno.toWBytes()
Expand All @@ -44,7 +43,8 @@ abstract class AbstractRouter(
protected val messageValidator: PubsubRouterMessageValidator
) : P2PServiceSemiDuplex(executor), PubsubRouter, PubsubRouterDebug {

protected var msgHandler: PubsubMessageHandler = { throw IllegalStateException("Message handler is not initialized for PubsubRouter") }
protected var msgHandler: PubsubMessageHandler =
{ _, _ -> throw IllegalStateException("Message handler is not initialized for PubsubRouter") }

protected open val peersTopics = mutableMultiBiMap<PeerHandler, Topic>()
protected open val subscribedTopics = linkedSetOf<Topic>()
Expand Down Expand Up @@ -153,7 +153,13 @@ abstract class AbstractRouter(
protected open fun notifyMalformedMessage(peer: PeerHandler) {}
protected open fun notifyUnseenMessage(peer: PeerHandler, msg: PubsubMessage) {}
protected open fun notifyNonSubscribedMessage(peer: PeerHandler, msg: Rpc.Message) {}
protected open fun notifySeenMessage(peer: PeerHandler, msg: PubsubMessage, validationResult: Optional<ValidationResult>) {}
protected open fun notifySeenMessage(
peer: PeerHandler,
msg: PubsubMessage,
validationResult: Optional<ValidationResult>
) {
}

protected open fun notifyUnseenInvalidMessage(peer: PeerHandler, msg: PubsubMessage) {}
protected open fun notifyUnseenValidMessage(peer: PeerHandler, msg: PubsubMessage) {}
protected open fun acceptRequestsFrom(peer: PeerHandler) = true
Expand Down Expand Up @@ -216,7 +222,7 @@ abstract class AbstractRouter(
}
}

val validFuts = msgValid.map { it to msgHandler(it) }
val validFuts = msgValid.map { it to msgHandler(peer.peerId, it) }
val doneUndone = validFuts.groupBy { it.second.isDone }
val done = doneUndone.getOrDefault(true, emptyList())
val undone = doneUndone.getOrDefault(false, emptyList())
Expand Down Expand Up @@ -247,7 +253,7 @@ abstract class AbstractRouter(
// broadcast others on completion
undone.forEach {
it.second.whenCompleteAsync(
BiConsumer { res, err ->
{ res, err ->
when {
err != null -> logger.warn("Exception while handling message from peer $peer: ${it.first}", err)
res == ValidationResult.Invalid -> logger.debug("Invalid pubsub message from peer $peer: ${it.first}")
Expand Down Expand Up @@ -331,7 +337,7 @@ abstract class AbstractRouter(
return peer.writeAndFlush(msg)
}

override fun initHandler(handler: (PubsubMessage) -> CompletableFuture<ValidationResult>) {
override fun initHandler(handler: (PeerId, PubsubMessage) -> CompletableFuture<ValidationResult>) {
msgHandler = handler
}
}
8 changes: 4 additions & 4 deletions libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubApiImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ open class PubsubApiImpl(val router: PubsubRouter) : PubsubApi {
}

init {
router.initHandler { onNewMessage(it) }
router.initHandler { peerId, msg -> onNewMessage(peerId, msg) }
}

val subscriptions: MutableMap<Topic, MutableList<SubscriptionImpl>> = mutableMapOf()
Expand All @@ -74,11 +74,11 @@ open class PubsubApiImpl(val router: PubsubRouter) : PubsubApi {
}
}

private fun onNewMessage(msg: PubsubMessage): CompletableFuture<ValidationResult> {
private fun onNewMessage(peerId: PeerId, msg: PubsubMessage): CompletableFuture<ValidationResult> {
val validationFuts = synchronized(this) {
msg.topics.mapNotNull { subscriptions[Topic(it)] }.flatten().distinct()
}.map {
it.receiver.apply(rpc2Msg(msg))
it.receiver.apply(peerId, rpc2Msg(msg))
}
return validationFuts.thenApplyAll {
if (it.isEmpty()) {
Expand All @@ -97,7 +97,7 @@ open class PubsubApiImpl(val router: PubsubRouter) : PubsubApi {

synchronized(this) {
for (topic in topics) {
val list = subscriptions.getOrPut(topic, { mutableListOf() })
val list = subscriptions.getOrPut(topic) { mutableListOf() }
if (list.isEmpty()) {
routerToSubscribe += topic.topic
}
Expand Down
5 changes: 3 additions & 2 deletions libp2p/src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ interface PubsubMessage {
}

abstract class AbstractPubsubMessage : PubsubMessage {
@Volatile private var sha256: ByteArray? = null
@Volatile
private var sha256: ByteArray? = null

override fun messageSha256(): ByteArray {
val cached = sha256
Expand Down Expand Up @@ -77,7 +78,7 @@ interface PubsubMessageRouter {
* All the messages received by the router are forwarded to the [handler] independently
* of any client subscriptions. Is it up to the client API to sort out subscriptions
*/
fun initHandler(handler: (PubsubMessage) -> CompletableFuture<ValidationResult>)
fun initHandler(handler: (PeerId, PubsubMessage) -> CompletableFuture<ValidationResult>)

/**
* Notifies the router that a client wants to receive messages on the following topics
Expand Down
4 changes: 2 additions & 2 deletions libp2p/src/test/java/io/libp2p/pubsub/GossipApiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ public void testFastMessageId() throws Exception {

BlockingQueue<PubsubMessage> messages = new LinkedBlockingQueue<>();
router.initHandler(
m -> {
messages.add(m);
(__, msg) -> {
messages.add(msg);
return CompletableFuture.completedFuture(ValidationResult.Valid);
});

Expand Down
8 changes: 4 additions & 4 deletions libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubApiTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class PubsubApiTest {
router1.connectSemiDuplex(router2)

val receivedMessages2 = LinkedBlockingQueue<MessageApi>()
api1.subscribe(Subscriber { println(it) }, Topic("myTopic"))
api2.subscribe(Subscriber { receivedMessages2 += it }, Topic("myTopic"))
api1.subscribe(Subscriber { _, msg -> println(msg) }, Topic("myTopic"))
api2.subscribe(Subscriber { _, msg -> receivedMessages2 += msg }, Topic("myTopic"))

fuzz.timeController.addTime(Duration.ofSeconds(10))

Expand Down Expand Up @@ -61,7 +61,7 @@ class PubsubApiTest {

router1.connectSemiDuplex(router2)

api1.subscribe(Subscriber { println(it) }, Topic("myTopic"))
api1.subscribe(Subscriber { _, msg -> println(msg) }, Topic("myTopic"))
router2.router.subscribe("myTopic")

fuzz.timeController.addTime(Duration.ofSeconds(10))
Expand All @@ -88,7 +88,7 @@ class PubsubApiTest {

router1.connectSemiDuplex(router2)

api1.subscribe(Subscriber { println(it) }, Topic("myTopic"))
api1.subscribe(Subscriber { _, msg -> println(msg) }, Topic("myTopic"))
router2.router.subscribe("myTopic")

fuzz.timeController.addTime(Duration.ofSeconds(10))
Expand Down
17 changes: 11 additions & 6 deletions libp2p/src/test/kotlin/io/libp2p/pubsub/PubsubRouterTest.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.libp2p.pubsub

import io.libp2p.core.PeerId
import io.libp2p.core.pubsub.*
import io.libp2p.core.pubsub.Topic
import io.libp2p.etc.types.seconds
Expand Down Expand Up @@ -279,7 +280,10 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor
doTenNeighborsTopology()
}

fun doTenNeighborsTopology(randomSeed: Int = 0, routerFactory: DeterministicFuzzRouterFactory = this.routerFactory) {
fun doTenNeighborsTopology(
randomSeed: Int = 0,
routerFactory: DeterministicFuzzRouterFactory = this.routerFactory
) {
val fuzz = DeterministicFuzz().also {
it.randomSeed = randomSeed.toLong()
}
Expand Down Expand Up @@ -398,9 +402,10 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor
routers[1].connectSemiDuplex(routers[2], pubsubLogs = LogLevel.ERROR)

val apis = routers.map { createPubsubApi(it.router) }

class RecordingSubscriber : Subscriber {
var count = 0
override fun accept(t: MessageApi) {
override fun accept(p: PeerId, t: MessageApi) {
count++
}
}
Expand All @@ -420,10 +425,10 @@ abstract class PubsubRouterTest(val routerFactory: DeterministicFuzzRouterFactor
scheduler.schedule({ it.complete(result) }, delayMs, TimeUnit.MILLISECONDS)
}
}
apis[1].subscribe(Validator { RESULT_VALID }, topics[0])
apis[1].subscribe(Validator { RESULT_INVALID }, topics[1])
apis[1].subscribe(Validator { delayed(ValidationResult.Valid, 500) }, topics[2])
apis[1].subscribe(Validator { delayed(ValidationResult.Invalid, 500) }, topics[3])
apis[1].subscribe(Validator { _, _ -> RESULT_VALID }, topics[0])
apis[1].subscribe(Validator { _, _ -> RESULT_INVALID }, topics[1])
apis[1].subscribe(Validator { _, _ -> delayed(ValidationResult.Valid, 500) }, topics[2])
apis[1].subscribe(Validator { _, _ -> delayed(ValidationResult.Invalid, 500) }, topics[3])

// 2 heartbeats for all
fuzz.timeController.addTime(Duration.ofSeconds(2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class GossipTwoHostTest : TwoGossipHostTestBase() {
val topic = Topic("topic")

val messages = mutableListOf<MessageApi>()
gossip2.subscribe(Subscriber { messages += it }, topic)
gossip2.subscribe(Subscriber { _, msg -> messages += msg }, topic)

waitForSubscribed(router1, topic.topic)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class GossipV1_1Tests {
.setSeqno(seqNo.toBytesBigEndian().toProtobuf())
.setData(data.toProtobuf())
.build()

private fun newMessage(topic: Topic, seqNo: Long, data: ByteArray) =
DefaultPubsubMessage(newProtoMessage(topic, seqNo, data))

Expand Down Expand Up @@ -147,7 +148,7 @@ class GossipV1_1Tests {

val api = createPubsubApi(test.gossipRouter)
val apiMessages = mutableListOf<MessageApi>()
api.subscribe(Subscriber { apiMessages += it }, io.libp2p.core.pubsub.Topic("topic2"))
api.subscribe(Subscriber { _, msg -> apiMessages += msg }, io.libp2p.core.pubsub.Topic("topic2"))

val msg1 = Rpc.RPC.newBuilder()
.addPublish(newProtoMessage("topic2", 0L, "Hello-1".toByteArray()))
Expand Down Expand Up @@ -181,12 +182,13 @@ class GossipV1_1Tests {
super.initChannelWithHandler(streamHandler, handler)
}
}

val test = TwoRoutersTest(mockRouterFactory = { exec, _, _ -> MalformedMockRouter(exec) })
val mockRouter = test.router2.router as MalformedMockRouter

val api = createPubsubApi(test.gossipRouter)
val apiMessages = mutableListOf<MessageApi>()
api.subscribe(Subscriber { apiMessages += it }, io.libp2p.core.pubsub.Topic("topic1"))
api.subscribe(Subscriber { _, msg -> apiMessages += msg }, io.libp2p.core.pubsub.Topic("topic1"))

val msg1 = Rpc.RPC.newBuilder()
.addPublish(newProtoMessage("topic1", 0L, "Hello-1".toByteArray()))
Expand Down Expand Up @@ -422,7 +424,7 @@ class GossipV1_1Tests {
fun testAppValidatorScore() {
val test = TwoRoutersTest()
val validator = AtomicReference<CompletableFuture<ValidationResult>>(RESULT_VALID)
test.gossipRouter.initHandler { validator.get() }
test.gossipRouter.initHandler { _, _ -> validator.get() }

test.mockRouter.subscribe("topic1")
test.gossipRouter.subscribe("topic1")
Expand Down Expand Up @@ -860,7 +862,7 @@ class GossipV1_1Tests {
fun testValidatorIgnoreResult() {
val test = ManyRoutersTest(mockRouterCount = 2)
val validator = AtomicReference<CompletableFuture<ValidationResult>>(RESULT_VALID)
test.gossipRouter.initHandler { validator.get() }
test.gossipRouter.initHandler { _, _ -> validator.get() }
test.connectAll()
test.gossipRouter.subscribe("topic1")
test.routers.forEach { it.router.subscribe("topic1") }
Expand Down Expand Up @@ -985,8 +987,8 @@ class GossipV1_1Tests {

val validationResult = CompletableFuture<ValidationResult>()
val receivedMessages = LinkedBlockingQueue<MessageApi>()
val slowValidator = Validator {
receivedMessages += it
val slowValidator = Validator { _, msg ->
receivedMessages += msg
validationResult
}
api.subscribe(slowValidator, io.libp2p.core.pubsub.Topic("topic1"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class SubscriptionsLimitTest : TwoGossipHostTestBase() {
@Test
fun `new peer subscribed to many topics`() {
val topics = (0..13).map { Topic("topic-$it") }.toTypedArray()
gossip1.subscribe(Subscriber {}, *topics)
gossip1.subscribe(Subscriber { _, _ -> }, *topics)
val messages2 = mutableListOf<MessageApi>()
gossip2.subscribe(Subscriber { messages2 += it }, *topics)
gossip2.subscribe(Subscriber { _, msg -> messages2 += msg }, *topics)

connect()
waitForSubscribed(router1, "topic-13")
Expand All @@ -44,8 +44,8 @@ class SubscriptionsLimitTest : TwoGossipHostTestBase() {
@Test
fun `new peer subscribed to few topics`() {
val topics = (0..4).map { Topic("topic-$it") }.toTypedArray()
gossip1.subscribe(Subscriber { }, *topics)
gossip2.subscribe(Subscriber { }, *topics)
gossip1.subscribe(Subscriber { _, _ -> }, *topics)
gossip2.subscribe(Subscriber { _, _ -> }, *topics)

connect()
waitForSubscribed(router1, "topic-4")
Expand All @@ -59,16 +59,16 @@ class SubscriptionsLimitTest : TwoGossipHostTestBase() {

@Test
fun `existing peer subscribed to many topics`() {
gossip1.subscribe(Subscriber { }, Topic("test-topic"))
gossip2.subscribe(Subscriber { }, Topic("test-topic"))
gossip1.subscribe(Subscriber { _, _ -> }, Topic("test-topic"))
gossip2.subscribe(Subscriber { _, _ -> }, Topic("test-topic"))

connect()
waitForSubscribed(router1, "test-topic")
waitForSubscribed(router2, "test-topic")

val topics = (0..13).map { Topic("topic-$it") }.toTypedArray()
gossip1.subscribe(Subscriber { }, *topics)
gossip2.subscribe(Subscriber { }, *topics)
gossip1.subscribe(Subscriber { _, _ -> }, *topics)
gossip2.subscribe(Subscriber { _, _ -> }, *topics)

waitForSubscribed(router1, "topic-13")
waitForSubscribed(router2, "topic-13")
Expand Down
10 changes: 6 additions & 4 deletions libp2p/src/testFixtures/kotlin/io/libp2p/pubsub/TestRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class TestRouter(

val inboundMessages = LinkedBlockingQueue<PubsubMessage>()
var handlerValidationResult = RESULT_VALID
val routerHandler: (PubsubMessage) -> CompletableFuture<ValidationResult> = {
inboundMessages += it
val routerHandler: (PeerId, PubsubMessage) -> CompletableFuture<ValidationResult> = { _, msg ->
inboundMessages += msg
handlerValidationResult
}

Expand Down Expand Up @@ -92,8 +92,10 @@ class TestRouter(
wireLogs: LogLevel? = null,
pubsubLogs: LogLevel? = null
): TestConnection {
val thisChannel = newChannel("[${idCnt.incrementAndGet()}]$name=>${another.name}", another, wireLogs, pubsubLogs, true)
val anotherChannel = another.newChannel("[${idCnt.incrementAndGet()}]${another.name}=>$name", this, wireLogs, pubsubLogs, false)
val thisChannel =
newChannel("[${idCnt.incrementAndGet()}]$name=>${another.name}", another, wireLogs, pubsubLogs, true)
val anotherChannel =
another.newChannel("[${idCnt.incrementAndGet()}]${another.name}=>$name", this, wireLogs, pubsubLogs, false)
listOf(thisChannel, anotherChannel).forEach {
it.attr(PROTOCOL).get().complete(this.protocol)
}
Expand Down

0 comments on commit d9e13b4

Please sign in to comment.