Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement caching for proxy responses #778

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ import io.micronaut.scheduling.annotation.ExecuteOn
import io.seqera.wave.ErrorHandler
import io.seqera.wave.configuration.HttpClientConfig
import io.seqera.wave.core.RegistryProxyService
import io.seqera.wave.core.RegistryProxyService.DelegateResponse
import io.seqera.wave.core.RouteHandler
import io.seqera.wave.core.RoutePath
import io.seqera.wave.exception.DockerRegistryException
import io.seqera.wave.exchange.RegistryErrorResponse
import io.seqera.wave.proxy.DelegateResponse
import io.seqera.wave.ratelimit.AcquireRequest
import io.seqera.wave.ratelimit.RateLimiterService
import io.seqera.wave.service.blob.BlobCacheService
Expand All @@ -54,7 +54,6 @@ import io.seqera.wave.storage.DigestStore
import io.seqera.wave.storage.DockerDigestStore
import io.seqera.wave.storage.HttpDigestStore
import io.seqera.wave.storage.Storage
import io.seqera.wave.util.Retryable
import jakarta.inject.Inject
import org.reactivestreams.Publisher
import reactor.core.publisher.Mono
Expand Down Expand Up @@ -274,7 +273,7 @@ class RegistryProxyController {
final resp = proxyService.handleRequest(route, headers)
HttpResponse
.status(HttpStatus.valueOf(resp.statusCode))
.body(resp.body.bytes)
.body(resp.body)
.headers(toMutableHeaders(resp.headers))
}

Expand Down Expand Up @@ -348,14 +347,9 @@ class RegistryProxyController {
}

MutableHttpResponse<?> fromContentResponse(DelegateResponse resp, RoutePath route) {
// create the retry logic on error §
final retryable = Retryable
.<byte[]>of(httpConfig)
.onRetry((event) -> log.warn("Unable to read manifest body - request: $route; event: $event"))

HttpResponse
.status(HttpStatus.valueOf(resp.statusCode))
.body(retryable.apply(()-> resp.body.bytes))
.body(resp.body)
.headers(toMutableHeaders(resp.headers))
}

Expand Down
50 changes: 38 additions & 12 deletions src/main/groovy/io/seqera/wave/core/RegistryProxyService.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ package io.seqera.wave.core

import java.util.concurrent.CompletableFuture

