Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fetch verification key from server via proxy during accesstoken and roletoken verification #2527

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clients/java/zpe/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<mock.server.version>5.15.0</mock.server.version>
<commons.io.version>2.15.1</commons.io.version>
<uberjar.name>benchmarks</uberjar.name>
<code.coverage.min>0.8702</code.coverage.min>
<code.coverage.min>0.8682</code.coverage.min>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why we're reducing our code coverage? please review your additions to make sure all the new code is included in the tests.
we only allow reducing the code coverage if you're removing functions from the module which doesn't seem to be the case here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We modified the implementation of setAccessTokenSignKeyResolver in AuthZpeClient.java to restore coverage.

</properties>

<dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,21 @@ public static void setAccessTokenSignKeyResolver(final String serverUrl, SSLCont
accessSignKeyResolver = new JwtsSigningKeyResolver(serverUrl, sslContext);
}

/**
* Set the server connection details for the sign key resolver for access
* tokens. By default, the resolver is looking for the "athenz.athenz_conf"
* system property, parses the athenz.conf file and loads any public keys
* defined. The caller can also specify the server URL, the sslcontext and the proxy URL
* (if required) for the resolver to call and fetch the public keys that
* will be required to verify the token signatures
* @param serverUrl server url to fetch json web keys
* @param sslContext ssl context to be used when establishing connection
* @param proxyUrl if a proxy is required, specify the proxy URL
*/
public static void setAccessTokenSignKeyResolver(final String serverUrl, SSLContext sslContext, final String proxyUrl) {
accessSignKeyResolver = new JwtsSigningKeyResolver(serverUrl, sslContext, proxyUrl);
}

