Skip to content

Commit

Permalink
feat: support organizations (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyijun authored Nov 29, 2023
1 parent c5942ff commit bf533c0
Show file tree
Hide file tree
Showing 16 changed files with 214 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import io.logto.sdk.android.type.LogtoConfig
import io.logto.sdk.android.util.LogtoUtils.expiresAtFrom
import io.logto.sdk.android.util.LogtoUtils.nowRoundToSec
import io.logto.sdk.core.Core
import io.logto.sdk.core.constant.UserScope
import io.logto.sdk.core.type.IdTokenClaims
import io.logto.sdk.core.type.OidcConfigResponse
import io.logto.sdk.core.type.UserInfoResponse
import io.logto.sdk.core.util.TokenUtils
import org.jetbrains.annotations.TestOnly
import org.jose4j.jwk.JsonWebKeySet
import org.jose4j.jwt.JwtClaims
import org.jose4j.jwt.consumer.InvalidJwtException
import org.jose4j.lang.JoseException

Expand Down Expand Up @@ -118,6 +120,10 @@ open class LogtoClient(
issuer = oidcConfig.issuer,
responseIdToken = codeToken.idToken,
responseRefreshToken = codeToken.refreshToken,
/**
* Treat `scopes` as `null` to construct the default access token key
*/
accessTokenKey = buildAccessTokenKey(),
accessToken = accessToken,
completion = completion,
)
Expand Down Expand Up @@ -176,7 +182,28 @@ open class LogtoClient(
* @param[completion] the completion which handles the result
*/
fun getAccessToken(completion: Completion<LogtoException, AccessToken>) =
getAccessToken(null, completion)
getAccessToken(null, null, completion)

/**
* Get the access token for the specified organization with refresh strategy.
*
* Scope `UserScope.Organizations` is required in the config to use organization-related methods.
*
*/
fun getOrganizationToken(
organizationId: String,
completion: Completion<LogtoException, AccessToken>,
) {
if (!logtoConfig.scopes.contains(UserScope.ORGANIZATIONS)) {
completion.onComplete(
LogtoException(LogtoException.Type.MISSING_SCOPE_ORGANIZATIONS),
null,
)
return
}

return getAccessToken(null, organizationId, completion)
}

/**
* Get access token
Expand All @@ -185,6 +212,7 @@ open class LogtoClient(
*/
fun getAccessToken(
resource: String?,
organizationId: String?,
completion: Completion<LogtoException, AccessToken>,
) {
if (!isAuthenticated) {
Expand All @@ -203,7 +231,7 @@ open class LogtoClient(
}

// MARK: Retrieve access token from accessTokenMap
val accessTokenKey = buildAccessTokenKey(null, resource)
val accessTokenKey = buildAccessTokenKey(null, resource, organizationId)
val accessToken = accessTokenMap[accessTokenKey]
accessToken?.let {
if (it.expiresAt > nowRoundToSec()) {
Expand All @@ -230,6 +258,7 @@ open class LogtoClient(
clientId = logtoConfig.appId,
refreshToken = requireNotNull(refreshToken),
resource = resource,
organizationId = organizationId,
scopes = null,
) { fetchRefreshedTokenException, fetchedTokenResponse ->
fetchRefreshedTokenException?.let {
Expand Down Expand Up @@ -257,6 +286,7 @@ open class LogtoClient(
issuer = oidcConfig.issuer,
responseIdToken = refreshedToken.idToken,
responseRefreshToken = refreshedToken.refreshToken,
accessTokenKey = buildAccessTokenKey(null, resource, organizationId),
accessToken = refreshedAccessToken,
) { verifyException ->
verifyException?.let { completion.onComplete(it, null) }
Expand Down Expand Up @@ -286,6 +316,34 @@ open class LogtoClient(
}
}

/**
* Get the organization token claims for the specified organization.
*
* @param[organizationId] The ID of the organization that the access token is granted for.
* @param[completion] the completion which handles the retrieved result
*/
fun getOrganizationTokenClaims(
organizationId: String,
completion: Completion<LogtoException, JwtClaims>,
) {
getOrganizationToken(organizationId) { getOrgTokenException, token ->
getOrgTokenException?.let {
completion.onComplete(it, null)
return@getOrganizationToken
}

try {
val tokenClaims = TokenUtils.decodeToken(requireNotNull(token).token)
completion.onComplete(null, tokenClaims)
} catch (exception: InvalidJwtException) {
completion.onComplete(
LogtoException(LogtoException.Type.UNABLE_TO_PARSE_TOKEN_CLAIMS, exception),
null,
)
}
}
}

/**
* Fetch user info
* @param[completion] the completion which handles the retrieved result
Expand Down Expand Up @@ -322,6 +380,7 @@ open class LogtoClient(
issuer: String,
responseIdToken: String?,
responseRefreshToken: String?,
accessTokenKey: String,
accessToken: AccessToken,
completion: EmptyCompletion<LogtoException>,
) {
Expand All @@ -340,10 +399,6 @@ open class LogtoClient(
idToken = it
}

// Note
// - Treat `scopes` as `null` to construct the default access token key
// for we do not support custom scopes in V1
val accessTokenKey = buildAccessTokenKey(null, getResourceFromAccessToken(accessToken.token))
accessTokenMap[accessTokenKey] = accessToken
refreshToken = responseRefreshToken
completion.onComplete(null)
Expand Down Expand Up @@ -405,16 +460,15 @@ open class LogtoClient(
idToken = storage?.getItem(StorageKey.ID_TOKEN)
}

private fun getResourceFromAccessToken(accessToken: String) = try {
TokenUtils.decodeToken(accessToken).audience[0]
} catch (_: InvalidJwtException) {
null
}

internal fun buildAccessTokenKey(scopes: List<String>?, resource: String?): String {
internal fun buildAccessTokenKey(
scopes: List<String>? = null,
resource: String? = null,
organizationId: String? = null,
): String {
val scopesPart = scopes?.sorted()?.joinToString(" ") ?: ""
val resourcePart = resource ?: ""
return "$scopesPart@$resourcePart"
val organizationPart = organizationId?.let { "#$it" } ?: ""
return "$scopesPart@$resourcePart$organizationPart"
}

@TestOnly
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ class LogtoException(
UNABLE_TO_FETCH_TOKEN_BY_AUTHORIZATION_CODE,
UNABLE_TO_FETCH_TOKEN_BY_REFRESH_TOKEN,
UNABLE_TO_REVOKE_TOKEN,
UNABLE_TO_PARSE_TOKEN_CLAIMS,
UNABLE_TO_PARSE_ID_TOKEN_CLAIMS,
UNABLE_TO_FETCH_USER_INFO,
UNABLE_TO_FETCH_JWKS_JSON,
UNABLE_TO_PARSE_JWKS,
INVALID_ID_TOKEN,
MISSING_SCOPE_ORGANIZATIONS,
ALIPAY_APP_ID_NO_FOUND,
ALIPAY_AUTH_FAILED,
WECHAT_APP_ID_NO_FOUND,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
package io.logto.sdk.android.type

import io.logto.sdk.core.constant.PromptValue
import io.logto.sdk.core.constant.ReservedResource
import io.logto.sdk.core.constant.UserScope
import io.logto.sdk.core.util.ScopeUtils

class LogtoConfig(
val endpoint: String,
val appId: String,
scopes: List<String>? = null,
val resources: List<String>? = null,
resources: List<String>? = null,
val usingPersistStorage: Boolean = true,
val prompt: String = PromptValue.CONSENT,
) {
/**
* Normalize the Logto client configuration per the following rules:
*
* - Add default scopes (`openid`, `offline_access` and `profile`) if not provided.
* - Add `ReservedResource.Organization` to resources if `UserScope.Organizations` is included in scopes.
*/
val scopes = ScopeUtils.withDefaultScopes(scopes)
val resources = if (this.scopes.contains(UserScope.ORGANIZATIONS)) {
(resources.orEmpty() + ReservedResource.ORGANIZATION)
} else {
resources
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class LogtoClientTest {

logtoClient.getAccessToken(
resource = TEST_RESOURCE_3,
organizationId = null,
) { logtoException, result ->
assertThat(logtoException)
.hasMessageThat()
Expand All @@ -284,7 +285,7 @@ class LogtoClientTest {
logtoClient = LogtoClient(logtoConfigMock, mockk())
logtoClient.setupRefreshToken(TEST_REFRESH_TOKEN)

val testTokenKey = logtoClient.buildAccessTokenKey(null, null)
val testTokenKey = logtoClient.buildAccessTokenKey()
val testAccessToken: AccessToken = mockk()
every { testAccessToken.expiresAt } returns LogtoUtils.nowRoundToSec() + timeBias

Expand All @@ -295,6 +296,7 @@ class LogtoClientTest {

logtoClient.getAccessToken(
null,
null,
) { logtoException, result ->
assertThat(logtoException).isNull()
assertThat(result).isEqualTo(testAccessToken)
Expand All @@ -305,13 +307,14 @@ class LogtoClientTest {
fun `getAccessToken should refresh token when existing accessToken is expired`() {
setupRefreshTokenTestEnv()

val expiredAccessTokenKey = logtoClient.buildAccessTokenKey(null, null)
val expiredAccessTokenKey = logtoClient.buildAccessTokenKey()
val expiredAccessToken: AccessToken = mockk()
every { expiredAccessToken.expiresAt } returns LogtoUtils.nowRoundToSec() - timeBias
logtoClient.setupAccessTokenMap(mapOf(expiredAccessTokenKey to expiredAccessToken))

logtoClient.getAccessToken(
null,
null,
) { logtoException, result ->
assertThat(logtoException).isNull()
assertThat(result).isNotNull()
Expand All @@ -321,7 +324,7 @@ class LogtoClientTest {
}

verify(exactly = 1) {
Core.fetchTokenByRefreshToken(any(), any(), any(), any(), any(), any())
Core.fetchTokenByRefreshToken(any(), any(), any(), any(), any(), any(), any())
}
}

Expand All @@ -331,6 +334,7 @@ class LogtoClientTest {

logtoClient.getAccessToken(
null,
null,
) { logtoException, result ->
assertThat(logtoException).isNull()
assertThat(result).isNotNull()
Expand All @@ -340,7 +344,7 @@ class LogtoClientTest {
}

verify(exactly = 1) {
Core.fetchTokenByRefreshToken(any(), any(), any(), any(), any(), any())
Core.fetchTokenByRefreshToken(any(), any(), any(), any(), any(), any(), any())
}
}

Expand Down Expand Up @@ -453,7 +457,7 @@ class LogtoClientTest {
}
val accessTokenMock: AccessToken = mockk()
every { accessTokenMock.token } returns TEST_ACCESS_TOKEN
every { logtoClient.getAccessToken(any(), any()) } answers {
every { logtoClient.getAccessToken(any(), any(), any()) } answers {
lastArg<Completion<LogtoException, AccessToken>>().onComplete(null, accessTokenMock)
}

Expand Down Expand Up @@ -502,7 +506,7 @@ class LogtoClientTest {
}

val mockGetAccessTokenException: LogtoException = mockk()
every { logtoClient.getAccessToken(any(), any()) } answers {
every { logtoClient.getAccessToken(any(), any(), any()) } answers {
lastArg<Completion<LogtoException, AccessToken>>().onComplete(mockGetAccessTokenException, null)
}

Expand All @@ -524,7 +528,7 @@ class LogtoClientTest {
}
val accessTokenMock: AccessToken = mockk()
every { accessTokenMock.token } returns TEST_ACCESS_TOKEN
every { logtoClient.getAccessToken(any(), any()) } answers {
every { logtoClient.getAccessToken(any(), any(), any()) } answers {
lastArg<Completion<LogtoException, AccessToken>>().onComplete(null, accessTokenMock)
}

Expand Down Expand Up @@ -701,7 +705,7 @@ class LogtoClientTest {
every { refreshTokenTokenResponseMock.idToken } returns TEST_ID_TOKEN

mockkObject(Core)
every { Core.fetchTokenByRefreshToken(any(), any(), any(), any(), any(), any()) } answers {
every { Core.fetchTokenByRefreshToken(any(), any(), any(), any(), any(), any(), any()) } answers {
lastArg<HttpCompletion<RefreshTokenTokenResponse>>()
.onComplete(null, refreshTokenTokenResponseMock)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.logto.sdk.android.type

import com.google.common.truth.Truth.assertThat
import io.logto.sdk.core.constant.ReservedResource
import io.logto.sdk.core.constant.ReservedScope
import io.logto.sdk.core.constant.UserScope
import org.junit.Test
Expand Down Expand Up @@ -32,4 +33,15 @@ class LogtoConfigTest {
contains("other_scope")
}
}

@Test
fun `LogtoConfig's resource should contain 'organization' if organization scope is provided`() {
val logtoConfig = LogtoConfig(
endpoint = "endpoint",
appId = "appId",
scopes = listOf(UserScope.ORGANIZATIONS)
)

assertThat(logtoConfig.resources).contains(ReservedResource.ORGANIZATION)
}
}
11 changes: 11 additions & 0 deletions kotlin-sdk/kotlin/src/main/kotlin/io/logto/sdk/core/Core.kt
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,20 @@ object Core {
}

fun fetchTokenByRefreshToken(
/** The token endpoint of the authorization server. */
tokenEndpoint: String,
/** The client ID of the application. */
clientId: String,
/** The refresh token to be used to fetch the access token. */
refreshToken: String,
/** The API resource to be fetch the access token for. */
resource: String?,
/** The ID of the organization to be fetch the access token for. */
organizationId: String?,
/**
* The scopes to request for the access token. If not provided, the authorization server
* will use all the scopes that the client is authorized for.
*/
scopes: List<String>?,
completion: HttpCompletion<RefreshTokenTokenResponse>,
) {
Expand All @@ -100,6 +110,7 @@ object Core {
add(QueryKey.REFRESH_TOKEN, refreshToken)
add(QueryKey.GRANT_TYPE, GrantType.REFRESH_TOKEN)
resource?.let { add(QueryKey.RESOURCE, it) }
organizationId?.let { add(QueryKey.ORGANIZATION_ID, it) }
scopes?.let { add(QueryKey.SCOPE, it.joinToString(" ")) }
}.build()
httpPost(tokenEndpoint, formBody, completion)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ object ClaimName {
const val EMAIL_VERIFIED = "email_verified"
const val PHONE_NUMBER = "phone_number"
const val PHONE_NUMBER_VERIFIED = "phone_number_verified"
const val ROLES = "roles"
const val ORGANIZATIONS = "organizations"
const val ORGANIZATION_ROLES = "organization_roles"
const val CUSTOM_DATA = "custom_data"
const val IDENTITIES = "identities"
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ object QueryKey {
const val SCOPE = "scope"
const val STATE = "state"
const val TOKEN = "token"
const val ORGANIZATION_ID = "organization_id"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.logto.sdk.core.constant

object ReservedResource {
const val ORGANIZATION = "urn:logto:resource:organizations"
}
Loading

0 comments on commit bf533c0

Please sign in to comment.