Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhukaihan committed Sep 24, 2024
1 parent 61dd7a8 commit c9b8130
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 153 deletions.
11 changes: 6 additions & 5 deletions src/main/kotlin/deployment/DeploymentRunner.kt
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
@file:OptIn(ExperimentalApi::class)

package com.amplitude.experiment.deployment

import com.amplitude.experiment.*
import com.amplitude.experiment.LocalEvaluationConfig
import com.amplitude.experiment.LocalEvaluationMetrics
import com.amplitude.experiment.cohort.CohortApi
import com.amplitude.experiment.cohort.CohortLoader
import com.amplitude.experiment.cohort.CohortStorage
import com.amplitude.experiment.flag.*
import com.amplitude.experiment.flag.FlagConfigApi
import com.amplitude.experiment.flag.FlagConfigFallbackRetryWrapper
import com.amplitude.experiment.flag.FlagConfigPoller
import com.amplitude.experiment.flag.FlagConfigStorage
import com.amplitude.experiment.flag.FlagConfigStreamApi
import com.amplitude.experiment.flag.FlagConfigStreamer
import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper
import com.amplitude.experiment.util.Logger
import com.amplitude.experiment.util.Once
Expand Down Expand Up @@ -46,7 +47,7 @@ internal class DeploymentRunner(
private val amplitudeFlagConfigUpdater =
if (flagConfigStreamApi != null)
FlagConfigFallbackRetryWrapper(
FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, config, metrics),
FlagConfigStreamer(flagConfigStreamApi, flagConfigStorage, cohortLoader, cohortStorage, metrics),
amplitudeFlagConfigPoller,
)
else amplitudeFlagConfigPoller
Expand Down
46 changes: 25 additions & 21 deletions src/main/kotlin/flag/FlagConfigStreamApi.kt
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package com.amplitude.experiment.flag

