Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

/oauth2/keys Specify the service to obtain the public key #2642

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clients/go/zts/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1092,9 +1092,9 @@ func (client ZTSClient) GetOAuthConfig() (*OAuthConfig, error) {
}
}

func (client ZTSClient) GetJWKList(rfc *bool) (*JWKList, error) {
func (client ZTSClient) GetJWKList(rfc *bool, service ServiceName) (*JWKList, error) {
var data *JWKList
url := client.URL + "/oauth2/keys" + encodeParams(encodeOptionalBoolParam("rfc", rfc))
url := client.URL + "/oauth2/keys" + encodeParams(encodeOptionalBoolParam("rfc", rfc), encodeStringParam("service", string(service), "zts"))
resp, err := client.httpGet(url, nil)
if err != nil {
return data, err
Expand Down
1 change: 1 addition & 0 deletions clients/go/zts/zts_schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 22 additions & 3 deletions clients/java/zts/src/main/java/com/yahoo/athenz/zts/ZTSClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -990,25 +990,44 @@ public OpenIDConfig getOpenIDConfig() {
/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @param rfcCurveNames EC curve names - use values defined in RFC only
* @param service service name - Obtain the public key of the specified service (zms or zts)
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList(boolean rfcCurveNames) {
public JWKList getJWKList(boolean rfcCurveNames, String service) {
updateServicePrincipal();
try {
return ztsClient.getJWKList(rfcCurveNames);
return ztsClient.getJWKList(rfcCurveNames, service);
} catch (ResourceException ex) {
throw new ZTSClientException(ex.getCode(), ex.getData());
} catch (Exception ex) {
throw new ZTSClientException(ResourceException.BAD_REQUEST, ex.getMessage());
}
}

/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @param service service name - Obtain the public key of the specified service (zms or zts)
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList(String service) {
return getJWKList(false, service);
}

/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @param rfcCurveNames EC curve names - use values defined in RFC only
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList(boolean rfcCurveNames) {
return getJWKList(rfcCurveNames, "zts");
}

/**
* Retrieve list of ZTS Server public keys in Json WEB Key (JWK) format
* @return list of public keys (JWKs) on success. ZTSClientException will be thrown in case of failure
*/
public JWKList getJWKList() {
return getJWKList(false);
return getJWKList(false, "zts");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -874,12 +874,15 @@ public OAuthConfig getOAuthConfig() throws URISyntaxException, IOException {
}
}

public JWKList getJWKList(Boolean rfc) throws URISyntaxException, IOException {
public JWKList getJWKList(Boolean rfc, String service) throws URISyntaxException, IOException {
UriTemplateBuilder uriTemplateBuilder = new UriTemplateBuilder(baseUrl, "/oauth2/keys");
URIBuilder uriBuilder = new URIBuilder(uriTemplateBuilder.getUri());
if (rfc != null) {
uriBuilder.setParameter("rfc", String.valueOf(rfc));
}
if (service != null) {
uriBuilder.setParameter("service", service);
}
HttpUriRequest httpUriRequest = RequestBuilder.get()
.setUri(uriBuilder.build())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public void setJwkFailure(int jwkExcCode) {
}

@Override
public JWKList getJWKList(Boolean rfc) {
public JWKList getJWKList(Boolean rfc, String service) {

if (jwkExcCode != 0) {
if (jwkExcCode < 500) {
Expand Down
1 change: 1 addition & 0 deletions core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,7 @@ private static Schema build() {

sb.resource("JWKList", "GET", "/oauth2/keys")
.queryParam("rfc", "rfc", "Bool", false, "flag to indicate ec curve names are restricted to RFC values")
.queryParam("service", "service", "ServiceName", "zts", "service")
.expected("OK")
.exception("BAD_REQUEST", "ResourceError", "")

Expand Down
3 changes: 2 additions & 1 deletion core/zts/src/main/rdl/OAuth.rdli
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ resource OAuthConfig GET "/.well-known/oauth-authorization-server" {
}
}

resource JWKList GET "/oauth2/keys?rfc={rfc}" {
resource JWKList GET "/oauth2/keys?rfc={rfc}&service={service}" {
Bool rfc (optional, default=false); //flag to indicate ec curve names are restricted to RFC values
ServiceName service (optional, default="zts"); //service
expected OK;
exceptions {
ResourceError BAD_REQUEST;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public interface ZTSHandler {
Response postSSHCertRequest(ResourceContext context, SSHCertRequest certRequest);
OpenIDConfig getOpenIDConfig(ResourceContext context);
OAuthConfig getOAuthConfig(ResourceContext context);
JWKList getJWKList(ResourceContext context, Boolean rfc);
JWKList getJWKList(ResourceContext context, Boolean rfc, String service);
AccessTokenResponse postAccessTokenRequest(ResourceContext context, String request);
Response getOIDCResponse(ResourceContext context, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim);
RoleCertificate postRoleCertificateRequestExt(ResourceContext context, RoleCertificateRequest req);
Expand Down
9 changes: 7 additions & 2 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -5020,13 +5020,18 @@ public ExternalCredentialsResponse postExternalCredentialsRequest(ResourceContex
}

@Override
public JWKList getJWKList(ResourceContext ctx, Boolean rfc) {
public JWKList getJWKList(ResourceContext ctx, Boolean rfc, String service) {

final String caller = ctx.getApiName();
final String principalDomain = logPrincipalAndGetDomain(ctx);

validateOIDCRequest(ctx.request(), principalDomain, caller);
return dataStore.getZtsJWKList(rfc);
switch (service) {
case ServerCommonConsts.ZMS_SERVICE:
return dataStore.getZmsJWKList(rfc);
default:
return dataStore.getZtsJWKList(rfc);
}
}

long getSvcTokenExpiryTime(Integer expiryTime) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -863,12 +863,13 @@ public OAuthConfig getOAuthConfig(
@Produces(MediaType.APPLICATION_JSON)
@Operation(description = "")
public JWKList getJWKList(
@Parameter(description = "flag to indicate ec curve names are restricted to RFC values", required = false) @QueryParam("rfc") @DefaultValue("false") Boolean rfc) {
@Parameter(description = "flag to indicate ec curve names are restricted to RFC values", required = false) @QueryParam("rfc") @DefaultValue("false") Boolean rfc,
@Parameter(description = "service", required = false) @QueryParam("service") @DefaultValue("zts") String service) {
int code = ResourceException.OK;
ResourceContext context = null;
try {
context = this.delegate.newResourceContext(this.servletContext, this.request, this.response, "getJWKList");
return this.delegate.getJWKList(context, rfc);
return this.delegate.getJWKList(context, rfc, service);
} catch (ResourceException e) {
code = e.getCode();
switch (code) {
Expand Down
67 changes: 46 additions & 21 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/store/DataStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ public class DataStore implements DataCacheProvider, RolesProvider, PubKeysProvi
final RequireRoleCertCache requireRoleCertCache;
final Map<String, List<String>> hostCache;
final Map<String, String> publicKeyCache;
final JWKList zmsJWKList;
final JWKList zmsJWKListStrictRFC;
final JWKList ztsJWKList;
final JWKList ztsJWKListStrictRFC;
private final ObjectMapper jsonMapper;
Expand Down Expand Up @@ -130,6 +132,8 @@ public DataStore(ChangeLogStore clogStore, CloudStore cloudStore, Metric metric)

requireRoleCertCache = new RequireRoleCertCache();

zmsJWKList = new JWKList();
zmsJWKListStrictRFC = new JWKList();
ztsJWKList = new JWKList();
ztsJWKListStrictRFC = new JWKList();

Expand Down Expand Up @@ -586,6 +590,10 @@ public JWKList getZtsJWKList(Boolean rfc) {
return rfc == Boolean.TRUE ? ztsJWKListStrictRFC : ztsJWKList;
}

public JWKList getZmsJWKList(Boolean rfc) {
return rfc == Boolean.TRUE ? zmsJWKListStrictRFC : zmsJWKList;
}

boolean loadAthenzPublicKeys() {

final String rootDir = ZTSImpl.getRootDir();
Expand Down Expand Up @@ -613,42 +621,59 @@ boolean loadAthenzPublicKeys() {
LOGGER.error("No valid public ZMS keys in conf file: {}", confFileName);
return false;
}
loadZmsJwk(zmsPublicKeys);

final ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> ztsPublicKeys = conf.getZtsPublicKeys();
if (ztsPublicKeys == null) {
LOGGER.error("Conf file {} has no ZTS Public keys", confFileName);
return false;
}
final List<JWK> jwkList = new ArrayList<>();
final List<JWK> jwkListStrictRFC = new ArrayList<>();
for (com.yahoo.athenz.zms.PublicKeyEntry publicKey : ztsPublicKeys) {
final String id = publicKey.getId();
final String key = publicKey.getKey();
if (key == null || id == null) {
LOGGER.error("Missing required zts public key attributes: {}/{}", id, key);
continue;
}
final JWK jwk = getJWK(key, id, false);
if (jwk != null) {
jwkList.add(jwk);
}
final JWK jwkRfc = getJWK(key, id, true);
if (jwkRfc != null) {
jwkListStrictRFC.add(jwkRfc);
}
}
if (jwkList.isEmpty() || jwkListStrictRFC.isEmpty()) {
if (!loadZtsJwk(ztsPublicKeys)) {
LOGGER.error("No valid public ZTS keys in conf file: {}", confFileName);
return false;
}
ztsJWKList.setKeys(jwkList);
ztsJWKListStrictRFC.setKeys(jwkListStrictRFC);
} catch (IOException ex) {
LOGGER.error("Unable to parse conf file {}, error: {}", confFileName, ex.getMessage());
return false;
}
return true;
}

boolean loadJwk(ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> keys, JWKList jwkList, JWKList jwkListStrictRFC) {
final List<JWK> tmpJwkList = new ArrayList<>();
final List<JWK> tmpJwkListStrictRFC = new ArrayList<>();
for (com.yahoo.athenz.zms.PublicKeyEntry publicKey : keys) {
final String id = publicKey.getId();
final String key = publicKey.getKey();
if (key == null || id == null) {
LOGGER.error("Missing required public key attributes: {}/{}", id, key);
continue;
}
final JWK jwk = getJWK(key, id, false);
if (jwk != null) {
tmpJwkList.add(jwk);
}
final JWK jwkRfc = getJWK(key, id, true);
if (jwkRfc != null) {
tmpJwkListStrictRFC.add(jwkRfc);
}
}
if (tmpJwkList.isEmpty() || tmpJwkListStrictRFC.isEmpty()) {
return false;
}
jwkList.setKeys(tmpJwkList);
jwkListStrictRFC.setKeys(tmpJwkListStrictRFC);
return true;
}

boolean loadZmsJwk(ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> keys) {
return loadJwk(keys, zmsJWKList, zmsJWKListStrictRFC);
}

boolean loadZtsJwk(ArrayList<com.yahoo.athenz.zms.PublicKeyEntry> keys) {
return loadJwk(keys, ztsJWKList, ztsJWKListStrictRFC);
}

@SuppressWarnings("rawtypes")
String getCurveName(org.bouncycastle.jce.spec.ECParameterSpec ecParameterSpec, boolean rfc) {

Expand Down
64 changes: 60 additions & 4 deletions servers/zts/src/test/java/com/yahoo/athenz/zts/ZTSImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1209,13 +1209,13 @@ public void testGetHostServices() {
}

@Test
public void testGetJWKList() {
public void testGetZTSJWKList() {

Principal principal = SimplePrincipal.create("user_domain", "user1",
"v=U1;d=user_domain;n=user;s=signature", 0, null);
ResourceContext context = createResourceContext(principal);

JWKList list = zts.getJWKList(context, false);
JWKList list = zts.getJWKList(context, false, "zts");
assertNotNull(list);
List<JWK> keys = list.getKeys();
assertEquals(keys.size(), 2);
Expand All @@ -1232,7 +1232,7 @@ public void testGetJWKList() {
// execute the same test with argument passed as null
// for the Boolean rfc object so it should be same result

list = zts.getJWKList(context, null);
list = zts.getJWKList(context, null, "zts");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);
Expand All @@ -1249,7 +1249,7 @@ public void testGetJWKList() {
// now let's try with rfc option on in which case
// we'll get the curve name as P-256

list = zts.getJWKList(context, true);
list = zts.getJWKList(context, true, "zts");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);
Expand All @@ -1264,6 +1264,62 @@ public void testGetJWKList() {
assertEquals(key2.getCrv(), "P-256", key2.getCrv());
}

@Test
public void testGetZMSJWKList() {

Principal principal = SimplePrincipal.create("user_domain", "user1",
"v=U1;d=user_domain;n=user;s=signature", 0, null);
ResourceContext context = createResourceContext(principal);

JWKList list = zts.getJWKList(context, false, "zms");
assertNotNull(list);
List<JWK> keys = list.getKeys();
assertEquals(keys.size(), 2);

JWK key1 = keys.get(0);
assertEquals(key1.getKty(), "RSA", key1.getKty());
assertEquals(key1.getKid(), "0", key1.getKid());

JWK key2 = keys.get(1);
assertEquals(key2.getKty(), "EC", key2.getKty());
assertEquals(key2.getKid(), "zms.dev.0", key2.getKid());
assertEquals(key2.getCrv(), "prime256v1", key2.getCrv());

// execute the same test with argument passed as null
// for the Boolean rfc object so it should be same result

list = zts.getJWKList(context, null, "zms");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);

key1 = keys.get(0);
assertEquals(key1.getKty(), "RSA", key1.getKty());
assertEquals(key1.getKid(), "0", key1.getKid());

key2 = keys.get(1);
assertEquals(key2.getKty(), "EC", key2.getKty());
assertEquals(key2.getKid(), "zms.dev.0", key2.getKid());
assertEquals(key2.getCrv(), "prime256v1", key2.getCrv());

// now let's try with rfc option on in which case
// we'll get the curve name as P-256

list = zts.getJWKList(context, true, "zms");
assertNotNull(list);
keys = list.getKeys();
assertEquals(keys.size(), 2);

key1 = keys.get(0);
assertEquals(key1.getKty(), "RSA", key1.getKty());
assertEquals(key1.getKid(), "0", key1.getKid());

key2 = keys.get(1);
assertEquals(key2.getKty(), "EC", key2.getKty());
assertEquals(key2.getKid(), "zms.dev.0", key2.getKid());
assertEquals(key2.getCrv(), "P-256", key2.getCrv());
}

@Test
public void testGetHostServicesInvalidHost() {

Expand Down
2 changes: 1 addition & 1 deletion utils/zpe-updater/zpu_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func getZtsPublicKey(config *ZpuConfiguration, ztsClient zts.ZTSClient, ztsKeyID
// fetch all zts jwk keys and update config
log.Debugf("key id: [%s] does not exist in also after reloading athenz jwks from disk, about to fetch directly from zts", ztsKeyID)
rfc := true
ztsJwkList, err := ztsClient.GetJWKList(&rfc)
ztsJwkList, err := ztsClient.GetJWKList(&rfc, "zts")
if err != nil {
return "", fmt.Errorf("unable to get the zts jwk keys, err: %v", err)
}
Expand Down