Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

shareIn and cache operators #1716

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/internal/SharedFlow.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.sync.*

internal fun <T> Flow<T>.asCachedFlow(
cacheHistory: Int
): Flow<T> {

require(cacheHistory > 0) { "cacheHistory parameter must be greater than 0, but was $cacheHistory" }

val cache = CircularArray<T>(cacheHistory)

return onEach { value ->
// While flowing, also record all values in the cache.
cache.add(value)
}.onStart {
// Before emitting any values in sourceFlow,
// emit any cached values starting with the oldest.
cache.forEach { emit(it) }
}
}

internal fun <T> Flow<T>.asSharedFlow(
scope: CoroutineScope, cacheHistory: Int
): Flow<T> = SharedFlow(this, scope, cacheHistory)

/**
* An auto-resetting [broadcast] flow. It tracks the number of active collectors, and automatically resets when
* the number reaches 0.
*
* `SharedFlow` has an optional [cache], where the latest _n_ elements emitted by the source Flow will be replayed to
* late collectors.
*
* ### Upon reset
* 1) The underlying [BroadcastChannel] is closed. A new BroadcastChannel will be created when a new collector starts.
* 2) The cache is reset. New collectors will not receive values from before the reset, but will generate a new cache.
*/
internal class SharedFlow<T>(
private val sourceFlow: Flow<T>,
private val scope: CoroutineScope,
private val cacheHistory: Int
) : Flow<T> {

private var refCount = 0
private var cache = CircularArray<T>(cacheHistory)
private val mutex = Mutex(false)

init {
require(cacheHistory >= 0) { "cacheHistory parameter must be at least 0, but was $cacheHistory" }
}

public override suspend fun collect(
collector: FlowCollector<T>
) = collector.emitAll(createFlow())

// Replay happens per new collector, if cacheHistory > 0.
private suspend fun createFlow(): Flow<T> = getChannel()
.asFlow()
.replayIfNeeded()
.onCompletion { onCollectEnd() }

private suspend fun getChannel(): BroadcastChannel<T> = mutex.withLock {
refCount++
lazyChannelRef.value
}

// lazy holder for the BroadcastChannel, which is reset whenever all collection ends
private var lazyChannelRef = createLazyChannel()

// must be lazy so that the broadcast doesn't begin immediately after a reset
private fun createLazyChannel() = lazy(LazyThreadSafetyMode.NONE) {
sourceFlow.cacheIfNeeded()
.broadcastIn(scope)
}

private fun Flow<T>.replayIfNeeded(): Flow<T> = if (cacheHistory > 0) {
onStart {
cache.forEach {
emit(it)
}
}
} else this

private fun Flow<T>.cacheIfNeeded(): Flow<T> = if (cacheHistory > 0) {
onEach { value ->
// While flowing, also record all values in the cache.
cache.add(value)
}
} else this

private suspend fun onCollectEnd() = mutex.withLock { if (--refCount == 0) reset() }

private fun reset() {
cache = CircularArray(cacheHistory)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Existing channel is not closed, which can leak the shared flow each time refCount reaches 0.
This function should close lazyChannelRef.value before assigning a new value to the lazyChannelRef variable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Believe it or not I removed that just before making the PR, thinking "well no one's collecting it, so what could leak?". Except of course the leak can go two ways.


lazyChannelRef.value.cancel()
lazyChannelRef = createLazyChannel()
}
}
166 changes: 166 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/operators/Share.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

@file:JvmMultifileClass
@file:JvmName("FlowKt")

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.internal.*
import kotlin.jvm.*

/**
* A "cached" [Flow] which will record the last [history] collected values.
*
* When a collector begins collecting after values have already been recorded,
* those values will be collected *before* values from the receiver [Flow] are collected.
*
* example:
* ```Kotlin
* val ints = flowOf(1, 2, 3, 4).cache(2) // cache the last 2 values
*
* ints.take(4).collect { ... } // 4 values are emitted, but also recorded. The last 2 remain.
*
* ints.collect { ... } // collects [3, 4, 1, 2, 3, 4]
* ```
*
* Throws [IllegalArgumentException] if size parameter is not greater than 0
*
* @param history the number of items to keep in the [Flow]'s history -- must be greater than 0
*/
@FlowPreview
public fun <T> Flow<T>.cache(history: Int): Flow<T> = asCachedFlow(history)

