Fixing rebalance handling (#532)
LMnet committed Feb 17, 2021
1 parent 653eb3e commit ce4e00b
Showing 5 changed files with 114 additions and 89 deletions.
8 changes: 4 additions & 4 deletions build.sbt
@@ -278,8 +278,8 @@ lazy val scalaSettings = Seq(
       "-Ywarn-dead-code",
       "-Ywarn-numeric-widen",
       "-Ywarn-value-discard",
-      "-Ywarn-unused",
-      "-Xfatal-warnings"
+      "-Ywarn-unused"
+      // "-Xfatal-warnings"
     )
   else if (scalaVersion.value.startsWith("2.12"))
     Seq(
@@ -290,8 +290,8 @@ lazy val scalaSettings = Seq(
       "-Ywarn-numeric-widen",
       "-Ywarn-value-discard",
       "-Ywarn-unused",
-      "-Ypartial-unification",
-      "-Xfatal-warnings"
+      "-Ypartial-unification"
+      // "-Xfatal-warnings"
     )
   else
     Seq(
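
Note that the commit simply comments the flag out in both Scala-version branches. A common sbt alternative (an assumption here, not something this commit does) is to keep fatal warnings for CI builds only:

// Hypothetical sketch: enable -Xfatal-warnings only when the CI env var is set.
scalacOptions ++= (if (sys.env.contains("CI")) Seq("-Xfatal-warnings") else Seq.empty)
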
87 changes: 47 additions & 40 deletions modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala
@@ -157,8 +157,8 @@ object KafkaConsumer {
 
     def createPartitionStream(
       streamId: StreamId,
-      partitionStreamId: PartitionStreamId,
-      partition: TopicPartition
+      partition: TopicPartition,
+      assignmentRevoked: F[Unit]
     ): F[Stream[F, CommittableConsumerRecord[F, K, V]]] =
       for {
         chunks <- chunkQueue
@@ -169,14 +169,17 @@ object KafkaConsumer {
               awaitTermination.attempt,
               dequeueDone.get
             ),
-            stopConsumingDeferred.get
+            F.race(
+              stopConsumingDeferred.get,
+              assignmentRevoked
+            )
           )
           .void
         stopReqs <- Deferred.tryable[F, Unit]
       } yield Stream.eval {
         def fetchPartition(deferred: Deferred[F, PartitionRequest]): F[Unit] = {
           val request =
-            Request.Fetch(partition, streamId, partitionStreamId, deferred.complete)
+            Request.Fetch(partition, streamId, deferred.complete)
           val fetch = requests.enqueue1(request) >> deferred.get
           F.race(shutdown, fetch).flatMap {
             case Left(()) =>
@@ -220,33 +223,18 @@ object KafkaConsumer {
 
     def enqueueAssignment(
       streamId: StreamId,
-      partitionStreamIdRef: Ref[F, PartitionStreamId],
       assigned: SortedSet[TopicPartition],
-      partitionsMapQueue: PartitionsMapQueue
+      partitionsMapQueue: PartitionsMapQueue,
+      assignmentRevoked: F[Unit]
     ): F[Unit] = {
       val assignment: F[PartitionsMap] = if (assigned.isEmpty) {
         F.pure(Map.empty)
       } else {
-        val indexedAssigned = assigned.toVector.zipWithIndex
-        partitionStreamIdRef
-          .modify { id =>
-            val result = indexedAssigned.map {
-              case (partition, idx) =>
-                (partition, idx + id)
-            }
-            val (_, lastId) = result.last
-            (lastId + 1, result)
-          }
-          .flatMap { (partitions: Vector[(TopicPartition, PartitionStreamId)]) =>
-            partitions
-              .traverse {
-                case (partition, partitionStreamId) =>
-                  createPartitionStream(streamId, partitionStreamId, partition).map { stream =>
-                    partition -> stream
-                  }
-              }
-              .map(_.toMap)
-          }
+        assigned.toVector.traverse { partition =>
+          createPartitionStream(streamId, partition, assignmentRevoked).map { stream =>
+            partition -> stream
+          }
+        }.map(_.toMap)
       }
 
       assignment.flatMap { assignment =>
@@ -261,24 +249,40 @@ object KafkaConsumer {
 
     def onRebalance(
       streamId: StreamId,
-      partitionStreamIdRef: Ref[F, PartitionStreamId],
+      prevAssignmentFinisherRef: Ref[F, Deferred[F, Unit]],
       partitionsMapQueue: PartitionsMapQueue
-    ): OnRebalance[F, K, V] = OnRebalance(
-      onAssigned = assigned =>
-        enqueueAssignment(streamId, partitionStreamIdRef, assigned, partitionsMapQueue),
-      onRevoked = _ => F.unit
-    )
+    ): OnRebalance[F, K, V] = {
+      OnRebalance(
+        onRevoked = _ => {
+          for {
+            newFinisher <- Deferred[F, Unit]
+            prevAssignmentFinisher <- prevAssignmentFinisherRef.getAndSet(newFinisher)
+            _ <- prevAssignmentFinisher.complete(())
+          } yield ()
+        },
+        onAssigned = assigned => {
+          prevAssignmentFinisherRef.get.flatMap { prevAssignmentFinisher =>
+            enqueueAssignment(
+              streamId = streamId,
+              assigned = assigned,
+              partitionsMapQueue = partitionsMapQueue,
+              assignmentRevoked = prevAssignmentFinisher.get
+            )
+          }
+        }
+      )
+    }
 
     def requestAssignment(
       streamId: StreamId,
-      partitionStreamIdRef: Ref[F, PartitionStreamId],
+      prevAssignmentFinisherRef: Ref[F, Deferred[F, Unit]],
       partitionsMapQueue: PartitionsMapQueue
     ): F[SortedSet[TopicPartition]] =
       Deferred[F, Either[Throwable, SortedSet[TopicPartition]]].flatMap { deferred =>
         val request =
           Request.Assignment[F, K, V](
             deferred.complete,
-            Some(onRebalance(streamId, partitionStreamIdRef, partitionsMapQueue))
+            Some(onRebalance(streamId, prevAssignmentFinisherRef, partitionsMapQueue))
           )
         val assignment = requests.enqueue1(request) >> deferred.get.rethrow
         F.race(awaitTermination.attempt, assignment).map {
@@ -290,20 +294,23 @@ object KafkaConsumer {
     def initialEnqueue(
       streamId: StreamId,
       partitionsMapQueue: PartitionsMapQueue,
-      partitionStreamIdRef: Ref[F, PartitionStreamId]
+      prevAssignmentFinisherRef: Ref[F, Deferred[F, Unit]]
     ): F[Unit] =
-      requestAssignment(streamId, partitionStreamIdRef, partitionsMapQueue).flatMap {
-        assigned =>
-          enqueueAssignment(streamId, partitionStreamIdRef, assigned, partitionsMapQueue)
-      }
+      for {
+        prevAssignmentFinisher <- prevAssignmentFinisherRef.get
+        assigned <- requestAssignment(streamId, prevAssignmentFinisherRef, partitionsMapQueue)
+        assignmentRevoked = prevAssignmentFinisher.get
+        _ <- enqueueAssignment(streamId, assigned, partitionsMapQueue, assignmentRevoked)
+      } yield ()
 
     Stream.eval(stopConsumingDeferred.tryGet).flatMap {
       case None =>
         for {
           partitionsMapQueue <- Stream.eval(Queue.noneTerminated[F, PartitionsMap])
           streamId <- Stream.eval(streamIdRef.modify(n => (n + 1, n)))
-          partitionStreamIdRef <- Stream.eval(Ref.of[F, PartitionStreamId](0))
-          _ <- Stream.eval(initialEnqueue(streamId, partitionsMapQueue, partitionStreamIdRef))
+          prevAssignmentFinisher <- Stream.eval(Deferred[F, Unit])
+          prevAssignmentFinisherRef <- Stream.eval(Ref[F].of(prevAssignmentFinisher))
+          _ <- Stream.eval(initialEnqueue(streamId, partitionsMapQueue, prevAssignmentFinisherRef))
           out <- partitionsMapQueue.dequeue
             .interruptWhen(awaitTermination.attempt)
             .concurrently(
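
The heart of the fix is visible above: each assignment now carries a "previous assignment finisher", a Deferred held in a Ref. onRevoked completes the current Deferred and swaps in a fresh one; onAssigned hands each newly created partition stream that Deferred's .get as its assignmentRevoked interrupt, so streams belonging to a stale assignment terminate instead of fetching forever. A minimal standalone sketch of that handoff, using the same cats-effect 2 primitives as this codebase (the RebalanceFinisher name and API are illustrative, not part of fs2-kafka):

import cats.effect.Concurrent
import cats.effect.concurrent.{Deferred, Ref}
import cats.implicits._

// Illustrative sketch, not the library's API: the "finisher" handoff
// implemented by onRebalance above.
final class RebalanceFinisher[F[_]](ref: Ref[F, Deferred[F, Unit]])(
  implicit F: Concurrent[F]
) {
  // On revocation: install a fresh Deferred and complete the previous one,
  // interrupting every partition stream of the old assignment.
  def revoke: F[Unit] =
    for {
      next <- Deferred[F, Unit]
      prev <- ref.getAndSet(next)
      _ <- prev.complete(())
    } yield ()

  // On assignment: the signal each newly created partition stream races
  // against (passed as assignmentRevoked to createPartitionStream).
  def revoked: F[Unit] =
    ref.get.flatMap(_.get)
}

object RebalanceFinisher {
  def apply[F[_]: Concurrent]: F[RebalanceFinisher[F]] =
    for {
      initial <- Deferred[F, Unit]
      ref <- Ref.of[F, Deferred[F, Unit]](initial)
    } yield new RebalanceFinisher(ref)
}

Because a Deferred completes at most once and getAndSet swaps atomically, racing rebalances stay safe: every stream observes exactly the finisher that was current when it was created.
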
55 changes: 10 additions & 45 deletions modules/core/src/main/scala/fs2/kafka/KafkaConsumerActor.scala
@@ -166,7 +166,6 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
   private[this] def fetch(
     partition: TopicPartition,
     streamId: StreamId,
-    partitionStreamId: PartitionStreamId,
     callback: ((Chunk[CommittableConsumerRecord[F, K, V]], FetchCompletedReason)) => F[Unit]
   ): F[Unit] = {
     val assigned =
@@ -176,7 +175,7 @@
       ref
         .modify { state =>
           val (newState, oldFetch) =
-            state.withFetch(partition, streamId, partitionStreamId, callback)
+            state.withFetch(partition, streamId, callback)
           (newState, (newState, oldFetch))
         }
         .flatMap {
@@ -567,8 +566,8 @@ private[kafka] final class KafkaConsumerActor[F[_], K, V](
       case Request.Assign(partitions, callback) => assign(partitions, callback)
       case Request.SubscribePattern(pattern, callback) => subscribe(pattern, callback)
       case Request.Unsubscribe(callback) => unsubscribe(callback)
-      case Request.Fetch(partition, streamId, partitionStreamId, callback) =>
-        fetch(partition, streamId, partitionStreamId, callback)
+      case Request.Fetch(partition, streamId, callback) =>
+        fetch(partition, streamId, callback)
       case request @ Request.Commit(_, _) => commit(request)
       case request @ Request.ManualCommitAsync(_, _) => manualCommitAsync(request)
       case request @ Request.ManualCommitSync(_, _) => manualCommitSync(request)
@@ -634,11 +633,9 @@ private[kafka] object KafkaConsumerActor {
   }
 
   type StreamId = Int
-  type PartitionStreamId = Int
 
   final case class State[F[_], K, V](
     fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]],
-    partitionStreamIds: Map[TopicPartition, PartitionStreamId],
     records: Map[TopicPartition, NonEmptyVector[CommittableConsumerRecord[F, K, V]]],
     pendingCommits: Chain[Request.Commit[F, K, V]],
     onRebalances: Chain[OnRebalance[F, K, V]],
@@ -655,48 +652,29 @@ private[kafka] object KafkaConsumerActor {
     def withFetch(
       partition: TopicPartition,
       streamId: StreamId,
-      partitionStreamId: PartitionStreamId,
       callback: ((Chunk[CommittableConsumerRecord[F, K, V]], FetchCompletedReason)) => F[Unit]
     ): (State[F, K, V], List[FetchRequest[F, K, V]]) = {
       val newFetchRequest =
         FetchRequest(callback)
 
-      val oldPartitionFetches =
+      val oldPartitionFetches: Map[StreamId, FetchRequest[F, K, V]] =
         fetches.getOrElse(partition, Map.empty)
 
-      val oldPartitionFetch =
-        oldPartitionFetches.get(streamId)
-
-      val oldPartitionStreamId =
-        partitionStreamIds.getOrElse(partition, 0)
-
-      val hasMoreRecentPartitionStreamIds =
-        oldPartitionStreamId > partitionStreamId
-
-      val newFetches =
-        fetches.updated(partition, {
-          if (hasMoreRecentPartitionStreamIds) oldPartitionFetches - streamId
-          else oldPartitionFetches.updated(streamId, newFetchRequest)
-        })
-
-      val newPartitionStreamIds =
-        partitionStreamIds.updated(partition, oldPartitionStreamId max partitionStreamId)
-
-      val fetchesToRevoke =
-        if (hasMoreRecentPartitionStreamIds)
-          newFetchRequest :: oldPartitionFetch.toList
-        else oldPartitionFetch.toList
+      val newFetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]] =
+        fetches.updated(partition, oldPartitionFetches.updated(streamId, newFetchRequest))
+
+      val fetchesToRevoke: List[FetchRequest[F, K, V]] =
+        oldPartitionFetches.get(streamId).toList
 
       (
-        copy(fetches = newFetches, partitionStreamIds = newPartitionStreamIds),
+        copy(fetches = newFetches),
        fetchesToRevoke
      )
    }
 
     def withoutFetches(partitions: Set[TopicPartition]): State[F, K, V] =
       copy(
-        fetches = fetches.filterKeysStrict(!partitions.contains(_)),
-        partitionStreamIds = partitionStreamIds.filterKeysStrict(!partitions.contains(_))
+        fetches = fetches.filterKeysStrict(!partitions.contains(_))
       )
 
     def withRecords(
@@ -707,7 +685,6 @@
     def withoutFetchesAndRecords(partitions: Set[TopicPartition]): State[F, K, V] =
       copy(
         fetches = fetches.filterKeysStrict(!partitions.contains(_)),
-        partitionStreamIds = partitionStreamIds.filterKeysStrict(!partitions.contains(_)),
         records = records.filterKeysStrict(!partitions.contains(_))
       )
 
@@ -743,25 +720,14 @@ private[kafka] object KafkaConsumerActor {
             append(fs.mkString("[", ", ", "]"))
         }("", ", ", "")
 
-      val partitionStreamIdsString =
-        partitionStreamIds.toList
-          .sortBy { case (tp, _) => tp }
-          .mkStringAppend {
-            case (append, (tp, id)) =>
-              append(tp.toString)
-              append(" -> ")
-              append(id.toString)
-          }("", ", ", "")
-
-      s"State(fetches = Map($fetchesString), partitionStreamIds = Map($partitionStreamIdsString), records = Map(${recordsString(records)}), pendingCommits = $pendingCommits, onRebalances = $onRebalances, rebalancing = $rebalancing, subscribed = $subscribed, streaming = $streaming)"
+      s"State(fetches = Map($fetchesString), records = Map(${recordsString(records)}), pendingCommits = $pendingCommits, onRebalances = $onRebalances, rebalancing = $rebalancing, subscribed = $subscribed, streaming = $streaming)"
     }
   }
 
   object State {
     def empty[F[_], K, V]: State[F, K, V] =
       State(
         fetches = Map.empty,
-        partitionStreamIds = Map.empty,
         records = Map.empty,
         pendingCommits = Chain.empty,
         onRebalances = Chain.empty,
@@ -830,7 +796,6 @@ private[kafka] object KafkaConsumerActor {
     final case class Fetch[F[_], K, V](
       partition: TopicPartition,
       streamId: StreamId,
-      partitionStreamId: PartitionStreamId,
       callback: ((Chunk[CommittableConsumerRecord[F, K, V]], FetchCompletedReason)) => F[Unit]
     ) extends Request[F, K, V]
 
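
With PartitionStreamId gone, State.withFetch reduces to a plain nested-map update: registering a fetch for a (partition, streamId) pair stores the new request and hands back any request it displaced, which the actor then completes as revoked. A toy model of just that bookkeeping (plain Scala with hypothetical names, no fs2-kafka types):

// Toy model of the simplified State.withFetch above: one pending request
// per (partition, streamId); a new registration returns the displaced
// request so the caller can complete it as revoked.
def register[P, S, R](
  fetches: Map[P, Map[S, R]],
  partition: P,
  streamId: S,
  request: R
): (Map[P, Map[S, R]], Option[R]) = {
  val perPartition = fetches.getOrElse(partition, Map.empty[S, R])
  val displaced = perPartition.get(streamId)
  (fetches.updated(partition, perPartition.updated(streamId, request)), displaced)
}
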
2 changes: 2 additions & 0 deletions modules/core/src/test/resources/logback-test.xml
@@ -6,6 +6,8 @@
     </encoder>
   </appender>
 
+  <logger name="fs2.kafka" level="DEBUG" />
+
   <root level="ERROR" additivity="false">
     <appender-ref ref="STDOUT"/>
   </root>
51 changes: 51 additions & 0 deletions modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala
@@ -585,6 +585,57 @@ final class KafkaConsumerSpec extends BaseKafkaSpec {
         }).unsafeRunSync()
       }
     }
+
+    it("should handle multiple rebalances with multiple instances under load #532") {
+      withTopic { topic =>
+        val numPartitions = 3
+        createCustomTopic(topic, partitions = numPartitions)
+
+        val produced = (0 until 10000).map(n => s"key-$n" -> s"value->$n")
+        publishToKafka(topic, produced)
+
+        def run(instance: Int, allAssignments: SignallingRef[IO, Map[Int, Set[Int]]]): IO[Unit] = {
+          KafkaConsumer
+            .stream(consumerSettings[IO].withGroupId("test"))
+            .evalTap(_.subscribeTo(topic))
+            .flatMap(_.partitionsMapStream)
+            .flatMap { assignment =>
+              Stream.eval(allAssignments.update { current =>
+                current.updated(instance, assignment.keySet.map(_.partition()))
+              }) >> Stream.emits(assignment.map { case (_, partitionStream) =>
+                partitionStream.evalMap(_ => IO.sleep(10.millis)) // imitating some work
+              }.toList).parJoinUnbounded
+            }.compile.drain
+        }
+
+        def checkAssignments(allAssignments: SignallingRef[IO, Map[Int, Set[Int]]])(instances: Set[Int]) = {
+          allAssignments.discrete.filter { state =>
+            state.keySet == instances &&
+            instances.forall { instance =>
+              state.get(instance).exists(_.nonEmpty)
+            } && state.values.toList.flatMap(_.toList).sorted == List(0, 1, 2)
+          }.take(1).compile.drain
+        }
+
+        (for {
+          allAssignments <- SignallingRef[IO, Map[Int, Set[Int]]](Map.empty)
+          check = checkAssignments(allAssignments)(_)
+          fiber0 <- run(0, allAssignments).start
+          _ <- check(Set(0))
+          fiber1 <- run(1, allAssignments).start
+          _ <- check(Set(0, 1))
+          fiber2 <- run(2, allAssignments).start
+          _ <- check(Set(0, 1, 2))
+          _ <- fiber2.cancel
+          _ <- allAssignments.update(_.removed(2))
+          _ <- check(Set(0, 1))
+          _ <- fiber1.cancel
+          _ <- allAssignments.update(_.removed(1))
+          _ <- check(Set(0))
+          _ <- fiber0.cancel
+        } yield succeed).unsafeRunSync()
+      }
+    }
   }
 
   describe("KafkaConsumer#assignmentStream") {
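
The new test drives three consumer instances through staggered joins and leaves while consuming under load, using a SignallingRef as an observable scoreboard of which instance owns which partitions. The waiting pattern behind checkAssignments generalizes to a small helper (a sketch, not part of this commit):

import cats.effect.IO
import fs2.concurrent.SignallingRef

// Hypothetical helper in the spirit of checkAssignments above:
// semantically block until the ref's value satisfies the predicate.
def awaitState[A](ref: SignallingRef[IO, A])(p: A => Boolean): IO[Unit] =
  ref.discrete.filter(p).take(1).compile.drain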
