Skip to content

Commit

Permalink
Swallow malformed messages on trigger bucket topics
Browse files Browse the repository at this point in the history
  • Loading branch information
phildobsontw committed Nov 27, 2024
1 parent c3ceb1d commit c5b6200
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.transferwise.tasks.ITaskDataSerializer;
import com.transferwise.tasks.ITasksService;
import com.transferwise.tasks.ITasksService.AddTaskRequest;
import com.transferwise.tasks.TasksProperties;
import com.transferwise.tasks.dao.ITaskDao;
import com.transferwise.tasks.dao.ITaskDao.InsertTaskRequest;
import com.transferwise.tasks.domain.IBaseTask;
Expand Down Expand Up @@ -48,6 +49,8 @@
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -70,6 +73,10 @@ public class TaskProcessingIntTest extends BaseIntTest {
protected ITaskDataSerializer taskDataSerializer;
@Autowired
protected GlobalProcessingState globalProcessingState;
@Autowired
protected TasksProperties tasksProperties;
@Autowired
protected KafkaProducer<String, String> kafkaProducer;

private KafkaTasksExecutionTriggerer kafkaTasksExecutionTriggerer;

Expand Down Expand Up @@ -433,6 +440,48 @@ public void freeSpace(IBaseTask task) {
}));
}

@Test
void taskProcessingWillHandlePoisonPillAttack() {
// given:
int tasksToFire = 10;
AtomicInteger counter = new AtomicInteger();

testTaskHandlerAdapter.setProcessor((ISyncTaskProcessor) task -> {
counter.incrementAndGet();
return new ProcessResult().setResultCode(ResultCode.DONE);
});

// when:
for (int i = 0; i < tasksToFire; i++) {
publishPosionPill();
addTask();
publishPosionPill();
}

// then:
await().until(() -> transactionsHelper.withTransaction().asNew().call(() -> {
try {
return testTasksService.getFinishedTasks("test", null).size() == tasksToFire
&& counter.get() == tasksToFire;
} catch (Throwable t) {
log.error(t.getMessage(), t);
}
return false;
}));
}

@SneakyThrows
private void publishPosionPill() {
final var topicName = "twTasks." + tasksProperties.getGroupId() + ".executeTask.default";
kafkaProducer.send(new ProducerRecord<>(topicName, "poison-pill")).get();
}

private void addTask() {
transactionsHelper.withTransaction().asNew().call(() ->
tasksService.addTask(new ITasksService.AddTaskRequest().setType("test").setData(taskDataSerializer.serialize("foo")))
);
}

private int counterSum(String name) {
return meterRegistry.find(name)
.counters()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.transferwise.tasks.testapp.config;

import com.transferwise.common.context.TwContextClockHolder;
import com.transferwise.tasks.TasksProperties;
import com.transferwise.tasks.buckets.BucketProperties;
import com.transferwise.tasks.buckets.IBucketsManager;
import com.transferwise.tasks.domain.ITask;
Expand All @@ -17,14 +18,19 @@
import java.time.ZonedDateTime;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.CommonClientConfigs;
import org.apache.kafka.clients.admin.AdminClient;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.CooperativeStickyAssignor;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.serialization.StringSerializer;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.kafka.KafkaProperties;
Expand Down Expand Up @@ -141,6 +147,31 @@ public IKafkaListenerConsumerPropertiesProvider twTasksKafkaListenerSpringKafkaC
};
}

@Bean
public KafkaProducer<String, String> kafkaTaskTriggererProducer(TasksProperties tasksProperties) {
Map<String, Object> configs = new HashMap<>();
configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, tasksProperties.getTriggering().getKafka().getBootstrapServers());

configs.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class);
configs.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class);

configs.put(ProducerConfig.ACKS_CONFIG, "all");
configs.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, 5);
configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, "5000");
configs.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true");
configs.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, "5000");
configs.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, "10000");
configs.put(ProducerConfig.LINGER_MS_CONFIG, "5");
configs.put(ProducerConfig.CLIENT_ID_CONFIG, tasksProperties.getGroupId() + ".tw-tasks-triggerer");
configs.put(ProducerConfig.RECONNECT_BACKOFF_MAX_MS_CONFIG, "5000");
configs.put(ProducerConfig.RECONNECT_BACKOFF_MS_CONFIG, "100");
configs.put(ProducerConfig.METADATA_MAX_AGE_CONFIG, "120000");

configs.putAll(tasksProperties.getTriggering().getKafka().getProperties());

return new KafkaProducer<>(configs);
}

