Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tower client cache v2 #774

Merged
merged 15 commits into from
Dec 17, 2024
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.15.6-B2
1.15.6-B5
17 changes: 12 additions & 5 deletions src/main/groovy/io/seqera/wave/encoder/MoshiEncodeStrategy.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,18 @@ abstract class MoshiEncodeStrategy<V> implements EncodingStrategy<V> {
init()
}

MoshiEncodeStrategy(JsonAdapter.Factory customFactory) {
this.type = TypeHelper.getGenericType(this, 0)
init(customFactory)
}

MoshiEncodeStrategy(Type type) {
this.type = type
init()
}

private void init() {
this.moshi = new Moshi.Builder()
private void init(JsonAdapter.Factory customFactory=null) {
final builder = new Moshi.Builder()
.add(new ByteArrayAdapter())
.add(new DateTimeAdapter())
.add(new PathAdapter())
Expand All @@ -73,9 +78,11 @@ abstract class MoshiEncodeStrategy<V> implements EncodingStrategy<V> {
.withSubtype(ProxyHttpRequest.class, ProxyHttpRequest.simpleName)
.withSubtype(ProxyHttpResponse.class, ProxyHttpResponse.simpleName)
.withSubtype(PairingHeartbeat.class, PairingHeartbeat.simpleName)
.withSubtype(PairingResponse.class, PairingResponse.simpleName)
)
.build()
.withSubtype(PairingResponse.class, PairingResponse.simpleName) )
// add custom factory if provider
if( customFactory )
builder.add(customFactory)
this.moshi = builder.build()
this.jsonAdapter = moshi.adapter(type)

}
Expand Down
27 changes: 27 additions & 0 deletions src/main/groovy/io/seqera/wave/encoder/MoshiExchange.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.encoder

