diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java b/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java index cbd310a63..771e616f4 100644 --- a/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java +++ b/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java @@ -120,4 +120,7 @@ public interface SpiGatewayMessages { @Message( level = MessageLevel.DEBUG, text = "Malformed dispatch URL: {0}" ) void malformedDispatchUrl(String url); + + @Message( level = MessageLevel.ERROR, text = "No valid principal found" ) + void noPrincipalFound(); } diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java b/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java index 9bbcde761..6f1d240e5 100644 --- a/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java +++ b/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java @@ -22,20 +22,29 @@ import org.apache.knox.gateway.audit.api.ActionOutcome; import org.apache.knox.gateway.config.Configure; import org.apache.knox.gateway.config.Default; +import org.apache.knox.gateway.security.SubjectUtils; import org.apache.knox.gateway.util.StringUtils; +import javax.security.auth.Subject; import javax.servlet.http.HttpServletRequest; import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; + import java.util.Collections; -import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.Arrays; -import java.util.Optional; +import java.util.HashSet; import java.util.HashMap; +import java.util.Optional; +import java.util.List; +import java.util.Collection; +import java.util.Locale; +import java.util.ArrayList; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Pattern; import java.util.stream.Collectors; /** @@ -50,6 +59,21 @@ public class ConfigurableDispatch extends DefaultDispatch { private Set responseExcludeSetCookieHeaderDirectives = super.getOutboundResponseExcludedSetCookieHeaderDirectives(); private Boolean removeUrlEncoding = false; + private boolean shouldIncludePrincipalAndGroups; + private String actorIdHeaderName = DEFAULT_AUTH_ACTOR_ID_HEADER_NAME; + private String actorGroupsHeaderPrefix = DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX; + private String groupFilterPattern = DEFAULT_GROUP_FILTER_PATTERN; + + static final String DEFAULT_AUTH_ACTOR_ID_HEADER_NAME = "X-Knox-Actor-ID"; + static final String DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX = "X-Knox-Actor-Groups"; + static final String DEFAULT_GROUP_FILTER_PATTERN = ".*"; + static final String DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED = "false"; + + protected static final int MAX_HEADER_LENGTH = 1000; + protected static final String ACTOR_GROUPS_HEADER_FORMAT = "%s-%d"; + protected Pattern groupPattern = Pattern.compile(DEFAULT_GROUP_FILTER_PATTERN); + + private Set convertCommaDelimitedHeadersToSet(String headers) { return headers == null ? Collections.emptySet(): new HashSet<>(Arrays.asList(headers.split("\\s*,\\s*"))); } @@ -123,6 +147,27 @@ protected void setRemoveUrlEncoding(@Default("false") String removeUrlEncoding) this.removeUrlEncoding = Boolean.parseBoolean(removeUrlEncoding); } + @Configure + public void setShouldIncludePrincipalAndGroups(@Default(DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED) boolean shouldIncludePrincipalAndGroups) { + this.shouldIncludePrincipalAndGroups = shouldIncludePrincipalAndGroups; + } + + @Configure + public void setActorIdHeaderName(@Default(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME) String actorIdHeaderName) { + this.actorIdHeaderName = actorIdHeaderName; + } + + @Configure + public void setActorGroupsHeaderPrefix(@Default(DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX) String actorGroupsHeaderPrefix) { + this.actorGroupsHeaderPrefix = actorGroupsHeaderPrefix; + } + + @Configure + public void setGroupFilterPattern(@Default(DEFAULT_GROUP_FILTER_PATTERN) String groupFilterPattern) { + this.groupFilterPattern = groupFilterPattern; + groupPattern = Pattern.compile(this.groupFilterPattern); + } + @Override public void copyRequestHeaderFields(HttpUriRequest outboundRequest, HttpServletRequest inboundRequest) { @@ -133,6 +178,61 @@ public void copyRequestHeaderFields(HttpUriRequest outboundRequest, if(MapUtils.isNotEmpty(extraHeaders)){ extraHeaders.forEach(outboundRequest::addHeader); } + + /* If we need to add user and groups to outbound request */ + if(shouldIncludePrincipalAndGroups) { + Map groups = addPrincipalAndGroups(); + if(MapUtils.isNotEmpty(groups)){ + groups.forEach(outboundRequest::addHeader); + } + } + } + + private Map addPrincipalAndGroups() { + final Map headers = new ConcurrentHashMap(); + final Subject subject = SubjectUtils.getCurrentSubject(); + + final String primaryPrincipalName = subject == null ? null : SubjectUtils.getPrimaryPrincipalName(subject); + if (primaryPrincipalName == null) { + LOG.noPrincipalFound(); + headers.put(actorIdHeaderName, ""); + } else { + headers.put(actorIdHeaderName, primaryPrincipalName); + } + + // Populate actor groups headers + final Set matchingGroupNames = subject == null ? Collections.emptySet() + : SubjectUtils.getGroupPrincipals(subject).stream().filter(group -> groupPattern.matcher(group.getName()).matches()).map(group -> group.getName()) + .collect(Collectors.toSet()); + if (!matchingGroupNames.isEmpty()) { + final List groupStrings = getGroupStrings(matchingGroupNames); + for (int i = 0; i < groupStrings.size(); i++) { + headers.put(String.format(Locale.ROOT, ACTOR_GROUPS_HEADER_FORMAT, actorGroupsHeaderPrefix, i + 1), groupStrings.get(i)); + } + } + return headers; + } + + private List getGroupStrings(final Collection groupNames) { + if (groupNames.isEmpty()) { + return Collections.emptyList(); + } + List groupStrings = new ArrayList<>(); + StringBuilder sb = new StringBuilder(); + for (String groupName : groupNames) { + if (sb.length() + groupName.length() > MAX_HEADER_LENGTH) { + groupStrings.add(sb.toString()); + sb = new StringBuilder(); + } + if (sb.length() > 0) { + sb.append(','); + } + sb.append(groupName); + } + if (sb.length() > 0) { + groupStrings.add(sb.toString()); + } + return groupStrings; } @Override @@ -180,4 +280,5 @@ public URI getDispatchUrl(HttpServletRequest request) { return super.getDispatchUrl(request); } + } diff --git a/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java b/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java index 8f38fd5a5..0386ac9af 100644 --- a/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java +++ b/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java @@ -18,6 +18,7 @@ package org.apache.knox.gateway.dispatch; import static org.apache.knox.gateway.dispatch.AbstractGatewayDispatch.REQUEST_ID_HEADER_NAME; +import static org.apache.knox.gateway.dispatch.ConfigurableDispatch.DEFAULT_AUTH_ACTOR_ID_HEADER_NAME; import static org.apache.knox.gateway.dispatch.DefaultDispatch.SET_COOKIE; import static org.apache.knox.gateway.dispatch.DefaultDispatch.WWW_AUTHENTICATE; import static org.hamcrest.CoreMatchers.containsString; @@ -26,12 +27,15 @@ import static org.junit.Assert.assertThat; import java.net.URI; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.UUID; +import javax.security.auth.Subject; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -41,6 +45,8 @@ import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpUriRequest; import org.apache.http.message.BasicHeader; +import org.apache.knox.gateway.security.GroupPrincipal; +import org.apache.knox.gateway.security.PrimaryPrincipal; import org.apache.knox.test.TestUtils; import org.apache.knox.test.mock.MockHttpServletResponse; import org.apache.logging.log4j.CloseableThreadContext; @@ -316,7 +322,7 @@ public void testRequestAppendHeadersConfig() { assertThat(outboundRequestHeaders[3].getName(), is("c")); } - @Test( timeout = TestUtils.SHORT_TIMEOUT ) + @Test( timeout = TestUtils.LONG_TIMEOUT ) public void testRequestExcludeAndAppendHeadersConfig() { ConfigurableDispatch dispatch = new ConfigurableDispatch(); dispatch.setRequestAppendHeaders("a : b ; c : d"); @@ -724,4 +730,47 @@ public void testXRequestIDHeaderExcludeListNoReqHeader() { assertThat(outboundResponse.getHeader(REQUEST_ID_HEADER_NAME), nullValue()); } + /** + * Make sure X-Knox-Actor-ID and X-Knox-Actor-Groups-1 headers + * are added for authenticated users. + */ + @Test + public void testGroupHeaders() throws PrivilegedActionException { + Subject subject = new Subject(); + subject.getPrincipals().add(new PrimaryPrincipal("knoxui")); + subject.getPrincipals().add(new GroupPrincipal("knox")); + subject.getPrincipals().add(new GroupPrincipal("admin")); + + ConfigurableDispatch dispatch = new ConfigurableDispatch(); + final String headerReqID = "1234567890ABCD"; + dispatch.setShouldIncludePrincipalAndGroups(true); + + Map headers = new HashMap<>(); + headers.put(REQUEST_ID_HEADER_NAME, headerReqID); + headers.put(HttpHeaders.ACCEPT, "abc"); + headers.put("TEST", "test"); + + HttpServletRequest inboundRequest = EasyMock.createNiceMock(HttpServletRequest.class); + EasyMock.expect(inboundRequest.getHeaderNames()).andReturn(Collections.enumeration(headers.keySet())).anyTimes(); + Capture capturedArgument = Capture.newInstance(); + EasyMock.expect(inboundRequest.getHeader(EasyMock.capture(capturedArgument))) + .andAnswer(() -> headers.get(capturedArgument.getValue())).anyTimes(); + EasyMock.replay(inboundRequest); + + HttpUriRequest outboundRequest = new HttpGet(); + + Subject.doAs(subject, new PrivilegedExceptionAction() { + + @Override + public Object run() throws Exception { + dispatch.copyRequestHeaderFields(outboundRequest, inboundRequest); + return null; + } + }); + + Header[] outboundRequestHeaders = outboundRequest.getAllHeaders(); + assertThat(outboundRequestHeaders.length, is(5)); + assertThat(outboundRequest.getHeaders(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME)[0].getValue(), is("knoxui")); + } + }