diff --git a/misk-aws/src/main/kotlin/misk/jobqueue/sqs/SqsJobConsumer.kt b/misk-aws/src/main/kotlin/misk/jobqueue/sqs/SqsJobConsumer.kt index 58223e15c56..379242ae8d8 100644 --- a/misk-aws/src/main/kotlin/misk/jobqueue/sqs/SqsJobConsumer.kt +++ b/misk-aws/src/main/kotlin/misk/jobqueue/sqs/SqsJobConsumer.kt @@ -21,7 +21,7 @@ import misk.tasks.RepeatedTaskQueue import misk.tasks.Status import misk.time.timed import org.slf4j.MDC -import wisp.logging.TaggedLogger +import wisp.logging.SmartTagsThreadLocalHandler import wisp.logging.error import wisp.logging.getLogger import wisp.tracing.traceWithNewRootSpan @@ -212,7 +212,7 @@ internal class SqsJobConsumer @Inject internal constructor( ) Status.OK } catch (th: Throwable) { - val mdcTags = TaggedLogger.popThreadLocalMdcContext() + val mdcTags = SmartTagsThreadLocalHandler.popThreadLocalSmartTags() log.error(th, *mdcTags.toTypedArray()) { "error handling job from ${queue.queueName}" } diff --git a/misk-aws/src/test/kotlin/misk/jobqueue/sqs/SmartTagsJobQueueTest.kt b/misk-aws/src/test/kotlin/misk/jobqueue/sqs/SmartTagsJobQueueTest.kt new file mode 100644 index 00000000000..78c04f7c3d3 --- /dev/null +++ b/misk-aws/src/test/kotlin/misk/jobqueue/sqs/SmartTagsJobQueueTest.kt @@ -0,0 +1,295 @@ +package misk.jobqueue.sqs + +import ch.qos.logback.classic.Level +import ch.qos.logback.classic.spi.ILoggingEvent +import com.amazonaws.services.sqs.AmazonSQS +import com.amazonaws.services.sqs.model.CreateQueueRequest +import jakarta.inject.Inject +import misk.inject.KAbstractModule +import misk.jobqueue.JobQueue +import misk.jobqueue.QueueName +import misk.jobqueue.sqs.SqsJobConsumer.Companion.CONSUMERS_BATCH_SIZE +import misk.jobqueue.subscribe +import misk.logging.LogCollectorModule +import misk.testing.MiskExternalDependency +import misk.testing.MiskTest +import misk.testing.MiskTestModule +import mu.KLogger +import mu.KotlinLogging +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import wisp.feature.testing.FakeFeatureFlags +import wisp.logging.LogCollector +import wisp.logging.Tag +import wisp.logging.getLogger +import wisp.logging.withSmartTags +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger +import kotlin.reflect.KClass + +@MiskTest(startService = true) +internal class SmartTagsJobQueueTest { + @MiskExternalDependency private val dockerSqs = DockerSqs + @MiskTestModule private val module = object : KAbstractModule() { + override fun configure() { + install(SqsJobQueueTestModule(dockerSqs.credentials, dockerSqs.client)) + install(LogCollectorModule()) + } + } + + @Inject private lateinit var sqs: AmazonSQS + @Inject private lateinit var queue: JobQueue + @Inject private lateinit var consumer: SqsJobConsumer + @Inject private lateinit var logCollector: LogCollector + @Inject private lateinit var fakeFeatureFlags: FakeFeatureFlags + + private lateinit var queueName: QueueName + + @BeforeEach + fun setUp() { + queueName = QueueName("sqs_job_queue_test") + sqs.createQueue( + CreateQueueRequest() + .withQueueName(queueName.value) + .withAttributes( + mapOf( + // 1 second visibility timeout + "VisibilityTimeout" to 1.toString() + ) + ) + ) + fakeFeatureFlags.override(CONSUMERS_BATCH_SIZE, 10) + } + + @Deprecated("This nesting should be removed once TaggedLogger is removed") + @Nested + inner class OriginalTaggedLoggerTestsConverted { + @Test + fun shouldLogMdcTagsWhenSqsJobConsumerLogsExceptionFromService() { + val allJobsComplete = CountDownLatch(1) + var messageIdToVerify: String? = null + val jobsReceived = AtomicInteger() + + consumer.subscribe(queueName) { + // If already received and processed job and thrown exception, now trigger the test verification + if (jobsReceived.getAndIncrement() == 1) { + it.acknowledge() + allJobsComplete.countDown() + return@subscribe + } + + taggedLogger + .testTag("test123") + .asContext { + messageIdToVerify = it.id + taggedLogger.info("Test log with mdc") + throw SqsJobQueueTestException("Test exception") + } + } + + queue.enqueue(queueName, "job body") + + assertThat(allJobsComplete.await(10, TimeUnit.SECONDS)).isTrue() + + val serviceLogEvents = + logCollector.takeEvents(SmartTagsJobQueueTest::class, consumeUnmatchedLogs = false) + val sqsLogErrorEvents = logCollector.takeEvents(SqsJobConsumer::class) + .filter { it.level == Level.ERROR } + + assertThat(serviceLogEvents).hasSize(1) + assertThat(serviceLogEvents.single().message).isEqualTo("Test log with mdc") + assertThat(serviceLogEvents.single().mdcPropertyMap).containsEntry("testTag", "test123") + + assertThat(sqsLogErrorEvents).hasSize(1) + assertThat(sqsLogErrorEvents.single().message).isEqualTo("error handling job from ${queueName.value}") + assertThat(sqsLogErrorEvents.single().mdcPropertyMap).containsEntry("testTag", "test123") + assertExistingMdcPropertiesArePresent(sqsLogErrorEvents.single(), messageIdToVerify) + } + + @Test + fun shouldLogNormallyWhenNotUsingTaggedLogger() { + val allJobsComplete = CountDownLatch(1) + var messageIdToVerify: String? = null + val jobsReceived = AtomicInteger() + + consumer.subscribe(queueName) { + // If already received and processed job and thrown exception, now trigger the test verification + if (jobsReceived.getAndIncrement() == 1) { + it.acknowledge() + allJobsComplete.countDown() + return@subscribe + } + + messageIdToVerify = it.id + logger.info("Test log without mdc") + throw SqsJobQueueTestException("Test exception") + } + + queue.enqueue(queueName, "job body") + + assertThat(allJobsComplete.await(10, TimeUnit.SECONDS)).isTrue() + + val serviceLogEvents = + logCollector.takeEvents(SmartTagsJobQueueTest::class, consumeUnmatchedLogs = false) + val sqsLogErrorEvents = logCollector.takeEvents(SqsJobConsumer::class) + .filter { it.level == Level.ERROR } + + assertThat(serviceLogEvents).hasSize(1) + assertThat(serviceLogEvents.single().message).isEqualTo("Test log without mdc") + + assertThat(sqsLogErrorEvents).hasSize(1) + assertThat(sqsLogErrorEvents.single().message).isEqualTo("error handling job from ${queueName.value}") + assertExistingMdcPropertiesArePresent(sqsLogErrorEvents.single(), messageIdToVerify) + } + + private fun assertExistingMdcPropertiesArePresent( + logEvent: ILoggingEvent, + messageIdToVerify: String? + ) { + assertThat(logEvent.mdcPropertyMap).containsEntry("sqs_job_id", messageIdToVerify) + assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.job_id", messageIdToVerify) + assertThat(logEvent.mdcPropertyMap).containsEntry( + "misk.job_queue.queue_name", + queueName.value + ) + assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.queue_type", "aws-sqs") + } + } + + @Nested + inner class WithSmartTagsTests { + @Test + fun shouldLogMdcTagsWhenSqsJobConsumerLogsExceptionFromService() { + val allJobsComplete = CountDownLatch(1) + var messageIdToVerify: String? = null + val jobsReceived = AtomicInteger() + + consumer.subscribe(queueName) { + // If already received and processed job and thrown exception, now trigger the test verification + if (jobsReceived.getAndIncrement() == 1) { + it.acknowledge() + allJobsComplete.countDown() + return@subscribe + } + + withSmartTags("testTag" to "test123") { + messageIdToVerify = it.id + logger.info("Test log with mdc") + throw SqsJobQueueTestException("Test exception") + } + } + + queue.enqueue(queueName, "job body") + + assertThat(allJobsComplete.await(10, TimeUnit.SECONDS)).isTrue() + + val serviceLogEvents = + logCollector.takeEvents(SmartTagsJobQueueTest::class, consumeUnmatchedLogs = false) + val sqsLogErrorEvents = logCollector.takeEvents(SqsJobConsumer::class) + .filter { it.level == Level.ERROR } + + assertThat(serviceLogEvents).hasSize(1) + assertThat(serviceLogEvents.single().message).isEqualTo("Test log with mdc") + assertThat(serviceLogEvents.single().mdcPropertyMap).containsEntry("testTag", "test123") + + assertThat(sqsLogErrorEvents).hasSize(1) + assertThat(sqsLogErrorEvents.single().message).isEqualTo("error handling job from ${queueName.value}") + assertThat(sqsLogErrorEvents.single().mdcPropertyMap).containsEntry("testTag", "test123") + assertExistingMdcPropertiesArePresent(sqsLogErrorEvents.single(), messageIdToVerify) + } + + @Test + fun shouldLogNormallyWhenNotUsingTaggedLogger() { + val allJobsComplete = CountDownLatch(1) + var messageIdToVerify: String? = null + val jobsReceived = AtomicInteger() + + consumer.subscribe(queueName) { + // If already received and processed job and thrown exception, now trigger the test verification + if (jobsReceived.getAndIncrement() == 1) { + it.acknowledge() + allJobsComplete.countDown() + return@subscribe + } + + messageIdToVerify = it.id + logger.info("Test log without mdc") + throw SqsJobQueueTestException("Test exception") + } + + queue.enqueue(queueName, "job body") + + assertThat(allJobsComplete.await(10, TimeUnit.SECONDS)).isTrue() + + val serviceLogEvents = + logCollector.takeEvents(SmartTagsJobQueueTest::class, consumeUnmatchedLogs = false) + val sqsLogErrorEvents = logCollector.takeEvents(SqsJobConsumer::class) + .filter { it.level == Level.ERROR } + + assertThat(serviceLogEvents).hasSize(1) + assertThat(serviceLogEvents.single().message).isEqualTo("Test log without mdc") + + assertThat(sqsLogErrorEvents).hasSize(1) + assertThat(sqsLogErrorEvents.single().message).isEqualTo("error handling job from ${queueName.value}") + assertExistingMdcPropertiesArePresent(sqsLogErrorEvents.single(), messageIdToVerify) + } + + private fun assertExistingMdcPropertiesArePresent( + logEvent: ILoggingEvent, + messageIdToVerify: String? + ) { + assertThat(logEvent.mdcPropertyMap).containsEntry("sqs_job_id", messageIdToVerify) + assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.job_id", messageIdToVerify) + assertThat(logEvent.mdcPropertyMap).containsEntry( + "misk.job_queue.queue_name", + queueName.value + ) + assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.queue_type", "aws-sqs") + } + } + + class SqsJobQueueTestException(override val message: String) : Exception() + + companion object { + val taggedLogger = this::class.getTaggedLogger() + val logger = getLogger() + } +} + +/** + * This is an example wrapper to demonstrate how a service using an existing `TaggedLogger` implementation + * could go about migrating to this new tagged logger with minimal changes in their service initially. + * + * In particular, this enables the tests above to have minimal migration from how they were originally + * written to be able to thoroughly test the new `withSmartTags` style of logging. + */ +data class TestTaggedLogger( + val kLogger: KLogger, + private val tags: Set = emptySet() +) : KLogger by kLogger { + fun testTag(value: String) = tag("testTag" to value) + fun testTagNested(value: String) = tag("testTagNested" to value) + + fun tag(vararg newTags: Tag) = TestTaggedLogger(kLogger, tags.plus(newTags)) + + // Adds the tags to the Mapped Diagnostic Context for the current thread for the duration of the + // block. + fun asContext(f: () -> T): T { + return withSmartTags(*tags.toTypedArray()) { + f() + } + } +} + +fun KClass.getTaggedLogger(): TestTaggedLogger = when { + this.isCompanion -> { + TestTaggedLogger(KotlinLogging.logger(this.java.declaringClass.canonicalName)) + } + + else -> { + TestTaggedLogger(KotlinLogging.logger(this.java.canonicalName)) + } +} diff --git a/misk-aws/src/test/kotlin/misk/jobqueue/sqs/TaggedLoggerJobQueueTest.kt b/misk-aws/src/test/kotlin/misk/jobqueue/sqs/TaggedLoggerJobQueueTest.kt deleted file mode 100644 index 0d0407e4811..00000000000 --- a/misk-aws/src/test/kotlin/misk/jobqueue/sqs/TaggedLoggerJobQueueTest.kt +++ /dev/null @@ -1,177 +0,0 @@ -package misk.jobqueue.sqs - -import ch.qos.logback.classic.Level -import ch.qos.logback.classic.spi.ILoggingEvent -import com.amazonaws.services.sqs.AmazonSQS -import com.amazonaws.services.sqs.model.CreateQueueRequest -import jakarta.inject.Inject -import misk.annotation.ExperimentalMiskApi -import misk.clustering.fake.lease.FakeLeaseManager -import misk.inject.KAbstractModule -import misk.jobqueue.JobQueue -import misk.jobqueue.QueueName -import misk.jobqueue.sqs.SqsJobConsumer.Companion.CONSUMERS_BATCH_SIZE -import misk.jobqueue.sqs.TaggedLoggerJobQueueTest.SqsJobQueueTestTaggedLogger.Companion.getTaggedLogger -import misk.jobqueue.subscribe -import misk.logging.LogCollectorModule -import misk.tasks.RepeatedTaskQueue -import misk.testing.MiskExternalDependency -import misk.testing.MiskTest -import misk.testing.MiskTestModule -import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.BeforeEach -import org.junit.jupiter.api.Test -import wisp.feature.testing.FakeFeatureFlags -import wisp.logging.Copyable -import wisp.logging.LogCollector -import wisp.logging.Tag -import wisp.logging.TaggedLogger -import wisp.logging.getLogger -import java.util.concurrent.CountDownLatch -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger -import kotlin.reflect.KClass - -@MiskTest(startService = true) -internal class TaggedLoggerJobQueueTest { - @MiskExternalDependency private val dockerSqs = DockerSqs - @MiskTestModule private val module = object: KAbstractModule() { - override fun configure() { - install(SqsJobQueueTestModule(dockerSqs.credentials, dockerSqs.client)) - install(LogCollectorModule()) - } - } - - @Inject private lateinit var sqs: AmazonSQS - @Inject private lateinit var queue: JobQueue - @Inject private lateinit var consumer: SqsJobConsumer - @Inject private lateinit var logCollector: LogCollector - @Inject private lateinit var sqsMetrics: SqsMetrics - @Inject @ForSqsHandling lateinit var taskQueue: RepeatedTaskQueue - @Inject private lateinit var fakeFeatureFlags: FakeFeatureFlags - @Inject private lateinit var fakeLeaseManager: FakeLeaseManager - @Inject private lateinit var queueResolver: QueueResolver - - private lateinit var queueName: QueueName - private lateinit var deadLetterQueueName: QueueName - - @BeforeEach - fun setUp() { - queueName = QueueName("sqs_job_queue_test") - sqs.createQueue( - CreateQueueRequest() - .withQueueName(queueName.value) - .withAttributes( - mapOf( - // 1 second visibility timeout - "VisibilityTimeout" to 1.toString() - ) - ) - ) - fakeFeatureFlags.override(CONSUMERS_BATCH_SIZE, 10) - } - - @Test - fun shouldLogMdcTagsWhenSqsJobConsumerLogsExceptionFromService() { - val allJobsComplete = CountDownLatch(1) - var messageIdToVerify: String? = null - val jobsReceived = AtomicInteger() - - consumer.subscribe(queueName) { - // If already received and processed job and thrown exception, now trigger the test verification - if (jobsReceived.getAndIncrement() == 1) { - it.acknowledge() - allJobsComplete.countDown() - return@subscribe - } - - taggedLogger - .testTag("test123") - .asContext { - messageIdToVerify = it.id - taggedLogger.info("Test log with mdc") - throw SqsJobQueueTestException("Test exception") - } - } - - queue.enqueue(queueName, "job body") - - assertThat(allJobsComplete.await(10, TimeUnit.SECONDS)).isTrue() - - val serviceLogEvents = logCollector.takeEvents(TaggedLoggerJobQueueTest::class, consumeUnmatchedLogs = false) - val sqsLogErrorEvents = logCollector.takeEvents(SqsJobConsumer::class) - .filter { it.level == Level.ERROR } - - assertThat(serviceLogEvents).hasSize(1) - assertThat(serviceLogEvents.single().message).isEqualTo("Test log with mdc") - assertThat(serviceLogEvents.single().mdcPropertyMap).containsEntry("testTag", "test123") - - assertThat(sqsLogErrorEvents).hasSize(1) - assertThat(sqsLogErrorEvents.single().message).isEqualTo("error handling job from ${queueName.value}") - assertThat(sqsLogErrorEvents.single().mdcPropertyMap).containsEntry("testTag", "test123") - assertExistingMdcPropertiesArePresent(sqsLogErrorEvents.single(), messageIdToVerify) - } - - @Test - fun shouldLogNormallyWhenNotUsingTaggedLogger() { - val allJobsComplete = CountDownLatch(1) - var messageIdToVerify: String? = null - val jobsReceived = AtomicInteger() - - consumer.subscribe(queueName) { - // If already received and processed job and thrown exception, now trigger the test verification - if (jobsReceived.getAndIncrement() == 1) { - it.acknowledge() - allJobsComplete.countDown() - return@subscribe - } - - messageIdToVerify = it.id - normalLogger.info("Test log without mdc") - throw SqsJobQueueTestException("Test exception") - } - - queue.enqueue(queueName, "job body") - - assertThat(allJobsComplete.await(10, TimeUnit.SECONDS)).isTrue() - - val serviceLogEvents = logCollector.takeEvents(TaggedLoggerJobQueueTest::class, consumeUnmatchedLogs = false) - val sqsLogErrorEvents = logCollector.takeEvents(SqsJobConsumer::class) - .filter { it.level == Level.ERROR } - - assertThat(serviceLogEvents).hasSize(1) - assertThat(serviceLogEvents.single().message).isEqualTo("Test log without mdc") - - assertThat(sqsLogErrorEvents).hasSize(1) - assertThat(sqsLogErrorEvents.single().message).isEqualTo("error handling job from ${queueName.value}") - assertExistingMdcPropertiesArePresent(sqsLogErrorEvents.single(), messageIdToVerify) - } - - private fun assertExistingMdcPropertiesArePresent(logEvent: ILoggingEvent, messageIdToVerify: String?) { - assertThat(logEvent.mdcPropertyMap).containsEntry("sqs_job_id", messageIdToVerify) - assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.job_id", messageIdToVerify) - assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.queue_name", queueName.value) - assertThat(logEvent.mdcPropertyMap).containsEntry("misk.job_queue.queue_type", "aws-sqs") - } - - class SqsJobQueueTestException(override val message: String): Exception() - - companion object { - val taggedLogger = this::class.getTaggedLogger() - val normalLogger = getLogger() - } - - @OptIn(ExperimentalMiskApi::class) - data class SqsJobQueueTestTaggedLogger(val logClass: KClass, val tags: Set = emptySet()): TaggedLogger>(logClass, tags), - Copyable> { - fun testTag(value: String) = tag(Tag("testTag", value)) - - companion object { - fun KClass.getTaggedLogger() = SqsJobQueueTestTaggedLogger(this) - } - - override fun copyWithNewTags(newTags: Set): SqsJobQueueTestTaggedLogger { - return copy(tags = newTags) - } - } -} diff --git a/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt b/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt index 86e51105c3b..55506ae64d6 100644 --- a/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt +++ b/misk/src/main/kotlin/misk/web/exceptions/ExceptionHandlingInterceptor.kt @@ -22,8 +22,8 @@ import okhttp3.Headers.Companion.toHeaders import okio.Buffer import okio.BufferedSink import okio.ByteString +import wisp.logging.SmartTagsThreadLocalHandler import wisp.logging.Tag -import wisp.logging.TaggedLogger import wisp.logging.error import wisp.logging.getLogger import wisp.logging.log @@ -52,7 +52,7 @@ class ExceptionHandlingInterceptor( chain.proceed(chain.httpCall) } catch (th: Throwable) { try { - val mdcTags = TaggedLogger.popThreadLocalMdcContext() + val mdcTags = SmartTagsThreadLocalHandler.popThreadLocalSmartTags() if (chain.httpCall.dispatchMechanism == DispatchMechanism.GRPC) { // This response object is only used for determining the status code. toGrpcResponse diff --git a/misk/src/test/kotlin/misk/web/exceptions/TaggedLoggerExceptionHandlingInterceptorTest.kt b/misk/src/test/kotlin/misk/web/exceptions/SmartTagsExceptionHandlingInterceptorTest.kt similarity index 80% rename from misk/src/test/kotlin/misk/web/exceptions/TaggedLoggerExceptionHandlingInterceptorTest.kt rename to misk/src/test/kotlin/misk/web/exceptions/SmartTagsExceptionHandlingInterceptorTest.kt index 6b16813623a..8b9da3c7bd1 100644 --- a/misk/src/test/kotlin/misk/web/exceptions/TaggedLoggerExceptionHandlingInterceptorTest.kt +++ b/misk/src/test/kotlin/misk/web/exceptions/SmartTagsExceptionHandlingInterceptorTest.kt @@ -9,10 +9,6 @@ import misk.MiskTestingServiceModule import misk.annotation.ExperimentalMiskApi import misk.inject.KAbstractModule import misk.logging.LogCollectorModule -import misk.web.exceptions.TaggedLoggerExceptionHandlingInterceptorTest.LogMDCContextTestAction.LogMDCContextTestActionLogger.Companion.getTaggedLogger -import misk.web.exceptions.TaggedLoggerExceptionHandlingInterceptorTest.NestedLoggersOuterExceptionHandled.ServiceExtendedTaggedLogger.Companion.getTaggedLoggerNestedOuterExceptionThrown -import misk.web.exceptions.TaggedLoggerExceptionHandlingInterceptorTest.NestedLoggersOuterExceptionHandledNoneThrown.ServiceExtendedTaggedLogger.Companion.getTaggedLoggerNestedOuterExceptionThrownThenNone -import misk.web.exceptions.TaggedLoggerExceptionHandlingInterceptorTest.NestedTaggedLoggersThrowsException.ServiceExtendedTaggedLogger.Companion.getTaggedLoggerNested import misk.security.authz.AccessControlModule import misk.security.authz.FakeCallerAuthenticator import misk.security.authz.MiskCallerAuthenticator @@ -25,29 +21,29 @@ import misk.web.ResponseContentType import misk.web.WebActionModule import misk.web.WebServerTestingModule import misk.web.actions.WebAction -import misk.web.exceptions.TaggedLoggerExceptionHandlingInterceptorTest.NestedTaggedLoggersBothSucceed.ServiceExtendedTaggedLogger.Companion.getTaggedLoggerNestedThreads import misk.web.jetty.JettyService import misk.web.mediatype.MediaTypes import mu.KLogger +import mu.KotlinLogging import okhttp3.Headers import okhttp3.OkHttpClient import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.slf4j.MDC -import wisp.logging.Copyable import wisp.logging.LogCollector +import wisp.logging.SmartTagsThreadLocalHandler import wisp.logging.Tag -import wisp.logging.TaggedLogger import wisp.logging.getLogger import wisp.logging.info +import wisp.logging.withSmartTags import java.util.concurrent.Callable import java.util.concurrent.Executors import java.util.concurrent.TimeUnit import kotlin.reflect.KClass @MiskTest(startService = true) -internal class TaggedLoggerExceptionHandlingInterceptorTest { +internal class SmartTagsExceptionHandlingInterceptorTest { @MiskTestModule val module = object :KAbstractModule() { override fun configure() { @@ -222,21 +218,6 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { val logger = this::class.getTaggedLogger() const val URL = "/log/LogMDCContextTestAction/test" } - - data class LogMDCContextTestActionLogger(val logClass: KClass, val tags: Set = emptySet()): TaggedLogger>(logClass, tags), - Copyable> { - fun testTag(value: String) = tag(Tag("testTag", value)) - - companion object { - fun KClass.getTaggedLogger(): LogMDCContextTestActionLogger { - return LogMDCContextTestActionLogger(this) - } - } - - override fun copyWithNewTags(newTags: Set): LogMDCContextTestActionLogger { - return this.copy(tags = newTags) - } - } } @Test @@ -337,36 +318,18 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { } // Manually add this tag to identify the execution for verification - logger.info(EXECUTION_IDENTIFIER to headers[IDENTIFIER_HEADER]) { "Log message after TaggedLogger" } + logger.tag(EXECUTION_IDENTIFIER to headers[IDENTIFIER_HEADER]) + .asContext { logger.info { "Log message after TaggedLogger" } } return result } companion object { - val logger = this::class.getTaggedLoggerNestedThreads() + val logger = this::class.getTaggedLogger() const val URL = "/log/NestedTaggedLoggersBothSucceed/test" const val IDENTIFIER_HEADER = "IDENTIFIER_HEADER" const val EXECUTION_IDENTIFIER = "executionIdentifier" } - data class ServiceExtendedTaggedLogger( - val logClass: KClass, - val tags: Set = emptySet() - ): TaggedLogger>(logClass, tags), Copyable> { - - fun testTag(value: String) = tag(Tag("testTag", value)) - fun testTagNested(value: String) = tag(Tag("testTagNested", value)) - - companion object { - fun KClass.getTaggedLoggerNestedThreads(): ServiceExtendedTaggedLogger { - return ServiceExtendedTaggedLogger(this) - } - } - - override fun copyWithNewTags(newTags: Set): ServiceExtendedTaggedLogger { - return this.copy(tags = newTags) - } - } - class AnotherClass() { fun functionWithNestedTaggedLogger(parentTag: String?): String { return logger @@ -379,7 +342,7 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { } companion object { - val logger = this::class.getTaggedLoggerNestedThreads() + val logger = this::class.getTaggedLogger() } } } @@ -436,25 +399,9 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { class NestedTaggedLoggersException(message: String) : Throwable(message) companion object { - val logger = this::class.getTaggedLoggerNested() + val logger = this::class.getTaggedLogger() const val URL = "/log/NestedTaggedLoggersLogger/test" } - - data class ServiceExtendedTaggedLogger(val logClass: KClass, val tags: Set = emptySet()): TaggedLogger>(logClass, tags), - Copyable> { - fun testTag(value: String) = tag(Tag("testTag", value)) - fun testTagNested(value: String) = tag(Tag("testTagNested", value)) - - companion object { - fun KClass.getTaggedLoggerNested(): ServiceExtendedTaggedLogger { - return ServiceExtendedTaggedLogger(this) - } - } - - override fun copyWithNewTags(newTags: Set): ServiceExtendedTaggedLogger { - return this.copy(tags = newTags) - } - } } @@ -512,25 +459,9 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { class OuterTaggedLoggerException(message: String) : Throwable(message) companion object { - val logger = this::class.getTaggedLoggerNestedOuterExceptionThrown() + val logger = this::class.getTaggedLogger() const val URL = "/log/NestedLoggersOuterExceptionHandled/test" } - - data class ServiceExtendedTaggedLogger(val logClass: KClass, val tags: Set = emptySet()): TaggedLogger>(logClass, tags), - Copyable> { - fun testTag(value: String)= tag("testTag" to value) - fun testTagNested(value: String) = tag("testTagNested" to value) - - companion object { - fun KClass.getTaggedLoggerNestedOuterExceptionThrown(): ServiceExtendedTaggedLogger { - return ServiceExtendedTaggedLogger(this) - } - } - - override fun copyWithNewTags(newTags: Set): ServiceExtendedTaggedLogger { - return this.copy(tags = newTags) - } - } } @@ -568,7 +499,7 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { // This is testing the ThreadLocal cleanup function within TaggedLogger when asContext() exits // without throwing an exception - val shouldBeEmptySet = TaggedLogger.popThreadLocalMdcContext() + val shouldBeEmptySet = SmartTagsThreadLocalHandler.popThreadLocalSmartTags() logger.info { "Should be zero size and log with no MDC context: ${shouldBeEmptySet.size}" } return "" } @@ -585,25 +516,10 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { class OuterTaggedLoggerException(message: String) : Throwable(message) companion object { - val logger = this::class.getTaggedLoggerNestedOuterExceptionThrownThenNone() + val logger = this::class.getTaggedLogger() const val URL = "/log/NestedLoggersOuterExceptionHandledNoneThrown/test" } - data class ServiceExtendedTaggedLogger(val logClass: KClass, val tags: Set = emptySet()): TaggedLogger>(logClass, tags), - Copyable> { - fun testTag(value: String)= tag("testTag" to value) - fun testTagNested(value: String) = tag("testTagNested" to value) - - companion object { - fun KClass.getTaggedLoggerNestedOuterExceptionThrownThenNone(): ServiceExtendedTaggedLogger { - return ServiceExtendedTaggedLogger(this) - } - } - - override fun copyWithNewTags(newTags: Set): ServiceExtendedTaggedLogger { - return this.copy(tags = newTags) - } - } } @@ -619,3 +535,38 @@ internal class TaggedLoggerExceptionHandlingInterceptorTest { return httpClient.newCall(request.build()).execute() } } + +/** + * This is an example wrapper to demonstrate how a service using an existing `TaggedLogger` implementation + * could go about migrating to this new tagged logger with minimal changes in their service initially. + * + * In particular, this enables the tests above to have minimal migration from how they were originally + * written to be able to thoroughly test the new `withSmartTags` style of logging. + */ +data class TestTaggedLogger( + val kLogger: KLogger, + private val tags: Set = emptySet() +): KLogger by kLogger { + fun testTag(value: String)= tag("testTag" to value) + fun testTagNested(value: String) = tag("testTagNested" to value) + + fun tag(vararg newTags: Tag) = TestTaggedLogger(kLogger, tags.plus(newTags)) + + // Adds the tags to the Mapped Diagnostic Context for the current thread for the duration of the + // block. + fun asContext(f: () -> T): T { + return withSmartTags(*tags.toTypedArray()) { + f() + } + } +} + +fun KClass.getTaggedLogger(): TestTaggedLogger = when { + this.isCompanion -> { + TestTaggedLogger(KotlinLogging.logger(this.java.declaringClass.canonicalName)) + } + + else -> { + TestTaggedLogger(KotlinLogging.logger(this.java.canonicalName)) + } +} diff --git a/wisp/wisp-logging/api/wisp-logging.api b/wisp/wisp-logging/api/wisp-logging.api index 32d712d2f1d..edd0e9a9217 100644 --- a/wisp/wisp-logging/api/wisp-logging.api +++ b/wisp/wisp-logging/api/wisp-logging.api @@ -17,6 +17,7 @@ public final class wisp/logging/LoggingKt { public static final fun trace (Lmu/KLogger;[Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)V public static final fun warn (Lmu/KLogger;Ljava/lang/Throwable;[Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)V public static final fun warn (Lmu/KLogger;[Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)V + public static final fun withSmartTags ([Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)Ljava/lang/Object; public static final fun withTags ([Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)V } @@ -112,6 +113,11 @@ public class wisp/logging/SampledLogger : mu/KLogger { public fun warn (Lorg/slf4j/Marker;Lkotlin/jvm/functions/Function0;)V } +public final class wisp/logging/SmartTagsThreadLocalHandler { + public static final field INSTANCE Lwisp/logging/SmartTagsThreadLocalHandler; + public final fun popThreadLocalSmartTags ()Ljava/util/Set; +} + public abstract class wisp/logging/TaggedLogger : mu/KLogger, wisp/logging/Copyable { public static final field Companion Lwisp/logging/TaggedLogger$Companion; public fun (Lkotlin/reflect/KClass;Ljava/util/Set;)V @@ -209,6 +215,5 @@ public abstract class wisp/logging/TaggedLogger : mu/KLogger, wisp/logging/Copya } public final class wisp/logging/TaggedLogger$Companion { - public final fun popThreadLocalMdcContext ()Ljava/util/Set; } diff --git a/wisp/wisp-logging/src/main/kotlin/wisp/logging/Logging.kt b/wisp/wisp-logging/src/main/kotlin/wisp/logging/Logging.kt index 41ac78d8635..7c47bfb2f8e 100644 --- a/wisp/wisp-logging/src/main/kotlin/wisp/logging/Logging.kt +++ b/wisp/wisp-logging/src/main/kotlin/wisp/logging/Logging.kt @@ -113,3 +113,82 @@ fun withTags(vararg tags: Tag, f: () -> Unit) { priorMDC.forEach { (k, v) -> if (v == null) MDC.remove(k) else MDC.put(k, v) } } } + +/** + * Use this function to add tags to the MDC context for the duration of the block. + * + * This is particularly useful (the smart aspect) when an exception is thrown within the block, + * the tags can be retrieved outside that block using `SmartTagsThreadLocalHandler.popThreadLocalSmartTags()` + * and added to the MDC context again when logging the exception. + * + * Within Misk this is already built into both WebAction (`misk.web.exceptions.ExceptionHandlingInterceptor`) + * and `misk.jobqueue.sqs.SqsJobConsumer`. These can be used as an example to extend for any + * other incoming "event" consumers within a service such as Kafka, scheduled tasks, temporal workflows, etc. + * + * Usage: + * ``` + * class ServiceAction (private val webClient: WebClient): WebAction { + * + * @Post("/api/resource") + * fun executeWebAction(@RequestBody request: ServiceActionRequest) { + * logger.info() { "Received request" } + * + * val loadedContext = aClient.load(request.id) + * + * withSmartTags( + * "processValue" to request.process_value, + * "contextToken" to loadedContext.token + * ) { + * logger.info() { "Processing request" } + * doSomething() + * } + * } + * + * private fun doSomething() { + * logger.info() { "Start Process" } + * + * client.someWebRequest() // Client throws exception which is caught and logged by misk framework + * + * logger.info() { "Done" } + * } + * + * companion object { + * val logger = KotlinLogging.logger(ServiceAction::class.java.canonicalName) + * } + * } + * ``` + * + * Logging result: + * ``` + * Log MDC context: [] Log message: "Received request" + * Log MDC context: [processValue: "PV_123", contextToken: "contextTokenValue"] Log message: "Processing request" + * Log MDC context: [processValue: "PV_123", contextToken: "contextTokenValue"] Log message: "Start Process" + * Log MDC context: [processValue: "PV_123", contextToken: "contextTokenValue"] Log message: "unexpected error dispatching to ServiceAction" // This log would not normally include the MDC context + * ``` + */ +fun withSmartTags(vararg tags: Tag, f: () -> T): T { + // Establish MDC, saving prior MDC + val priorMDC = tags.map { (k, v) -> + val priorValue = MDC.get(k) + MDC.put(k, v.toString()) + k to priorValue + } + + try { + return f().also { + // Exiting this block gracefully: Lets do some cleanup to keep the ThreadLocal clear. + // The scenario here is that when nested `withSmartTags` threw an exception and it was + // caught and handled by this `withSmartTags`, it should clean up the unused and unneeded context. + SmartTagsThreadLocalHandler.clear() + } + } catch (th: Throwable) { + // Calls to `withSmartTags` can be nested - only set if there is not already a context set + // This will be cleared upon logging of the exception within misk or if the thrown exception + // is handled by a higher level `withSmartTags` + SmartTagsThreadLocalHandler.addOrClearTags(th, tags.toSet()) + throw th + } finally { + // Restore or clear prior MDC + priorMDC.forEach { (k, v) -> if (v == null) MDC.remove(k) else MDC.put(k, v) } + } +} diff --git a/wisp/wisp-logging/src/main/kotlin/wisp/logging/SmartTagsThreadLocalHandler.kt b/wisp/wisp-logging/src/main/kotlin/wisp/logging/SmartTagsThreadLocalHandler.kt new file mode 100644 index 00000000000..3b2301dd020 --- /dev/null +++ b/wisp/wisp-logging/src/main/kotlin/wisp/logging/SmartTagsThreadLocalHandler.kt @@ -0,0 +1,41 @@ +package wisp.logging + +object SmartTagsThreadLocalHandler { + private val threadLocalMdcContext = ThreadLocal() + + /** + * Retrieves all the logging MDC tags that were added to the logger via `withSmartTags()` and + * clears the thread local storage. + * + * Note: the thread local storage is only populated when an exception is thrown within a + * `withSmartTags()` block. + */ + fun popThreadLocalSmartTags() = threadLocalMdcContext + .get() + ?.tags + ?.also { threadLocalMdcContext.remove() } + ?: emptySet() + + internal fun clear() = threadLocalMdcContext.remove() + + internal fun addOrClearTags(th: Throwable, tags: Set) { + val existingContext = threadLocalMdcContext.get() + + if (existingContext == null || !existingContext.wasTriggeredBy(th)) { + threadLocalMdcContext.set(ThreadLocalTaggedLoggerMdcContext(th, tags.toSet())) + } else if (existingContext.wasTriggeredBy(th)) { + threadLocalMdcContext.set(existingContext.copy(tags = existingContext.tags + tags.toSet())) + } else { + threadLocalMdcContext.remove() + } + } + + private data class ThreadLocalTaggedLoggerMdcContext( + val triggeringThrowable: Throwable, + val tags: Set + ) { + fun wasTriggeredBy(throwable: Throwable): Boolean { + return triggeringThrowable == throwable + } + } +} diff --git a/wisp/wisp-logging/src/main/kotlin/wisp/logging/TaggedLogger.kt b/wisp/wisp-logging/src/main/kotlin/wisp/logging/TaggedLogger.kt index 225ac9aa6b4..1c986cdbe32 100644 --- a/wisp/wisp-logging/src/main/kotlin/wisp/logging/TaggedLogger.kt +++ b/wisp/wisp-logging/src/main/kotlin/wisp/logging/TaggedLogger.kt @@ -3,7 +3,6 @@ package wisp.logging import misk.annotation.ExperimentalMiskApi import mu.KLogger import mu.KotlinLogging -import org.slf4j.MDC import kotlin.reflect.KClass /** @@ -97,68 +96,12 @@ abstract class TaggedLogger ( // Adds the tags to the Mapped Diagnostic Context for the current thread for the duration of the block. fun asContext(f: () -> T): T { - val priorMDC = MDC.getCopyOfContextMap() ?: emptyMap() - - tags.forEach { (k, v) -> - if (v != null) { - MDC.put(k, v.toString()) - } - } - - try { - return f().also { - // Exiting this TaggedLogger gracefully: Lets do some cleanup to keep the ThreadLocal clear. - // The scenario here is that when nested TaggedLogger threw an exception and it was - // caught and handled by this TaggedLogger, it should clean up the unused and unneeded context. - threadLocalMdcContext.remove() - } - } catch (th: Throwable) { - // TaggedLoggers can be nested - only set if there is not already a context set - // This will be cleared upon logging of the exception within misk or if the thrown exception - // is handled by a higher level TaggedLogger - if (shouldSetThreadLocalContext(th)) { - // Set thread local MDC context for the ExceptionHandlingInterceptor to read - threadLocalMdcContext.set(ThreadLocalTaggedLoggerMdcContext.createWithMdcSnapshot(th)) - } - throw th - } finally { - MDC.setContextMap(priorMDC) - } - } - - private fun shouldSetThreadLocalContext(th: Throwable): Boolean { - // This is the first of any nested TaggedLoggers to catch this exception - if (threadLocalMdcContext.get() == null) { - return true - } - - // A nested TaggedLogger may have caught and handled the exception, and has now thrown something else - return !(threadLocalMdcContext.get()?.wasTriggeredBy(th) ?: false) - } - - private data class ThreadLocalTaggedLoggerMdcContext( - val triggeringThrowable: Throwable, - val tags: Set - ) { - fun wasTriggeredBy(throwable: Throwable): Boolean { - return triggeringThrowable == throwable - } - - companion object { - fun createWithMdcSnapshot(triggeringThrowable: Throwable) = - ThreadLocalTaggedLoggerMdcContext(triggeringThrowable, MDC.getCopyOfContextMap().map { Tag(it.key, it.value) }.toSet()) + return withSmartTags(*tags.toTypedArray()) { + f() } } companion object { - private val threadLocalMdcContext = ThreadLocal() - - fun popThreadLocalMdcContext() = threadLocalMdcContext - .get() - ?.tags - ?.also { threadLocalMdcContext.remove() } - ?: emptySet() - private fun getLogger(loggerClass: KClass): KLogger { return when { loggerClass.isCompanion -> {