Skip to content

Commit

Permalink
Improve Tower client caching via Tiered cache (#772)
Browse files Browse the repository at this point in the history

Signed-off-by: Paolo Di Tommaso <paolo.ditommaso@gmail.com>
Signed-off-by: munishchouhan <hrma017@gmail.com>
Co-authored-by: munishchouhan <hrma017@gmail.com>
  • Loading branch information
pditommaso and munishchouhan authored Dec 17, 2024
1 parent 13a5937 commit f0ca0f6
Show file tree
Hide file tree
Showing 32 changed files with 1,101 additions and 118 deletions.
17 changes: 12 additions & 5 deletions src/main/groovy/io/seqera/wave/encoder/MoshiEncodeStrategy.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,18 @@ abstract class MoshiEncodeStrategy<V> implements EncodingStrategy<V> {
init()
}

MoshiEncodeStrategy(JsonAdapter.Factory customFactory) {
this.type = TypeHelper.getGenericType(this, 0)
init(customFactory)
}

MoshiEncodeStrategy(Type type) {
this.type = type
init()
}

private void init() {
this.moshi = new Moshi.Builder()
private void init(JsonAdapter.Factory customFactory=null) {
final builder = new Moshi.Builder()
.add(new ByteArrayAdapter())
.add(new DateTimeAdapter())
.add(new PathAdapter())
Expand All @@ -73,9 +78,11 @@ abstract class MoshiEncodeStrategy<V> implements EncodingStrategy<V> {
.withSubtype(ProxyHttpRequest.class, ProxyHttpRequest.simpleName)
.withSubtype(ProxyHttpResponse.class, ProxyHttpResponse.simpleName)
.withSubtype(PairingHeartbeat.class, PairingHeartbeat.simpleName)
.withSubtype(PairingResponse.class, PairingResponse.simpleName)
)
.build()
.withSubtype(PairingResponse.class, PairingResponse.simpleName) )
// add custom factory if provider
if( customFactory )
builder.add(customFactory)
this.moshi = builder.build()
this.jsonAdapter = moshi.adapter(type)

}
Expand Down
27 changes: 27 additions & 0 deletions src/main/groovy/io/seqera/wave/encoder/MoshiExchange.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.encoder

/**
* Marker interface for Moshi encoded exchange objects
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
interface MoshiExchange {
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CredentialServiceImpl implements CredentialsService {
if (pairing.isExpired())
log.debug("Exchange key registered for service ${PairingService.TOWER_SERVICE} at endpoint: ${identity.towerEndpoint} used after expiration, should be renewed soon")

final all = towerClient.listCredentials(identity.towerEndpoint, JwtAuth.of(identity), identity.workspaceId).get().credentials
final all = towerClient.listCredentials(identity.towerEndpoint, JwtAuth.of(identity), identity.workspaceId, identity.workflowId).credentials

if (!all) {
log.debug "No credentials found for userId=$identity.userId; workspaceId=$identity.workspaceId; endpoint=$identity.towerEndpoint"
Expand Down Expand Up @@ -92,7 +92,7 @@ class CredentialServiceImpl implements CredentialsService {
// log for debugging purposes
log.debug "Credentials matching criteria registryName=$registryName; userId=$identity.userId; workspaceId=$identity.workspaceId; endpoint=$identity.towerEndpoint => $creds"
// now fetch the encrypted key
final encryptedCredentials = towerClient.fetchEncryptedCredentials(identity.towerEndpoint, JwtAuth.of(identity), creds.id, pairing.pairingId, identity.workspaceId).get()
final encryptedCredentials = towerClient.fetchEncryptedCredentials(identity.towerEndpoint, JwtAuth.of(identity), creds.id, pairing.pairingId, identity.workspaceId, identity.workflowId)
final privateKey = pairing.privateKey
final credentials = decryptCredentials(privateKey, encryptedCredentials.keys)
return parsePayload(credentials)
Expand All @@ -112,7 +112,7 @@ class CredentialServiceImpl implements CredentialsService {
final response = towerClient.describeWorkflowLaunch(identity.towerEndpoint, JwtAuth.of(identity), identity.workflowId)
if( !response )
return null
final computeEnv = response.get()?.launch?.computeEnv
final computeEnv = response?.launch?.computeEnv
if( !computeEnv )
return null
if( computeEnv.platform != 'aws-batch' )
Expand Down
31 changes: 14 additions & 17 deletions src/main/groovy/io/seqera/wave/service/UserServiceImpl.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@
package io.seqera.wave.service

import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutionException
import io.micronaut.core.annotation.Nullable
import java.util.concurrent.ExecutorService

import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import io.micronaut.core.annotation.Nullable
import io.micronaut.scheduling.TaskExecutors
import io.seqera.wave.exception.UnauthorizedException
import io.seqera.wave.tower.User
import io.seqera.wave.tower.auth.JwtAuth
import io.seqera.wave.tower.client.TowerClient
import io.seqera.wave.tower.client.UserInfoResponse
import jakarta.inject.Inject
import jakarta.inject.Named
import jakarta.inject.Singleton
/**
* Define a service to access a Tower user
Expand All @@ -45,28 +46,24 @@ class UserServiceImpl implements UserService {
@Nullable
private TowerClient towerClient

@Inject
@Named(TaskExecutors.BLOCKING)
private ExecutorService ioExecutor

@Override
CompletableFuture<User> getUserByAccessTokenAsync(String endpoint, JwtAuth auth) {
if( !towerClient )
throw new IllegalStateException("Missing Tower client - make sure the 'tower' micronaut environment has been provided")

towerClient.userInfo(endpoint, auth).handle( (UserInfoResponse resp, Throwable error) -> {
if( error )
throw error
if (!resp || !resp.user)
throw new UnauthorizedException("Unauthorized - Make sure you have provided a valid access token")
log.debug("Authorized user=$resp.user")
return resp.user
})
return CompletableFuture.supplyAsync(()-> getUserByAccessToken(endpoint,auth), ioExecutor)
}

@Override
User getUserByAccessToken(String endpoint, JwtAuth auth) {
try {
return getUserByAccessTokenAsync(endpoint, auth).get()
}
catch(ExecutionException e){
throw e.cause
}
final resp = towerClient.userInfo(endpoint, auth)
if (!resp || !resp.user)
throw new UnauthorizedException("Unauthorized - Make sure you have provided a valid access token")
log.debug("Authorized user=$resp.user")
return resp.user
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class PairingWebSocket {
// Register the client and the sender callback that it's needed to deliver
// the message to the remote client
channel.registerClient(service, endpoint, session.id,(pairingMessage) -> {
log.trace "Sendind message=${pairingMessage} - endpoint: ${endpoint} [sessionId: $session.id]"
log.trace "Sending message=${pairingMessage} - endpoint: ${endpoint} [sessionId: $session.id]"
session .sendAsync(pairingMessage)
})

Expand Down
180 changes: 180 additions & 0 deletions src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.store.cache

import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.Lock
import java.util.concurrent.locks.ReentrantLock
import java.util.function.Function

import com.github.benmanes.caffeine.cache.AsyncCache
import com.github.benmanes.caffeine.cache.Caffeine
import com.github.benmanes.caffeine.cache.RemovalCause
import com.github.benmanes.caffeine.cache.RemovalListener
import groovy.transform.Canonical
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import io.seqera.wave.encoder.EncodingStrategy
import io.seqera.wave.encoder.MoshiEncodeStrategy
import io.seqera.wave.encoder.MoshiExchange
import org.jetbrains.annotations.Nullable
/**
* Implement a tiered-cache mechanism using a local caffeine cache as 1st level access
* and a 2nd-level cache backed on Redis.
*
* This allow the use in distributed deployment. Note however strong consistently is not guaranteed.
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Slf4j
@CompileStatic
abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCache<String,V> {

@Canonical
static class Entry implements MoshiExchange {
MoshiExchange value
long expiresAt
}

private EncodingStrategy<Entry> encoder

// FIXME https://github.com/seqeralabs/wave/issues/747
private AsyncCache<String,V> l1

private final Duration ttl

private L2TieredCache<String,String> l2

private final Lock sync = new ReentrantLock()

AbstractTieredCache(L2TieredCache<String,String> l2, MoshiEncodeStrategy encoder, Duration duration, long maxSize) {
log.info "Cache '${getName()}' config - prefix=${getPrefix()}; ttl=${duration}; max-size: ${maxSize}; l2=${l2}"
this.l2 = l2
this.ttl = duration
this.encoder = encoder
this.l1 = Caffeine.newBuilder()
.expireAfterWrite(duration.toMillis(), TimeUnit.MILLISECONDS)
.maximumSize(maxSize)
.removalListener(removalListener0())
.buildAsync()
}

abstract protected getName()

abstract protected String getPrefix()

private RemovalListener removalListener0() {
new RemovalListener() {
@Override
void onRemoval(@Nullable key, @Nullable value, RemovalCause cause) {
log.trace "Cache '${name}' removing key=$key; value=$value; cause=$cause"
}
}
}

@Override
V get(String key) {
getOrCompute(key, null)
}

V getOrCompute(String key, Function<String,V> loader) {
log.trace "Cache '${name}' checking key=$key"
// Try L1 cache first
V value = l1.synchronous().getIfPresent(key)
if (value != null) {
log.trace "Cache '${name}' L1 hit (a) - key=$key => value=$value"
return value
}

sync.lock()
try {
value = l1.synchronous().getIfPresent(key)
if (value != null) {
log.trace "Cache '${name}' L1 hit (b) - key=$key => value=$value"
return value
}

// Fallback to L2 cache
value = l2Get(key)
if (value != null) {
log.trace "Cache '${name}' L2 hit - key=$key => value=$value"
// Rehydrate L1 cache
l1.synchronous().put(key, value)
return value
}

// still not value found, use loader function to fetch the value
if( value==null && loader!=null ) {
log.trace "Cache '${name}' invoking loader - key=$key"
value = loader.apply(key)
if( value!=null ) {
l1.synchronous().put(key,value)
l2Put(key,value)
}
}

log.trace "Cache '${name}' missing value - key=$key => value=${value}"
// finally return the value
return value
}
finally {
sync.unlock()
}
}

@Override
void put(String key, V value) {
assert key!=null, "Cache key argument cannot be null"
assert value!=null, "Cache value argument cannot be null"
log.trace "Cache '${name}' putting - key=$key; value=${value}"
l1.synchronous().put(key, value)
l2Put(key, value)
}

protected String key0(String k) { return getPrefix() + ':' + k }

protected V l2Get(String key) {
if( l2 == null )
return null

final raw = l2.get(key0(key))
if( raw == null )
return null

final Entry payload = encoder.decode(raw)
if( System.currentTimeMillis() > payload.expiresAt ) {
log.trace "Cache '${name}' L2 exipired - key=$key => value=${payload.value}"
return null
}
return (V) payload.value
}

protected void l2Put(String key, V value) {
if( l2 != null ) {
final raw = encoder.encode(new Entry(value, ttl.toMillis() + System.currentTimeMillis()))
l2.put(key0(key), raw, ttl)
}
}

void invalidateAll() {
l1.synchronous().invalidateAll()
}

}
41 changes: 41 additions & 0 deletions src/main/groovy/io/seqera/wave/store/cache/L2TieredCache.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.store.cache

import java.time.Duration

/**
* Define the interface for 2nd level tired cache
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
interface L2TieredCache<K,V> extends TieredCache<K,V> {


/**
* Add a value in the cache with the specified key. If a value already exists is overridden
* with the new value.
*
* @param key The key of the value to be added. {@code null} is not allowed.
* @param value The value to be added in the cache for the specified key. {@code null} is not allowed.
* @param ttl The value time-to-live, after which the value is automatically evicted.
*/
void put(K key, V value, Duration ttl)

}
Loading

0 comments on commit f0ca0f6

Please sign in to comment.