import com.amplitude.experiment.evaluation.EvaluationFlag
import com.amplitude.experiment.util.*
import com.amplitude.experiment.util.SseStream
import com.amplitude.experiment.util.StreamException
import com.amplitude.experiment.util.json
import kotlinx.serialization.decodeFromString
import okhttp3.HttpUrl
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.sse.EventSource
import okhttp3.sse.EventSourceListener
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutionException
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -37,9 +35,6 @@ internal class FlagConfigStreamApi (
reconnIntervalMillis: Long = RECONN_INTERVAL_MILLIS_DEFAULT,
) {
private val lock: ReentrantLock = ReentrantLock()
var onInitUpdate: ((List<EvaluationFlag>) -> Unit)? = null
var onUpdate: ((List<EvaluationFlag>) -> Unit)? = null
var onError: ((Exception?) -> Unit)? = null
val url = serverUrl.newBuilder().addPathSegments("sdk/stream/v1/flags").build()
private val stream: SseStream = SseStream(
"Api-Key $deploymentKey",
Expand All @@ -49,14 +44,23 @@ internal class FlagConfigStreamApi (
keepaliveTimeoutMillis,
reconnIntervalMillis)

internal fun connect() {
/**
* Connects to flag configs stream.
* This will ensure stream connects, first set of flags is received and processed successfully, then returns.
* If stream fails to connect, first set of flags is not received, or first set of flags did not process successfully, it throws.
*/
internal fun connect(
onInitUpdate: ((List<EvaluationFlag>) -> Unit)? = null,
onUpdate: ((List<EvaluationFlag>) -> Unit)? = null,
onError: ((Exception?) -> Unit)? = null
) {
// Guarded by lock. Update to callbacks and waits can lead to race conditions.
lock.withLock {
val isInit = AtomicBoolean(true)
val isDuringInit = AtomicBoolean(true)
val connectTimeoutFuture = CompletableFuture<Unit>()
val updateTimeoutFuture = CompletableFuture<Unit>()
stream.onUpdate = { data ->
if (isInit.getAndSet(false)) {
val onSseUpdate: ((String) -> Unit) = { data ->
if (isDuringInit.getAndSet(false)) {
// Stream is establishing. First data received.
// Resolve timeout.
connectTimeoutFuture.complete(Unit)
Expand All @@ -67,9 +71,9 @@ internal class FlagConfigStreamApi (

try {
if (onInitUpdate != null) {
onInitUpdate?.let { it(flags) }
onInitUpdate.invoke(flags)
} else {
onUpdate?.let { it(flags) }
onUpdate?.invoke(flags)
}
updateTimeoutFuture.complete(Unit)
} catch (e: Throwable) {
Expand All @@ -86,26 +90,26 @@ internal class FlagConfigStreamApi (
val flags = getFlagsFromData(data)

try {
onUpdate?.let { it(flags) }
onUpdate?.invoke(flags)
} catch (_: Throwable) {
// Don't care about application error.
}
} catch (_: Throwable) {
// Stream corrupted. Reconnect.
handleError(FlagConfigStreamApiDataCorruptError())
handleError(onError, FlagConfigStreamApiDataCorruptError())
}

}
}
stream.onError = { t ->
if (isInit.getAndSet(false)) {
val onSseError: ((Throwable?) -> Unit) = { t ->
if (isDuringInit.getAndSet(false)) {
connectTimeoutFuture.completeExceptionally(t)
updateTimeoutFuture.completeExceptionally(t)
} else {
handleError(FlagConfigStreamApiStreamError(t))
handleError(onError, FlagConfigStreamApiStreamError(t))
}
}
stream.connect()
stream.connect(onSseUpdate, onSseError)

val t: Throwable
try {
Expand Down Expand Up @@ -139,8 +143,8 @@ internal class FlagConfigStreamApi (
return json.decodeFromString<List<EvaluationFlag>>(data)
}

private fun handleError(e: Exception?) {
private fun handleError(onError: ((Exception?) -> Unit)?, e: Exception?) {
close()
onError?.let { it(e) }
onError?.invoke(e)
}
}
81 changes: 59 additions & 22 deletions src/main/kotlin/flag/FlagConfigUpdater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import com.amplitude.experiment.LocalEvaluationMetrics
import com.amplitude.experiment.cohort.CohortLoader
import com.amplitude.experiment.cohort.CohortStorage
import com.amplitude.experiment.evaluation.EvaluationFlag
import com.amplitude.experiment.util.*
import com.amplitude.experiment.util.LocalEvaluationMetricsWrapper
import com.amplitude.experiment.util.Logger
import com.amplitude.experiment.util.daemonFactory
import com.amplitude.experiment.util.getAllCohortIds
import com.amplitude.experiment.util.wrapMetrics
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
Expand All @@ -19,23 +20,43 @@ import kotlin.concurrent.withLock
import kotlin.math.max
import kotlin.math.min

/**
* Flag config updaters should receive flags through their own means (ex. http GET, SSE stream),
* or as wrapper of others.
* They all should have these methods to control their lifecycle.
*/
internal interface FlagConfigUpdater {
// Start the updater. There can be multiple calls.
// If start fails, it should throw exception. The caller should handle error.
// If some other error happened while updating (already started successfully), it should call onError.
/**
* Start the updater. There can be multiple calls.
* If start fails, it should throw exception. The caller should handle error.
* If some other error happened while updating (already started successfully), it should call onError.
*/
fun start(onError: (() -> Unit)? = null)
// Stop should stop updater temporarily. There may be another start in the future.
// To stop completely, with intention to never start again, use shutdown() instead.

/**
* Stop should stop updater temporarily. There may be another start in the future.
* To stop completely, with intention to never start again, use shutdown() instead.
*/
fun stop()
// Destroy should stop the updater forever in preparation for server shutdown.

/**
* Destroy should stop the updater forever in preparation for server shutdown.
*/
fun shutdown()
}

/**
* All flag config updaters should share this class, which contains a function to properly process flag updates.
*/
internal abstract class FlagConfigUpdaterBase(
private val flagConfigStorage: FlagConfigStorage,
private val cohortLoader: CohortLoader?,
private val cohortStorage: CohortStorage?,
): FlagConfigUpdater {
) {
/**
* Call this method after receiving and parsing flag configs from network.
* This method updates flag configs into storage and download all cohorts if needed.
*/
protected fun update(flagConfigs: List<EvaluationFlag>) {
// Remove flags that no longer exist.
val flagKeys = flagConfigs.map { it.key }.toSet()
Expand Down Expand Up @@ -85,19 +106,27 @@ internal abstract class FlagConfigUpdaterBase(
}
}

/**
* This is the poller for flag configs.
* It keeps polling flag configs with specified interval until error occurs.
*/
internal class FlagConfigPoller(
private val flagConfigApi: FlagConfigApi,
private val storage: FlagConfigStorage,
private val cohortLoader: CohortLoader?,
private val cohortStorage: CohortStorage?,
storage: FlagConfigStorage,
cohortLoader: CohortLoader?,
cohortStorage: CohortStorage?,
private val config: LocalEvaluationConfig,
private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper(),
): FlagConfigUpdaterBase(
): FlagConfigUpdater, FlagConfigUpdaterBase(
storage, cohortLoader, cohortStorage
) {
private val lock: ReentrantLock = ReentrantLock()
private val pool = Executors.newScheduledThreadPool(1, daemonFactory)
private var scheduledFuture: ScheduledFuture<*>? = null // @GuardedBy(lock)

/**
* Start will fetch once, then start poller to poll flag configs.
*/
override fun start(onError: (() -> Unit)?) {
refresh()
lock.withLock {
Expand Down Expand Up @@ -153,29 +182,36 @@ internal class FlagConfigPoller(
}
}

/**
* Streamer for flag configs. This receives flag updates with an SSE connection.
*/
internal class FlagConfigStreamer(
private val flagConfigStreamApi: FlagConfigStreamApi,
private val storage: FlagConfigStorage,
private val cohortLoader: CohortLoader?,
private val cohortStorage: CohortStorage?,
private val config: LocalEvaluationConfig,
storage: FlagConfigStorage,
cohortLoader: CohortLoader?,
cohortStorage: CohortStorage?,
private val metrics: LocalEvaluationMetrics = LocalEvaluationMetricsWrapper()
): FlagConfigUpdaterBase(
): FlagConfigUpdater, FlagConfigUpdaterBase(
storage, cohortLoader, cohortStorage
) {
private val lock: ReentrantLock = ReentrantLock()

/**
* Start makes sure it connects to stream and the first set of flag configs is loaded.
* Then, it will update the flags whenever there's a stream.
*/
override fun start(onError: (() -> Unit)?) {
lock.withLock {
flagConfigStreamApi.onUpdate = { flags ->
val onStreamUpdate: ((List<EvaluationFlag>) -> Unit) = { flags ->
update(flags)
}
flagConfigStreamApi.onError = { e ->
val onStreamError: ((Exception?) -> Unit) = { e ->
Logger.e("Stream flag configs streaming failed.", e)
metrics.onFlagConfigStreamFailure(e)
onError?.invoke()
}
wrapMetrics(metric = metrics::onFlagConfigStream, failure = metrics::onFlagConfigStreamFailure) {
flagConfigStreamApi.connect()
flagConfigStreamApi.connect(onStreamUpdate, onStreamUpdate, onStreamError)
}
}
}
Expand All @@ -190,11 +226,12 @@ internal class FlagConfigStreamer(

private const val RETRY_DELAY_MILLIS_DEFAULT = 15 * 1000L
private const val MAX_JITTER_MILLIS_DEFAULT = 2000L

internal class FlagConfigFallbackRetryWrapper(
private val mainUpdater: FlagConfigUpdater,
private val fallbackUpdater: FlagConfigUpdater?,
private val retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT,
private val maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT,
retryDelayMillis: Long = RETRY_DELAY_MILLIS_DEFAULT,
maxJitterMillis: Long = MAX_JITTER_MILLIS_DEFAULT,
): FlagConfigUpdater {
private val lock: ReentrantLock = ReentrantLock()
private val reconnIntervalRange = max(0, retryDelayMillis - maxJitterMillis)..(min(retryDelayMillis, retryDelayMillis - maxJitterMillis) + maxJitterMillis)
Expand Down
Loading

0 comments on commit c9b8130

Please sign in to comment.