/**
* Marker interface for Moshi encoder objects
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
interface MoshiExchange {
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class CredentialServiceImpl implements CredentialsService {
if (pairing.isExpired())
log.debug("Exchange key registered for service ${PairingService.TOWER_SERVICE} at endpoint: ${identity.towerEndpoint} used after expiration, should be renewed soon")

final all = towerClient.listCredentials(identity.towerEndpoint, JwtAuth.of(identity), identity.workspaceId, identity.workflowId).get().credentials
final all = towerClient.listCredentials(identity.towerEndpoint, JwtAuth.of(identity), identity.workspaceId, identity.workflowId).credentials

if (!all) {
log.debug "No credentials found for userId=$identity.userId; workspaceId=$identity.workspaceId; endpoint=$identity.towerEndpoint"
Expand Down Expand Up @@ -92,7 +92,7 @@ class CredentialServiceImpl implements CredentialsService {
// log for debugging purposes
log.debug "Credentials matching criteria registryName=$registryName; userId=$identity.userId; workspaceId=$identity.workspaceId; endpoint=$identity.towerEndpoint => $creds"
// now fetch the encrypted key
final encryptedCredentials = towerClient.fetchEncryptedCredentials(identity.towerEndpoint, JwtAuth.of(identity), creds.id, pairing.pairingId, identity.workspaceId, identity.workflowId).get()
final encryptedCredentials = towerClient.fetchEncryptedCredentials(identity.towerEndpoint, JwtAuth.of(identity), creds.id, pairing.pairingId, identity.workspaceId, identity.workflowId)
final privateKey = pairing.privateKey
final credentials = decryptCredentials(privateKey, encryptedCredentials.keys)
return parsePayload(credentials)
Expand All @@ -112,7 +112,7 @@ class CredentialServiceImpl implements CredentialsService {
final response = towerClient.describeWorkflowLaunch(identity.towerEndpoint, JwtAuth.of(identity), identity.workflowId)
if( !response )
return null
final computeEnv = response.get()?.launch?.computeEnv
final computeEnv = response?.launch?.computeEnv
if( !computeEnv )
return null
if( computeEnv.platform != 'aws-batch' )
Expand Down
31 changes: 14 additions & 17 deletions src/main/groovy/io/seqera/wave/service/UserServiceImpl.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@
package io.seqera.wave.service

import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutionException
import io.micronaut.core.annotation.Nullable
import java.util.concurrent.ExecutorService

import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import io.micronaut.core.annotation.Nullable
import io.micronaut.scheduling.TaskExecutors
import io.seqera.wave.exception.UnauthorizedException
import io.seqera.wave.tower.User
import io.seqera.wave.tower.auth.JwtAuth
import io.seqera.wave.tower.client.TowerClient
import io.seqera.wave.tower.client.UserInfoResponse
import jakarta.inject.Inject
import jakarta.inject.Named
import jakarta.inject.Singleton
/**
* Define a service to access a Tower user
Expand All @@ -45,28 +46,24 @@ class UserServiceImpl implements UserService {
@Nullable
private TowerClient towerClient

@Inject
@Named(TaskExecutors.BLOCKING)
private ExecutorService ioExecutor

@Override
CompletableFuture<User> getUserByAccessTokenAsync(String endpoint, JwtAuth auth) {
if( !towerClient )
throw new IllegalStateException("Missing Tower client - make sure the 'tower' micronaut environment has been provided")

towerClient.userInfo(endpoint, auth).handle( (UserInfoResponse resp, Throwable error) -> {
if( error )
throw error
if (!resp || !resp.user)
throw new UnauthorizedException("Unauthorized - Make sure you have provided a valid access token")
log.debug("Authorized user=$resp.user")
return resp.user
})
return CompletableFuture.supplyAsync(()-> getUserByAccessToken(endpoint,auth), ioExecutor)
}

@Override
User getUserByAccessToken(String endpoint, JwtAuth auth) {
try {
return getUserByAccessTokenAsync(endpoint, auth).get()
}
catch(ExecutionException e){
throw e.cause
}
final resp = towerClient.userInfo(endpoint, auth)
if (!resp || !resp.user)
throw new UnauthorizedException("Unauthorized - Make sure you have provided a valid access token")
log.debug("Authorized user=$resp.user")
return resp.user
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class PairingWebSocket {
// Register the client and the sender callback that it's needed to deliver
// the message to the remote client
channel.registerClient(service, endpoint, session.id,(pairingMessage) -> {
log.trace "Sendind message=${pairingMessage} - endpoint: ${endpoint} [sessionId: $session.id]"
log.trace "Sending message=${pairingMessage} - endpoint: ${endpoint} [sessionId: $session.id]"
session .sendAsync(pairingMessage)
})

Expand Down
177 changes: 177 additions & 0 deletions src/main/groovy/io/seqera/wave/store/cache/AbstractTieredCache.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.store.cache

import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.Lock
import java.util.concurrent.locks.ReentrantLock
import java.util.function.Function

import com.github.benmanes.caffeine.cache.AsyncCache
import com.github.benmanes.caffeine.cache.Caffeine
import com.github.benmanes.caffeine.cache.RemovalCause
import com.github.benmanes.caffeine.cache.RemovalListener
import groovy.transform.Canonical
import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import io.seqera.wave.encoder.EncodingStrategy
import io.seqera.wave.encoder.MoshiEncodeStrategy
import io.seqera.wave.encoder.MoshiExchange
import org.jetbrains.annotations.Nullable
/**
* Abstract implementation for tiered-cache
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Slf4j
@CompileStatic
abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCache<String,V> {

@Canonical
static class Payload implements MoshiExchange {
MoshiExchange value
long expiresAt
}

private EncodingStrategy<Payload> encoder

// FIXME https://github.com/seqeralabs/wave/issues/747
private AsyncCache<String,V> l1

private final Duration ttl

private L2TieredCache<String,String> l2

private final Lock sync = new ReentrantLock()

AbstractTieredCache(L2TieredCache<String,String> l2, MoshiEncodeStrategy encoder, Duration duration, long maxSize) {
log.info "Cache configuring '${getName()}' - prefix=${getPrefix()}; ttl=${duration}; max-size: ${maxSize}; l2=${l2}"
this.l2 = l2
this.ttl = duration
this.encoder = encoder
this.l1 = Caffeine.newBuilder()
.expireAfterWrite(duration.toMillis(), TimeUnit.MILLISECONDS)
.maximumSize(maxSize)
.removalListener(removalListener0())
.buildAsync()
}

abstract protected getName()

abstract protected String getPrefix()

private RemovalListener removalListener0() {
new RemovalListener() {
@Override
void onRemoval(@Nullable key, @Nullable value, RemovalCause cause) {
log.trace "Cache '${name}' removing key=$key; value=$value; cause=$cause"
}
}
}

@Override
V get(String key) {
getOrCompute(key, null)
}

V getOrCompute(String key, Function<String,V> loader) {
log.trace "Cache '${name}' checking key=$key"
// Try L1 cache first
V value = l1.synchronous().getIfPresent(key)
if (value != null) {
log.trace "Cache '${name}' L1 hit (a) - key=$key => value=$value"
return value
}

sync.lock()
try {
value = l1.synchronous().getIfPresent(key)
if (value != null) {
log.trace "Cache '${name}' L1 hit (b) - key=$key => value=$value"
return value
}

// Fallback to L2 cache
value = l2Get(key)
if (value != null) {
log.trace "Cache '${name}' L2 hit - key=$key => value=$value"
// Rehydrate L1 cache
l1.synchronous().put(key, value)
return value
}

// still not value found, use loader function to fetch the value
if( value==null && loader!=null ) {
log.trace "Cache '${name}' invoking loader - key=$key"
value = loader.apply(key)
if( value!=null ) {
l1.synchronous().put(key,value)
l2Put(key,value)
}
}

log.trace "Cache '${name}' missing value - key=$key => value=${value}"
// finally return the value
return value
}
finally {
sync.unlock()
}
}

@Override
void put(String key, V value) {
assert key!=null, "Cache key argument cannot be null"
assert value!=null, "Cache value argument cannot be null"
log.trace "Cache '${name}' putting - key=$key; value=${value}"
l1.synchronous().put(key, value)
l2Put(key, value)
}

protected String key0(String k) { return getPrefix() + ':' + k }

protected V l2Get(String key) {
if( l2 == null )
return null

final raw = l2.get(key0(key))
if( raw == null )
return null

final Payload payload = encoder.decode(raw)
if( System.currentTimeMillis() > payload.expiresAt ) {
log.trace "Cache '${name}' L2 exipired - key=$key => value=${payload.value}"
return null
}
return (V) payload.value
}

protected void l2Put(String key, V value) {
if( l2 != null ) {
final raw = encoder.encode(new Payload(value, ttl.toMillis() + System.currentTimeMillis()))
l2.put(key0(key), raw, ttl)
}
}

void invalidateAll() {
l1.synchronous().invalidateAll()
}

}
31 changes: 31 additions & 0 deletions src/main/groovy/io/seqera/wave/store/cache/L2TieredCache.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.store.cache

import java.time.Duration

/**
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
interface L2TieredCache<K,V> extends TieredCache<K,V> {

void put(K key, V value, Duration ttl)

}
Loading
Loading