diff --git a/data-prepper-plugins/kinesis-source/build.gradle b/data-prepper-plugins/kinesis-source/build.gradle new file mode 100644 index 0000000000..400fe22c00 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/build.gradle @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +plugins { + id 'java' +} + +dependencies { + implementation project(':data-prepper-api') + implementation project(':data-prepper-plugins:common') + implementation project(':data-prepper-plugins:buffer-common') + implementation libs.armeria.core + implementation 'com.fasterxml.jackson.core:jackson-core' + implementation 'com.fasterxml.jackson.core:jackson-databind' + implementation project(':data-prepper-plugins:blocking-buffer') + implementation 'software.amazon.awssdk:kinesis' + implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-ion' + implementation 'software.amazon.kinesis:amazon-kinesis-client:2.6.0' + compileOnly 'org.projectlombok:lombok:1.18.20' + annotationProcessor 'org.projectlombok:lombok:1.18.20' + implementation("software.amazon.awssdk:dynamodb") + implementation("com.amazonaws:aws-java-sdk:1.12.394") + implementation project(path: ':data-prepper-plugins:aws-plugin-api') + + testImplementation 'org.yaml:snakeyaml:2.2' + testImplementation project(':data-prepper-test-common') + testImplementation platform('org.junit:junit-bom:5.9.1') + testImplementation 'org.junit.jupiter:junit-jupiter' + testImplementation project(':data-prepper-test-event') +} + +jacocoTestCoverageVerification { + dependsOn jacocoTestReport + violationRules { + rule { //in addition to core projects rule + limit { + minimum = 1.0 + } + } + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/ClientFactory.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/ClientFactory.java new file mode 100644 index 0000000000..5abf9185d0 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/ClientFactory.java @@ -0,0 +1,48 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.AwsAuthenticationConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.kinesis.common.KinesisClientUtil; + +public class ClientFactory { + private final AwsCredentialsProvider awsCredentialsProvider; + private final AwsAuthenticationConfig awsAuthenticationConfig; + + public ClientFactory(final AwsCredentialsSupplier awsCredentialsSupplier, + final AwsAuthenticationConfig awsAuthenticationConfig) { + awsCredentialsProvider = awsCredentialsSupplier.getProvider(AwsCredentialsOptions.builder() + .withRegion(awsAuthenticationConfig.getAwsRegion()) + .withStsRoleArn(awsAuthenticationConfig.getAwsStsRoleArn()) + .withStsExternalId(awsAuthenticationConfig.getAwsStsExternalId()) + .withStsHeaderOverrides(awsAuthenticationConfig.getAwsStsHeaderOverrides()) + .build()); + this.awsAuthenticationConfig = awsAuthenticationConfig; + } + 
+ public DynamoDbAsyncClient buildDynamoDBClient() { + return DynamoDbAsyncClient.builder() + .credentialsProvider(awsAuthenticationConfig.authenticateAwsConfiguration()) + .region(awsAuthenticationConfig.getAwsRegion()) + .build(); + } + + public KinesisAsyncClient buildKinesisAsyncClient() { + return KinesisClientUtil.createKinesisAsyncClient( + KinesisAsyncClient.builder() + .credentialsProvider(awsAuthenticationConfig.authenticateAwsConfiguration()) + .region(awsAuthenticationConfig.getAwsRegion()) + ); + } + + public CloudWatchAsyncClient buildCloudWatchAsyncClient() { + return CloudWatchAsyncClient.builder() + .credentialsProvider(awsAuthenticationConfig.authenticateAwsConfiguration()) + .region(awsAuthenticationConfig.getAwsRegion()) + .build(); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisMultiStreamTracker.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisMultiStreamTracker.java new file mode 100644 index 0000000000..5da7a3d160 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisMultiStreamTracker.java @@ -0,0 +1,82 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.kinesis.common.InitialPositionInStreamExtended; +import software.amazon.kinesis.common.StreamConfig; +import software.amazon.kinesis.common.StreamIdentifier; +import software.amazon.kinesis.processor.FormerStreamsLeasesDeletionStrategy; +import software.amazon.kinesis.processor.MultiStreamTracker; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + + +public class KinesisMultiStreamTracker implements MultiStreamTracker { + private static final String COLON = ":"; + + private final KinesisAsyncClient kinesisClient; + private final KinesisSourceConfig sourceConfig; + private final String applicationName; + + public KinesisMultiStreamTracker(KinesisAsyncClient kinesisClient, final KinesisSourceConfig sourceConfig, final String applicationName) { + this.kinesisClient = kinesisClient; + this.sourceConfig = sourceConfig; + this.applicationName = applicationName; + } + + @Override + public List streamConfigList() { + List streamConfigList = new ArrayList<>(); + for (KinesisStreamConfig kinesisStreamConfig : sourceConfig.getStreams()) { + StreamConfig streamConfig; + try { + streamConfig = getStreamConfig(kinesisStreamConfig); + } catch (Exception e) { + throw new RuntimeException(e); + } + streamConfigList.add(streamConfig); + } + return streamConfigList; + } + + private StreamConfig getStreamConfig(KinesisStreamConfig kinesisStreamConfig) throws Exception { + StreamIdentifier sourceStreamIdentifier = getStreamIdentifier(kinesisStreamConfig); + return new StreamConfig(sourceStreamIdentifier, + InitialPositionInStreamExtended.newInitialPosition(kinesisStreamConfig.getInitialPosition())); + } + + private StreamIdentifier getStreamIdentifier(KinesisStreamConfig 
kinesisStreamConfig) throws Exception { + DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder() + .streamName(kinesisStreamConfig.getName()) + .build(); + DescribeStreamResponse describeStreamResponse = kinesisClient.describeStream(describeStreamRequest).get(); + String streamIdentifierString = getStreamIdentifierString(describeStreamResponse.streamDescription()); + return StreamIdentifier.multiStreamInstance(streamIdentifierString); + } + + private String getStreamIdentifierString(StreamDescription streamDescription) { + String accountId = streamDescription.streamARN().split(COLON)[4]; + long creationEpochSecond = streamDescription.streamCreationTimestamp().getEpochSecond(); + return String.join(COLON, accountId, streamDescription.streamName(), String.valueOf(creationEpochSecond)); + } + + /** + * Setting the deletion policy as autodetect and release shard lease with a wait time of 10 sec + */ + @Override + public FormerStreamsLeasesDeletionStrategy formerStreamsLeasesDeletionStrategy() { + return new FormerStreamsLeasesDeletionStrategy.AutoDetectionAndDeferredDeletionStrategy() { + @Override + public Duration waitPeriodToDeleteFormerStreams() { + return Duration.ofSeconds(10); + } + }; + + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisService.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisService.java new file mode 100644 index 0000000000..0d48aaf161 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisService.java @@ -0,0 +1,140 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import io.micrometer.core.instrument.util.StringUtils; +import lombok.Setter; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.ConsumerStrategy; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.processor.KinesisShardRecordProcessorFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.BillingMode; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.kinesis.common.ConfigsBuilder; +import software.amazon.kinesis.coordinator.Scheduler; +import software.amazon.kinesis.processor.ShardRecordProcessorFactory; +import software.amazon.kinesis.retrieval.polling.PollingConfig; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class KinesisService { + private static final Logger LOG = LoggerFactory.getLogger(KinesisService.class); + private 
static final int GRACEFUL_SHUTDOWN_WAIT_INTERVAL_SECONDS = 20;
+
+    private final PluginMetrics pluginMetrics;
+    private final PluginFactory pluginFactory;
+
+    private final String applicationName;
+    private final String tableName;
+    private final String pipelineName;
+    private final AcknowledgementSetManager acknowledgementSetManager;
+    private final KinesisSourceConfig sourceConfig;
+    private final KinesisAsyncClient kinesisClient;
+    private final DynamoDbAsyncClient dynamoDbClient;
+    private final CloudWatchAsyncClient cloudWatchClient;
+
+    @Setter
+    private Scheduler scheduler;
+
+    private final ExecutorService executorService;
+
+    public KinesisService(final KinesisSourceConfig sourceConfig,
+                          final ClientFactory clientFactory,
+                          final PluginMetrics pluginMetrics,
+                          final PluginFactory pluginFactory,
+                          final PipelineDescription pipelineDescription,
+                          final AcknowledgementSetManager acknowledgementSetManager
+    ){
+        this.sourceConfig = sourceConfig;
+        this.pluginMetrics = pluginMetrics;
+        this.pluginFactory = pluginFactory;
+        this.acknowledgementSetManager = acknowledgementSetManager;
+        this.dynamoDbClient = clientFactory.buildDynamoDBClient();
+        this.kinesisClient = clientFactory.buildKinesisAsyncClient();
+        this.cloudWatchClient = clientFactory.buildCloudWatchAsyncClient();
+        this.pipelineName = pipelineDescription.getPipelineName();
+        this.applicationName = pipelineName;
+        // Use the configured lease coordination table if provided, otherwise default to the application name
+        this.tableName = StringUtils.isNotEmpty(sourceConfig.getLeaseCoordinationTable()) ?
+                sourceConfig.getLeaseCoordinationTable() : applicationName;
+        this.executorService = Executors.newFixedThreadPool(1);
+    }
+
+    public void start(final Buffer<Record<Event>> buffer) {
+        if (buffer == null) {
+            throw new IllegalStateException("Buffer provided is null");
+        }
+
+        if (sourceConfig.getStreams() == null || sourceConfig.getStreams().isEmpty()) {
+            throw new IllegalStateException("Streams are empty!");
+        }
+
+        scheduler = getScheduler(buffer);
+        executorService.execute(scheduler);
+    }
+
+    public void shutDown() {
+        LOG.info("Stop request received for Kinesis Source");
+
+        Future<Boolean> gracefulShutdownFuture = scheduler.startGracefulShutdown();
+        LOG.info("Waiting up to {} seconds for shutdown to complete.", GRACEFUL_SHUTDOWN_WAIT_INTERVAL_SECONDS);
+        try {
+            gracefulShutdownFuture.get(GRACEFUL_SHUTDOWN_WAIT_INTERVAL_SECONDS, TimeUnit.SECONDS);
+        } catch (InterruptedException | ExecutionException | TimeoutException ex) {
+            LOG.error("Exception while executing graceful shutdown of the Kinesis consumer, forcing shutdown", ex);
+            scheduler.shutdown();
+        }
+        LOG.info("Completed, shutting down now.");
+    }
+
+    public Scheduler getScheduler(final Buffer<Record<Event>> buffer) {
+        if (scheduler == null) {
+            return createScheduler(buffer);
+        }
+        return scheduler;
+    }
+
+    public Scheduler createScheduler(final Buffer<Record<Event>> buffer) {
+        final ShardRecordProcessorFactory processorFactory = new KinesisShardRecordProcessorFactory(
+                buffer, sourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory);
+
+        ConfigsBuilder configsBuilder =
+                new ConfigsBuilder(
+                        new KinesisMultiStreamTracker(kinesisClient, sourceConfig, applicationName),
+                        applicationName, kinesisClient, dynamoDbClient, cloudWatchClient,
+                        new WorkerIdentifierGenerator().generate(), processorFactory
+                ).tableName(tableName);
+
+        // Default is enhanced fan-out; switch the retrieval mode to polling when configured
+        ConsumerStrategy consumerStrategy = sourceConfig.getConsumerStrategy();
+        if (consumerStrategy == ConsumerStrategy.POLLING) {
+            configsBuilder.retrievalConfig().retrievalSpecificConfig(
+                    new PollingConfig(kinesisClient)
+                            .maxRecords(sourceConfig.getPollingConfig().getMaxPollingRecords())
+                            .idleTimeBetweenReadsInMillis(
+                                    sourceConfig.getPollingConfig().getIdleTimeBetweenReadsInMillis()));
+        }
+
+        return new Scheduler(
+                configsBuilder.checkpointConfig(),
+                configsBuilder.coordinatorConfig(),
+                configsBuilder.leaseManagementConfig()
+                        .billingMode(BillingMode.PAY_PER_REQUEST),
+                configsBuilder.lifecycleConfig(),
+                configsBuilder.metricsConfig(),
+                configsBuilder.processorConfig(),
+                configsBuilder.retrievalConfig()
+        );
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisSource.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisSource.java
new file mode 100644
index 0000000000..98069b5bb8
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisSource.java
@@ -0,0 +1,56 @@
+package org.opensearch.dataprepper.plugins.source.kinesis;
+
+import lombok.Setter;
+import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.annotations.DataPrepperPlugin;
+import org.opensearch.dataprepper.model.annotations.DataPrepperPluginConstructor;
+import org.opensearch.dataprepper.model.buffer.Buffer;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.model.source.Source;
+import org.opensearch.dataprepper.model.configuration.PipelineDescription;
+import org.opensearch.dataprepper.model.plugin.PluginFactory;
+import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@DataPrepperPlugin(name = "kinesis", pluginType = Source.class, pluginConfigurationType = KinesisSourceConfig.class)
+public class KinesisSource implements Source<Record<Event>> {
+    private static final Logger LOG = LoggerFactory.getLogger(KinesisSource.class);
+    private final KinesisSourceConfig kinesisSourceConfig;
+
+    @Setter
+    private KinesisService kinesisService;
+
+    @DataPrepperPluginConstructor
+    public KinesisSource(final KinesisSourceConfig kinesisSourceConfig,
+                         final PluginMetrics pluginMetrics,
+                         final PluginFactory pluginFactory,
+                         final PipelineDescription pipelineDescription,
+                         final AwsCredentialsSupplier awsCredentialsSupplier,
+                         final AcknowledgementSetManager acknowledgementSetManager) {
+        this.kinesisSourceConfig = kinesisSourceConfig;
+        ClientFactory clientFactory = new ClientFactory(awsCredentialsSupplier, kinesisSourceConfig.getAwsAuthenticationConfig());
+        this.kinesisService = new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory, pipelineDescription, acknowledgementSetManager);
+    }
+    @Override
+    public void start(final Buffer<Record<Event>> buffer) {
+        if (buffer == null) {
+            throw new IllegalStateException("Buffer provided is null");
+        }
+
+        kinesisService.start(buffer);
+    }
+
+    @Override
+    public void stop() {
+        kinesisService.shutDown();
+    }
+
+    @Override
+    public boolean areAcknowledgementsEnabled() {
+        return kinesisSourceConfig.isAcknowledgments();
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/WorkerIdentifierGenerator.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/WorkerIdentifierGenerator.java
new file mode 100644
index 0000000000..9390e25920
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/WorkerIdentifierGenerator.java
@@ -0,0 +1,28 @@
+package org.opensearch.dataprepper.plugins.source.kinesis;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+
+/**
+ * Generate a unique ID to represent a consumer application instance.
+ */
+public class WorkerIdentifierGenerator {
+
+    private static final String hostName;
+
+    static {
+        try {
+            hostName = InetAddress.getLocalHost().getHostName();
+        } catch (final UnknownHostException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+
+    /**
+     * @return Default to use host name.
+     */
+    public String generate() {
+        return hostName;
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/AwsAuthenticationConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/AwsAuthenticationConfig.java
new file mode 100644
index 0000000000..d51a7ee6a4
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/AwsAuthenticationConfig.java
@@ -0,0 +1,87 @@
+package org.opensearch.dataprepper.plugins.source.kinesis.configuration;
+
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.constraints.Size;
+import software.amazon.awssdk.arns.Arn;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.sts.StsClient;
+import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
+import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
+
+import java.util.Map;
+import java.util.UUID;
+
+public class AwsAuthenticationConfig {
+    private static final String AWS_IAM_ROLE = "role";
+    private static final String AWS_IAM = "iam";
+
+    @JsonProperty("region")
+    @Size(min = 1, message = "Region cannot be empty string")
+    private String awsRegion;
+
+    @JsonProperty("sts_role_arn")
+    @Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 20 and 2048 characters")
+    private String awsStsRoleArn;
+
+    @JsonProperty("sts_external_id")
+    @Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
+    private String awsStsExternalId;
+
+    @JsonProperty("sts_header_overrides")
+    @Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
+    private Map<String, String> awsStsHeaderOverrides;
+
+    public Region getAwsRegion() {
+        return awsRegion != null ? Region.of(awsRegion) : null;
+    }
+
+    public String getAwsStsRoleArn() {
+        return awsStsRoleArn;
+    }
+
+    public String getAwsStsExternalId() {
+        return awsStsExternalId;
+    }
+
+    public Map<String, String> getAwsStsHeaderOverrides() {
+        return awsStsHeaderOverrides;
+    }
+
+    public AwsCredentialsProvider authenticateAwsConfiguration() {
+
+        final AwsCredentialsProvider awsCredentialsProvider;
+        if (awsStsRoleArn != null && !awsStsRoleArn.isEmpty()) {
+            try {
+                Arn.fromString(awsStsRoleArn);
+            } catch (final Exception e) {
+                throw new IllegalArgumentException("Invalid ARN format for awsStsRoleArn");
+            }
+
+            final StsClient stsClient = StsClient.builder().region(getAwsRegion()).build();
+
+            AssumeRoleRequest.Builder assumeRoleRequestBuilder = AssumeRoleRequest.builder()
+                    .roleSessionName("Kinesis-Source-" + UUID.randomUUID()).roleArn(awsStsRoleArn);
+
+            if (awsStsHeaderOverrides != null && !awsStsHeaderOverrides.isEmpty()) {
+                assumeRoleRequestBuilder = assumeRoleRequestBuilder.overrideConfiguration(
+                        configuration -> awsStsHeaderOverrides.forEach(configuration::putHeader));
+            }
+
+            awsCredentialsProvider = StsAssumeRoleCredentialsProvider.builder()
+                    .stsClient(stsClient)
+                    .refreshRequest(assumeRoleRequestBuilder.build())
+                    .build();
+
+        } else {
+            awsCredentialsProvider = DefaultCredentialsProvider.create();
+        }
+        return awsCredentialsProvider;
+    }
+}
\ No newline at end of file
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/ConsumerStrategy.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/ConsumerStrategy.java
new file mode 100644
index 0000000000..7b7a7ddf9d
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/ConsumerStrategy.java
@@ -0,0 +1,23 @@
+package org.opensearch.dataprepper.plugins.source.kinesis.configuration;
+
+import com.fasterxml.jackson.annotation.JsonValue;
+
+// Reference: https://docs.aws.amazon.com/streams/latest/dev/enhanced-consumers.html
+
+public enum ConsumerStrategy {
+
+    POLLING("Polling"),
+
+    ENHANCED_FAN_OUT("Fan-Out");
+
+    private final String value;
+
+    ConsumerStrategy(String value) {
+        this.value = value;
+    }
+
+    @JsonValue
+    public String getValue() {
+        return value;
+    }
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisSourceConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisSourceConfig.java
new file mode 100644
index 0000000000..3aefe7bece
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisSourceConfig.java
@@ -0,0 +1,68 @@
+package org.opensearch.dataprepper.plugins.source.kinesis.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotNull;
+import jakarta.validation.constraints.Size;
+import lombok.Getter;
+import org.opensearch.dataprepper.model.configuration.PluginModel;
+
+import java.time.Duration;
+import java.util.List;
+
+public class KinesisSourceConfig {
+    static final Duration DEFAULT_TIME_OUT_IN_MILLIS = Duration.ofMillis(1000);
+    static final int DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE = 100;
+    static final Duration DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT = Duration.ofMinutes(10);
+
+    @Getter
+    @JsonProperty("streams")
+    @NotNull
+    @Valid
+    @Size(min = 1, max = 4, message = "Only a maximum of 4 streams is supported")
+    private List<KinesisStreamConfig> streams;
+
+    @Getter
+    @JsonProperty("aws")
+    @NotNull
+    @Valid
+    private AwsAuthenticationConfig awsAuthenticationConfig;
+
+    @Getter
+    @JsonProperty("buffer_timeout")
+    private Duration bufferTimeout = DEFAULT_TIME_OUT_IN_MILLIS;
+
+    @Getter
+    @JsonProperty("records_to_accumulate")
+    private int numberOfRecordsToAccumulate = DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE;
+
+    @JsonProperty("acknowledgments")
+    @Getter
+    private boolean acknowledgments = false;
+
+    @Getter
+    @JsonProperty("consumer_strategy")
+    private ConsumerStrategy consumerStrategy = ConsumerStrategy.ENHANCED_FAN_OUT;
+
+    @Getter
+    @JsonProperty("polling")
+    private KinesisStreamPollingConfig pollingConfig;
+
+    @JsonProperty("lease_coordination_table")
+    @Getter
+    private String leaseCoordinationTable;
+
+    @Getter
+    @JsonProperty("codec")
+    private PluginModel codec;
+
+    @JsonProperty("shard_acknowledgment_timeout")
+    private Duration shardAcknowledgmentTimeout = DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT;
+
+    public Duration getShardAcknowledgmentTimeout() {
+        return shardAcknowledgmentTimeout;
+    }
+}
+
+
+
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamConfig.java
new file mode 100644
index 0000000000..49e3a5743f
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamConfig.java
@@ -0,0 +1,32 @@
+package org.opensearch.dataprepper.plugins.source.kinesis.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import jakarta.validation.Valid;
+import jakarta.validation.constraints.NotNull;
+import lombok.Getter;
+import software.amazon.kinesis.common.InitialPositionInStream;
+
+@Getter
+public class KinesisStreamConfig {
+    // Checkpointing interval
+    private static final int MINIMAL_CHECKPOINT_INTERVAL_MILLIS = 2 * 60 * 1000; // 2 minutes
+    private static final boolean DEFAULT_ENABLE_CHECKPOINT = false;
+
+    @JsonProperty("stream_name")
+    @NotNull
+    @Valid
+    private String name;
+
+    @JsonProperty("stream_arn")
+    private String arn;
+
+    @JsonProperty("initial_position")
+    private InitialPositionInStream initialPosition = InitialPositionInStream.LATEST;
+
+    @JsonProperty("checkpoint_interval")
+    private int checkPointIntervalInMilliseconds = MINIMAL_CHECKPOINT_INTERVAL_MILLIS;
+
+    @Getter
+    @JsonProperty("enableCheckpoint")
+    private boolean enableCheckPoint = DEFAULT_ENABLE_CHECKPOINT;
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamPollingConfig.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamPollingConfig.java
new file mode 100644
index 0000000000..e48908d7f8
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamPollingConfig.java
@@ -0,0 +1,17 @@
+package org.opensearch.dataprepper.plugins.source.kinesis.configuration;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.Getter;
+
+public class KinesisStreamPollingConfig {
+    private static final int DEFAULT_MAX_RECORDS = 10000;
+    private static final int IDLE_TIME_BETWEEN_READS_IN_MILLIS = 250;
+    @Getter
+    @JsonProperty("maxPollingRecords")
+    private int maxPollingRecords = DEFAULT_MAX_RECORDS;
+
+    @Getter
+    @JsonProperty("idleTimeBetweenReadsInMillis")
+    private int idleTimeBetweenReadsInMillis = IDLE_TIME_BETWEEN_READS_IN_MILLIS;
+
+}
diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisRecordProcessor.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisRecordProcessor.java
new file mode 100644
index 0000000000..0a7c7c0ebe
--- /dev/null
+++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisRecordProcessor.java
@@ -0,0 +1,180 @@
+package org.opensearch.dataprepper.plugins.source.kinesis.processor;
+
+import io.micrometer.core.instrument.Counter;
+import org.opensearch.dataprepper.metrics.PluginMetrics;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
+import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager;
+import org.opensearch.dataprepper.model.buffer.Buffer;
+import org.opensearch.dataprepper.model.codec.InputCodec;
+import org.opensearch.dataprepper.model.configuration.PluginModel;
+import org.opensearch.dataprepper.model.configuration.PluginSetting;
+import org.opensearch.dataprepper.model.event.Event;
+import org.opensearch.dataprepper.model.plugin.PluginFactory;
+import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.plugins.source.kinesis.KinesisSource;
+import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig;
+import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import software.amazon.kinesis.common.StreamIdentifier;
+import software.amazon.kinesis.exceptions.InvalidStateException;
+import software.amazon.kinesis.exceptions.ShutdownException;
+import software.amazon.kinesis.exceptions.ThrottlingException;
+import software.amazon.kinesis.lifecycle.events.InitializationInput;
+import software.amazon.kinesis.lifecycle.events.LeaseLostInput;
+import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput;
+import software.amazon.kinesis.lifecycle.events.ShardEndedInput;
+import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput;
+import software.amazon.kinesis.processor.RecordProcessorCheckpointer;
+import software.amazon.kinesis.processor.ShardRecordProcessor;
+import software.amazon.kinesis.retrieval.KinesisClientRecord;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+
+public class KinesisRecordProcessor implements ShardRecordProcessor {
+    private static final Logger LOG = LoggerFactory.getLogger(KinesisRecordProcessor.class);
+    private final StreamIdentifier streamIdentifier;
+    private final KinesisStreamConfig kinesisStreamConfig;
+    private final int checkpointIntervalMilliSeconds;
+    private final KinesisSourceConfig kinesisSourceConfig;
+    private final Buffer<Record<Event>> buffer;
+    private String kinesisShardId;
+    private final InputCodec codec;
+    private long lastCheckpointTimeInMillis;
+    private final int bufferTimeoutMillis;
+    private final AcknowledgementSetManager
acknowledgementSetManager; + private final Counter acknowledgementSetCallbackCounter; + private final Counter recordProcessingErrors; + private final Counter checkpointFailures; + private static final Duration ACKNOWLEDGEMENT_SET_TIMEOUT = Duration.ofSeconds(20); + private static final String ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME = "acknowledgementSetCallbackCounter"; + public static final String KINESIS_RECORD_PROCESSING_ERRORS = "recordProcessingErrors"; + public static final String KINESIS_CHECKPOINT_FAILURES = "checkpointFailures"; + public static final String KINESIS_STREAM_TAG_KEY = "stream"; + + public KinesisRecordProcessor(Buffer> buffer, + final KinesisSourceConfig kinesisSourceConfig, + final AcknowledgementSetManager acknowledgementSetManager, + final PluginMetrics pluginMetrics, + final PluginFactory pluginFactory, + final StreamIdentifier streamIdentifier) { + this.bufferTimeoutMillis = (int) kinesisSourceConfig.getBufferTimeout().toMillis(); + this.streamIdentifier = streamIdentifier; + this.kinesisSourceConfig = kinesisSourceConfig; + this.kinesisStreamConfig = getStreamConfig(kinesisSourceConfig); + final PluginModel codecConfiguration = kinesisSourceConfig.getCodec(); + final PluginSetting codecPluginSettings = new PluginSetting(codecConfiguration.getPluginName(), codecConfiguration.getPluginSettings()); + this.codec = pluginFactory.loadPlugin(InputCodec.class, codecPluginSettings); + this.acknowledgementSetManager = acknowledgementSetManager; + this.acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLBACK_METRIC_NAME); + this.recordProcessingErrors = pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSING_ERRORS, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.checkpointFailures = pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName()); + this.checkpointIntervalMilliSeconds = kinesisStreamConfig.getCheckPointIntervalInMilliseconds(); + this.buffer = buffer; + } + + private KinesisStreamConfig getStreamConfig(final KinesisSourceConfig kinesisSourceConfig) { + return kinesisSourceConfig.getStreams().stream().filter(streamConfig -> streamConfig.getName().equals(streamIdentifier.streamName())).findAny().get(); + } + + @Override + public void initialize(InitializationInput initializationInput) { + // Called once when the processor is initialized. 
+ kinesisShardId = initializationInput.shardId(); + LOG.info("Initialize Processor for shard: " + kinesisShardId); + lastCheckpointTimeInMillis = System.currentTimeMillis(); + } + + private AcknowledgementSet createAcknowledgmentSet(final ProcessRecordsInput processRecordsInput) { + return acknowledgementSetManager.create((result) -> { + acknowledgementSetCallbackCounter.increment(); + if (result) { + LOG.info("acknowledgements received"); + checkpoint(processRecordsInput.checkpointer()); + } else { + LOG.info("acknowledgements received with false"); + } + + }, ACKNOWLEDGEMENT_SET_TIMEOUT); + } + + @Override + public void processRecords(ProcessRecordsInput processRecordsInput) { + List> records = new ArrayList<>(); + + try { + AcknowledgementSet acknowledgementSet; + boolean acknowledgementsEnabled = kinesisSourceConfig.isAcknowledgments(); + if (acknowledgementsEnabled) { + acknowledgementSet = createAcknowledgmentSet(processRecordsInput); + } else { + acknowledgementSet = null; + } + + for (KinesisClientRecord record : processRecordsInput.records()) { + processRecord(record, records::add); + } + + if (acknowledgementSet != null) { + records.forEach(record -> { + acknowledgementSet.add(record.getData()); + }); + } + + buffer.writeAll(records, bufferTimeoutMillis); + + if (acknowledgementSet != null) { + acknowledgementSet.complete(); + } + + // Checkpoint for shard + if (kinesisStreamConfig.isEnableCheckPoint() && System.currentTimeMillis() - lastCheckpointTimeInMillis > checkpointIntervalMilliSeconds) { + LOG.info("Regular checkpointing for shard " + kinesisShardId); + checkpoint(processRecordsInput.checkpointer()); + lastCheckpointTimeInMillis = System.currentTimeMillis(); + } + } catch (Exception ex) { + recordProcessingErrors.increment(); + LOG.error("Failed writing shard data to buffer: ", ex); + } + } + + private void processRecord(KinesisClientRecord record, Consumer> eventConsumer) throws IOException { + // Read bytebuffer + byte[] arr = new byte[record.data().remaining()]; + record.data().get(arr); + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(arr); + codec.parse(byteArrayInputStream, eventConsumer); + } + + @Override + public void leaseLost(LeaseLostInput leaseLostInput) { + LOG.debug("Lease Lost"); + } + + @Override + public void shardEnded(ShardEndedInput shardEndedInput) { + LOG.info("Reached shard end, checkpointing shard: {}", kinesisShardId); + checkpoint(shardEndedInput.checkpointer()); + } + + @Override + public void shutdownRequested(ShutdownRequestedInput shutdownRequestedInput) { + LOG.info("Scheduler is shutting down, checkpointing shard: {}", kinesisShardId); + checkpoint(shutdownRequestedInput.checkpointer()); + } + + private void checkpoint(RecordProcessorCheckpointer checkpointer) { + try { + checkpointer.checkpoint(); + } catch (ShutdownException | ThrottlingException | InvalidStateException ex) { + LOG.info("Caught exception at checkpoint, skipping checkpoint.", ex); + checkpointFailures.increment(); + } + } +} diff --git a/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisShardRecordProcessorFactory.java b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisShardRecordProcessorFactory.java new file mode 100644 index 0000000000..91294857b4 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/main/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisShardRecordProcessorFactory.java 
@@ -0,0 +1,43 @@ +package org.opensearch.dataprepper.plugins.source.kinesis.processor; + +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import software.amazon.kinesis.common.StreamIdentifier; +import software.amazon.kinesis.processor.ShardRecordProcessor; +import software.amazon.kinesis.processor.ShardRecordProcessorFactory; + +public class KinesisShardRecordProcessorFactory implements ShardRecordProcessorFactory { + + private final Buffer> buffer; + private final KinesisSourceConfig kinesisSourceConfig; + private final AcknowledgementSetManager acknowledgementSetManager; + private final PluginMetrics pluginMetrics; + private final PluginFactory pluginFactory; + + public KinesisShardRecordProcessorFactory(Buffer> buffer, + KinesisSourceConfig kinesisSourceConfig, + final AcknowledgementSetManager acknowledgementSetManager, + final PluginMetrics pluginMetrics, + final PluginFactory pluginFactory) { + this.kinesisSourceConfig = kinesisSourceConfig; + this.buffer = buffer; + this.acknowledgementSetManager = acknowledgementSetManager; + this.pluginMetrics = pluginMetrics; + this.pluginFactory = pluginFactory; + } + + @Override + public ShardRecordProcessor shardRecordProcessor() { + throw new UnsupportedOperationException("Use the method with stream details!"); + } + + @Override + public ShardRecordProcessor shardRecordProcessor(StreamIdentifier streamIdentifier) { + return new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + } +} \ No newline at end of file diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/ClientFactoryTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/ClientFactoryTest.java new file mode 100644 index 0000000000..3365a1e690 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/ClientFactoryTest.java @@ -0,0 +1,47 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.AwsAuthenticationConfig; +import org.opensearch.dataprepper.test.helper.ReflectivelySetField; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.mock; + +public class ClientFactoryTest { + private Region region = Region.US_EAST_1; + private String roleArn; + private Map stsHeader; + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Test + void testCreateClient() throws NoSuchFieldException, IllegalAccessException { + roleArn = 
"arn:aws:iam::278936200144:role/test-role"; + stsHeader= new HashMap<>(); + stsHeader.put(UUID.randomUUID().toString(),UUID.randomUUID().toString()); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + + AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig(); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1"); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", roleArn); + + ClientFactory clientFactory = new ClientFactory(awsCredentialsSupplier, awsAuthenticationOptionsConfig); + + final DynamoDbAsyncClient dynamoDbAsyncClient = clientFactory.buildDynamoDBClient(); + assertNotNull(dynamoDbAsyncClient); + + final KinesisAsyncClient kinesisAsyncClient = clientFactory.buildKinesisAsyncClient(); + assertNotNull(kinesisAsyncClient); + + final CloudWatchAsyncClient cloudWatchAsyncClient = clientFactory.buildCloudWatchAsyncClient(); + assertNotNull(cloudWatchAsyncClient); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisMultiStreamTrackerTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisMultiStreamTrackerTest.java new file mode 100644 index 0000000000..637cb16375 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisMultiStreamTrackerTest.java @@ -0,0 +1,143 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.kinesis.common.InitialPositionInStream; +import software.amazon.kinesis.common.InitialPositionInStreamExtended; +import software.amazon.kinesis.common.StreamConfig; +import software.amazon.kinesis.common.StreamIdentifier; +import software.amazon.kinesis.processor.FormerStreamsLeasesDeletionStrategy; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KinesisMultiStreamTrackerTest { + private static final String APPLICATION_NAME = "multi-stream-application"; + private static final String awsAccountId = "1234"; + private static final String streamArnFormat = "arn:aws:kinesis:us-east-1:%s:stream/%s"; + private static final Instant streamCreationTime = Instant.now(); + private static final List STREAMS_LIST = ImmutableList.of("stream-1", "stream-2", "stream-3"); + + private 
KinesisMultiStreamTracker kinesisMultiStreamTracker; + @Mock + private KinesisAsyncClient kinesisClient; + private List streamConfigList; + + private Map streamConfigMap; + + @Mock + KinesisSourceConfig kinesisSourceConfig; + + @BeforeEach + public void setUp() { + MockitoAnnotations.openMocks(this); + List kinesisStreamConfigs = new ArrayList<>(); + streamConfigMap = new HashMap<>(); + STREAMS_LIST.forEach(stream -> { + KinesisStreamConfig kinesisStreamConfig = mock(KinesisStreamConfig.class); + when(kinesisStreamConfig.getName()).thenReturn(stream); + when(kinesisStreamConfig.getInitialPosition()).thenReturn(InitialPositionInStream.LATEST); + + StreamDescription streamDescription = StreamDescription.builder() + .streamARN(String.format(streamArnFormat, awsAccountId, stream)) + .streamCreationTimestamp(streamCreationTime) + .streamName(stream) + .build(); + + DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder() + .streamName(stream) + .build(); + + DescribeStreamResponse describeStreamResponse = DescribeStreamResponse.builder() + .streamDescription(streamDescription) + .build(); + + when(kinesisClient.describeStream(describeStreamRequest)).thenReturn(CompletableFuture.completedFuture(describeStreamResponse)); + kinesisStreamConfigs.add(kinesisStreamConfig); + + streamConfigMap.put(stream, kinesisStreamConfig); + }); + + when(kinesisSourceConfig.getStreams()).thenReturn(kinesisStreamConfigs); + kinesisMultiStreamTracker = new KinesisMultiStreamTracker(kinesisClient, kinesisSourceConfig, APPLICATION_NAME); + } + + @Test + public void testStreamConfigList() { + streamConfigList = kinesisMultiStreamTracker.streamConfigList(); + assertEquals(kinesisSourceConfig.getStreams().size(), streamConfigList.size()); + + int totalStreams = streamConfigList.size(); + for (int i=0; i kinesisStreamConfigs = new ArrayList<>(); + streamConfigMap = new HashMap<>(); + STREAMS_LIST.forEach(stream -> { + KinesisStreamConfig kinesisStreamConfig = mock(KinesisStreamConfig.class); + when(kinesisStreamConfig.getName()).thenReturn(stream); + when(kinesisStreamConfig.getInitialPosition()).thenReturn(InitialPositionInStream.LATEST); + + DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder() + .streamName(stream) + .build(); + + when(kinesisClient.describeStream(describeStreamRequest)).thenThrow(new RuntimeException()); + kinesisStreamConfigs.add(kinesisStreamConfig); + + streamConfigMap.put(stream, kinesisStreamConfig); + }); + + when(kinesisSourceConfig.getStreams()).thenReturn(kinesisStreamConfigs); + kinesisMultiStreamTracker = new KinesisMultiStreamTracker(kinesisClient, kinesisSourceConfig, APPLICATION_NAME); + + assertThrows(RuntimeException.class, () -> kinesisMultiStreamTracker.streamConfigList()); + } + + @Test + public void formerStreamsLeasesDeletionStrategy() { + + FormerStreamsLeasesDeletionStrategy formerStreamsLeasesDeletionStrategy = + kinesisMultiStreamTracker.formerStreamsLeasesDeletionStrategy(); + + Duration duration = formerStreamsLeasesDeletionStrategy.waitPeriodToDeleteFormerStreams(); + + Assertions.assertTrue(formerStreamsLeasesDeletionStrategy instanceof + FormerStreamsLeasesDeletionStrategy.AutoDetectionAndDeferredDeletionStrategy); + assertEquals(10, duration.getSeconds()); + } + + private StreamIdentifier getStreamIdentifier(final String streamName) { + return StreamIdentifier.multiStreamInstance(String.join(":", awsAccountId, streamName, String.valueOf(streamCreationTime.getEpochSecond()))); + } +} diff --git 
a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisServiceTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisServiceTest.java new file mode 100644 index 0000000000..3a631dae97 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisServiceTest.java @@ -0,0 +1,287 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.AwsAuthenticationConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.ConsumerStrategy; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamPollingConfig; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.KinesisServiceClientConfiguration; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamResponse; +import software.amazon.awssdk.services.kinesis.model.StreamDescription; +import software.amazon.kinesis.common.InitialPositionInStream; +import software.amazon.kinesis.coordinator.Scheduler; +import software.amazon.kinesis.metrics.MetricsLevel; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KinesisServiceTest { + private final String PIPELINE_NAME = "kinesis-pipeline-test"; + private final String streamId = "stream-1"; + + private static final int CHECKPOINT_INTERVAL_MS = 0; + private static final int NUMBER_OF_RECORDS_TO_ACCUMULATE = 10; + private static final int DEFAULT_MAX_RECORDS = 10000; + private static final int IDLE_TIME_BETWEEN_READS_IN_MILLIS = 250; + private static final String awsAccountId = "123456789012"; + private static final String streamArnFormat = "arn:aws:kinesis:us-east-1:%s:stream/%s"; + private static 
final Instant streamCreationTime = Instant.now(); + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private KinesisSourceConfig kinesisSourceConfig; + + @Mock + private KinesisStreamConfig kinesisStreamConfig; + + @Mock + private KinesisStreamPollingConfig kinesisStreamPollingConfig; + + @Mock + private AwsAuthenticationConfig awsAuthenticationConfig; + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @Mock + private PipelineDescription pipelineDescription; + + @Mock + private ClientFactory clientFactory; + + @Mock + private KinesisAsyncClient kinesisClient; + + @Mock + private DynamoDbAsyncClient dynamoDbClient; + + @Mock + private CloudWatchAsyncClient cloudWatchClient; + + @Mock + Buffer> buffer; + + @Mock + private Scheduler scheduler; + + @BeforeEach + void setup() { + awsAuthenticationConfig = mock(AwsAuthenticationConfig.class); + kinesisSourceConfig = mock(KinesisSourceConfig.class); + kinesisStreamConfig = mock(KinesisStreamConfig.class); + kinesisStreamPollingConfig = mock(KinesisStreamPollingConfig.class); + kinesisClient = mock(KinesisAsyncClient.class); + dynamoDbClient = mock(DynamoDbAsyncClient.class); + cloudWatchClient = mock(CloudWatchAsyncClient.class); + clientFactory = mock(ClientFactory.class); + scheduler = mock(Scheduler.class); + pipelineDescription = mock(PipelineDescription.class); + buffer = mock(Buffer.class); + + when(awsAuthenticationConfig.getAwsRegion()).thenReturn(Region.of("us-west-2")); + when(awsAuthenticationConfig.getAwsStsRoleArn()).thenReturn(UUID.randomUUID().toString()); + when(awsAuthenticationConfig.getAwsStsExternalId()).thenReturn(UUID.randomUUID().toString()); + final Map stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + when(awsAuthenticationConfig.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); + StreamDescription streamDescription = StreamDescription.builder() + .streamARN(String.format(streamArnFormat, awsAccountId, streamId)) + .streamCreationTimestamp(streamCreationTime) + .streamName(streamId) + .build(); + + DescribeStreamRequest describeStreamRequest = DescribeStreamRequest.builder() + .streamName(streamId) + .build(); + + DescribeStreamResponse describeStreamResponse = DescribeStreamResponse.builder() + .streamDescription(streamDescription) + .build(); + + when(kinesisClient.describeStream(describeStreamRequest)).thenReturn(CompletableFuture.completedFuture(describeStreamResponse)); + + when(kinesisSourceConfig.getAwsAuthenticationConfig()).thenReturn(awsAuthenticationConfig); + when(kinesisStreamConfig.getName()).thenReturn(streamId); + when(kinesisStreamConfig.getCheckPointIntervalInMilliseconds()).thenReturn(CHECKPOINT_INTERVAL_MS); + when(kinesisStreamConfig.getInitialPosition()).thenReturn(InitialPositionInStream.LATEST); + when(kinesisSourceConfig.getConsumerStrategy()).thenReturn(ConsumerStrategy.POLLING); + when(kinesisSourceConfig.getPollingConfig()).thenReturn(kinesisStreamPollingConfig); + when(kinesisStreamPollingConfig.getMaxPollingRecords()).thenReturn(DEFAULT_MAX_RECORDS); + when(kinesisStreamPollingConfig.getIdleTimeBetweenReadsInMillis()).thenReturn(IDLE_TIME_BETWEEN_READS_IN_MILLIS); + + List streamConfigs = new ArrayList<>(); + streamConfigs.add(kinesisStreamConfig); + when(kinesisSourceConfig.getStreams()).thenReturn(streamConfigs); + when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(NUMBER_OF_RECORDS_TO_ACCUMULATE); + + 
when(clientFactory.buildDynamoDBClient()).thenReturn(dynamoDbClient); + when(clientFactory.buildKinesisAsyncClient()).thenReturn(kinesisClient); + when(clientFactory.buildCloudWatchAsyncClient()).thenReturn(cloudWatchClient); + when(kinesisClient.serviceClientConfiguration()).thenReturn(KinesisServiceClientConfiguration.builder().region(Region.US_EAST_1).build()); + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.completedFuture(true)); + when(pipelineDescription.getPipelineName()).thenReturn(PIPELINE_NAME); + } + + public KinesisService createObjectUnderTest() { + return new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory, pipelineDescription, acknowledgementSetManager); + } + + @Test + void testServiceStart() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.start(buffer); + assertNotNull(kinesisService.getScheduler(buffer)); + } + + @Test + void testCreateScheduler() { + KinesisService kinesisService = new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory, pipelineDescription, acknowledgementSetManager); + Scheduler schedulerObjectUnderTest = kinesisService.createScheduler(buffer); + + assertNotNull(schedulerObjectUnderTest); + assertNotNull(schedulerObjectUnderTest.checkpointConfig()); + assertNotNull(schedulerObjectUnderTest.leaseManagementConfig()); + assertSame(schedulerObjectUnderTest.leaseManagementConfig().initialPositionInStream().getInitialPositionInStream(), InitialPositionInStream.TRIM_HORIZON); + assertNotNull(schedulerObjectUnderTest.lifecycleConfig()); + assertNotNull(schedulerObjectUnderTest.metricsConfig()); + assertSame(schedulerObjectUnderTest.metricsConfig().metricsLevel(), MetricsLevel.DETAILED); + assertNotNull(schedulerObjectUnderTest.processorConfig()); + assertNotNull(schedulerObjectUnderTest.retrievalConfig()); + } + + @Test + void testCreateSchedulerWithPollingStrategy() { + when(kinesisSourceConfig.getConsumerStrategy()).thenReturn(ConsumerStrategy.POLLING); + KinesisService kinesisService = new KinesisService(kinesisSourceConfig, clientFactory, pluginMetrics, pluginFactory, pipelineDescription, acknowledgementSetManager); + Scheduler schedulerObjectUnderTest = kinesisService.createScheduler(buffer); + + assertNotNull(schedulerObjectUnderTest); + assertNotNull(schedulerObjectUnderTest.checkpointConfig()); + assertNotNull(schedulerObjectUnderTest.leaseManagementConfig()); + assertSame(schedulerObjectUnderTest.leaseManagementConfig().initialPositionInStream().getInitialPositionInStream(), InitialPositionInStream.TRIM_HORIZON); + assertNotNull(schedulerObjectUnderTest.lifecycleConfig()); + assertNotNull(schedulerObjectUnderTest.metricsConfig()); + assertSame(schedulerObjectUnderTest.metricsConfig().metricsLevel(), MetricsLevel.DETAILED); + assertNotNull(schedulerObjectUnderTest.processorConfig()); + assertNotNull(schedulerObjectUnderTest.retrievalConfig()); + } + + + @Test + void testServiceStartNullBufferThrows() { + KinesisService kinesisService = createObjectUnderTest(); + assertThrows(IllegalStateException.class, () -> kinesisService.start(null)); + + verify(scheduler, times(0)).run(); + } + + @Test + void testServiceStartNullStreams() { + when(kinesisSourceConfig.getStreams()).thenReturn(null); + + KinesisService kinesisService = createObjectUnderTest(); + assertThrows(IllegalStateException.class, () -> kinesisService.start(buffer)); + + verify(scheduler, times(0)).run(); + } + + @Test + void testServiceStartEmptyStreams() { + 
when(kinesisSourceConfig.getStreams()).thenReturn(new ArrayList<>()); + + KinesisService kinesisService = createObjectUnderTest(); + assertThrows(IllegalStateException.class, () -> kinesisService.start(buffer)); + + verify(scheduler, times(0)).run(); + } + + @Test + public void testShutdownGraceful() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + kinesisService.shutDown(); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(0)).shutdown(); + } + + @Test + public void testShutdownGracefulThrowInterruptedException() { + KinesisService kinesisService = createObjectUnderTest(); + + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new InterruptedException())); + kinesisService.setScheduler(scheduler); + assertDoesNotThrow(kinesisService::shutDown); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(1)).shutdown(); + } + + @Test + public void testShutdownGracefulThrowTimeoutException() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new TimeoutException())); + assertDoesNotThrow(kinesisService::shutDown); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(1)).shutdown(); + } + + @Test + public void testShutdownGracefulThrowExecutionException() { + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new ExecutionException(new Throwable()))); + assertDoesNotThrow(kinesisService::shutDown); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler, times(1)).shutdown(); + } + + @Test + public void testShutdownExecutorServiceInterruptedException() { + when(scheduler.startGracefulShutdown()).thenReturn(CompletableFuture.failedFuture(new InterruptedException())); + + KinesisService kinesisService = createObjectUnderTest(); + kinesisService.setScheduler(scheduler); + kinesisService.shutDown(); + + verify(scheduler).startGracefulShutdown(); + verify(scheduler).shutdown(); + } + +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisSourceTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisSourceTest.java new file mode 100644 index 0000000000..d367780571 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/KinesisSourceTest.java @@ -0,0 +1,141 @@ +package org.opensearch.dataprepper.plugins.source.kinesis; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.configuration.PipelineDescription; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.AwsAuthenticationConfig; +import 
org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig; +import software.amazon.awssdk.regions.Region; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KinesisSourceTest { + private final String PIPELINE_NAME = "kinesis-pipeline-test"; + private final String streamId = "stream-1"; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private KinesisSourceConfig kinesisSourceConfig; + + @Mock + private AwsAuthenticationConfig awsAuthenticationConfig; + + private KinesisSource source; + + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private PipelineDescription pipelineDescription; + + @Mock KinesisService kinesisService; + + @BeforeEach + void setup() { + pluginMetrics = mock(PluginMetrics.class); + pluginFactory = mock(PluginFactory.class); + kinesisSourceConfig = mock(KinesisSourceConfig.class); + this.pipelineDescription = mock(PipelineDescription.class); + awsCredentialsSupplier = mock(AwsCredentialsSupplier.class); + awsAuthenticationConfig = mock(AwsAuthenticationConfig.class); + acknowledgementSetManager = mock(AcknowledgementSetManager.class); + kinesisService = mock(KinesisService.class); + + when(awsAuthenticationConfig.getAwsRegion()).thenReturn(Region.of("us-west-2")); + when(awsAuthenticationConfig.getAwsStsRoleArn()).thenReturn(UUID.randomUUID().toString()); + when(awsAuthenticationConfig.getAwsStsExternalId()).thenReturn(UUID.randomUUID().toString()); + final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + when(awsAuthenticationConfig.getAwsStsHeaderOverrides()).thenReturn(stsHeaderOverrides); + when(kinesisSourceConfig.getAwsAuthenticationConfig()).thenReturn(awsAuthenticationConfig); + when(pipelineDescription.getPipelineName()).thenReturn(PIPELINE_NAME); + } + + public KinesisSource createObjectUnderTest() { + return new KinesisSource(kinesisSourceConfig, pluginMetrics, pluginFactory, pipelineDescription, awsCredentialsSupplier, acknowledgementSetManager); + } + + @Test + public void testSourceWithoutAcknowledgements() { + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false); + source = createObjectUnderTest(); + assertThat(source.areAcknowledgementsEnabled(), equalTo(false)); + } + + @Test + public void testSourceWithAcknowledgements() { + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(true); + source = createObjectUnderTest(); + assertThat(source.areAcknowledgementsEnabled(), equalTo(true)); + } + + @Test + public void testSourceStart() { + + source = createObjectUnderTest(); + + Buffer<Record<Event>> buffer = mock(Buffer.class); + when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(100); + KinesisStreamConfig kinesisStreamConfig = mock(KinesisStreamConfig.class); + when(kinesisStreamConfig.getName()).thenReturn(streamId); + when(kinesisSourceConfig.getStreams()).thenReturn(List.of(kinesisStreamConfig)); +
source.setKinesisService(kinesisService); + + source.start(buffer); + + verify(kinesisService, times(1)).start(any(Buffer.class)); + + } + + @Test + public void testSourceStartBufferNull() { + + source = createObjectUnderTest(); + + assertThrows(IllegalStateException.class, () -> source.start(null)); + + verify(kinesisService, times(0)).start(any(Buffer.class)); + + } + + @Test + public void testSourceStop() { + + source = createObjectUnderTest(); + + source.setKinesisService(kinesisService); + + source.stop(); + + verify(kinesisService, times(1)).shutDown(); + + } + +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/AwsAuthenticationConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/AwsAuthenticationConfigTest.java new file mode 100644 index 0000000000..bfb6debc6f --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/AwsAuthenticationConfigTest.java @@ -0,0 +1,263 @@ +package org.opensearch.dataprepper.plugins.source.kinesis.configuration; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; +import org.opensearch.dataprepper.test.helper.ReflectivelySetField; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.function.Consumer; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AwsAuthenticationConfigTest { + private ObjectMapper objectMapper = new ObjectMapper(); + private final String TEST_ROLE = "arn:aws:iam::123456789012:role/test-role"; + + @ParameterizedTest + @ValueSource(strings = {"us-east-1", "us-west-2", "eu-central-1"}) + void getAwsRegionReturnsRegion(final String regionString) { + final Region expectedRegionObject = Region.of(regionString); + final Map<String, Object> jsonMap = Map.of("region", regionString); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsRegion(), equalTo(expectedRegionObject)); + } + + @Test + void getAwsRegionReturnsNullWhenRegionIsNull() { + final Map<String, Object> jsonMap = 
Collections.emptyMap(); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsRegion(), nullValue()); + } + + @Test + void getAwsStsRoleArnReturnsValueFromDeserializedJSON() { + final String stsRoleArn = UUID.randomUUID().toString(); + final Map<String, Object> jsonMap = Map.of("sts_role_arn", stsRoleArn); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsStsRoleArn(), equalTo(stsRoleArn)); + } + + @Test + void getAwsStsRoleArnReturnsNullIfNotInJSON() { + final Map<String, Object> jsonMap = Collections.emptyMap(); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsStsRoleArn(), nullValue()); + } + + @Test + void getAwsStsExternalIdReturnsValueFromDeserializedJSON() { + final String stsExternalId = UUID.randomUUID().toString(); + final Map<String, Object> jsonMap = Map.of("sts_external_id", stsExternalId); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsStsExternalId(), equalTo(stsExternalId)); + } + + @Test + void getAwsStsExternalIdReturnsNullIfNotInJSON() { + final Map<String, Object> jsonMap = Collections.emptyMap(); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsStsExternalId(), nullValue()); + } + + @Test + void getAwsStsHeaderOverridesReturnsValueFromDeserializedJSON() { + final Map<String, String> stsHeaderOverrides = Map.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + final Map<String, Object> jsonMap = Map.of("sts_header_overrides", stsHeaderOverrides); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsStsHeaderOverrides(), equalTo(stsHeaderOverrides)); + } + + @Test + void getAwsStsHeaderOverridesReturnsNullIfNotInJSON() { + final Map<String, Object> jsonMap = Collections.emptyMap(); + final AwsAuthenticationConfig objectUnderTest = objectMapper.convertValue(jsonMap, AwsAuthenticationConfig.class); + assertThat(objectUnderTest.getAwsStsHeaderOverrides(), nullValue()); + } + + @Test + void authenticateAWSConfigurationShouldReturnWithoutStsRoleArn() throws NoSuchFieldException, IllegalAccessException { + AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig(); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1"); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", null); + + final DefaultCredentialsProvider mockedCredentialsProvider = mock(DefaultCredentialsProvider.class); + final AwsCredentialsProvider actualCredentialsProvider; + try (final MockedStatic<DefaultCredentialsProvider> defaultCredentialsProviderMockedStatic = mockStatic(DefaultCredentialsProvider.class)) { + defaultCredentialsProviderMockedStatic.when(DefaultCredentialsProvider::create) + .thenReturn(mockedCredentialsProvider); + actualCredentialsProvider = awsAuthenticationOptionsConfig.authenticateAwsConfiguration(); + } + + assertThat(actualCredentialsProvider, sameInstance(mockedCredentialsProvider)); + } + + + @Nested + class WithSts { + private StsClient stsClient; + private StsClientBuilder stsClientBuilder; + + @BeforeEach + void setUp() { + stsClient = 
mock(StsClient.class); + stsClientBuilder = mock(StsClientBuilder.class); + + when(stsClientBuilder.build()).thenReturn(stsClient); + } + + @Test + void authenticateAWSConfigurationShouldReturnWithStsRoleArn() throws NoSuchFieldException, IllegalAccessException { + AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig(); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1"); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", TEST_ROLE); + + when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); + final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); + when(assumeRoleRequestBuilder.roleSessionName(anyString())) + .thenReturn(assumeRoleRequestBuilder); + when(assumeRoleRequestBuilder.roleArn(anyString())) + .thenReturn(assumeRoleRequestBuilder); + + final AwsCredentialsProvider actualCredentialsProvider; + try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class); + final MockedStatic<AssumeRoleRequest> assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { + stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); + assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); + actualCredentialsProvider = awsAuthenticationOptionsConfig.authenticateAwsConfiguration(); + } + + assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); + + verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); + verify(assumeRoleRequestBuilder).roleSessionName(anyString()); + verify(assumeRoleRequestBuilder).build(); + verifyNoMoreInteractions(assumeRoleRequestBuilder); + } + + @Test + void authenticateAWSConfigurationShouldReturnWithStsRoleArnWhenNoRegion() throws NoSuchFieldException, IllegalAccessException { + AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig(); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", null); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", TEST_ROLE); + assertThat(awsAuthenticationOptionsConfig.getAwsRegion(), CoreMatchers.equalTo(null)); + + when(stsClientBuilder.region(null)).thenReturn(stsClientBuilder); + + final AwsCredentialsProvider actualCredentialsProvider; + try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class)) { + stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); + actualCredentialsProvider = awsAuthenticationOptionsConfig.authenticateAwsConfiguration(); + } + + assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); + } + + @Test + void authenticateAWSConfigurationShouldOverrideSTSHeadersWhenHeaderOverridesSet() throws NoSuchFieldException, IllegalAccessException { + final String headerName1 = UUID.randomUUID().toString(); + final String headerValue1 = UUID.randomUUID().toString(); + final String headerName2 = UUID.randomUUID().toString(); + final String headerValue2 = UUID.randomUUID().toString(); + final Map<String, String> overrideHeaders = Map.of(headerName1, headerValue1, headerName2, headerValue2); + + AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig(); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1"); + 
ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", TEST_ROLE); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsHeaderOverrides", overrideHeaders); + + when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); + + final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); + when(assumeRoleRequestBuilder.roleSessionName(anyString())) + .thenReturn(assumeRoleRequestBuilder); + when(assumeRoleRequestBuilder.roleArn(anyString())) + .thenReturn(assumeRoleRequestBuilder); + when(assumeRoleRequestBuilder.overrideConfiguration(any(Consumer.class))) + .thenReturn(assumeRoleRequestBuilder); + + final AwsCredentialsProvider actualCredentialsProvider; + try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class); + final MockedStatic<AssumeRoleRequest> assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { + stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); + assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); + actualCredentialsProvider = awsAuthenticationOptionsConfig.authenticateAwsConfiguration(); + } + + assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); + + final ArgumentCaptor<Consumer<AwsRequestOverrideConfiguration.Builder>> configurationCaptor = ArgumentCaptor.forClass(Consumer.class); + + verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); + verify(assumeRoleRequestBuilder).roleSessionName(anyString()); + verify(assumeRoleRequestBuilder).overrideConfiguration(configurationCaptor.capture()); + verify(assumeRoleRequestBuilder).build(); + verifyNoMoreInteractions(assumeRoleRequestBuilder); + + final Consumer<AwsRequestOverrideConfiguration.Builder> actualOverride = configurationCaptor.getValue(); + + final AwsRequestOverrideConfiguration.Builder configurationBuilder = mock(AwsRequestOverrideConfiguration.Builder.class); + actualOverride.accept(configurationBuilder); + verify(configurationBuilder).putHeader(headerName1, headerValue1); + verify(configurationBuilder).putHeader(headerName2, headerValue2); + verifyNoMoreInteractions(configurationBuilder); + } + + @Test + void authenticateAWSConfigurationShouldNotOverrideSTSHeadersWhenHeaderOverridesAreEmpty() throws NoSuchFieldException, IllegalAccessException { + + AwsAuthenticationConfig awsAuthenticationOptionsConfig = new AwsAuthenticationConfig(); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsRegion", "us-east-1"); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsRoleArn", TEST_ROLE); + ReflectivelySetField.setField(AwsAuthenticationConfig.class, awsAuthenticationOptionsConfig, "awsStsHeaderOverrides", Collections.emptyMap()); + + when(stsClientBuilder.region(Region.US_EAST_1)).thenReturn(stsClientBuilder); + final AssumeRoleRequest.Builder assumeRoleRequestBuilder = mock(AssumeRoleRequest.Builder.class); + when(assumeRoleRequestBuilder.roleSessionName(anyString())) + .thenReturn(assumeRoleRequestBuilder); + when(assumeRoleRequestBuilder.roleArn(anyString())) + .thenReturn(assumeRoleRequestBuilder); + + final AwsCredentialsProvider actualCredentialsProvider; + try (final MockedStatic<StsClient> stsClientMockedStatic = mockStatic(StsClient.class); + final MockedStatic<AssumeRoleRequest> assumeRoleRequestMockedStatic = mockStatic(AssumeRoleRequest.class)) { + stsClientMockedStatic.when(StsClient::builder).thenReturn(stsClientBuilder); + 
assumeRoleRequestMockedStatic.when(AssumeRoleRequest::builder).thenReturn(assumeRoleRequestBuilder); + actualCredentialsProvider = awsAuthenticationOptionsConfig.authenticateAwsConfiguration(); + } + + assertThat(actualCredentialsProvider, instanceOf(AwsCredentialsProvider.class)); + + verify(assumeRoleRequestBuilder).roleArn(TEST_ROLE); + verify(assumeRoleRequestBuilder).roleSessionName(anyString()); + verify(assumeRoleRequestBuilder).build(); + verifyNoMoreInteractions(assumeRoleRequestBuilder); + } + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisSourceConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisSourceConfigTest.java new file mode 100644 index 0000000000..1f2b723919 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisSourceConfigTest.java @@ -0,0 +1,114 @@ +package org.opensearch.dataprepper.plugins.source.kinesis.configuration; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInfo; +import org.yaml.snakeyaml.Yaml; +import software.amazon.awssdk.regions.Region; +import software.amazon.kinesis.common.InitialPositionInStream; + +import java.io.FileReader; +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class KinesisSourceConfigTest { + private static final String SIMPLE_PIPELINE_CONFIG = "simple-pipeline.yaml"; + private static final String SIMPLE_PIPELINE_CONFIG_2 = "simple-pipeline-2.yaml"; + private static final int MINIMAL_CHECKPOINT_INTERVAL_MILLIS = 2 * 60 * 1000; // 2 minute + private static final int DEFAULT_MAX_RECORDS = 10000; + private static final int IDLE_TIME_BETWEEN_READS_IN_MILLIS = 250; + + KinesisSourceConfig kinesisSourceConfig; + + @BeforeEach + void setUp(TestInfo testInfo) throws IOException { + String fileName = testInfo.getTags().stream().findFirst().orElse(""); + Yaml yaml = new Yaml(); + FileReader fileReader = new FileReader(getClass().getClassLoader().getResource(fileName).getFile()); + Object data = yaml.load(fileReader); + ObjectMapper mapper = new ObjectMapper(); + if (data instanceof Map) { + Map propertyMap = (Map) data; + Map logPipelineMap = (Map) propertyMap.get("kinesis-pipeline"); + Map sourceMap = (Map) logPipelineMap.get("source"); + Map kinesisConfigMap = (Map) sourceMap.get("kinesis"); + mapper.registerModule(new JavaTimeModule()); + String json = mapper.writeValueAsString(kinesisConfigMap); + Reader reader = new StringReader(json); + kinesisSourceConfig = mapper.readValue(reader, KinesisSourceConfig.class); + } + } + + @Test + @Tag(SIMPLE_PIPELINE_CONFIG) + void testSourceConfig() { + + assertThat(kinesisSourceConfig, notNullValue()); + 
assertEquals(KinesisSourceConfig.DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, kinesisSourceConfig.getNumberOfRecordsToAccumulate()); + assertEquals(KinesisSourceConfig.DEFAULT_TIME_OUT_IN_MILLIS, kinesisSourceConfig.getBufferTimeout()); + assertTrue(kinesisSourceConfig.isAcknowledgments()); + assertEquals(KinesisSourceConfig.DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT, kinesisSourceConfig.getShardAcknowledgmentTimeout()); + assertThat(kinesisSourceConfig.getAwsAuthenticationConfig(), notNullValue()); + assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsRegion(), Region.US_EAST_1); + assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsRoleArn(), "arn:aws:iam::123456789012:role/OSI-PipelineRole"); + assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsExternalId()); + assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsHeaderOverrides()); + + List streamConfigs = kinesisSourceConfig.getStreams(); + assertNull(kinesisSourceConfig.getCodec()); + assertEquals(kinesisSourceConfig.getConsumerStrategy(), ConsumerStrategy.ENHANCED_FAN_OUT); + assertNull(kinesisSourceConfig.getPollingConfig()); + + for (KinesisStreamConfig kinesisStreamConfig: streamConfigs) { + assertTrue(kinesisStreamConfig.getName().contains("stream")); + assertTrue(kinesisStreamConfig.getArn().contains("123456789012:stream/stream")); + assertFalse(kinesisStreamConfig.isEnableCheckPoint()); + assertEquals(kinesisStreamConfig.getInitialPosition(), InitialPositionInStream.LATEST); + assertEquals(kinesisStreamConfig.getCheckPointIntervalInMilliseconds(), MINIMAL_CHECKPOINT_INTERVAL_MILLIS); + } + } + + @Test + @Tag(SIMPLE_PIPELINE_CONFIG_2) + void testSourceConfigWithStreamCodec() { + + assertThat(kinesisSourceConfig, notNullValue()); + assertEquals(KinesisSourceConfig.DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, kinesisSourceConfig.getNumberOfRecordsToAccumulate()); + assertEquals(KinesisSourceConfig.DEFAULT_TIME_OUT_IN_MILLIS, kinesisSourceConfig.getBufferTimeout()); + assertFalse(kinesisSourceConfig.isAcknowledgments()); + assertEquals(KinesisSourceConfig.DEFAULT_SHARD_ACKNOWLEDGEMENT_TIMEOUT, kinesisSourceConfig.getShardAcknowledgmentTimeout()); + assertThat(kinesisSourceConfig.getAwsAuthenticationConfig(), notNullValue()); + assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsRegion(), Region.US_EAST_1); + assertEquals(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsRoleArn(), "arn:aws:iam::123456789012:role/OSI-PipelineRole"); + assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsExternalId()); + assertNull(kinesisSourceConfig.getAwsAuthenticationConfig().getAwsStsHeaderOverrides()); + assertNotNull(kinesisSourceConfig.getCodec()); + List streamConfigs = kinesisSourceConfig.getStreams(); + assertEquals(kinesisSourceConfig.getConsumerStrategy(), ConsumerStrategy.POLLING); + assertNotNull(kinesisSourceConfig.getPollingConfig()); + assertEquals(kinesisSourceConfig.getPollingConfig().getMaxPollingRecords(), 10); + assertEquals(kinesisSourceConfig.getPollingConfig().getIdleTimeBetweenReadsInMillis(), 10); + + for (KinesisStreamConfig kinesisStreamConfig: streamConfigs) { + assertTrue(kinesisStreamConfig.getName().contains("stream")); + assertTrue(kinesisStreamConfig.getArn().contains("123456789012:stream/stream")); + assertFalse(kinesisStreamConfig.isEnableCheckPoint()); + assertEquals(kinesisStreamConfig.getInitialPosition(), InitialPositionInStream.LATEST); + assertEquals(kinesisStreamConfig.getCheckPointIntervalInMilliseconds(), 
MINIMAL_CHECKPOINT_INTERVAL_MILLIS); + } + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamPollingConfigTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamPollingConfigTest.java new file mode 100644 index 0000000000..342e9d70f6 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/configuration/KinesisStreamPollingConfigTest.java @@ -0,0 +1,18 @@ +package org.opensearch.dataprepper.plugins.source.kinesis.configuration; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class KinesisStreamPollingConfigTest { + private static final int DEFAULT_MAX_RECORDS = 10000; + private static final int IDLE_TIME_BETWEEN_READS_IN_MILLIS = 250; + + @Test + void testConfig() { + KinesisStreamPollingConfig kinesisStreamPollingConfig = new KinesisStreamPollingConfig(); + assertEquals(kinesisStreamPollingConfig.getMaxPollingRecords(), DEFAULT_MAX_RECORDS); + assertEquals(kinesisStreamPollingConfig.getIdleTimeBetweenReadsInMillis(), IDLE_TIME_BETWEEN_READS_IN_MILLIS); + } + +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisRecordProcessorTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisRecordProcessorTest.java new file mode 100644 index 0000000000..8ccfbd235a --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisRecordProcessorTest.java @@ -0,0 +1,301 @@ +package org.opensearch.dataprepper.plugins.source.kinesis.processor; + +import io.micrometer.core.instrument.Counter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.JacksonEvent; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig; +import software.amazon.kinesis.common.StreamIdentifier; +import software.amazon.kinesis.exceptions.InvalidStateException; +import software.amazon.kinesis.exceptions.ShutdownException; +import software.amazon.kinesis.exceptions.ThrottlingException; +import software.amazon.kinesis.lifecycle.events.InitializationInput; +import software.amazon.kinesis.lifecycle.events.ProcessRecordsInput; +import software.amazon.kinesis.lifecycle.events.ShardEndedInput; +import software.amazon.kinesis.lifecycle.events.ShutdownRequestedInput; +import 
software.amazon.kinesis.processor.RecordProcessorCheckpointer; +import software.amazon.kinesis.retrieval.KinesisClientRecord; + +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyCollection; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.opensearch.dataprepper.plugins.source.kinesis.processor.KinesisRecordProcessor.KINESIS_CHECKPOINT_FAILURES; +import static org.opensearch.dataprepper.plugins.source.kinesis.processor.KinesisRecordProcessor.KINESIS_RECORD_PROCESSING_ERRORS; +import static org.opensearch.dataprepper.plugins.source.kinesis.processor.KinesisRecordProcessor.KINESIS_STREAM_TAG_KEY; + +public class KinesisRecordProcessorTest { + private KinesisRecordProcessor kinesisRecordProcessor; + private static final String shardId = "123"; + private static final String streamId = "stream-1"; + private static final String codec_plugin_name = "json"; + + private static final int CHECKPOINT_INTERVAL_MS = 1000; + private static final int NUMBER_OF_RECORDS_TO_ACCUMULATE = 10; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private KinesisSourceConfig kinesisSourceConfig; + + @Mock + private KinesisStreamConfig kinesisStreamConfig; + + @Mock + private InitializationInput initializationInput; + + @Mock + private ProcessRecordsInput processRecordsInput; + + @Mock + private RecordProcessorCheckpointer checkpointer; + + @Mock + private Buffer<Record<Event>> buffer; + + @Mock + StreamIdentifier streamIdentifier; + + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @Mock + private AcknowledgementSet acknowledgementSet; + + @Mock + private Counter recordProcessingErrors; + + @Mock + private Counter checkpointFailures; + + @BeforeEach + public void setup() { + MockitoAnnotations.initMocks(this); + pluginMetrics = mock(PluginMetrics.class); + pluginFactory = mock(PluginFactory.class); + acknowledgementSet = mock(AcknowledgementSet.class); + + when(initializationInput.shardId()).thenReturn(shardId); + when(streamIdentifier.streamName()).thenReturn(streamId); + when(kinesisStreamConfig.getName()).thenReturn(streamId); + PluginModel pluginModel = mock(PluginModel.class); + when(pluginModel.getPluginName()).thenReturn(codec_plugin_name); + when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap()); + when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel); + + InputCodec codec = mock(InputCodec.class); + when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec); + when(kinesisStreamConfig.getCheckPointIntervalInMilliseconds()).thenReturn(CHECKPOINT_INTERVAL_MS); + when(kinesisSourceConfig.getNumberOfRecordsToAccumulate()).thenReturn(NUMBER_OF_RECORDS_TO_ACCUMULATE); + when(kinesisSourceConfig.getStreams()).thenReturn(List.of(kinesisStreamConfig)); + when(processRecordsInput.checkpointer()).thenReturn(checkpointer); + } + + @Test + void 
testProcessRecordsWithoutAcknowledgementsCheckpointsEnabled() + throws Exception { + when(processRecordsInput.records()).thenReturn(createInputKinesisClientRecords()); + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false); + when(kinesisStreamConfig.isEnableCheckPoint()).thenReturn(true); + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + + kinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + kinesisRecordProcessor.initialize(initializationInput); + + Thread.sleep(2000); + + kinesisRecordProcessor.processRecords(processRecordsInput); + + verify(checkpointer).checkpoint(); + verify(buffer).writeAll(anyCollection(), anyInt()); + verify(acknowledgementSetManager, times(0)).create(any(), any(Duration.class)); + } + + @Test + void testProcessRecordsWithNDJsonInputCodec() + throws Exception { + when(processRecordsInput.records()).thenReturn(createInputKinesisClientRecords()); + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false); + when(kinesisStreamConfig.isEnableCheckPoint()).thenReturn(true); + + PluginModel pluginModel = mock(PluginModel.class); + when(pluginModel.getPluginName()).thenReturn("ndjson"); + when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap()); + when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel); + + InputCodec codec = mock(InputCodec.class); + when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec); + + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + kinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + kinesisRecordProcessor.initialize(initializationInput); + + Thread.sleep(2000); + kinesisRecordProcessor.processRecords(processRecordsInput); + + verify(checkpointer).checkpoint(); + verify(buffer).writeAll(anyCollection(), anyInt()); + verify(acknowledgementSetManager, times(0)).create(any(), any(Duration.class)); + } + + + @Test + void testProcessRecordsNoThrowException() + throws Exception { + when(processRecordsInput.records()).thenReturn(createInputKinesisClientRecords()); + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false); + when(kinesisStreamConfig.isEnableCheckPoint()).thenReturn(false); + + final Throwable exception = mock(RuntimeException.class); + doThrow(exception).when(buffer).writeAll(any(), anyInt()); + + recordProcessingErrors = mock(Counter.class); + when(pluginMetrics.counterWithTags(KINESIS_RECORD_PROCESSING_ERRORS, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(recordProcessingErrors); + + kinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + kinesisRecordProcessor.initialize(initializationInput); + + assertDoesNotThrow(() -> kinesisRecordProcessor.processRecords(processRecordsInput)); + verify(recordProcessingErrors, times(1)).increment(); + } + + @Test + void testProcessRecordsWithoutAcknowledgementsAndCheckpoints() + throws Exception { + when(processRecordsInput.records()).thenReturn(createInputKinesisClientRecords()); + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(false); + when(kinesisStreamConfig.isEnableCheckPoint()).thenReturn(false); + when(acknowledgementSetManager.create(any(), 
any(Duration.class))).thenReturn(acknowledgementSet); + + kinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + kinesisRecordProcessor.initialize(initializationInput); + + kinesisRecordProcessor.processRecords(processRecordsInput); + + verifyNoInteractions(checkpointer); + verify(buffer).writeAll(anyCollection(), anyInt()); + verify(acknowledgementSetManager, times(0)).create(any(), any(Duration.class)); + } + + @Test + void testProcessRecordsWithAcknowledgements() + throws Exception { + when(processRecordsInput.records()).thenReturn(createInputKinesisClientRecords()); + when(kinesisSourceConfig.isAcknowledgments()).thenReturn(true); + when(kinesisStreamConfig.isEnableCheckPoint()).thenReturn(false); + when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); + + kinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + kinesisRecordProcessor.initialize(initializationInput); + + kinesisRecordProcessor.processRecords(processRecordsInput); + + verifyNoInteractions(checkpointer); + verify(buffer).writeAll(anyCollection(), anyInt()); + verify(acknowledgementSetManager, times(1)).create(any(), any(Duration.class)); + } + + @Test + void testShardEndedFlushCalled() throws Exception { + KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + ShardEndedInput shardEndedInput = mock(ShardEndedInput.class); + when(shardEndedInput.checkpointer()).thenReturn(checkpointer); + + mockKinesisRecordProcessor.shardEnded(shardEndedInput); + + verify(shardEndedInput).checkpointer(); + } + + @ParameterizedTest + @ValueSource(classes = {ShutdownException.class, ThrottlingException.class, InvalidStateException.class}) + void testShardEndedCheckpointerThrowsNoThrowException(final Class exceptionType) throws Exception { + checkpointFailures = mock(Counter.class); + when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures); + + KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + ShardEndedInput shardEndedInput = mock(ShardEndedInput.class); + when(shardEndedInput.checkpointer()).thenReturn(checkpointer); + doThrow(exceptionType).when(checkpointer).checkpoint(); + + assertDoesNotThrow(() -> mockKinesisRecordProcessor.shardEnded(shardEndedInput)); + + verify(checkpointer).checkpoint(); + verify(shardEndedInput, times(1)).checkpointer(); + verify(checkpointFailures, times(1)).increment(); + } + + @Test + void testShutdownRequested() { + checkpointFailures = mock(Counter.class); + when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures); + + KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + ShutdownRequestedInput shutdownRequestedInput = mock(ShutdownRequestedInput.class); + when(shutdownRequestedInput.checkpointer()).thenReturn(checkpointer); + + 
mockKinesisRecordProcessor.shutdownRequested(shutdownRequestedInput); + + verify(shutdownRequestedInput).checkpointer(); + verify(checkpointFailures, times(0)).increment(); + } + + @ParameterizedTest + @ValueSource(classes = {ShutdownException.class, ThrottlingException.class, InvalidStateException.class}) + void testShutdownRequestedCheckpointerThrowsNoThrowException(final Class exceptionType) throws Exception { + checkpointFailures = mock(Counter.class); + when(pluginMetrics.counterWithTags(KINESIS_CHECKPOINT_FAILURES, KINESIS_STREAM_TAG_KEY, streamIdentifier.streamName())).thenReturn(checkpointFailures); + + KinesisRecordProcessor mockKinesisRecordProcessor = new KinesisRecordProcessor(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory, streamIdentifier); + ShutdownRequestedInput shutdownRequestedInput = mock(ShutdownRequestedInput.class); + when(shutdownRequestedInput.checkpointer()).thenReturn(checkpointer); + doThrow(exceptionType).when(checkpointer).checkpoint(); + + assertDoesNotThrow(() -> mockKinesisRecordProcessor.shutdownRequested(shutdownRequestedInput)); + + verify(checkpointer).checkpoint(); + verify(shutdownRequestedInput, times(1)).checkpointer(); + verify(checkpointFailures, times(1)).increment(); + } + + private List createInputKinesisClientRecords() { + List kinesisClientRecords = new ArrayList<>(); + for (int i = 0; i< KinesisRecordProcessorTest.NUMBER_OF_RECORDS_TO_ACCUMULATE; i++) { + Event event = JacksonEvent.fromMessage(UUID.randomUUID().toString()); + KinesisClientRecord kinesisClientRecord = KinesisClientRecord.builder().data(ByteBuffer.wrap(event.toJsonString().getBytes())).build(); + kinesisClientRecords.add(kinesisClientRecord); + } + return kinesisClientRecords; + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisShardRecordProcessorFactoryTest.java b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisShardRecordProcessorFactoryTest.java new file mode 100644 index 0000000000..1bfe08ad6a --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/java/org/opensearch/dataprepper/plugins/source/kinesis/processor/KinesisShardRecordProcessorFactoryTest.java @@ -0,0 +1,84 @@ +package org.opensearch.dataprepper.plugins.source.kinesis.processor; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.model.codec.InputCodec; +import org.opensearch.dataprepper.model.configuration.PluginModel; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisSourceConfig; +import org.opensearch.dataprepper.plugins.source.kinesis.configuration.KinesisStreamConfig; +import software.amazon.kinesis.common.StreamIdentifier; + +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static 
org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KinesisShardRecordProcessorFactoryTest { + private KinesisShardRecordProcessorFactory kinesisShardRecordProcessorFactory; + + private static final String streamId = "stream-1"; + private static final String codec_plugin_name = "json"; + + @Mock + private Buffer<Record<Event>> buffer; + + @Mock + StreamIdentifier streamIdentifier; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private PluginFactory pluginFactory; + + @Mock + private KinesisSourceConfig kinesisSourceConfig; + + @Mock + private KinesisStreamConfig kinesisStreamConfig; + + @Mock + private AcknowledgementSetManager acknowledgementSetManager; + + @BeforeEach + void setup() { + MockitoAnnotations.initMocks(this); + + PluginModel pluginModel = mock(PluginModel.class); + when(pluginModel.getPluginName()).thenReturn(codec_plugin_name); + when(pluginModel.getPluginSettings()).thenReturn(Collections.emptyMap()); + when(kinesisSourceConfig.getCodec()).thenReturn(pluginModel); + + InputCodec codec = mock(InputCodec.class); + when(pluginFactory.loadPlugin(eq(InputCodec.class), any())).thenReturn(codec); + + when(streamIdentifier.streamName()).thenReturn(streamId); + when(kinesisStreamConfig.getName()).thenReturn(streamId); + when(kinesisSourceConfig.getStreams()).thenReturn(List.of(kinesisStreamConfig)); + } + + @Test + void testKinesisRecordProcessFactoryReturnsKinesisRecordProcessor() { + kinesisShardRecordProcessorFactory = new KinesisShardRecordProcessorFactory(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory); + assertInstanceOf(KinesisRecordProcessor.class, kinesisShardRecordProcessorFactory.shardRecordProcessor(streamIdentifier)); + } + + @Test + void testKinesisRecordProcessFactoryDefaultUnsupported() { + kinesisShardRecordProcessorFactory = new KinesisShardRecordProcessorFactory(buffer, kinesisSourceConfig, acknowledgementSetManager, pluginMetrics, pluginFactory); + assertThrows(UnsupportedOperationException.class, () -> kinesisShardRecordProcessorFactory.shardRecordProcessor()); + } +} diff --git a/data-prepper-plugins/kinesis-source/src/test/resources/simple-pipeline-2.yaml b/data-prepper-plugins/kinesis-source/src/test/resources/simple-pipeline-2.yaml new file mode 100644 index 0000000000..46234d50fb --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/resources/simple-pipeline-2.yaml @@ -0,0 +1,21 @@ +kinesis-pipeline: + source: + kinesis: + streams: + - stream_name: "stream-1" + stream_arn: "arn:aws:kinesis:us-east-1:123456789012:stream/stream-1" + - stream_name: "stream-2" + stream_arn: "arn:aws:kinesis:us-east-1:123456789012:stream/stream-2" + - stream_name: "stream-3" + stream_arn: "arn:aws:kinesis:us-east-1:123456789012:stream/stream-3" + codec: + ndjson: + aws: + sts_role_arn: "arn:aws:iam::123456789012:role/OSI-PipelineRole" + region: "us-east-1" + consumer_strategy: "Polling" + polling: + maxPollingRecords: 10 + idleTimeBetweenReadsInMillis: 10 + sink: + - stdout: \ No newline at end of file diff --git a/data-prepper-plugins/kinesis-source/src/test/resources/simple-pipeline.yaml b/data-prepper-plugins/kinesis-source/src/test/resources/simple-pipeline.yaml new file mode 100644 index 0000000000..07caf50d02 --- /dev/null +++ b/data-prepper-plugins/kinesis-source/src/test/resources/simple-pipeline.yaml @@ -0,0 +1,16 @@ +kinesis-pipeline: + source: + kinesis: + streams: + - stream_name: "stream-1" + stream_arn: 
"arn:aws:kinesis:us-east-1:123456789012:stream/stream-1" + - stream_name: "stream-2" + stream_arn: "arn:aws:kinesis:us-east-1:123456789012:stream/stream-2" + - stream_name: "stream-3" + stream_arn: "arn:aws:kinesis:us-east-1:123456789012:stream/stream-3" + aws: + sts_role_arn: "arn:aws:iam::123456789012:role/OSI-PipelineRole" + region: "us-east-1" + acknowledgments: true + sink: + - stdout: \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 18ccd4dc7b..e9346e7b56 100644 --- a/settings.gradle +++ b/settings.gradle @@ -179,4 +179,5 @@ include 'data-prepper-plugins:mongodb' include 'data-prepper-plugins:rds-source' include 'data-prepper-plugins:http-source-common' include 'data-prepper-plugins:http-common' -include 'data-prepper-plugins:aws-lambda' \ No newline at end of file +include 'data-prepper-plugins:aws-lambda' +include 'data-prepper-plugins:kinesis-source' \ No newline at end of file