/**
* Creates a [broadcast] coroutine which collects the [Flow] receiver and shares with multiple collectors.
*
* A [BroadcastChannel] with [default][Channel.Factory.BUFFERED] buffer size is created.
* Use [buffer] operator on the flow before calling `shareIn` to specify a value other than
* default and to control what happens when data is produced faster than it is consumed,
* that is to control back-pressure behavior.
*
* Concurrent collectors will all collect from a single [broadcast] flow. This flow will be cancelled automatically
* when it is no longer being collected, and the underlying channel will be closed.
*
* If a new collector is added after the channel has been closed, a new channel will be created.
*
* By default, this flow is effectively **stateless** in that collectors will only receive values emitted after collection begins.
*
* example:
*
* ```
* val sourceFlow = flowOf(1, 2, 3, 4, 5)
* .onStart { println("start source") }
* .onEach { println("emit $it") }
* .onCompletion { println("complete source") }
* .shareIn(this)
*
* val a = async { sourceFlow.toList() }
* val b = async { sourceFlow.toList() } // collect concurrently
*
* println(a.await())
* println(b.await())
*
* println("** break **")
*
* println(sourceFlow.toList())
*
* prints:
*
* start source
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* complete source
* [1, 2, 3, 4, 5]
* [1, 2, 3, 4, 5]
* ** break **
* start source
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* complete source
* [1, 2, 3, 4, 5]
*
* ```
* ### Caching
*
* When a shared flow is cached, the values are recorded as they are emitted from the source Flow.
* They are then replayed for each new subscriber.
*
* When a shared flow is reset, the cached values are cleared.
*
* example:
*
* ```
* val sourceFlow = flowOf(1, 2, 3, 4, 5)
* .onEach {
* delay(50)
* println("emit $it")
* }.shareIn(this, 1)
*
* val a = async { sourceFlow.toList() }
* delay(125)
* val b = async { sourceFlow.toList() } // begin collecting after "emit 3"
*
* println(a.await())
* println(b.await())
*
* println("** break **")
*
* println(sourceFlow.toList()) // the shared flow has been reset, so the cached values are cleared
*
* prints:
*
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* [1, 2, 3, 4, 5]
* [2, 3, 4, 5]
* ** break **
* emit 1
* emit 2
* emit 3
* emit 4
* emit 5
* [1, 2, 3, 4, 5]
*
* ```
*
* In order to have cached values persist across resets, use `cache(n)` before `shareIn(...)`.
*
* example:
*
* ```
* // resets cache whenever the Flow is reset
* flowOf(1, 2, 3).shareIn(myScope, 3)
*
* // persists cache across resets
* flowOf(1, 2, 3).cached(3).shareIn(myScope)
* ```
*
* ### Cancellation semantics
* 1) Flow consumer is cancelled when the original channel is cancelled.
* 2) Flow consumer completes normally when the original channel completes (~is closed) normally.
* 3) Collection is cancelled when the (scope)[CoroutineScope] parameter is cancelled,
* thereby ending the consumer when it has run out of elements.
* 4) If the flow consumer fails with an exception, subscription is cancelled.
*
* @param scope The [CoroutineScope] used to create the [broadcast] coroutine. Cancellation of this scope
* will close the underlying [BroadcastChannel].
* @param cacheHistory (default = 0). Any value greater than zero will add a [cache] to the shared Flow.
*
*/
@FlowPreview
fun <T> Flow<T>.shareIn(
scope: CoroutineScope, cacheHistory: Int = 0
): Flow<T> = asSharedFlow(scope, cacheHistory)
89 changes: 89 additions & 0 deletions kotlinx-coroutines-core/common/src/internal/CircularArray.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package kotlinx.coroutines.internal

import kotlinx.coroutines.*


/**
* CircularArray implementation which will hold the latest of up to `size` elements.
*
* After the cache has been filled, all further additions will overwrite the least recent value.
*
* @param size the maximum number of elements to store in the array
*/
internal class CircularArray<T>(size: Int) : Iterable<T> {

private val array: Array<Any?> = arrayOfNulls(size)
private var count: Int = 0
private var tail: Int = -1

/**
* Adds [item] to the [CircularArray].
*
* If the `CircularArray` has not yet been filled,
* `item` will simply be added to the next available slot.
*
* If the `CircularArray` has already been filled,
* `item` will replace the oldest item already in the array.
*
* example:
* ```
* val ca = CircularArray<T>(3)
*
* ca.add(0) // ca contents : [0, null, null]
* ca.add(1) // ca contents : [0, 1, null]
* ca.add(2) // ca contents : [0, 1, 2]
* // overwrite the oldest value
* ca.add(3) // ca contents : [3, 1, 2]
* ```
*/
public fun add(item: T) {
tail = (tail + 1) % array.size
array[tail] = item
if (count < array.size) count++
}

/**
* Iterates over the [CircularArray].
*
* Order is always first-in-first-out, with the oldest item being used first.
*
* example:
* ```
* val ca = CircularArray<Int>(3)
*
* ca.add(0) // ca contents : [0, null, null]
* ca.add(1) // ca contents : [0, 1, null]
* ca.add(2) // ca contents : [0, 1, 2]
* // overwrite the oldest value
* ca.add(3) // ca contents : [3, 1, 2]
*
* ca.forEach { ... } // order : [1, 2, 3]
* ```
*/
public override fun iterator(): Iterator<T> = object : Iterator<T> {
private val arraySnapshot = array.copyOf()
private val tailSnapshot = tail

private var _index = 0

private val head: Int
get() = when (count) {
arraySnapshot.size -> (tailSnapshot + 1) % count
else -> 0
}

@Suppress("UNCHECKED_CAST")
private fun get(index: Int): T = when (count) {
arraySnapshot.size -> arraySnapshot[(head + index) % arraySnapshot.size] as T
else -> arraySnapshot[index] as T
}

override fun hasNext(): Boolean = _index < count
override fun next(): T = get(_index++)

}

public override fun toString(): String = "$classSimpleName[array=${contentToString()}]"

private fun contentToString(): String = joinToString { "$it" }
}
Loading