diff --git a/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala b/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala index 6a1f4fe02..c67ef462f 100644 --- a/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala +++ b/modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala @@ -91,17 +91,16 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( .handleErrorWith(e => F.delay(callback(Left(e)))) private[this] def commit(request: Request.Commit[F]): F[Unit] = - ref - .modify { state => - if (state.rebalancing) { - val newState = state.withPendingCommit(request) - (newState, Some(StoredPendingCommit(request, newState))) - } else (state, None) - } - .flatMap { - case Some(log) => logging.log(log) - case None => commitAsync(request.offsets, request.callback) - } + ref.flatModify { state => + val commitF = commitAsync(request.offsets, request.callback) + if (state.rebalancing || state.pendingCommits.nonEmpty) { + val newState = state.withPendingCommit( + commitF >> logging.log(CommittedPendingCommit(request)) + ) + (newState, logging.log(StoredPendingCommit(request, newState))) + } else + (state, commitF) + } private[this] def manualCommitSync(request: Request.ManualCommitSync[F]): F[Unit] = { val commit = @@ -302,7 +301,7 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( } .flatMap(records) - def handlePoll(newRecords: ConsumerRecords, initialRebalancing: Boolean): F[Unit] = { + def handlePoll(newRecords: ConsumerRecords): F[Unit] = { def handleBatch( state: State[F, K, V], pendingCommits: Option[HandlePollResult.PendingCommits] @@ -381,17 +380,12 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( } def handlePendingCommits(state: State[F, K, V]) = { - val currentRebalancing = state.rebalancing - - if (initialRebalancing && !currentRebalancing && state.pendingCommits.nonEmpty) { + if (!state.rebalancing && state.pendingCommits.nonEmpty) { val newState = state.withoutPendingCommits ( newState, Some( - HandlePollResult.PendingCommits( - commits = state.pendingCommits, - log = CommittedPendingCommits(state.pendingCommits, newState) - ) + HandlePollResult.PendingCommits(commits = state.pendingCommits) ) ) } else (state, None) @@ -418,10 +412,9 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( ref .get .flatMap { state => - if (state.subscribed && state.streaming) { - val initialRebalancing = state.rebalancing - pollConsumer(state).flatMap(handlePoll(_, initialRebalancing)) - } else F.unit + if (state.subscribed && state.streaming) + pollConsumer(state).flatMap(handlePoll(_)) + else F.unit } } @@ -448,15 +441,9 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V]( private[this] object HandlePollResult { - case class PendingCommits( - commits: Chain[Request.Commit[F]], - log: CommittedPendingCommits[F] - ) { + case class PendingCommits(commits: Chain[F[Unit]]) { - def commit: F[Unit] = - commits.traverse { commitRequest => - commitAsync(commitRequest.offsets, commitRequest.callback) - } >> logging.log(log) + def commit: F[Unit] = commits.sequence_ } @@ -506,7 +493,7 @@ private[kafka] object KafkaConsumerActor { final case class State[F[_], K, V]( fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F, K, V]]], records: Map[TopicPartition, NonEmptyVector[CommittableConsumerRecord[F, K, V]]], - pendingCommits: Chain[Request.Commit[F]], + pendingCommits: Chain[F[Unit]], onRebalances: Chain[OnRebalance[F]], rebalancing: Boolean, subscribed: Boolean, @@ -562,7 +549,7 @@ private[kafka] object KafkaConsumerActor { def withoutRecords(partitions: Set[TopicPartition]): State[F, K, V] = copy(records = records.filterKeysStrict(!partitions.contains(_))) - def withPendingCommit(pendingCommit: Request.Commit[F]): State[F, K, V] = + def withPendingCommit(pendingCommit: F[Unit]): State[F, K, V] = copy(pendingCommits = pendingCommits.append(pendingCommit)) def withoutPendingCommits: State[F, K, V] = diff --git a/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala b/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala index 782945a1d..adc26ed39 100644 --- a/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala +++ b/modules/core/src/main/scala/fs2/kafka/internal/LogEntry.scala @@ -10,7 +10,7 @@ import java.util.regex.Pattern import scala.collection.immutable.SortedSet -import cats.data.{Chain, NonEmptyList, NonEmptySet, NonEmptyVector} +import cats.data.{NonEmptyList, NonEmptySet, NonEmptyVector} import cats.syntax.all.* import fs2.kafka.instances.* import fs2.kafka.internal.syntax.* @@ -211,15 +211,11 @@ private[kafka] object LogEntry { } - final case class CommittedPendingCommits[F[_]]( - pendingCommits: Chain[Request.Commit[F]], - state: State[F, ?, ?] - ) extends LogEntry { + final case class CommittedPendingCommit[F[_]](pendingCommit: Request.Commit[F]) extends LogEntry { override def level: LogLevel = Debug - override def message: String = - s"Committed pending commits [$pendingCommits]. Current state [$state]." + override def message: String = s"Committed pending commit [$pendingCommit]." }