Skip to content

Commit

Permalink
shareIn and cache operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Rick Busarow committed Dec 16, 2019
1 parent a930b0c commit c02c44c
Show file tree
Hide file tree
Showing 5 changed files with 646 additions and 0 deletions.
105 changes: 105 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/internal/SharedFlow.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.sync.*

internal fun <T> Flow<T>.asCachedFlow(
cacheHistory: Int
): Flow<T> {

require(cacheHistory > 0) { "cacheHistory parameter must be greater than 0, but was $cacheHistory" }

val cache = CircularArray<T>(cacheHistory)

return onEach { value ->
// While flowing, also record all values in the cache.
cache.add(value)
}.onStart {
// Before emitting any values in sourceFlow,
// emit any cached values starting with the oldest.
cache.forEach { emit(it) }
}
}

internal fun <T> Flow<T>.asSharedFlow(
scope: CoroutineScope, cacheHistory: Int
): Flow<T> = SharedFlow(this, scope, cacheHistory)

/**
* An auto-resetting [broadcast] flow. It tracks the number of active collectors, and automatically resets when
* the number reaches 0.
*
* `SharedFlow` has an optional [cache], where the latest _n_ elements emitted by the source Flow will be replayed to
* late collectors.
*
* ### Upon reset
* 1) The underlying [BroadcastChannel] is closed. A new BroadcastChannel will be created when a new collector starts.
* 2) The cache is reset. New collectors will not receive values from before the reset, but will generate a new cache.
*/
internal class SharedFlow<T>(
private val sourceFlow: Flow<T>,
private val scope: CoroutineScope,
private val cacheHistory: Int
) : Flow<T> {

private var refCount = 0
private var cache = CircularArray<T>(cacheHistory)
private val mutex = Mutex(false)

init {
require(cacheHistory >= 0) { "cacheHistory parameter must be at least 0, but was $cacheHistory" }
}

public override suspend fun collect(
collector: FlowCollector<T>
) = collector.emitAll(createFlow())

// Replay happens per new collector, if cacheHistory > 0.
private suspend fun createFlow(): Flow<T> = getChannel()
.asFlow()
.replayIfNeeded()
.onCompletion { onCollectEnd() }

// lazy holder for the BroadcastChannel, which is reset whenever all collection ends
private var lazyChannelRef = createLazyChannel()

// must be lazy so that the broadcast doesn't begin immediately after a reset
private fun createLazyChannel() = lazy(LazyThreadSafetyMode.NONE) {
sourceFlow.cacheIfNeeded()
.broadcastIn(scope)
}

private fun Flow<T>.replayIfNeeded(): Flow<T> = if (cacheHistory > 0) {
onStart {
cache.forEach {
emit(it)
}
}
} else this

private fun Flow<T>.cacheIfNeeded(): Flow<T> = if (cacheHistory > 0) {
onEach { value ->
// While flowing, also record all values in the cache.
cache.add(value)
}
} else this

private fun reset() {
cache = CircularArray(cacheHistory)
lazyChannelRef = createLazyChannel()
}

private suspend fun onCollectEnd() = mutex.withLock { if (--refCount == 0) reset() }
private suspend fun getChannel(): BroadcastChannel<T> = mutex.withLock {
refCount++
lazyChannelRef.value
}

}
166 changes: 166 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/operators/Share.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

@file:JvmMultifileClass
@file:JvmName("FlowKt")

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.internal.*
import kotlin.jvm.*

/**
* A "cached" [Flow] which will record the last [history] collected values.
*
* When a collector begins collecting after values have already been recorded,
* those values will be collected *before* values from the receiver [Flow] are collected.
*
* example:
* ```Kotlin
* val ints = flowOf(1, 2, 3, 4).cache(2) // cache the last 2 values
*
* ints.take(4).collect { ... } // 4 values are emitted, but also recorded. The last 2 remain.
*
* ints.collect { ... } // collects [3, 4, 1, 2, 3, 4]
* ```
*
* Throws [IllegalArgumentException] if size parameter is not greater than 0
*
* @param history the number of items to keep in the [Flow]'s history -- must be greater than 0
*/
@FlowPreview
public fun <T> Flow<T>.cache(history: Int): Flow<T> = asCachedFlow(history)