/**
* Include the specified public key and id in the access token
* signing resolver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLConnection;
import java.net.*;


public class JwtsHelper {

Expand All @@ -42,8 +41,12 @@ static ObjectMapper initJsonMapper() {
}

public String extractJwksUri(final String openIdConfigUri, final SSLContext sslContext) {
return this.extractJwksUri(openIdConfigUri, sslContext, null);
}

public String extractJwksUri(final String openIdConfigUri, final SSLContext sslContext, final String proxyUrl) {

final String opendIdConfigData = getHttpData(openIdConfigUri, sslContext);
final String opendIdConfigData = getHttpData(openIdConfigUri, sslContext, proxyUrl);
if (opendIdConfigData == null) {
return null;
}
Expand All @@ -59,13 +62,24 @@ public String extractJwksUri(final String openIdConfigUri, final SSLContext sslC
}

public String getHttpData(final String serverUri, final SSLContext sslContext) {
return getHttpData(serverUri, sslContext, null);
}

public String getHttpData(final String serverUri, final SSLContext sslContext, final String proxyUrl) {

if (serverUri == null || serverUri.isEmpty()) {
return null;
}

try {
URLConnection con = getUrlConnection(serverUri);
URLConnection con;
if (proxyUrl == null || proxyUrl.isEmpty()) {
con = getUrlConnection(serverUri);
} else {
URL url = new URL(proxyUrl);
con = getUrlConnection(serverUri, url.getHost(), url.getPort());
}

con.setRequestProperty("Accept", "application/json");
con.setConnectTimeout(10000);
con.setReadTimeout(15000);
Expand Down Expand Up @@ -119,4 +133,10 @@ SSLSocketFactory getSocketFactory(SSLContext sslContext) {
URLConnection getUrlConnection(final String serverUrl) throws IOException {
return new URL(serverUrl).openConnection();
}

URLConnection getUrlConnection(final String serverUrl, final String proxyHost, final Integer proxyPort) throws IOException {
SocketAddress addr = new InetSocketAddress(proxyHost, proxyPort);
Proxy proxy = new Proxy(Proxy.Type.HTTP, addr);
return new URL(serverUrl).openConnection(proxy);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class JwtsSigningKeyResolver implements SigningKeyResolver {
private static final ObjectMapper JSON_MAPPER = initJsonMapper();
private final SSLContext sslContext;
private final String jwksUri;
private final String proxyUrl;
private static long lastZtsJwkFetchTime;
private static long millisBetweenZtsCalls;

Expand All @@ -66,10 +67,19 @@ public JwtsSigningKeyResolver(final String jwksUri, final SSLContext sslContext)
this(jwksUri, sslContext, false);
}

public JwtsSigningKeyResolver(final String jwksUri, final SSLContext sslContext, final String proxyUrl) {
this(jwksUri, sslContext, proxyUrl, false);
}

public JwtsSigningKeyResolver(final String jwksUri, final SSLContext sslContext, boolean skipConfig) {
this(jwksUri, sslContext, null, skipConfig);
}

public JwtsSigningKeyResolver(final String jwksUri, final SSLContext sslContext, final String proxyUrl, boolean skipConfig) {
this.jwksUri = jwksUri;
this.sslContext = sslContext;
this.publicKeys = new ConcurrentHashMap<>();
this.proxyUrl = proxyUrl;
if (!skipConfig) {
loadPublicKeysFromConfig();
loadJwksFromConfig();
Expand Down Expand Up @@ -124,7 +134,7 @@ public int publicKeyCount() {

public void loadPublicKeysFromServer() {

final String jwksData = getHttpData(jwksUri, sslContext);
final String jwksData = getHttpData(jwksUri, sslContext, proxyUrl);
if (jwksData == null) {
return;
}
Expand All @@ -143,9 +153,9 @@ public void loadPublicKeysFromServer() {
}
}

String getHttpData(final String jwksUri, final SSLContext sslContext) {
String getHttpData(final String jwksUri, final SSLContext sslContext, final String proxyUrl) {
JwtsHelper jwtsHelper = new JwtsHelper();
return jwtsHelper.getHttpData(jwksUri, sslContext);
return jwtsHelper.getHttpData(jwksUri, sslContext, proxyUrl);
}

void loadPublicKeysFromConfig() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@
import org.mockito.Mockito;
import org.testng.annotations.Test;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;

import static org.mockito.Mockito.verify;
import static org.testng.Assert.*;

import java.io.ByteArrayInputStream;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;

public class JwtsHelperTest {



@Test
public void testExtractJwksUri() {
Expand Down Expand Up @@ -81,4 +90,33 @@ public void testGetSocketFactory() {
JwtsHelper helper = new JwtsHelper();
assertNull(helper.getSocketFactory(sslContext));
}

@Test
public void testGetHttpData() throws Exception {
String url = "https://localhost/";
JwtsHelper helper = Mockito.spy(JwtsHelper.class);
HttpsURLConnection mockHttpConn = Mockito.mock(HttpsURLConnection.class);
Mockito.when(mockHttpConn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);
Mockito.when(mockHttpConn.getInputStream()).thenReturn(new ByteArrayInputStream("".getBytes(StandardCharsets.UTF_8)));
Mockito.doReturn(mockHttpConn).when(helper).getUrlConnection(url);

helper.getHttpData(url, null);

verify(helper).getUrlConnection(url);
}

@Test
public void testGetHttpDataProxy() throws Exception {
String url = "https://localhost/";
String proxyUrl = "http://localhost:8128";
JwtsHelper helper = Mockito.spy(JwtsHelper.class);
HttpsURLConnection mockHttpConn = Mockito.mock(HttpsURLConnection.class);
Mockito.when(mockHttpConn.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);
Mockito.when(mockHttpConn.getInputStream()).thenReturn(new ByteArrayInputStream("".getBytes(StandardCharsets.UTF_8)));
Mockito.doReturn(mockHttpConn).when(helper).getUrlConnection(url, "localhost", 8128);

helper.getHttpData(url, null, proxyUrl);

verify(helper).getUrlConnection(url, "localhost", 8128);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static void setResponseBody(final String body) {
}

@Override
String getHttpData(String jwksUri, SSLContext sslContext) {
String getHttpData(String jwksUri, SSLContext sslContext, String proxyUrl) {
return responseBody;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public void testLoadPublicKeysFromServerInvalidUri() {
@Test
public void testLoadJWKPublicKeysFromServer() {
System.setProperty(ZTS_PROP_JWK_ATHENZ_CONF, TestJwtsSigningKeyResolver.class.getClassLoader().getResource("jwk/athenz.conf").getPath());
JwtsSigningKeyResolver resolver = spy(new JwtsSigningKeyResolver("https://localhost:10099", mock(SSLContext.class)));
JwtsSigningKeyResolver resolver = spy(new JwtsSigningKeyResolver("https://localhost:10099", mock(SSLContext.class), "http://localhost:8128"));
assertNotNull(resolver);
String ecKeys = "{\n" +
" \"keys\": [\n" +
Expand All @@ -102,10 +102,11 @@ public void testLoadJWKPublicKeysFromServer() {
" }\n" +
" ]\n" +
" }";
when(resolver.getHttpData(any(), any())).thenReturn(ecKeys);
when(resolver.getHttpData(any(), any(), any())).thenReturn(ecKeys);
resolver.loadPublicKeysFromServer();
assertNotNull(resolver.getPublicKey("FdFYFzERwC2uCBB46pZQi4GG85LujR8obt-KWRBICVQ"));
assertNotNull(resolver.getPublicKey("c6e34b18-fb1c-43bb-9de7-7edc8981b14d"));
verify(resolver).getHttpData(any(), any(), eq("http://localhost:8128"));
System.clearProperty(ZTS_PROP_JWK_ATHENZ_CONF);
}
}