Skip to content

Commit

Permalink
/oauth2/keys Specify the service to obtain the public key (#2642)
Browse files Browse the repository at this point in the history
* Specify the service to obtain the public key

---------

Signed-off-by: takumats <takumats@lycorp.co.jp>
  • Loading branch information
TakuyaMatsu authored Jul 1, 2024
1 parent 48efa78 commit 11ed67e
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 39 deletions.
4 changes: 2 additions & 2 deletions clients/go/zts/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1092,9 +1092,9 @@ func (client ZTSClient) GetOAuthConfig() (*OAuthConfig, error) {
}
}

func (client ZTSClient) GetJWKList(rfc *bool) (*JWKList, error) {
func (client ZTSClient) GetJWKList(rfc *bool, service ServiceName) (*JWKList, error) {
var data *JWKList
url := client.URL + "/oauth2/keys" + encodeParams(encodeOptionalBoolParam("rfc", rfc))
url := client.URL + "/oauth2/keys" + encodeParams(encodeOptionalBoolParam("rfc", rfc), encodeStringParam("service", string(service), "zts"))
resp, err := client.httpGet(url, nil)
if err != nil {
return data, err
Expand Down
1 change: 1 addition & 0 deletions clients/go/zts/zts_schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 22 additions & 3 deletions clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -990,25 +990,44 @@ public OpenIDConfig getOpenIDConfig() {
/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @param rfcCurveNames EC curve names - use values defined in RFC only
* @param service service name - Obtain the public key of the specified service (zms or zts)
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList(boolean rfcCurveNames) {
public JWKList getJWKList(boolean rfcCurveNames, String service) {
updateServicePrincipal();
try {
return ztsClient.getJWKList(rfcCurveNames);
return ztsClient.getJWKList(rfcCurveNames, service);
} catch (ResourceException ex) {
throw new ZTSClientException(ex.getCode(), ex.getData());
} catch (Exception ex) {
throw new ZTSClientException(ResourceException.BAD_REQUEST, ex.getMessage());
}
}

/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @param service service name - Obtain the public key of the specified service (zms or zts)
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList(String service) {
return getJWKList(false, service);
}

/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @param rfcCurveNames EC curve names - use values defined in RFC only
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList(boolean rfcCurveNames) {
return getJWKList(rfcCurveNames, "zts");
}

/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList() {
return getJWKList(false);
return getJWKList(false, "zts");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -874,12 +874,15 @@ public OAuthConfig getOAuthConfig() throws URISyntaxException, IOException {
}
}

public JWKList getJWKList(Boolean rfc) throws URISyntaxException, IOException {
public JWKList getJWKList(Boolean rfc, String service) throws URISyntaxException, IOException {
UriTemplateBuilder uriTemplateBuilder = new UriTemplateBuilder(baseUrl, "/oauth2/keys");
URIBuilder uriBuilder = new URIBuilder(uriTemplateBuilder.getUri());
if (rfc != null) {
uriBuilder.setParameter("rfc", String.valueOf(rfc));
}
if (service != null) {
uriBuilder.setParameter("service", service);
}
HttpUriRequest httpUriRequest = RequestBuilder.get()
.setUri(uriBuilder.build())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public void setJwkFailure(int jwkExcCode) {
}

@Override
public JWKList getJWKList(Boolean rfc) {
public JWKList getJWKList(Boolean rfc, String service) {

if (jwkExcCode != 0) {
if (jwkExcCode < 500) {
Expand Down
1 change: 1 addition & 0 deletions core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ private static Schema build() {

sb.resource("JWKList", "GET", "/oauth2/keys")
.queryParam("rfc", "rfc", "Bool", false, "flag to indicate ec curve names are restricted to RFC values")
.queryParam("service", "service", "ServiceName", "zts", "service")
.expected("OK")
.exception("BAD_REQUEST", "ResourceError", "")

Expand Down
3 changes: 2 additions & 1 deletion core/zts/src/main/rdl/OAuth.rdli
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ resource OAuthConfig GET "/.well-known/oauth-authorization-server" {
}
}

resource JWKList GET "/oauth2/keys?rfc={rfc}" {
resource JWKList GET "/oauth2/keys?rfc={rfc}&service={service}" {
Bool rfc (optional, default=false); //flag to indicate ec curve names are restricted to RFC values
ServiceName service (optional, default="zts"); //service
expected OK;
exceptions {
ResourceError BAD_REQUEST;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public interface ZTSHandler {
Response postSSHCertRequest(ResourceContext context, SSHCertRequest certRequest);
OpenIDConfig getOpenIDConfig(ResourceContext context);
OAuthConfig getOAuthConfig(ResourceContext context);
JWKList getJWKList(ResourceContext context, Boolean rfc);
JWKList getJWKList(ResourceContext context, Boolean rfc, String service);
AccessTokenResponse postAccessTokenRequest(ResourceContext context, String request);
Response getOIDCResponse(ResourceContext context, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim);
RoleCertificate postRoleCertificateRequestExt(ResourceContext context, RoleCertificateRequest req);
Expand Down
9 changes: 7 additions & 2 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -5020,13 +5020,18 @@ public ExternalCredentialsResponse postExternalCredentialsRequest(ResourceContex
}

@Override
public JWKList getJWKList(ResourceContext ctx, Boolean rfc) {
public JWKList getJWKList(ResourceContext ctx, Boolean rfc, String service) {

final String caller = ctx.getApiName();
final String principalDomain = logPrincipalAndGetDomain(ctx);

validateOIDCRequest(ctx.request(), principalDomain, caller);
return dataStore.getZtsJWKList(rfc);
switch (service) {
case ServerCommonConsts.ZMS_SERVICE:
return dataStore.getZmsJWKList(rfc);
default:
return dataStore.getZtsJWKList(rfc);
}
}

long getSvcTokenExpiryTime(Integer expiryTime) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -863,12 +863,13 @@ public OAuthConfig getOAuthConfig(
@Produces(MediaType.APPLICATION_JSON)
@Operation(description = "")
public JWKList getJWKList(
@Parameter(description = "flag to indicate ec curve names are restricted to RFC values", required = false) @QueryParam("rfc") @DefaultValue("false") Boolean rfc) {
@Parameter(description = "flag to indicate ec curve names are restricted to RFC values", required = false) @QueryParam("rfc") @DefaultValue("false") Boolean rfc,
@Parameter(description = "service", required = false) @QueryParam("service") @DefaultValue("zts") String service) {
int code = ResourceException.OK;
ResourceContext context = null;
try {
context = this.delegate.newResourceContext(this.servletContext, this.request, this.response, "getJWKList");
return this.delegate.getJWKList(context, rfc);
return this.delegate.getJWKList(context, rfc, service);
} catch (ResourceException e) {
code = e.getCode();
switch (code) {
Expand Down
67 changes: 46 additions & 21 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/store/DataStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public class DataStore implements DataCacheProvider, RolesProvider, PubKeysProvi
final RequireRoleCertCache requireRoleCertCache;
final Map<String, List<String>> hostCache;
final Map<String, String> publicKeyCache;
final JWKList zmsJWKList;
final JWKList zmsJWKListStrictRFC;
final JWKList ztsJWKList;
final JWKList ztsJWKListStrictRFC;
private final ObjectMapper jsonMapper;
Expand Down Expand Up @@ -130,6 +132,8 @@ public DataStore(ChangeLogStore clogStore, CloudStore cloudStore, Metric metric)

requireRoleCertCache = new RequireRoleCertCache();

zmsJWKList = new JWKList();
zmsJWKListStrictRFC = new JWKList();
ztsJWKList = new JWKList();
ztsJWKListStrictRFC = new JWKList();

Expand Down Expand Up @@ -586,6 +590,10 @@ public JWKList getZtsJWKList(Boolean rfc) {
return rfc == Boolean.TRUE ? ztsJWKListStrictRFC : ztsJWKList;
}

public JWKList getZmsJWKList(Boolean rfc) {
return rfc == Boolean.TRUE ? zmsJWKListStrictRFC : zmsJWKList;
}

boolean loadAthenzPublicKeys() {

final String rootDir = ZTSImpl.getRootDir();
Expand Down Expand Up @@ -613,42 +621,59 @@ boolean loadAthenzPublicKeys() {
LOGGER.error("No valid public ZMS keys in conf file: {}", confFileName);
return false;
}
loadZmsJwk(zmsPublicKeys);

final ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> ztsPublicKeys = conf.getZtsPublicKeys();
if (ztsPublicKeys == null) {
LOGGER.error("Conf file {} has no ZTS Public keys", confFileName);
return false;
}
final List<JWK> jwkList = new ArrayList<>();
final List<JWK> jwkListStrictRFC = new ArrayList<>();
for (com.yahoo.athenz.zms.PublicKeyEntry publicKey : ztsPublicKeys) {
final String id = publicKey.getId();
final String key = publicKey.getKey();
if (key == null || id == null) {
LOGGER.error("Missing required zts public key attributes: {}/{}", id, key);
continue;
}
final JWK jwk = getJWK(key, id, false);
if (jwk != null) {
jwkList.add(jwk);
}
final JWK jwkRfc = getJWK(key, id, true);
if (jwkRfc != null) {
jwkListStrictRFC.add(jwkRfc);
}
}
if (jwkList.isEmpty() || jwkListStrictRFC.isEmpty()) {
if (!loadZtsJwk(ztsPublicKeys)) {
LOGGER.error("No valid public ZTS keys in conf file: {}", confFileName);
return false;
}
ztsJWKList.setKeys(jwkList);
ztsJWKListStrictRFC.setKeys(jwkListStrictRFC);
} catch (IOException ex) {
LOGGER.error("Unable to parse conf file {}, error: {}", confFileName, ex.getMessage());
return false;
}
return true;
}

boolean loadJwk(ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> keys, JWKList jwkList, JWKList jwkListStrictRFC) {
final List<JWK> tmpJwkList = new ArrayList<>();
final List<JWK> tmpJwkListStrictRFC = new ArrayList<>();
for (com.yahoo.athenz.zms.PublicKeyEntry publicKey : keys) {
final String id = publicKey.getId();
final String key = publicKey.getKey();
if (key == null || id == null) {
LOGGER.error("Missing required public key attributes: {}/{}", id, key);
continue;
}
final JWK jwk = getJWK(key, id, false);
if (jwk != null) {
tmpJwkList.add(jwk);
}
final JWK jwkRfc = getJWK(key, id, true);
if (jwkRfc != null) {
tmpJwkListStrictRFC.add(jwkRfc);
}
}
if (tmpJwkList.isEmpty() || tmpJwkListStrictRFC.isEmpty()) {
return false;
}
jwkList.setKeys(tmpJwkList);
jwkListStrictRFC.setKeys(tmpJwkListStrictRFC);
return true;
}

boolean loadZmsJwk(ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> keys) {
return loadJwk(keys, zmsJWKList, zmsJWKListStrictRFC);
}

boolean loadZtsJwk(ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> keys) {
return loadJwk(keys, ztsJWKList, ztsJWKListStrictRFC);
}

@SuppressWarnings("rawtypes")
String getCurveName(org.bouncycastle.jce.spec.ECParameterSpec ecParameterSpec, boolean rfc) {

Expand Down
64 changes: 60 additions & 4 deletions servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1209,13 +1209,13 @@ public void testGetHostServices() {
}

@Test
public void testGetJWKList() {
public void testGetZTSJWKList() {

Principal principal = SimplePrincipal.create("user_domain", "user1",
"v=U1;d=user_domain;n=user;s=signature", 0, null);
ResourceContext context = createResourceContext(principal);

JWKList list = zts.getJWKList(context, false);
JWKList list = zts.getJWKList(context, false, "zts");
assertNotNull(list);
List<JWK> keys = list.getKeys();
assertEquals(keys.size(), 2);
Expand All @@ -1232,7 +1232,7 @@ public void testGetJWKList() {
// execute the same test with argument passed as null
// for the Boolean rfc object so it should be same result

list = zts.getJWKList(context, null);
list = zts.getJWKList(context, null, "zts");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);
Expand All @@ -1249,7 +1249,7 @@ public void testGetJWKList() {
// now let's try with rfc option on in which case
// we'll get the curve name as P-256

list = zts.getJWKList(context, true);
list = zts.getJWKList(context, true, "zts");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);
Expand All @@ -1264,6 +1264,62 @@ public void testGetJWKList() {
assertEquals(key2.getCrv(), "P-256", key2.getCrv());
}

@Test
public void testGetZMSJWKList() {

Principal principal = SimplePrincipal.create("user_domain", "user1",
"v=U1;d=user_domain;n=user;s=signature", 0, null);
ResourceContext context = createResourceContext(principal);

JWKList list = zts.getJWKList(context, false, "zms");
assertNotNull(list);
List<JWK> keys = list.getKeys();
assertEquals(keys.size(), 2);

JWK key1 = keys.get(0);
assertEquals(key1.getKty(), "RSA", key1.getKty());
assertEquals(key1.getKid(), "0", key1.getKid());

JWK key2 = keys.get(1);
assertEquals(key2.getKty(), "EC", key2.getKty());
assertEquals(key2.getKid(), "zms.dev.0", key2.getKid());
assertEquals(key2.getCrv(), "prime256v1", key2.getCrv());

// execute the same test with argument passed as null
// for the Boolean rfc object so it should be same result

list = zts.getJWKList(context, null, "zms");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);

key1 = keys.get(0);
assertEquals(key1.getKty(), "RSA", key1.getKty());
assertEquals(key1.getKid(), "0", key1.getKid());

key2 = keys.get(1);
assertEquals(key2.getKty(), "EC", key2.getKty());
assertEquals(key2.getKid(), "zms.dev.0", key2.getKid());
assertEquals(key2.getCrv(), "prime256v1", key2.getCrv());

// now let's try with rfc option on in which case
// we'll get the curve name as P-256

list = zts.getJWKList(context, true, "zms");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);

key1 = keys.get(0);
assertEquals(key1.getKty(), "RSA", key1.getKty());
assertEquals(key1.getKid(), "0", key1.getKid());

key2 = keys.get(1);
assertEquals(key2.getKty(), "EC", key2.getKty());
assertEquals(key2.getKid(), "zms.dev.0", key2.getKid());
assertEquals(key2.getCrv(), "P-256", key2.getCrv());
}

@Test
public void testGetHostServicesInvalidHost() {

Expand Down
2 changes: 1 addition & 1 deletion utils/zpe-updater/zpu_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func getZtsPublicKey(config *ZpuConfiguration, ztsClient zts.ZTSClient, ztsKeyID
// fetch all zts jwk keys and update config
log.Debugf("key id: [%s] does not exist in also after reloading athenz jwks from disk, about to fetch directly from zts", ztsKeyID)
rfc := true
ztsJwkList, err := ztsClient.GetJWKList(&rfc)
ztsJwkList, err := ztsClient.GetJWKList(&rfc, "zts")
if err != nil {
return "", fmt.Errorf("unable to get the zts jwk keys, err: %v", err)
}
Expand Down

0 comments on commit 11ed67e

Please sign in to comment.