import com.google.common.hash.Hashing
import groovy.transform.CompileStatic
import groovy.transform.ToString
import groovy.util.logging.Slf4j
import io.micronaut.cache.annotation.Cacheable
import io.micronaut.context.annotation.Context
Expand All @@ -36,6 +36,8 @@ import io.seqera.wave.auth.RegistryLookupService
import io.seqera.wave.configuration.HttpClientConfig
import io.seqera.wave.http.HttpClientFactory
import io.seqera.wave.model.ContainerCoordinates
import io.seqera.wave.proxy.DelegateResponse
import io.seqera.wave.proxy.ProxyCache
import io.seqera.wave.proxy.ProxyClient
import io.seqera.wave.service.CredentialsService
import io.seqera.wave.service.builder.BuildRequest
Expand All @@ -44,6 +46,7 @@ import io.seqera.wave.storage.DigestStore
import io.seqera.wave.storage.Storage
import io.seqera.wave.tower.PlatformId
import io.seqera.wave.util.RegHelper
import io.seqera.wave.util.Retryable
import jakarta.inject.Inject
import jakarta.inject.Singleton
import reactor.core.publisher.Flux
Expand Down Expand Up @@ -91,6 +94,9 @@ class RegistryProxyService {
@Client("stream-client")
private ReactorStreamingHttpClient streamClient

@Inject
private ProxyCache cache

private ContainerAugmenter scanner(ProxyClient proxyClient) {
return new ContainerAugmenter()
.withStorage(storage)
Expand Down Expand Up @@ -141,7 +147,31 @@ class RegistryProxyService {
}
}

DelegateResponse handleRequest(RoutePath route, Map<String,List<String>> headers){
static protected String requestKey(RoutePath route, Map<String,List<String>> headers) {
final hasher = Hashing.sipHash24().newHasher()
hasher.putUnencodedChars(route.stableHash())
hasher.putUnencodedChars('/')
for( Map.Entry<String,List<String>> entry : (headers ?: Map.of()) ) {
hasher.putUnencodedChars(entry.key)
for( String it : entry.value ) {
if( it )
hasher.putUnencodedChars(it)
hasher.putUnencodedChars('/')
}
hasher.putUnencodedChars('/')
}
return hasher.hash().toString()
}

DelegateResponse handleRequest(RoutePath route, Map<String,List<String>> headers) {
final resp = cache.getOrCompute(
requestKey(route, headers),
(it)-> handleRequest0(route, headers),
(resp)-> route.isDigest() && resp.isCacheable() )
return resp
}

private DelegateResponse handleRequest0(RoutePath route, Map<String,List<String>> headers) {
ProxyClient proxyClient = client(route)
final resp1 = proxyClient.getStream(route.path, headers, false)
final redirect = resp1.headers().firstValue('Location').orElse(null)
Expand Down Expand Up @@ -182,10 +212,15 @@ class RegistryProxyService {
// otherwise read it and include the body input stream in the response
// the caller must consume and close the body to prevent memory leaks
else {
// create the retry logic on error §
final retryable = Retryable
.<byte[]>of(httpConfig)
.onRetry((event) -> log.warn("Unable to read blob body - request: $route; event: $event"))
// read the body and compose the response
return new DelegateResponse(
statusCode: resp1.statusCode(),
headers: resp1.headers().map(),
body: resp1.body() )
body: retryable.apply(()-> resp1.body().bytes) )
}
}

Expand Down Expand Up @@ -226,15 +261,6 @@ class RegistryProxyService {
return result
}

@ToString(includeNames = true, includePackage = false)
static class DelegateResponse {
int statusCode
Map<String,List<String>> headers
InputStream body
String location
boolean isRedirect() { location }
}

Flux<ByteBuffer<?>> streamBlob(RoutePath route, Map<String,List<String>> headers) {
ProxyClient proxyClient = client(route)
return proxyClient.stream(streamClient, route.path, headers)
Expand Down
6 changes: 6 additions & 0 deletions src/main/groovy/io/seqera/wave/core/RoutePath.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import io.micronaut.core.annotation.Nullable
import io.seqera.wave.model.ContainerCoordinates
import io.seqera.wave.service.request.ContainerRequest
import io.seqera.wave.tower.PlatformId
import io.seqera.wave.util.RegHelper
import static io.seqera.wave.WaveDefault.DOCKER_IO
/**
* Model a container registry route path
Expand Down Expand Up @@ -150,4 +151,9 @@ class RoutePath implements ContainerPath {
else
throw new IllegalArgumentException("Not a valid container path - offending value: '$location'")
}

String stableHash() {
RegHelper.sipHash(type, registry, path, image, reference, identity.stableHash())
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,22 @@
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.tower.client
package io.seqera.wave.proxy

import spock.lang.Specification
import groovy.transform.ToString
import io.seqera.wave.encoder.MoshiExchange

/**
*
* Model a response object to be forwarded to the client
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
class TowerClientTest extends Specification {

def 'should create consistent hash' () {
given:
def client = new TowerClient()

expect:
client.makeKey('a') == '92cf27ac76c18d8e'
and:
client.makeKey('a') == client.makeKey('a')
and:
client.makeKey('a','b','c') == client.makeKey('a','b','c')
and:
client.makeKey('a','b',null) == client.makeKey('a','b',null)
and:
client.makeKey(new URI('http://foo.com')) == client.makeKey('http://foo.com')
and:
client.makeKey(100l) == client.makeKey('100')
}

@ToString(includeNames = true, includePackage = false)
class DelegateResponse implements MoshiExchange {
int statusCode
Map<String,List<String>> headers
byte[] body
String location
boolean isRedirect() { location }
boolean isCacheable() { location!=null || (body!=null && statusCode>=200 && statusCode<400) }
}
62 changes: 62 additions & 0 deletions src/main/groovy/io/seqera/wave/proxy/ProxyCache.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Wave, containers provisioning service
* Copyright (c) 2023-2024, Seqera Labs
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package io.seqera.wave.proxy

import java.time.Duration

import com.squareup.moshi.adapters.PolymorphicJsonAdapterFactory
import groovy.transform.CompileStatic
import io.micronaut.context.annotation.Value
import io.micronaut.core.annotation.Nullable
import io.seqera.wave.encoder.MoshiEncodeStrategy
import io.seqera.wave.encoder.MoshiExchange
import io.seqera.wave.store.cache.AbstractTieredCache
import io.seqera.wave.store.cache.L2TieredCache
import jakarta.inject.Singleton
/**
* Implements a tiered cache for proxied http responses
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Singleton
@CompileStatic
class ProxyCache extends AbstractTieredCache<DelegateResponse> {
ProxyCache(@Nullable L2TieredCache l2,
@Value('${wave.proxy-cache.duration:1h}') Duration duration,
@Value('${wave.proxy-cache.max-size:10000}') long maxSize) {
super(l2, encoder(), duration, maxSize)
}

static MoshiEncodeStrategy encoder() {
// json adapter factory
final factory = PolymorphicJsonAdapterFactory.of(MoshiExchange.class, "@type")
.withSubtype(Entry.class, Entry.name)
.withSubtype(DelegateResponse.class, DelegateResponse.simpleName)
// the encoding strategy
return new MoshiEncodeStrategy<AbstractTieredCache.Entry>(factory) {}
}

String getName() {
'proxy-cache'
}

String getPrefix() {
'proxy-cache/v1'
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class StreamServiceImpl implements StreamService {
// when it's a response with a binary body, just return it
if( resp.body!=null ) {
log.debug "Streaming response body for route: $route"
return resp.body
return new ByteArrayInputStream(resp.body)
}
// otherwise cache the blob and stream the resulting uri
if( blobCacheService ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac

private L2TieredCache<String,String> l2

private final Lock sync = new ReentrantLock()
private final WeakHashMap<String,Lock> locks = new WeakHashMap<>()

AbstractTieredCache(L2TieredCache<String,String> l2, MoshiEncodeStrategy encoder, Duration duration, long maxSize) {
log.info "Cache '${getName()}' config - prefix=${getPrefix()}; ttl=${duration}; max-size: ${maxSize}; l2=${l2}"
log.info "Cache '${getName()}' config - prefix=${getPrefix()}; ttl=${duration}; max-size: ${maxSize}"
if( l2==null )
log.warn "Missing L2 cache for tiered cache '${getName()}'"
this.l2 = l2
this.ttl = duration
this.encoder = encoder
Expand All @@ -89,12 +91,49 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
}
}

/**
* Retrieve the value associated with the specified key
*
* @param key
* The key of the value to be retrieved
* @return
* The value associated with the specified key, or {@code null} otherwise
*/
@Override
V get(String key) {
getOrCompute(key, null)
getOrCompute(key, null, (v)->true)
}

/**
* Retrieve the value associated with the specified key
*
* @param key
* The key of the value to be retrieved
* @param loader
* A function invoked to load the value the entry with the specified key is not available
* @return
* The value associated with the specified key, or {@code null} otherwise
*/
V getOrCompute(String key, Function<String,V> loader) {
getOrCompute(key, loader, (v)->true)
}

/**
* Retrieve the value associated with the specified key
*
* @param key
* The key of the value to be retrieved
* @param loader
* The function invoked to load the value the entry with the specified key is not available
* @param cacheCondition
* The function to determine if the loaded value should be cached
* @return
* The value associated with the specified key, or #function result otherwise
*/
V getOrCompute(String key, Function<String,V> loader, Function<V,Boolean> cacheCondition) {
assert key!=null, "Argument key cannot be null"
assert cacheCondition!=null, "Argument condition cannot be null"

log.trace "Cache '${name}' checking key=$key"
// Try L1 cache first
V value = l1.synchronous().getIfPresent(key)
Expand All @@ -103,6 +142,7 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
return value
}

final sync = locks.computeIfAbsent(key, (k)-> new ReentrantLock())
sync.lock()
try {
value = l1.synchronous().getIfPresent(key)
Expand All @@ -124,7 +164,7 @@ abstract class AbstractTieredCache<V extends MoshiExchange> implements TieredCac
if( value==null && loader!=null ) {
log.trace "Cache '${name}' invoking loader - key=$key"
value = loader.apply(key)
if( value!=null ) {
if( value!=null && cacheCondition.apply(value) ) {
l1.synchronous().put(key,value)
l2Put(key,value)
}
Expand Down
11 changes: 11 additions & 0 deletions src/main/groovy/io/seqera/wave/tower/PlatformId.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import groovy.transform.Canonical
import groovy.transform.CompileStatic
import io.seqera.wave.api.ContainerInspectRequest
import io.seqera.wave.api.SubmitContainerTokenRequest
import io.seqera.wave.util.RegHelper
import io.seqera.wave.util.StringUtils

/**
Expand Down Expand Up @@ -80,4 +81,14 @@ class PlatformId {
", workflowId=" + workflowId +
')';
}

String stableHash() {
RegHelper.sipHash(
getUserId(),
getUserEmail(),
workspaceId,
accessToken,
towerEndpoint,
workflowId )
}
}
Loading