Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate IP address in certs #568

Merged
merged 2 commits into from
Oct 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -558,12 +558,8 @@ AuditLogMsgBuilder getAuditLogMsgBuilder(ResourceContext ctx, String domainName,

Principal princ = ((RsrcCtxWrapper) ctx).principal();
if (princ != null) {
String unsignedCreds = princ.getUnsignedCredentials();
if (unsignedCreds == null) {
msgBldr.who(princ.getFullName());
} else {
msgBldr.who(unsignedCreds);
}
final String unsignedCreds = princ.getUnsignedCredentials();
msgBldr.who(unsignedCreds == null ? princ.getFullName() : unsignedCreds);
}

// get the client IP
Expand Down Expand Up @@ -1392,8 +1388,8 @@ public RoleAccess getRoleAccess(ResourceContext ctx, String domainName, String p
}

@Override
public RoleToken postRoleCertificateRequest(ResourceContext ctx, String domainName, String roleName,
RoleCertificateRequest req) {
public RoleToken postRoleCertificateRequest(ResourceContext ctx, String domainName,
String roleName, RoleCertificateRequest req) {

final String caller = "postrolecertificaterequest";
final String callerTiming = "postrolecertificaterequest_timing";
Expand Down Expand Up @@ -1472,8 +1468,11 @@ public RoleToken postRoleCertificateRequest(ResourceContext ctx, String domainNa

// validate request/csr details

X509Certificate cert = principal.getX509Certificate();
final String ipAddress = ServletRequestUtil.getRemoteAddress(ctx.request());

if (!validateRoleCertificateRequest(req.getCsr(), domainName, roles, principalName,
validCertSubjectOrgValues)) {
cert, ipAddress, validCertSubjectOrgValues)) {
throw requestError("postRoleCertificateRequest: Unable to validate cert request",
caller, domainName);
}
Expand All @@ -1491,7 +1490,8 @@ public RoleToken postRoleCertificateRequest(ResourceContext ctx, String domainNa
}

boolean validateRoleCertificateRequest(final String csr, final String domainName,
Set<String> roles, final String principal, Set<String> validOrgValues) {
Set<String> roles, final String principal, X509Certificate cert,
final String ip, Set<String> validOrgValues) {

X509RoleCertRequest certReq;
try {
Expand All @@ -1501,7 +1501,13 @@ boolean validateRoleCertificateRequest(final String csr, final String domainName
return false;
}

return certReq.validate(roles, domainName, principal, validOrgValues);
if (!certReq.validate(roles, domainName, principal, validOrgValues)) {
return false;
}

// validate the ip address if any provided

return certReq.validateIPAddress(cert, ip);
}

boolean isAuthorizedServicePrincipal(final Principal principal) {
Expand Down Expand Up @@ -2406,14 +2412,21 @@ public Identity postInstanceRefreshRequest(ResourceContext ctx, String domain,
if (!x509CertReq.validatePublicKeys(publicKey)) {
throw requestError("Invalid CSR - public key mismatch", caller, domain);
}


// verify the IP address in the request matches where the connection
// is coming from

final String ipAddress = ServletRequestUtil.getRemoteAddress(ctx.request());
if (!x509CertReq.validateIPAddress(ipAddress)) {
throw requestError("Invalid CSR - IP address mismatch", caller, domain);
}

// if this is not a user request and the principal authority is the
// certificate authority then we're refreshing our certificate as
// opposed to requesting a new one for the service so we're going
// to do further validation based on the certificate we authenticated

if (refreshOperation) {
final String ipAddress = ServletRequestUtil.getRemoteAddress(ctx.request());
ServiceX509RefreshRequestStatus status = validateServiceX509RefreshRequest(principal,
x509CertReq, ipAddress);
if (status == ServiceX509RefreshRequestStatus.IP_NOT_ALLOWED) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,28 @@ public boolean validatePublicKeys(X509Certificate cert) {
return true;
}

public boolean validateIPAddress(final String ip) {

// if we have no IP addresses in the request, then we're good

if (ipAddresses.isEmpty()) {
return true;
}

// if we have more than 1 IP address in the request then
// we're going to reject it as we can't validate if those
// multiple addresses are from the same host. In this
// scenario a provider model must be used which supports
// multiple IPs in a request

if (ipAddresses.size() != 1) {
LOGGER.error("Cert request contains multiple IP: {} addresses", ipAddresses.size());
return false;
}

return ipAddresses.get(0).equals(ip);
}

boolean validateSpiffeURI(final String domain, final String name, final String value) {

// first extract the URI list from the request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.security.cert.X509Certificate;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -107,4 +108,31 @@ public boolean validate(Set<String> roles, final String domainName,

return validateSpiffeURI(domainName, "role", roleName);
}

public boolean validateIPAddress(X509Certificate cert, final String ip) {

// if we have no IP addresses in the request, then we're good

if (ipAddresses.isEmpty()) {
return true;
}

// if we have a certificate then we need to make sure
// that all the ip addresses in the request match
// the ip addresses in the certificate

if (cert != null) {

List<String> certIPs = Crypto.extractX509CertIPAddresses(cert);

// if the certificate has no ip then we'll do
// validation based on the connection ip

if (!certIPs.isEmpty()) {
return certIPs.containsAll(ipAddresses);
}
}

return validateIPAddress(ip);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.yahoo.athenz.auth.impl.PrincipalAuthority;
import com.yahoo.athenz.auth.impl.SimplePrincipal;
import org.mockito.Mockito;
import org.testng.annotations.Test;
import com.yahoo.athenz.common.metrics.Metric;
Expand All @@ -33,7 +35,7 @@
public class RsrcCtxWrapperTest {

@Test
public void TestRsrcCtxWrapperSimpleAssertion() {
public void testRsrcCtxWrapperSimpleAssertion() {
HttpServletRequest reqMock = Mockito.mock(HttpServletRequest.class);
HttpServletResponse resMock = Mockito.mock(HttpServletResponse.class);

Expand All @@ -52,7 +54,8 @@ public void TestRsrcCtxWrapperSimpleAssertion() {
Mockito.when(reqMock.getMethod()).thenReturn("POST");
authListMock.add(authMock);

RsrcCtxWrapper wrapper = new RsrcCtxWrapper(reqMock, resMock, authListMock, false, authorizerMock, metricMock);
RsrcCtxWrapper wrapper = new RsrcCtxWrapper(reqMock, resMock, authListMock, false,
authorizerMock, metricMock);

assertNotNull(wrapper.context());

Expand All @@ -77,7 +80,7 @@ public void TestRsrcCtxWrapperSimpleAssertion() {
}

@Test
public void TestAuthorize() {
public void testAuthorize() {
HttpServletRequest reqMock = Mockito.mock(HttpServletRequest.class);
HttpServletResponse resMock = Mockito.mock(HttpServletResponse.class);

Expand Down Expand Up @@ -109,7 +112,7 @@ public void TestAuthorize() {
}

@Test(expectedExceptions = { ResourceException.class })
public void TestAuthorizeInvalid() {
public void testAuthorizeInvalid() {
HttpServletRequest reqMock = Mockito.mock(HttpServletRequest.class);
HttpServletResponse resMock = Mockito.mock(HttpServletResponse.class);

Expand All @@ -130,4 +133,54 @@ public void TestAuthorizeInvalid() {
// when not set authority
wrapper.authorize("add-domain", "test", "test");
}

@Test
public void testLogPrincipal() {

HttpServletRequest servletRequest = new MockHttpServletRequest();
HttpServletResponse servletResponse = Mockito.mock(HttpServletResponse.class);

AuthorityList authListMock = new AuthorityList();
Authorizer authorizerMock = Mockito.mock(Authorizer.class);
Metric metricMock = Mockito.mock(Metric.class);

RsrcCtxWrapper wrapper = new RsrcCtxWrapper(servletRequest, servletResponse,
authListMock, false, authorizerMock, metricMock);

wrapper.logPrincipal((Principal) null);
assertNull(servletRequest.getAttribute("com.yahoo.athenz.auth.principal"));

wrapper.logPrincipal((String) null);
assertNull(servletRequest.getAttribute("com.yahoo.athenz.auth.principal"));

SimplePrincipal principal = (SimplePrincipal) SimplePrincipal.create("hockey", "kings",
"v=S1,d=hockey;n=kings;s=sig", 0, new PrincipalAuthority());

wrapper.logPrincipal(principal);
assertEquals(servletRequest.getAttribute("com.yahoo.athenz.auth.principal"), "hockey.kings");
}

@Test
public void testThrowZtsException() {

HttpServletRequest servletRequest = new MockHttpServletRequest();
HttpServletResponse servletResponse = Mockito.mock(HttpServletResponse.class);

AuthorityList authListMock = new AuthorityList();
Authorizer authorizerMock = Mockito.mock(Authorizer.class);
Metric metricMock = Mockito.mock(Metric.class);

RsrcCtxWrapper wrapper = new RsrcCtxWrapper(servletRequest, servletResponse,
authListMock, false, authorizerMock, metricMock);

com.yahoo.athenz.common.server.rest.ResourceException restExc =
new com.yahoo.athenz.common.server.rest.ResourceException(503, null);

try {
wrapper.throwZtsException(restExc);
fail();
} catch (ResourceException ex) {
assertEquals(503, ex.getCode());
}
}
}
Loading