Skip to content

Commit

Permalink
Allow multiple Log4jServletContextListener registrations
Browse files Browse the repository at this point in the history
This closes apache#1782.
  • Loading branch information
ppkarwasz committed Sep 11, 2023
1 parent 3b990f4 commit 1b79941
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
*/
public class Log4jServletContextListener implements ServletContextListener {

static final String START_COUNT_ATTR = Log4jServletContextListener.class.getName() + ".START_COUNT";

private static final int DEFAULT_STOP_TIMEOUT = 30;
private static final TimeUnit DEFAULT_STOP_TIMEOUT_TIMEUNIT = TimeUnit.SECONDS;

Expand All @@ -47,11 +49,30 @@ public class Log4jServletContextListener implements ServletContextListener {
private ServletContext servletContext;
private Log4jWebLifeCycle initializer;

private int getAndIncrementCount() {
Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR);
if (count == null) {
count = 0;
}
servletContext.setAttribute(START_COUNT_ATTR, count + 1);
return count;
}

private int decrementAndGetCount() {
Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR);
if (count == null) {
LOGGER.warn(
"{} received a 'contextDestroyed' message without a corresponding 'contextInitialized' message.",
getClass().getName());
count = 1;
}
servletContext.setAttribute(START_COUNT_ATTR, --count);
return count;
}