/**
* Creates a [broadcast] coroutine which collects the [Flow] receiver and shares with multiple collectors.
*
* A [BroadcastChannel] with [default][Channel.Factory.BUFFERED] buffer size is created.
* Use [buffer] operator on the flow before calling `shareIn` to specify a value other than
* default and to control what happens when data is produced faster than it is consumed,
* that is to control back-pressure behavior.
*
* Concurrent collectors will all collect from a single [broadcast] flow. This flow will be cancelled automatically
* when it is no longer being collected, and the underlying channel will be closed.
*
* If a new collector is added after the channel has been closed, a new channel will be created.
*
* By default, this flow is effectively **stateless** in that collectors will only receive values emitted after collection begins.
*
* example:
*
* ```
* val sourceFlow = flowOf(1, 2, 3, 4, 5)
* .onStart { println("start source") }
* .onEach { println("emit $it") }
* .onCompletion { println("complete source") }
* .shareIn(this)
*
* val a = async { sourceFlow.toList() }
* val b = async { sourceFlow.toList() } // collect concurrently
*
* println(a.await())
* println(b.await())
*
* println("** break **")
*
* println(sourceFlow.toList())
*
* prints:
*
* start source
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* complete source
* [1, 2, 3, 4, 5]
* [1, 2, 3, 4, 5]
* ** break **
* start source
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* complete source
* [1, 2, 3, 4, 5]
*
* ```
* ### Caching
*
* When a shared flow is cached, the values are recorded as they are emitted from the source Flow.
* They are then replayed for each new subscriber.
*
* When a shared flow is reset, the cached values are cleared.
*
* example:
*
* ```
* val sourceFlow = flowOf(1, 2, 3, 4, 5)
* .onEach {
* delay(50)
* println("emit $it")
* }.shareIn(this, 1)
*
* val a = async { sourceFlow.toList() }
* delay(125)
* val b = async { sourceFlow.toList() } // begin collecting after "emit 3"
*
* println(a.await())
* println(b.await())
*
* println("** break **")
*
* println(sourceFlow.toList()) // the shared flow has been reset, so the cached values are cleared
*
* prints:
*
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* [1, 2, 3, 4, 5]
* [2, 3, 4, 5]
* ** break **
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* [1, 2, 3, 4, 5]
*
* ```
*
* In order to have cached values persist across resets, use `cache(n)` before `shareIn(...)`.
*
* example:
*
* ```
* // resets cache whenever the Flow is reset
* flowOf(1, 2, 3).shareIn(myScope, 3)
*
* // persists cache across resets
* flowOf(1, 2, 3).cached(3).shareIn(myScope)
* ```
*
* ### Cancellation semantics
* 1) Flow consumer is cancelled when the original channel is cancelled.
* 2) Flow consumer completes normally when the original channel completes (~is closed) normally.
* 3) Collection is cancelled when the (scope)[CoroutineScope] parameter is cancelled,
* thereby ending the consumer when it has run out of elements.
* 4) If the flow consumer fails with an exception, subscription is cancelled.
*
* @param scope The [CoroutineScope] used to create the [broadcast] coroutine. Cancellation of this scope
* will close the underlying [BroadcastChannel].
* @param cacheHistory (default = 0). Any value greater than zero will add a [cache] to the shared Flow.
*
*/
@FlowPreview
fun <T> Flow<T>.shareIn(
scope: CoroutineScope, cacheHistory: Int = 0
): Flow<T> = asSharedFlow(scope, cacheHistory)
89 changes: 89 additions & 0 deletions kotlinx-coroutines-core/common/src/internal/CircularArray.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package kotlinx.coroutines.internal

import kotlinx.coroutines.*


/**
* CircularArray implementation which will hold the latest of up to `size` elements.
*
* After the cache has been filled, all further additions will overwrite the least recent value.
*
* @param size the maximum number of elements to store in the array
*/
internal class CircularArray<T>(size: Int) : Iterable<T> {

private val array: Array<Any?> = arrayOfNulls(size)
private var count: Int = 0
private var tail: Int = -1

/**
* Adds [item] to the [CircularArray].
*
* If the `CircularArray` has not yet been filled,
* `item` will simply be added to the next available slot.
*
* If the `CircularArray` has already been filled,
* `item` will replace the oldest item already in the array.
*
* example:
* ```
* val ca = CircularArray<T>(3)
*
* ca.add(0) // ca contents : [0, null, null]
* ca.add(1) // ca contents : [0, 1, null]
* ca.add(2) // ca contents : [0, 1, 2]
* // overwrite the oldest value
* ca.add(3) // ca contents : [3, 1, 2]
* ```
*/
public fun add(item: T) {
tail = (tail + 1) % array.size
array[tail] = item
if (count < array.size) count++
}

/**
* Iterates over the [CircularArray].
*
* Order is always first-in-first-out, with the oldest item being used first.
*
* example:
* ```
* val ca = CircularArray<Int>(3)
*
* ca.add(0) // ca contents : [0, null, null]
* ca.add(1) // ca contents : [0, 1, null]
* ca.add(2) // ca contents : [0, 1, 2]
* // overwrite the oldest value
* ca.add(3) // ca contents : [3, 1, 2]
*
* ca.forEach { ... } // order : [1, 2, 3]
* ```
*/
public override fun iterator(): Iterator<T> = object : Iterator<T> {
private val arraySnapshot = array.copyOf()
private val tailSnapshot = tail

private var _index = 0

private val head: Int
get() = when (count) {
arraySnapshot.size -> (tailSnapshot + 1) % count
else -> 0
}

@Suppress("UNCHECKED_CAST")
private fun get(index: Int): T = when (count) {
arraySnapshot.size -> arraySnapshot[(head + index) % arraySnapshot.size] as T
else -> arraySnapshot[index] as T
}

override fun hasNext(): Boolean = _index < count
override fun next(): T = get(_index++)

}

public override fun toString(): String = "$classSimpleName[array=${contentToString()}]"

private fun contentToString(): String = joinToString { "$it" }
}
Loading

0 comments on commit c02c44c

Please sign in to comment.