@Bean
ITaskRegistrationDecorator jambiRegistrationInterceptor() {
return new JambiTaskRegistrationDecorator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import lombok.SneakyThrows;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
Expand All @@ -34,12 +36,15 @@ class KafkaTasksExecutionTriggererIntTest extends BaseIntTest {

public static final String PARTITION_KEY = "7a1a43c9-35af-4bea-9349-a1f344c8185c";
private static final String BUCKET_ID = "manualStart";
private static final String TASK_TYPE = "test";

private KafkaConsumer<String, String> kafkaConsumer;
@Autowired
protected ITaskDataSerializer taskDataSerializer;
@Autowired
private TasksProperties tasksProperties;
@Autowired
protected KafkaProducer<String, String> kafkaProducer;

@BeforeEach
@SneakyThrows
Expand Down Expand Up @@ -68,7 +73,6 @@ void cleanup() {
@Test
void shouldUsePartitionKeyStrategyWhenCustomStrategyDefinedInProcessor() {
final var data = "Hello World!";
final var taskType = "test";
final var taskId = UuidUtils.generatePrefixCombUuid();

testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor)
Expand All @@ -82,12 +86,12 @@ void shouldUsePartitionKeyStrategyWhenCustomStrategyDefinedInProcessor() {
final var taskRequest = new AddTaskRequest()
.setTaskId(taskId)
.setData(taskDataSerializer.serialize(data))
.setType(taskType)
.setType(TASK_TYPE)
.setUniqueKey(uniqueKey.toString())
.setRunAfterTime(ZonedDateTime.now().plusHours(1));

transactionsHelper.withTransaction().asNew().call(() -> testTasksService.addTask(taskRequest));
await().until(() -> testTasksService.getWaitingTasks(taskType, null).size() > 0);
await().until(() -> testTasksService.getWaitingTasks(TASK_TYPE, null).size() > 0);

assertTrue(transactionsHelper.withTransaction().asNew().call(() ->
testTasksService.resumeTask(new ITasksService.ResumeTaskRequest().setTaskId(taskId).setVersion(0))
Expand All @@ -106,6 +110,48 @@ void shouldUsePartitionKeyStrategyWhenCustomStrategyDefinedInProcessor() {
Assertions.assertTrue(keys.contains(PARTITION_KEY));
}

@Test
void handlesPoisonPills() {
// setup:
testTaskHandlerAdapter.setProcessor(resultRegisteringSyncTaskProcessor)
.setProcessingPolicy(new SimpleTaskProcessingPolicy()
.setProcessingBucket(BUCKET_ID)
.setMaxProcessingDuration(Duration.of(1, ChronoUnit.HOURS))
.setPartitionKeyStrategy(new TestPartitionKeyStrategy()));


// when
int tasksToFire = 10;
for (int i = 0; i < tasksToFire; i++) {
publishPosionPill();
addTask();
publishPosionPill();
}
testTasksService.startTasksProcessing(BUCKET_ID);

await().until(
() -> resultRegisteringSyncTaskProcessor.getTaskResults().size() == tasksToFire
);

}

@SneakyThrows
private void addTask() {
UUID taskId = UuidUtils.generatePrefixCombUuid();
final var taskRequest = new AddTaskRequest()
.setTaskId(taskId)
.setData(taskDataSerializer.serialize("Hello World!"))
.setType(TASK_TYPE);

transactionsHelper.withTransaction().asNew().call(() -> testTasksService.addTask(taskRequest));
}

@SneakyThrows
private void publishPosionPill() {
final var topicName = "twTasks." + tasksProperties.getGroupId() + ".executeTask.default";
kafkaProducer.send(new ProducerRecord<>(topicName, "poison-pill")).get();
}

static class TestPartitionKeyStrategy implements IPartitionKeyStrategy {

private static final UUID KEY = UUID.fromString(PARTITION_KEY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,16 @@ public void poll(String bucketId) {

log.debug("Received Kafka message from topic '{}' partition {} offset {}.", consumerRecord.topic(), consumerRecord.partition(), offset);

BaseTask task = JsonUtils.fromJson(objectMapper, consumerRecord.value(), BaseTask.class);
BaseTask task;
try {
task = JsonUtils.fromJson(objectMapper, consumerRecord.value(), BaseTask.class);
} catch (Exception e) {
log.error("Received malformed task trigger in bucket {} [from topic '{}' partition {} offset {}].",
bucketId, consumerRecord.topic(), consumerRecord.partition(), offset, e);
releaseCompletedOffset(consumerBucket, topicPartition, offset);
break;
}

mdcService.with(() -> {
mdcService.put(task);

Expand Down

0 comments on commit c5b6200

Please sign in to comment.