@Override
public void contextInitialized(final ServletContextEvent event) {
this.servletContext = event.getServletContext();
LOGGER.debug("Log4jServletContextListener ensuring that Log4j starts up properly.");

if ("true".equalsIgnoreCase(servletContext.getInitParameter(
Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED))) {
throw new IllegalStateException("Do not use " + getClass().getSimpleName() + " when "
Expand All @@ -61,6 +82,12 @@ public void contextInitialized(final ServletContextEvent event) {
}

this.initializer = WebLoggerContextUtils.getWebLifeCycle(this.servletContext);
if (getAndIncrementCount() != 0) {
LOGGER.debug("Skipping Log4j context initialization, since {} is registered multiple times.",
getClass().getSimpleName());
return;
}
LOGGER.info("{} triggered a Log4j context initialization.", getClass().getSimpleName());
try {
this.initializer.start();
this.initializer.setLoggerContext(); // the application is just now starting to start up
Expand All @@ -72,23 +99,32 @@ public void contextInitialized(final ServletContextEvent event) {
@Override
public void contextDestroyed(final ServletContextEvent event) {
if (this.servletContext == null || this.initializer == null) {
LOGGER.warn("Context destroyed before it was initialized.");
LOGGER.warn("Servlet context destroyed before it was initialized.");
return;
}
LOGGER.debug("Log4jServletContextListener ensuring that Log4j shuts down properly.");

this.initializer.clearLoggerContext(); // the application is finished
// shutting down now
if (initializer instanceof LifeCycle2) {
final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT);
final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT
: Long.parseLong(stopTimeoutStr);
final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT);
final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT
: TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr));
((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit);
} else {
this.initializer.stop();

if (decrementAndGetCount() != 0) {
LOGGER.debug("Skipping Log4j context shutdown, since {} is registered multiple times.",
getClass().getSimpleName());
return;
}
LOGGER.info("{} triggered a Log4j context shutdown.", getClass().getSimpleName());
try {
this.initializer.clearLoggerContext(); // the application is finished
// shutting down now
if (initializer instanceof LifeCycle2) {
final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT);
final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT
: Long.parseLong(stopTimeoutStr);
final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT);
final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT
: TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr));
((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit);
} else {
this.initializer.stop();
}
} catch (final IllegalStateException e) {
throw new IllegalStateException("Failed to shutdown Log4j properly.", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,66 +16,94 @@
*/
package org.apache.logging.log4j.web;

import java.util.concurrent.atomic.AtomicReference;

import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletContextEvent;

import org.apache.logging.log4j.util.Strings;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.Mock.Strictness;
import org.mockito.junit.jupiter.MockitoExtension;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.BDDMockito.eq;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.AdditionalAnswers.answerVoid;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.doAnswer;

@ExtendWith(MockitoExtension.class)
public class Log4jServletContextListenerTest {
/* event and servletContext are marked lenient because they aren't used in the
* testDestroyWithNoInit but are only accessed during initialization
*/
@Mock(lenient = true)
@Mock(strictness = Strictness.LENIENT)
private ServletContextEvent event;
@Mock(lenient = true)
@Mock(strictness = Strictness.LENIENT)
private ServletContext servletContext;
@Mock
private Log4jWebLifeCycle initializer;

private Log4jServletContextListener listener;
private final AtomicReference<Object> count = new AtomicReference<>();

@BeforeEach
public void setUp() {
this.listener = new Log4jServletContextListener();
given(event.getServletContext()).willReturn(servletContext);
given(servletContext.getAttribute(Log4jWebSupport.SUPPORT_ATTRIBUTE)).willReturn(initializer);
}

@Test
public void testInitAndDestroy() throws Exception {
this.listener.contextInitialized(this.event);
doAnswer(answerVoid((k, v) -> count.set(v)))
.when(servletContext)
.setAttribute(eq(Log4jServletContextListener.START_COUNT_ATTR), any());
doAnswer(__ -> count.get())
.when(servletContext)
.getAttribute(Log4jServletContextListener.START_COUNT_ATTR);
}

then(initializer).should().start();
then(initializer).should().setLoggerContext();
@ParameterizedTest
@ValueSource(ints = { 1, 2, 3 })
public void testInitAndDestroy(final int listenerCount) throws Exception {
final Log4jServletContextListener[] listeners = new Log4jServletContextListener[listenerCount];
for (int idx = 0; idx < listenerCount; idx++) {
final Log4jServletContextListener listener = new Log4jServletContextListener();
listeners[idx] = listener;

listener.contextInitialized(event);
if (idx == 0) {
then(initializer).should().start();
then(initializer).should().setLoggerContext();
} else {
then(initializer).shouldHaveNoMoreInteractions();
}
}

this.listener.contextDestroyed(this.event);
for (int idx = listenerCount - 1; idx >= 0; idx--) {
final Log4jServletContextListener listener = listeners[idx];

then(initializer).should().clearLoggerContext();
then(initializer).should().stop();
listener.contextDestroyed(event);
if (idx == 0) {
then(initializer).should().clearLoggerContext();
then(initializer).should().stop();
} else {
then(initializer).shouldHaveNoMoreInteractions();
}
}
}

@Test
public void testInitFailure() throws Exception {
willThrow(new IllegalStateException(Strings.EMPTY)).given(initializer).start();
final Log4jServletContextListener listener = new Log4jServletContextListener();

try {
this.listener.contextInitialized(this.event);
fail("Expected a RuntimeException.");
} catch (final RuntimeException e) {
assertEquals("Failed to initialize Log4j properly.", e.getMessage(), "The message is not correct.");
}
assertThrows(RuntimeException.class, () -> listener.contextInitialized(this.event),
"Failed to initialize Log4j properly.");
}

@Test
Expand All @@ -93,17 +121,12 @@ public void initializingLog4jServletContextListenerShouldFaileWhenAutoShutdownIs
}

private void ensureInitializingFailsWhenAuthShutdownIsEnabled() {
try {
this.listener.contextInitialized(this.event);
fail("Expected a RuntimeException.");
} catch (final RuntimeException e) {
final String expectedMessage =
"Do not use " + Log4jServletContextListener.class.getSimpleName() + " when "
+ Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED + " is true. Please use "
+ Log4jShutdownOnContextDestroyedListener.class.getSimpleName() + " instead of "
+ Log4jServletContextListener.class.getSimpleName() + ".";

assertEquals(expectedMessage, e.getMessage(), "The message is not correct");
}
final Log4jServletContextListener listener = new Log4jServletContextListener();
final String message = "Do not use " + Log4jServletContextListener.class.getSimpleName() + " when "
+ Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED + " is true. Please use "
+ Log4jShutdownOnContextDestroyedListener.class.getSimpleName() + " instead of "
+ Log4jServletContextListener.class.getSimpleName() + ".";

assertThrows(RuntimeException.class, () -> listener.contextInitialized(event), message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
*/
public class Log4jServletContextListener implements ServletContextListener {

static final String START_COUNT_ATTR = Log4jServletContextListener.class.getName() + ".START_COUNT";

private static final int DEFAULT_STOP_TIMEOUT = 30;
private static final TimeUnit DEFAULT_STOP_TIMEOUT_TIMEUNIT = TimeUnit.SECONDS;

Expand All @@ -47,11 +49,30 @@ public class Log4jServletContextListener implements ServletContextListener {
private ServletContext servletContext;
private Log4jWebLifeCycle initializer;

private int getAndIncrementCount() {
Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR);
if (count == null) {
count = 0;
}
servletContext.setAttribute(START_COUNT_ATTR, count + 1);
return count;
}

private int decrementAndGetCount() {
Integer count = (Integer) servletContext.getAttribute(START_COUNT_ATTR);
if (count == null) {
LOGGER.warn(
"{} received a 'contextDestroyed' message without a corresponding 'contextInitialized' message.",
getClass().getName());
count = 1;
}
servletContext.setAttribute(START_COUNT_ATTR, --count);
return count;
}

@Override
public void contextInitialized(final ServletContextEvent event) {
this.servletContext = event.getServletContext();
LOGGER.debug("Log4jServletContextListener ensuring that Log4j starts up properly.");

if ("true".equalsIgnoreCase(servletContext.getInitParameter(
Log4jWebSupport.IS_LOG4J_AUTO_SHUTDOWN_DISABLED))) {
throw new IllegalStateException("Do not use " + getClass().getSimpleName() + " when "
Expand All @@ -61,6 +82,12 @@ public void contextInitialized(final ServletContextEvent event) {
}

this.initializer = WebLoggerContextUtils.getWebLifeCycle(this.servletContext);
if (getAndIncrementCount() != 0) {
LOGGER.debug("Skipping Log4j context initialization, since {} is registered multiple times.",
getClass().getSimpleName());
return;
}
LOGGER.info("{} triggered a Log4j context initialization.", getClass().getSimpleName());
try {
this.initializer.start();
this.initializer.setLoggerContext(); // the application is just now starting to start up
Expand All @@ -72,23 +99,32 @@ public void contextInitialized(final ServletContextEvent event) {
@Override
public void contextDestroyed(final ServletContextEvent event) {
if (this.servletContext == null || this.initializer == null) {
LOGGER.warn("Context destroyed before it was initialized.");
LOGGER.warn("Servlet context destroyed before it was initialized.");
return;
}
LOGGER.debug("Log4jServletContextListener ensuring that Log4j shuts down properly.");

this.initializer.clearLoggerContext(); // the application is finished
// shutting down now
if (initializer instanceof LifeCycle2) {
final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT);
final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT
: Long.parseLong(stopTimeoutStr);
final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT);
final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT
: TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr));
((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit);
} else {
this.initializer.stop();

if (decrementAndGetCount() != 0) {
LOGGER.debug("Skipping Log4j context shutdown, since {} is registered multiple times.",
getClass().getSimpleName());
return;
}
LOGGER.info("{} triggered a Log4j context shutdown.", getClass().getSimpleName());
try {
this.initializer.clearLoggerContext(); // the application is finished
// shutting down now
if (initializer instanceof LifeCycle2) {
final String stopTimeoutStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT);
final long stopTimeout = Strings.isEmpty(stopTimeoutStr) ? DEFAULT_STOP_TIMEOUT
: Long.parseLong(stopTimeoutStr);
final String timeoutTimeUnitStr = servletContext.getInitParameter(KEY_STOP_TIMEOUT_TIMEUNIT);
final TimeUnit timeoutTimeUnit = Strings.isEmpty(timeoutTimeUnitStr) ? DEFAULT_STOP_TIMEOUT_TIMEUNIT
: TimeUnit.valueOf(toRootUpperCase(timeoutTimeUnitStr));
((LifeCycle2) this.initializer).stop(stopTimeout, timeoutTimeUnit);
} else {
this.initializer.stop();
}
} catch (final IllegalStateException e) {
throw new IllegalStateException("Failed to shutdown Log4j properly.", e);
}
}
}
Loading

0 comments on commit 1b79941

Please sign in to comment.