Skip to content

Commit

Permalink
Add redis cache
Browse files Browse the repository at this point in the history
Signed-off-by: Paolo Di Tommaso <paolo.ditommaso@gmail.com>
  • Loading branch information
pditommaso committed Dec 16, 2024
1 parent 8a6c2fb commit db27aae
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CredentialServiceImpl implements CredentialsService {
if (pairing.isExpired())
log.debug("Exchange key registered for service ${PairingService.TOWER_SERVICE} at endpoint: ${identity.towerEndpoint} used after expiration, should be renewed soon")

final all = towerClient.listCredentials(identity.towerEndpoint, JwtAuth.of(identity), identity.workspaceId, identity.workflowId).get().credentials
final all = towerClient.listCredentials(identity.towerEndpoint, JwtAuth.of(identity), identity.workspaceId, identity.workflowId).credentials

if (!all) {
log.debug "No credentials found for userId=$identity.userId; workspaceId=$identity.workspaceId; endpoint=$identity.towerEndpoint"
Expand Down Expand Up @@ -92,7 +92,7 @@ class CredentialServiceImpl implements CredentialsService {
// log for debugging purposes
log.debug "Credentials matching criteria registryName=$registryName; userId=$identity.userId; workspaceId=$identity.workspaceId; endpoint=$identity.towerEndpoint => $creds"
// now fetch the encrypted key
final encryptedCredentials = towerClient.fetchEncryptedCredentials(identity.towerEndpoint, JwtAuth.of(identity), creds.id, pairing.pairingId, identity.workspaceId, identity.workflowId).get()
final encryptedCredentials = towerClient.fetchEncryptedCredentials(identity.towerEndpoint, JwtAuth.of(identity), creds.id, pairing.pairingId, identity.workspaceId, identity.workflowId)
final privateKey = pairing.privateKey
final credentials = decryptCredentials(privateKey, encryptedCredentials.keys)
return parsePayload(credentials)
Expand All @@ -112,7 +112,7 @@ class CredentialServiceImpl implements CredentialsService {
final response = towerClient.describeWorkflowLaunch(identity.towerEndpoint, JwtAuth.of(identity), identity.workflowId)
if( !response )
return null
final computeEnv = response.get()?.launch?.computeEnv
final computeEnv = response?.launch?.computeEnv
if( !computeEnv )
return null
if( computeEnv.platform != 'aws-batch' )
Expand Down
24 changes: 7 additions & 17 deletions src/main/groovy/io/seqera/wave/service/UserServiceImpl.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,14 @@
package io.seqera.wave.service

import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutionException
import io.micronaut.core.annotation.Nullable

import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import io.micronaut.core.annotation.Nullable
import io.seqera.wave.exception.UnauthorizedException
import io.seqera.wave.tower.User
import io.seqera.wave.tower.auth.JwtAuth
import io.seqera.wave.tower.client.TowerClient
import io.seqera.wave.tower.client.UserInfoResponse
import jakarta.inject.Inject
import jakarta.inject.Singleton
/**
Expand All @@ -50,23 +48,15 @@ class UserServiceImpl implements UserService {
if( !towerClient )
throw new IllegalStateException("Missing Tower client - make sure the 'tower' micronaut environment has been provided")

towerClient.userInfo(endpoint, auth).handle( (UserInfoResponse resp, Throwable error) -> {
if( error )
throw error
if (!resp || !resp.user)
throw new UnauthorizedException("Unauthorized - Make sure you have provided a valid access token")
log.debug("Authorized user=$resp.user")
return resp.user
})
return CompletableFuture.supplyAsync(()-> getUserByAccessToken(endpoint,auth))
}

@Override
User getUserByAccessToken(String endpoint, JwtAuth auth) {
try {
return getUserByAccessTokenAsync(endpoint, auth).get()
}
catch(ExecutionException e){
throw e.cause
}
final resp = towerClient.userInfo(endpoint, auth)
if (!resp || !resp.user)
throw new UnauthorizedException("Unauthorized - Make sure you have provided a valid access token")
log.debug("Authorized user=$resp.user")
return resp.user
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ import io.seqera.wave.encoder.MoshiEncodeStrategy
import io.seqera.wave.store.state.AbstractStateStore
import io.seqera.wave.store.state.impl.StateProvider
import jakarta.inject.Inject
import jakarta.inject.Singleton

/**
* Implement a distributed store for blob cache entry.
*
Expand All @@ -39,7 +37,6 @@ import jakarta.inject.Singleton
*/
@Slf4j
@CompileStatic
@Singleton
class BlobStoreImpl extends AbstractStateStore<BlobEntry> implements BlobStateStore {

@Inject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package io.seqera.wave.store.cache

import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.function.Function

import com.github.benmanes.caffeine.cache.AsyncCache
import com.github.benmanes.caffeine.cache.Caffeine
Expand All @@ -33,7 +34,7 @@ import io.seqera.wave.encoder.MoshiEncodeStrategy
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@CompileStatic
class AbstractTieredCache<V> implements TieredCache<String,V> {
abstract class AbstractTieredCache<V> implements TieredCache<String,V> {

@Canonical
static class Payload {
Expand All @@ -59,6 +60,8 @@ class AbstractTieredCache<V> implements TieredCache<String,V> {
.buildAsync()
}

abstract protected String getPrefix()

@Override
V get(String key) {
// Try local cache first
Expand All @@ -77,6 +80,21 @@ class AbstractTieredCache<V> implements TieredCache<String,V> {
return value
}

V getOrCompute(String key, Function<String,V> loader) {
def result = get(key)
if( result!=null ) {
return result
}

result = loader.apply(key)
if( result!=null ) {
l1.synchronous().put(key,result)
l2Put(key,result)
}

return result
}

@Override
void put(String key, V value) {
// Store in Caffeine
Expand All @@ -85,11 +103,13 @@ class AbstractTieredCache<V> implements TieredCache<String,V> {
l2Put(key, value)
}

protected String key0(String k) { return getPrefix() + ':' + k }

protected V l2Get(String key) {
if( l2 == null )
return null

final raw = l2.get(key)
final raw = l2.get(key0(key))
if( raw == null )
return null

Expand All @@ -102,8 +122,11 @@ class AbstractTieredCache<V> implements TieredCache<String,V> {
protected void l2Put(String key, V value) {
if( l2 != null ) {
final raw = encoder.encode(new Payload(value, ttl.toMillis() + System.currentTimeMillis()))
l2.put(key, raw, ttl)
l2.put(key0(key), raw, ttl)
}
}

void invalidateAll() {
l1.synchronous().invalidateAll()
}
}
44 changes: 27 additions & 17 deletions src/main/groovy/io/seqera/wave/tower/client/TowerClient.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import java.util.concurrent.CompletableFuture

import com.google.common.hash.Hashing
import groovy.transform.CompileStatic
import io.micronaut.cache.annotation.Cacheable
import io.micronaut.core.annotation.Nullable
import io.seqera.wave.tower.auth.JwtAuth
import io.seqera.wave.tower.client.cache.ClientCacheShort
import io.seqera.wave.tower.client.connector.TowerConnector
import io.seqera.wave.tower.compute.DescribeWorkflowLaunchResponse
import jakarta.inject.Inject
Expand All @@ -43,42 +43,46 @@ class TowerClient {
@Inject
private TowerConnector connector

@Inject
private ClientCacheShort cacheShort

@Inject
private ClientCacheShort cacheLong

protected <T> CompletableFuture<T> getAsync(URI uri, String endpoint, @Nullable JwtAuth authorization, Class<T> type) {
assert uri, "Missing uri argument"
assert endpoint, "Missing endpoint argument"
return connector.sendAsync(endpoint, uri, authorization, type)
}

@Cacheable(value = 'cache-tower-client-short', atomic = true, parameters = ['cacheKey'])
protected <T> CompletableFuture<T> getCacheShort(URI uri, String endpoint, @Nullable JwtAuth authorization, Class<T> type, String cacheKey) {
return getAsync(uri, endpoint, authorization, type)
protected Object getCacheShort(URI uri, String endpoint, @Nullable JwtAuth authorization, Class type, String cacheKey) {
return cacheShort.getOrCompute(cacheKey, (k)-> getAsync(uri, endpoint, authorization, type).get())
}

@Cacheable(value = 'cache-tower-client-long', atomic = true, parameters = ['cacheKey'])
protected <T> CompletableFuture<T> getCacheLong(URI uri, String endpoint, @Nullable JwtAuth authorization, Class<T> type, String cacheKey) {
return getAsync(uri, endpoint, authorization, type)
protected Object getCacheLong(URI uri, String endpoint, @Nullable JwtAuth authorization, Class type, String cacheKey) {
return cacheLong.getOrCompute(cacheKey, (k)-> getAsync(uri, endpoint, authorization, type).get())
}

CompletableFuture<UserInfoResponse> userInfo(String towerEndpoint, JwtAuth authorization) {
UserInfoResponse userInfo(String towerEndpoint, JwtAuth authorization) {
final uri = userInfoEndpoint(towerEndpoint)
final k = makeKey(uri, authorization.key, null, null)
return getCacheLong(uri, towerEndpoint, authorization, UserInfoResponse, k)
getCacheLong(uri, towerEndpoint, authorization, UserInfoResponse, k) as UserInfoResponse
}

CompletableFuture<ListCredentialsResponse> listCredentials(String towerEndpoint, JwtAuth authorization, Long workspaceId, String workflowId) {
ListCredentialsResponse listCredentials(String towerEndpoint, JwtAuth authorization, Long workspaceId, String workflowId) {
final uri = listCredentialsEndpoint(towerEndpoint, workspaceId)
final k = makeKey(uri, authorization.key, workspaceId, workflowId)
return workflowId
return (workflowId
? getCacheLong(uri, towerEndpoint, authorization, ListCredentialsResponse, k)
: getCacheShort(uri, towerEndpoint, authorization, ListCredentialsResponse, k)
: getCacheShort(uri, towerEndpoint, authorization, ListCredentialsResponse, k)) as ListCredentialsResponse
}

CompletableFuture<GetCredentialsKeysResponse> fetchEncryptedCredentials(String towerEndpoint, JwtAuth authorization, String credentialsId, String pairingId, Long workspaceId, String workflowId) {
GetCredentialsKeysResponse fetchEncryptedCredentials(String towerEndpoint, JwtAuth authorization, String credentialsId, String pairingId, Long workspaceId, String workflowId) {
final uri = fetchCredentialsEndpoint(towerEndpoint, credentialsId, pairingId, workspaceId)
final k = makeKey(uri, authorization.key, workspaceId, workflowId)
return workflowId
return (workflowId
? getCacheLong(uri, towerEndpoint, authorization, GetCredentialsKeysResponse, k)
: getCacheShort(uri, towerEndpoint, authorization, GetCredentialsKeysResponse, k)
: getCacheShort(uri, towerEndpoint, authorization, GetCredentialsKeysResponse, k)) as GetCredentialsKeysResponse
}

protected static URI fetchCredentialsEndpoint(String towerEndpoint, String credentialsId, String pairingId, Long workspaceId) {
Expand Down Expand Up @@ -116,10 +120,10 @@ class TowerClient {
StringUtils.removeEnd(endpoint, "/")
}

CompletableFuture<DescribeWorkflowLaunchResponse> describeWorkflowLaunch(String towerEndpoint, JwtAuth authorization, String workflowId) {
DescribeWorkflowLaunchResponse describeWorkflowLaunch(String towerEndpoint, JwtAuth authorization, String workflowId) {
final uri = workflowLaunchEndpoint(towerEndpoint,workflowId)
final k = makeKey(uri, authorization.key, null, workflowId)
return getCacheShort(uri, towerEndpoint, authorization, DescribeWorkflowLaunchResponse.class, k)
return getCacheShort(uri, towerEndpoint, authorization, DescribeWorkflowLaunchResponse.class, k) as DescribeWorkflowLaunchResponse
}

protected static URI workflowLaunchEndpoint(String towerEndpoint, String workflowId) {
Expand All @@ -135,4 +139,10 @@ class TowerClient {
}
return h.hash()
}

/** Only for testing - do not use */
protected void invalidateCache() {
cacheLong.invalidateAll()
cacheShort.invalidateAll()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.tower.client.cache

import java.time.Duration

import groovy.transform.CompileStatic
import io.micronaut.context.annotation.Value
import io.micronaut.core.annotation.Nullable
import io.seqera.wave.store.cache.AbstractTieredCache
import io.seqera.wave.store.cache.L2TieredCache
import jakarta.inject.Singleton

/**
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Singleton
@CompileStatic
class ClientCacheLong extends AbstractTieredCache {
ClientCacheLong(@Nullable L2TieredCache l2,
@Value('${wave.pairing.client.long.ttl:24h}') Duration ttl,
@Value('${wave.pairing.client.long.max-size:10000}') int maxSize)
{
super(l2, ttl, maxSize)
}

@Override
protected String getPrefix() {
return 'pairing-cache-long/v1'
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.tower.client.cache

import java.time.Duration

import groovy.transform.CompileStatic
import io.micronaut.context.annotation.Value
import io.micronaut.core.annotation.Nullable
import io.seqera.wave.store.cache.AbstractTieredCache
import io.seqera.wave.store.cache.L2TieredCache
import jakarta.inject.Singleton

/**
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Singleton
@CompileStatic
class ClientCacheShort extends AbstractTieredCache {
ClientCacheShort(@Nullable L2TieredCache l2,
@Value('${wave.pairing.client.short.ttl:60s}') Duration ttl,
@Value('${wave.pairing.client.short.max-size:10000}') int maxSize)
{
super(l2, ttl, maxSize)
}

@Override
protected String getPrefix() {
return 'pairing-cache-short/v1'
}
}
Loading

0 comments on commit db27aae

Please sign in to comment.