From 6e500bdfbd0be10ed6c1830fb4b79d55e31f9e6b Mon Sep 17 00:00:00 2001 From: "Piotr P. Karwasz" Date: Mon, 11 Sep 2023 13:11:37 +0200 Subject: [PATCH] Allow multiple Log4jServletContextListener registrations This closes #1782. --- .../web/Log4jServletContextListener.java | 70 +++++++++++---- .../web/Log4jServletContextListenerTest.java | 89 ++++++++++++------- .../web/Log4jServletContextListener.java | 70 +++++++++++---- .../web/Log4jServletContextListenerTest.java | 89 ++++++++++++------- ...tiple_servletcontextlistener_instances.xml | 28 ++++++ 5 files changed, 246 insertions(+), 100 deletions(-) create mode 100644 src/changelog/.2.x.x/1782_allow_multiple_servletcontextlistener_instances.xml diff --git a/log4j-jakarta-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java b/log4j-jakarta-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java index 599eddb8b82..3a518c62e74 100644 --- a/log4j-jakarta-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java +++ b/log4j-jakarta-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java @@ -36,6 +36,8 @@ */ public class Log4jServletContextListener implements ServletContextListener { + static final String START_COUNT_ATTR = Log4jServletContextListener.class.getName() + ".START_COUNT"; + private static final int DEFAULT_STOP_TIMEOUT = 30; private static final TimeUnit DEFAULT_STOP_TIMEOUT_TIMEUNIT = TimeUnit.SECONDS; @@ -47,11 +49,30 @@ public class Log4jServletContextListener implements ServletContextListener { private ServletContext servletContext; private Log4jWebLifeCycle initializer; + private int getAndIncrementCount() { + Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR); + if (count == null) { + count = 0; + } + servletContext.setAttribute(START_COUNT_ATTR, count + 1); + return count; + } + + private int decrementAndGetCount() { + Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR); + if (count == null) { + LOGGER.warn( + "{} received a 'contextDestroyed' message without a corresponding 'contextInitialized' message.", + getClass().getName()); + count = 1; + } + servletContext.setAttribute(START_COUNT_ATTR, --count); + return count; + } + @Override public void contextInitialized(final ServletContextEvent event) { this.servletContext = event.getServletContext(); - LOGGER.debug("Log4jServletContextListener ensuring that Log4j starts up properly."); - if ("true".equalsIgnoreCase(servletContext.getInitParameter( Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED))) { throw new IllegalStateException("Do not use " + getClass().getSimpleName() + " when " @@ -61,6 +82,12 @@ public void contextInitialized(final ServletContextEvent event) { } this.initializer = WebLoggerContextUtils.getWebLifeCycle(this.servletContext); + if (getAndIncrementCount() != 0) { + LOGGER.debug("Skipping Log4j context initialization, since {} is registered multiple times.", + getClass().getSimpleName()); + return; + } + LOGGER.info("{} triggered a Log4j context initialization.", getClass().getSimpleName()); try { this.initializer.start(); this.initializer.setLoggerContext(); // the application is just now starting to start up @@ -72,23 +99,32 @@ public void contextInitialized(final ServletContextEvent event) { @Override public void contextDestroyed(final ServletContextEvent event) { if (this.servletContext == null || this.initializer == null) { - LOGGER.warn("Context destroyed before it was initialized."); + LOGGER.warn("Servlet context destroyed before it was initialized."); return; } - LOGGER.debug("Log4jServletContextListener ensuring that Log4j shuts down properly."); - - this.initializer.clearLoggerContext(); // the application is finished - // shutting down now - if (initializer instanceof LifeCycle2) { - final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT); - final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT - : Long.parseLong(stopTimeoutStr); - final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT); - final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT - : TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr)); - ((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit); - } else { - this.initializer.stop(); + + if (decrementAndGetCount() != 0) { + LOGGER.debug("Skipping Log4j context shutdown, since {} is registered multiple times.", + getClass().getSimpleName()); + return; + } + LOGGER.info("{} triggered a Log4j context shutdown.", getClass().getSimpleName()); + try { + this.initializer.clearLoggerContext(); // the application is finished + // shutting down now + if (initializer instanceof LifeCycle2) { + final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT); + final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT + : Long.parseLong(stopTimeoutStr); + final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT); + final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT + : TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr)); + ((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit); + } else { + this.initializer.stop(); + } + } catch (final IllegalStateException e) { + throw new IllegalStateException("Failed to shutdown Log4j properly.", e); } } } diff --git a/log4j-jakarta-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java b/log4j-jakarta-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java index 2022d137683..909e67296c1 100644 --- a/log4j-jakarta-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java +++ b/log4j-jakarta-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java @@ -16,6 +16,8 @@ */ package org.apache.logging.log4j.web; +import java.util.concurrent.atomic.AtomicReference; + import jakarta.servlet.ServletContext; import jakarta.servlet.ServletContextEvent; @@ -23,59 +25,85 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; +import org.mockito.Mock.Strictness; import org.mockito.junit.jupiter.MockitoExtension; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.BDDMockito.eq; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.AdditionalAnswers.answerVoid; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.doAnswer; @ExtendWith(MockitoExtension.class) public class Log4jServletContextListenerTest { /* event and servletContext are marked lenient because they aren't used in the * testDestroyWithNoInit but are only accessed during initialization */ - @Mock(lenient = true) + @Mock(strictness = Strictness.LENIENT) private ServletContextEvent event; - @Mock(lenient = true) + @Mock(strictness = Strictness.LENIENT) private ServletContext servletContext; @Mock private Log4jWebLifeCycle initializer; - private Log4jServletContextListener listener; + private final AtomicReference count = new AtomicReference<>(); @BeforeEach public void setUp() { - this.listener = new Log4jServletContextListener(); given(event.getServletContext()).willReturn(servletContext); given(servletContext.getAttribute(Log4jWebSupport.SUPPORT_ATTRIBUTE)).willReturn(initializer); - } - @Test - public void testInitAndDestroy() throws Exception { - this.listener.contextInitialized(this.event); + doAnswer(answerVoid((k, v) -> count.set(v))) + .when(servletContext) + .setAttribute(eq(Log4jServletContextListener.START_COUNT_ATTR), any()); + doAnswer(__ -> count.get()) + .when(servletContext) + .getAttribute(Log4jServletContextListener.START_COUNT_ATTR); + } - then(initializer).should().start(); - then(initializer).should().setLoggerContext(); + @ParameterizedTest + @ValueSource(ints = { 1, 2, 3 }) + public void testInitAndDestroy(final int listenerCount) throws Exception { + final Log4jServletContextListener[] listeners = new Log4jServletContextListener[listenerCount]; + for (int idx = 0; idx < listenerCount; idx++) { + final Log4jServletContextListener listener = new Log4jServletContextListener(); + listeners[idx] = listener; + + listener.contextInitialized(event); + if (idx == 0) { + then(initializer).should().start(); + then(initializer).should().setLoggerContext(); + } else { + then(initializer).shouldHaveNoMoreInteractions(); + } + } - this.listener.contextDestroyed(this.event); + for (int idx = listenerCount - 1; idx >= 0; idx--) { + final Log4jServletContextListener listener = listeners[idx]; - then(initializer).should().clearLoggerContext(); - then(initializer).should().stop(); + listener.contextDestroyed(event); + if (idx == 0) { + then(initializer).should().clearLoggerContext(); + then(initializer).should().stop(); + } else { + then(initializer).shouldHaveNoMoreInteractions(); + } + } } @Test public void testInitFailure() throws Exception { willThrow(new IllegalStateException(Strings.EMPTY)).given(initializer).start(); + final Log4jServletContextListener listener = new Log4jServletContextListener(); - try { - this.listener.contextInitialized(this.event); - fail("Expected a RuntimeException."); - } catch (final RuntimeException e) { - assertEquals("Failed to initialize Log4j properly.", e.getMessage(), "The message is not correct."); - } + assertThrows(RuntimeException.class, () -> listener.contextInitialized(this.event), + "Failed to initialize Log4j properly."); } @Test @@ -93,17 +121,12 @@ public void initializingLog4jServletContextListenerShouldFaileWhenAutoShutdownIs } private void ensureInitializingFailsWhenAuthShutdownIsEnabled() { - try { - this.listener.contextInitialized(this.event); - fail("Expected a RuntimeException."); - } catch (final RuntimeException e) { - final String expectedMessage = - "Do not use " + Log4jServletContextListener.class.getSimpleName() + " when " - + Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED + " is true. Please use " - + Log4jShutdownOnContextDestroyedListener.class.getSimpleName() + " instead of " - + Log4jServletContextListener.class.getSimpleName() + "."; - - assertEquals(expectedMessage, e.getMessage(), "The message is not correct"); - } + final Log4jServletContextListener listener = new Log4jServletContextListener(); + final String message = "Do not use " + Log4jServletContextListener.class.getSimpleName() + " when " + + Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED + " is true. Please use " + + Log4jShutdownOnContextDestroyedListener.class.getSimpleName() + " instead of " + + Log4jServletContextListener.class.getSimpleName() + "."; + + assertThrows(RuntimeException.class, () -> listener.contextInitialized(event), message); } } diff --git a/log4j-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java b/log4j-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java index 9ca83202084..93312d14e19 100644 --- a/log4j-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java +++ b/log4j-web/src/main/java/org/apache/logging/log4j/web/Log4jServletContextListener.java @@ -36,6 +36,8 @@ */ public class Log4jServletContextListener implements ServletContextListener { + static final String START_COUNT_ATTR = Log4jServletContextListener.class.getName() + ".START_COUNT"; + private static final int DEFAULT_STOP_TIMEOUT = 30; private static final TimeUnit DEFAULT_STOP_TIMEOUT_TIMEUNIT = TimeUnit.SECONDS; @@ -47,11 +49,30 @@ public class Log4jServletContextListener implements ServletContextListener { private ServletContext servletContext; private Log4jWebLifeCycle initializer; + private int getAndIncrementCount() { + Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR); + if (count == null) { + count = 0; + } + servletContext.setAttribute(START_COUNT_ATTR, count + 1); + return count; + } + + private int decrementAndGetCount() { + Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR); + if (count == null) { + LOGGER.warn( + "{} received a 'contextDestroyed' message without a corresponding 'contextInitialized' message.", + getClass().getName()); + count = 1; + } + servletContext.setAttribute(START_COUNT_ATTR, --count); + return count; + } + @Override public void contextInitialized(final ServletContextEvent event) { this.servletContext = event.getServletContext(); - LOGGER.debug("Log4jServletContextListener ensuring that Log4j starts up properly."); - if ("true".equalsIgnoreCase(servletContext.getInitParameter( Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED))) { throw new IllegalStateException("Do not use " + getClass().getSimpleName() + " when " @@ -61,6 +82,12 @@ public void contextInitialized(final ServletContextEvent event) { } this.initializer = WebLoggerContextUtils.getWebLifeCycle(this.servletContext); + if (getAndIncrementCount() != 0) { + LOGGER.debug("Skipping Log4j context initialization, since {} is registered multiple times.", + getClass().getSimpleName()); + return; + } + LOGGER.info("{} triggered a Log4j context initialization.", getClass().getSimpleName()); try { this.initializer.start(); this.initializer.setLoggerContext(); // the application is just now starting to start up @@ -72,23 +99,32 @@ public void contextInitialized(final ServletContextEvent event) { @Override public void contextDestroyed(final ServletContextEvent event) { if (this.servletContext == null || this.initializer == null) { - LOGGER.warn("Context destroyed before it was initialized."); + LOGGER.warn("Servlet context destroyed before it was initialized."); return; } - LOGGER.debug("Log4jServletContextListener ensuring that Log4j shuts down properly."); - - this.initializer.clearLoggerContext(); // the application is finished - // shutting down now - if (initializer instanceof LifeCycle2) { - final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT); - final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT - : Long.parseLong(stopTimeoutStr); - final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT); - final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT - : TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr)); - ((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit); - } else { - this.initializer.stop(); + + if (decrementAndGetCount() != 0) { + LOGGER.debug("Skipping Log4j context shutdown, since {} is registered multiple times.", + getClass().getSimpleName()); + return; + } + LOGGER.info("{} triggered a Log4j context shutdown.", getClass().getSimpleName()); + try { + this.initializer.clearLoggerContext(); // the application is finished + // shutting down now + if (initializer instanceof LifeCycle2) { + final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT); + final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT + : Long.parseLong(stopTimeoutStr); + final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT); + final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT + : TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr)); + ((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit); + } else { + this.initializer.stop(); + } + } catch (final IllegalStateException e) { + throw new IllegalStateException("Failed to shutdown Log4j properly.", e); } } } diff --git a/log4j-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java b/log4j-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java index 698e25871a4..187b44c6ede 100644 --- a/log4j-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java +++ b/log4j-web/src/test/java/org/apache/logging/log4j/web/Log4jServletContextListenerTest.java @@ -16,6 +16,8 @@ */ package org.apache.logging.log4j.web; +import java.util.concurrent.atomic.AtomicReference; + import javax.servlet.ServletContext; import javax.servlet.ServletContextEvent; @@ -23,59 +25,85 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; +import org.mockito.Mock.Strictness; import org.mockito.junit.jupiter.MockitoExtension; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.BDDMockito.eq; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.AdditionalAnswers.answerVoid; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.BDDMockito.willThrow; +import static org.mockito.Mockito.doAnswer; @ExtendWith(MockitoExtension.class) public class Log4jServletContextListenerTest { /* event and servletContext are marked lenient because they aren't used in the * testDestroyWithNoInit but are only accessed during initialization */ - @Mock(lenient = true) + @Mock(strictness = Strictness.LENIENT) private ServletContextEvent event; - @Mock(lenient = true) + @Mock(strictness = Strictness.LENIENT) private ServletContext servletContext; @Mock private Log4jWebLifeCycle initializer; - private Log4jServletContextListener listener; + private final AtomicReference count = new AtomicReference<>(); @BeforeEach public void setUp() { - this.listener = new Log4jServletContextListener(); given(event.getServletContext()).willReturn(servletContext); given(servletContext.getAttribute(Log4jWebSupport.SUPPORT_ATTRIBUTE)).willReturn(initializer); - } - @Test - public void testInitAndDestroy() throws Exception { - this.listener.contextInitialized(this.event); + doAnswer(answerVoid((k, v) -> count.set(v))) + .when(servletContext) + .setAttribute(eq(Log4jServletContextListener.START_COUNT_ATTR), any()); + doAnswer(__ -> count.get()) + .when(servletContext) + .getAttribute(Log4jServletContextListener.START_COUNT_ATTR); + } - then(initializer).should().start(); - then(initializer).should().setLoggerContext(); + @ParameterizedTest + @ValueSource(ints = { 1, 2, 3 }) + public void testInitAndDestroy(final int listenerCount) throws Exception { + final Log4jServletContextListener[] listeners = new Log4jServletContextListener[listenerCount]; + for (int idx = 0; idx < listenerCount; idx++) { + final Log4jServletContextListener listener = new Log4jServletContextListener(); + listeners[idx] = listener; + + listener.contextInitialized(event); + if (idx == 0) { + then(initializer).should().start(); + then(initializer).should().setLoggerContext(); + } else { + then(initializer).shouldHaveNoMoreInteractions(); + } + } - this.listener.contextDestroyed(this.event); + for (int idx = listenerCount - 1; idx >= 0; idx--) { + final Log4jServletContextListener listener = listeners[idx]; - then(initializer).should().clearLoggerContext(); - then(initializer).should().stop(); + listener.contextDestroyed(event); + if (idx == 0) { + then(initializer).should().clearLoggerContext(); + then(initializer).should().stop(); + } else { + then(initializer).shouldHaveNoMoreInteractions(); + } + } } @Test public void testInitFailure() throws Exception { willThrow(new IllegalStateException(Strings.EMPTY)).given(initializer).start(); + final Log4jServletContextListener listener = new Log4jServletContextListener(); - try { - this.listener.contextInitialized(this.event); - fail("Expected a RuntimeException."); - } catch (final RuntimeException e) { - assertEquals("Failed to initialize Log4j properly.", e.getMessage(), "The message is not correct."); - } + assertThrows(RuntimeException.class, () -> listener.contextInitialized(this.event), + "Failed to initialize Log4j properly."); } @Test @@ -93,17 +121,12 @@ public void initializingLog4jServletContextListenerShouldFaileWhenAutoShutdownIs } private void ensureInitializingFailsWhenAuthShutdownIsEnabled() { - try { - this.listener.contextInitialized(this.event); - fail("Expected a RuntimeException."); - } catch (final RuntimeException e) { - final String expectedMessage = - "Do not use " + Log4jServletContextListener.class.getSimpleName() + " when " - + Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED + " is true. Please use " - + Log4jShutdownOnContextDestroyedListener.class.getSimpleName() + " instead of " - + Log4jServletContextListener.class.getSimpleName() + "."; - - assertEquals(expectedMessage, e.getMessage(), "The message is not correct"); - } + final Log4jServletContextListener listener = new Log4jServletContextListener(); + final String message = "Do not use " + Log4jServletContextListener.class.getSimpleName() + " when " + + Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED + " is true. Please use " + + Log4jShutdownOnContextDestroyedListener.class.getSimpleName() + " instead of " + + Log4jServletContextListener.class.getSimpleName() + "."; + + assertThrows(RuntimeException.class, () -> listener.contextInitialized(event), message); } } diff --git a/src/changelog/.2.x.x/1782_allow_multiple_servletcontextlistener_instances.xml b/src/changelog/.2.x.x/1782_allow_multiple_servletcontextlistener_instances.xml new file mode 100644 index 00000000000..50f3f61988c --- /dev/null +++ b/src/changelog/.2.x.x/1782_allow_multiple_servletcontextlistener_instances.xml @@ -0,0 +1,28 @@ + + + + + + + + Only shutdown Log4j after last `Log4jServletContextListener` is executed. + +