Skip to content

Commit

Permalink
KafkaIO SDF: Fetch end position for each topic-partition tuple in a b…
Browse files Browse the repository at this point in the history
…ackground thread, reusing consumers.
  • Loading branch information
bzablocki committed Oct 2, 2024
1 parent dcc0bd2 commit 6d9857b
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;

import java.io.IOException;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Closeables;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.common.TopicPartition;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KafkaOffsetConsumerPollThread {

private static final Logger LOG = LoggerFactory.getLogger(KafkaOffsetConsumerPollThread.class);
private @Nullable Consumer<byte[], byte[]> consumer;
private @Nullable TopicPartition topicPartition;
private final AtomicLong collectedEndOffset = new AtomicLong();
private final AtomicBoolean closed = new AtomicBoolean(false);
private @Nullable Future<?> offsetTrackerFuture;

KafkaOffsetConsumerPollThread() {
consumer = null;
offsetTrackerFuture = null;
topicPartition = null;
}

void startOnExecutor(
ScheduledExecutorService executorService,
Consumer<byte[], byte[]> consumer,
TopicPartition topicPartition) {
this.consumer = consumer;
this.topicPartition = topicPartition;
this.offsetTrackerFuture =
executorService.scheduleAtFixedRate(
this::fetchEndOffsetFromKafka, 0, 500, TimeUnit.MILLISECONDS);
}

private void fetchEndOffsetFromKafka() {
Consumer<byte[], byte[]> consumer = checkStateNotNull(this.consumer);
TopicPartition topicPartition = checkStateNotNull(this.topicPartition);

if (closed.get()) {
return;
}

Long currentEndOffset =
consumer.endOffsets(ImmutableList.of(topicPartition)).get(topicPartition);
if (currentEndOffset != null) {
collectedEndOffset.set(currentEndOffset);
} else {
LOG.warn("Unable to get an end offset for {}", topicPartition);
}
}

void close() throws IOException {
if (consumer == null) {
LOG.debug("Closing consumer poll thread that was never started.");
return;
}
closed.set(true);
Closeables.close(consumer, true);
checkStateNotNull(offsetTrackerFuture).cancel(true);
}

long readEndOffset() {
return collectedEndOffset.get();
}

boolean isClosed() {
return closed.get();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;

import java.io.IOException;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.CacheBuilder;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.cache.RemovalNotification;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.kafka.clients.consumer.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KafkaOffsetConsumerPollThreadCache {

private static final Logger LOG =
LoggerFactory.getLogger(KafkaOffsetConsumerPollThreadCache.class);
private static final Duration OFFSET_THREAD_ALLOWED_NOT_ACCESSED_DURATION =
Duration.ofMinutes(10);

private final ExecutorService invalidationExecutor =
Executors.newCachedThreadPool(
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("KafkaOffsetConsumerPollCache-invalidation-%d")
.build());

private final ScheduledExecutorService backgroundThread =
Executors.newScheduledThreadPool(
0,
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat("KafkaOffsetConsumerPollCache-poll-%d")
.build());

private final Cache<KafkaSourceDescriptor, KafkaOffsetConsumerPollThread> offsetThreadCache;

@SuppressWarnings("method.invocation")
KafkaOffsetConsumerPollThreadCache() {
this.offsetThreadCache =
CacheBuilder.newBuilder()
.expireAfterAccess(
OFFSET_THREAD_ALLOWED_NOT_ACCESSED_DURATION.toMillis(), TimeUnit.MILLISECONDS)
.removalListener(
(RemovalNotification<KafkaSourceDescriptor, KafkaOffsetConsumerPollThread>
notification) -> {
LOG.info(
"Asynchronously closing offset reader for {}. Reason: {}",
notification.getKey(),
notification.getCause());
asyncCloseOffsetConsumer(
checkNotNull(notification.getKey()), checkNotNull(notification.getValue()));
})
.build();
}

void invalidate(KafkaSourceDescriptor kafkaSourceDescriptor) {
this.offsetThreadCache.invalidate(kafkaSourceDescriptor);
}

KafkaOffsetConsumerPollThread acquireOffsetTrackerThread(
Map<String, Object> consumerConfig,
SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn,
KafkaSourceDescriptor kafkaSourceDescriptor) {
try {
return offsetThreadCache.get(
kafkaSourceDescriptor,
() -> {
Consumer<byte[], byte[]> consumer =
createAndSetupKafkaEndOffsetConsumer(
consumerConfig, consumerFactoryFn, kafkaSourceDescriptor);
KafkaOffsetConsumerPollThread pollThread = new KafkaOffsetConsumerPollThread();
pollThread.startOnExecutor(
backgroundThread, consumer, kafkaSourceDescriptor.getTopicPartition());
return pollThread;
});
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
}

private static Consumer<byte[], byte[]> createAndSetupKafkaEndOffsetConsumer(
Map<String, Object> consumerConfig,
SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn,
KafkaSourceDescriptor kafkaSourceDescriptor) {
LOG.info(
"Creating a new kafka consumer for tracking backlog for {}",
kafkaSourceDescriptor.getTopicPartition());
Consumer<byte[], byte[]> consumer = consumerFactoryFn.apply(consumerConfig);
ConsumerSpEL.evaluateAssign(
consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition()));
return consumer;
}

/** Close the reader and log a warning if close fails. */
private void asyncCloseOffsetConsumer(
KafkaSourceDescriptor kafkaSourceDescriptor, KafkaOffsetConsumerPollThread offsetPollThread) {
invalidationExecutor.execute(
() -> {
try {
offsetPollThread.close();
LOG.info("Finished closing consumer for {}", kafkaSourceDescriptor);
} catch (IOException e) {
LOG.warn("Failed to close consumer for {}", kafkaSourceDescriptor, e);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.kafka;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.util.HashMap;
Expand All @@ -25,7 +26,6 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors;
import org.apache.beam.sdk.io.kafka.KafkaIOUtils.MovingAvg;
Expand Down Expand Up @@ -223,10 +223,13 @@ private ReadFromKafkaDoFn(
private transient @Nullable Deserializer<V> valueDeserializerInstance = null;
private transient @Nullable Map<TopicPartition, KafkaLatestOffsetEstimator> offsetEstimatorCache;

private static final Supplier<KafkaOffsetConsumerPollThreadCache> offsetConsumerPollThreadCache =
Suppliers.memoize(KafkaOffsetConsumerPollThreadCache::new);

private transient @Nullable LoadingCache<TopicPartition, AverageRecordSize> avgRecordSize;
private static final long DEFAULT_KAFKA_POLL_TIMEOUT = 2L;

private HashMap<String, Long> perPartitionBacklogMetrics = new HashMap<String, Long>();;
private final HashMap<String, Long> perPartitionBacklogMetrics = new HashMap<>();

@VisibleForTesting final long consumerPollingTimeout;
@VisibleForTesting final DeserializerProvider<K> keyDeserializerProvider;
Expand All @@ -244,46 +247,26 @@ private ReadFromKafkaDoFn(
private static class KafkaLatestOffsetEstimator
implements GrowableOffsetRangeTracker.RangeEndEstimator {

private final Consumer<byte[], byte[]> offsetConsumer;
private final TopicPartition topicPartition;
private final Supplier<Long> memoizedBacklog;
private boolean closed;

KafkaLatestOffsetEstimator(
Consumer<byte[], byte[]> offsetConsumer, TopicPartition topicPartition) {
this.offsetConsumer = offsetConsumer;
this.topicPartition = topicPartition;
ConsumerSpEL.evaluateAssign(this.offsetConsumer, ImmutableList.of(this.topicPartition));
memoizedBacklog =
Suppliers.memoizeWithExpiration(
() -> {
synchronized (offsetConsumer) {
ConsumerSpEL.evaluateSeek2End(offsetConsumer, topicPartition);
return offsetConsumer.position(topicPartition);
}
},
1,
TimeUnit.SECONDS);
}

@Override
protected void finalize() {
try {
Closeables.close(offsetConsumer, true);
closed = true;
LOG.info("Offset Estimator consumer was closed for {}", topicPartition);
} catch (Exception anyException) {
LOG.warn("Failed to close offset consumer for {}", topicPartition);
}
private final KafkaOffsetConsumerPollThread offsetTrackerThread;

public KafkaLatestOffsetEstimator(
Map<String, Object> fullOffsetConsumerConfig,
SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn,
KafkaSourceDescriptor kafkaSourceDescriptor) {
this.offsetTrackerThread =
offsetConsumerPollThreadCache
.get()
.acquireOffsetTrackerThread(
fullOffsetConsumerConfig, consumerFactoryFn, kafkaSourceDescriptor);
}

@Override
public long estimate() {
return memoizedBacklog.get();
return offsetTrackerThread.readEndOffset();
}

public boolean isClosed() {
return closed;
return offsetTrackerThread.isClosed();
}
}

Expand Down Expand Up @@ -312,7 +295,6 @@ public OffsetRange initialRestriction(@Element KafkaSourceDescriptor kafkaSource
} else if (stopReadTime != null) {
endOffset = ConsumerSpEL.offsetForTime(offsetConsumer, partition, stopReadTime);
}
new OffsetRange(startOffset, endOffset);
Lineage.getSources()
.add(
"kafka",
Expand Down Expand Up @@ -369,21 +351,23 @@ public OffsetRangeTracker restrictionTracker(
// so we want to minimize the amount of connections that we start and track with Kafka. Another
// point is that it has a memoized backlog, and this should make that more reusable estimations.
final Map<TopicPartition, KafkaLatestOffsetEstimator> offsetEstimatorCacheInstance =
Preconditions.checkStateNotNull(this.offsetEstimatorCache);
Preconditions.checkStateNotNull(offsetEstimatorCache);

TopicPartition topicPartition = kafkaSourceDescriptor.getTopicPartition();
KafkaLatestOffsetEstimator offsetEstimator = offsetEstimatorCacheInstance.get(topicPartition);
if (offsetEstimator == null || offsetEstimator.isClosed()) {
Map<String, Object> updatedConsumerConfig =
Map<String, Object> consumerConfigWithBootstrapServers =
overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor);

LOG.info("Creating Kafka consumer for offset estimation for {}", topicPartition);
Map<String, Object> combinedConsumerConfigs =
KafkaIOUtils.getOffsetConsumerConfig(
"tracker-" + topicPartition,
offsetConsumerConfig,
consumerConfigWithBootstrapServers);

Consumer<byte[], byte[]> offsetConsumer =
consumerFactoryFn.apply(
KafkaIOUtils.getOffsetConsumerConfig(
"tracker-" + topicPartition, offsetConsumerConfig, updatedConsumerConfig));
offsetEstimator = new KafkaLatestOffsetEstimator(offsetConsumer, topicPartition);
offsetEstimator =
new KafkaLatestOffsetEstimator(
combinedConsumerConfigs, consumerFactoryFn, kafkaSourceDescriptor);
offsetEstimatorCacheInstance.put(topicPartition, offsetEstimator);
}

Expand Down Expand Up @@ -454,6 +438,10 @@ public ProcessContinuation processElement(
if (rawRecords.isEmpty()) {
if (!topicPartitionExists(
kafkaSourceDescriptor.getTopicPartition(), consumer.listTopics())) {
LOG.info(
"Closing the offset thread for {} because the partition doesn't exist anymore",
kafkaSourceDescriptor.getTopicPartition());
offsetConsumerPollThreadCache.get().invalidate(kafkaSourceDescriptor);
return ProcessContinuation.stop();
}
if (timestampPolicy != null) {
Expand Down Expand Up @@ -644,7 +632,7 @@ public void teardown() throws Exception {
}

if (offsetEstimatorCache != null) {
offsetEstimatorCache.clear();
checkNotNull(offsetEstimatorCache).clear();
}
if (checkStopReadingFn != null) {
checkStopReadingFn.teardown();
Expand All @@ -666,8 +654,8 @@ private Map<String, Object> overrideBootstrapServersConfig(
}

private static class AverageRecordSize {
private MovingAvg avgRecordSize;
private MovingAvg avgRecordGap;
private final MovingAvg avgRecordSize;
private final MovingAvg avgRecordGap;

public AverageRecordSize() {
this.avgRecordSize = new MovingAvg();
Expand Down

0 comments on commit 6d9857b

Please sign in to comment.