Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SafeCollector rework #1196

Merged
merged 5 commits into from
May 29, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion kotlinx-coroutines-core/common/src/JobSupport.kt
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,11 @@ public open class JobSupport constructor(active: Boolean) : Job, ChildJob, Paren
if (!tryFinalizeSimpleState(state, proposedUpdate, mode)) return COMPLETING_RETRY
return COMPLETING_COMPLETED
}
// The separate slow-path function to simplify profiling
return tryMakeCompletingSlowPath(state, proposedUpdate, mode)
}

private fun tryMakeCompletingSlowPath(state: Incomplete, proposedUpdate: Any?, mode: Int): Int {
// get state's list or else promote to list to correctly operate on child lists
val list = getOrPromoteCancellingList(state) ?: return COMPLETING_RETRY
// promote to Finishing state if we are not in it yet
Expand Down Expand Up @@ -1202,7 +1207,7 @@ public open class JobSupport constructor(active: Boolean) : Job, ChildJob, Paren
* Class to represent object as the final state of the Job
*/
private class IncompleteStateBox(@JvmField val state: Incomplete)
private fun Any?.boxIncomplete(): Any? = if (this is Incomplete) IncompleteStateBox(this) else this
internal fun Any?.boxIncomplete(): Any? = if (this is Incomplete) IncompleteStateBox(this) else this
internal fun Any?.unboxState(): Any? = (this as? IncompleteStateBox)?.state ?: this

// --------------- helper classes & constants for job implementation
Expand Down
52 changes: 26 additions & 26 deletions kotlinx-coroutines-core/common/src/flow/Flow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import kotlinx.coroutines.*
* trigger their evaluation every time [collect] is executed) or hot ones, but, conventionally, they represent cold streams.
* Transitions between hot and cold streams are supported via channels and the corresponding API: [flowViaChannel], [broadcastIn], [produceIn].
*
* The flow has a context preserving property: it encapsulates its own execution context and never propagates or leaks it downstream, thus making
* The flow has a context preservation property: it encapsulates its own execution context and never propagates or leaks it downstream, thus making
* reasoning about the execution context of particular transformations or terminal operations trivial.
*
* There are two ways to change the context of a flow: [flowOn][Flow.flowOn] and [flowWith][Flow.flowWith].
Expand All @@ -52,24 +52,37 @@ import kotlinx.coroutines.*
* }
* ```
*
* From the implementation point of view it means that all intermediate operators on [Flow] should abide by the following constraint:
* If collection or emission of a flow is to be separated into multiple coroutines, it should use [coroutineScope] or [supervisorScope] and
* is not allowed to modify the coroutines' context:
* From the implementation point of view it means that all intermediate operators on [Flow] should abide by the following constraints:
* 1) If an operator is trivial and does not start any coroutines, regular [flow] builder should be used. Its implementation
* efficiently enforces all the invariants and prevents most of the development mistakes.
*
* 2) If the collection and emission of the flow are to be separated into multiple coroutines, [channelFlow] should be used.
* [channelFlow] encapsulates all the context preservation work and allows you to focus on your domain-specific problem,
* rather than invariant implementation details. It is possible to use any combination of coroutine builders from within [channelFlow].
*
* 3) If you are looking for the performance and are sure that no concurrent emits and context jumps will happen, [flow] builder
* alongside with [coroutineScope] or [supervisorScope] can be used instead:
*
* - scoped primitive should be used to provide a [CoroutineScope]
* - changing the context of the coroutines is prohibited, no matter whether it is `withContext(ctx)` or builder argument (e.g. `launch(ctx)`)
* - Emissions are allowed only from the enclosing scope, emissions from launched coroutines are prohibited.
*
* These constraints are enforced by the default [flow] builder.
* Example of the proper `buffer` implementation:
* ```
* fun <T> Flow<T>.buffer(bufferSize: Int): Flow<T> = flow {
* coroutineScope { // coroutine scope is necessary, withContext is prohibited
* val channel = Channel<T>(bufferSize)
* // GlobalScope.launch { is prohibited
* // launch(Dispatchers.IO) { is prohibited
* launch { // is OK
* collect { value ->
* // GlobalScope.produce { is prohibited
* // produce(Dispatchers.IO) { is prohibited
* val channel = produce(bufferSize) {
elizarov marked this conversation as resolved.
Show resolved Hide resolved
* collect { value -> // Collect from started coroutine -- OK
* channel.send(value)
* }
* channel.close()
* }
*
* for (i in channel) {
* emit(i)
* emit(i) // Emission from the enclosing scope -- OK
* // launch { emit (i) } -- prohibited
* }
* }
* }
Expand All @@ -87,23 +100,10 @@ public interface Flow<out T> {
* A valid implementation of this method has the following constraints:
* 1) It should not change the coroutine context (e.g. with `withContext(Dispatchers.IO)`) when emitting values.
* The emission should happen in the context of the [collect] call.
*
* Only coroutine builders that inherit the context are allowed, for example:
* ```
* class MyFlow : Flow<Int> {
* override suspend fun collect(collector: FlowCollector<Int>) {
* coroutineScope {
* // Context is inherited
* launch { // Dispatcher is not overridden, fine as well
* collector.emit(42) // Emit from the launched coroutine
* }
* }
* }
* }
* ```
* is a proper [Flow] implementation, but using `launch(Dispatchers.IO)` is not.
* Please refer to the top-level [Flow] documentation for more details.
*
* 2) It should serialize calls to [emit][FlowCollector.emit] as [FlowCollector] implementations are not thread safe by default.
* To automatically serialize emissions [channelFlow] builder can be used instead of [flow]
*/
public suspend fun collect(collector: FlowCollector<T>)
}
1 change: 1 addition & 0 deletions kotlinx-coroutines-core/common/src/flow/FlowCollector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public interface FlowCollector<in T> {

/**
* Collects the value emitted by the upstream.
* This method is not thread-safe and should not be invoked concurrently.
*/
public suspend fun emit(value: T)
}
88 changes: 80 additions & 8 deletions kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,90 @@ import kotlin.coroutines.*
@PublishedApi
internal class SafeCollector<T>(
private val collector: FlowCollector<T>,
collectContext: CoroutineContext
) : FlowCollector<T>, SynchronizedObject() {
private val collectContext: CoroutineContext
) : FlowCollector<T> {

private val collectContext = collectContext.minusKey(Job).minusId()
// Note, it is non-capturing lambda, so no extra allocation during init of SafeCollector
private val collectContextSize = collectContext.fold(0) { count, _ -> count + 1 }
private var lastEmissionContext: CoroutineContext? = null

override suspend fun emit(value: T) {
val emitContext = coroutineContext.minusKey(Job).minusId()
if (emitContext != collectContext) {
/*
* Benign data-race here:
* We read potentially racy published coroutineContext, but we only use it for
* referential comparison (=> thus safe) and are not using it for structural comparisons.
*/
val currentContext = coroutineContext
// This check is triggered once per flow on happy path.
if (lastEmissionContext !== currentContext) {
checkContext(currentContext)
lastEmissionContext = currentContext
}
collector.emit(value) // TCE
}

private fun checkContext(currentContext: CoroutineContext) {
val result = currentContext.fold(0) fold@{ count, element ->
val key = element.key
val collectElement = collectContext[key]
if (key !== Job) {
return@fold if (element !== collectElement) Int.MIN_VALUE
else count + 1
}

val collectJob = collectElement as Job?
val emissionParentJob = (element as Job).transitiveCoroutineParent(collectJob)
/*
* Things like
* ```
* coroutineScope {
* launch {
* emit(1)
* }
*
* launch {
* emit(2)
* }
* }
* ```
* are prohibited because 'emit' is not thread-safe by default. Use channelFlow instead if you need concurrent emission
* or want to switch context dynamically (e.g. with `withContext`).
*
* Note that collecting from another coroutine is allowed, e.g.:
* ```
* coroutineScope {
* val channel = produce {
* collect { value ->
* send(value)
* }
* }
* channel.consumeEach { value ->
* emit(value)
* }
* }
* ```
* is a completely valid.
*/
if (emissionParentJob !== collectJob) {
error(
"Flow invariant is violated: emission from another coroutine is detected (child of $emissionParentJob, expected child of $collectJob). " +
"FlowCollector is not thread-safe and concurrent emissions are prohibited. To mitigate this restriction please use 'flowChannel' builder instead of 'flow'"
)
}
count + 1
}
if (result != collectContextSize) {
error(
"Flow invariant is violated: flow was collected in $collectContext, but emission happened in $emitContext. " +
"Please refer to 'flow' documentation or use 'flowOn' instead")
"Flow invariant is violated: flow was collected in $collectContext, but emission happened in $currentContext. " +
"Please refer to 'flow' documentation or use 'flowOn' instead"
)
}
collector.emit(value)
}

private tailrec fun Job?.transitiveCoroutineParent(collectJob: Job?): Job? {
if (this === null) return null
if (this === collectJob) return this
if (this !is ScopeCoroutine<*>) return this
return parent.transitiveCoroutineParent(collectJob)
}
}
2 changes: 2 additions & 0 deletions kotlinx-coroutines-core/common/src/internal/Scopes.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal open class ScopeCoroutine<in T>(
final override fun getStackTraceElement(): StackTraceElement? = null
override val defaultResumeMode: Int get() = MODE_DIRECT

internal val parent: Job? get() = parentContext[Job]

override val cancelsParent: Boolean
get() = false // it throws exception to parent instead of cancelling it

Expand Down
107 changes: 104 additions & 3 deletions kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ import kotlin.test.*
class FlowInvariantsTest : TestBase() {

@Test
fun testWithContextContract() = runTest {
fun testWithContextContract() = runTest({ it is IllegalStateException }) {
flow {
kotlinx.coroutines.withContext(NonCancellable) {
// This one cannot be prevented :(
emit(1)
}
}.collect {
Expand All @@ -34,6 +33,27 @@ class FlowInvariantsTest : TestBase() {
}
}

@Test
fun testCachedInvariantCheckResult() = runTest {
flow {
emit(1)

try {
kotlinx.coroutines.withContext(NamedDispatchers("foo")) {
emit(1)
}
fail()
} catch (e: IllegalStateException) {
expect(2)
}

emit(3)
}.collect {
expect(it)
}
finish(4)
}

@Test
fun testWithNameContractViolated() = runTest({ it is IllegalStateException }) {
flow {
Expand Down Expand Up @@ -66,7 +86,7 @@ class FlowInvariantsTest : TestBase() {
}

@Test
fun testScopedJob() = runTest {
fun testScopedJob() = runTest({ it is IllegalStateException }) {
flow { emit(1) }.buffer(EmptyCoroutineContext).collect {
expect(1)
}
Expand All @@ -83,6 +103,87 @@ class FlowInvariantsTest : TestBase() {
finish(2)
}

@Test
fun testMergeViolation() = runTest {
fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = flow {
coroutineScope {
launch {
collect { value -> emit(value) }
}
other.collect { value -> emit(value) }
}
}

fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = flow {
coroutineScope {
launch {
collect { value ->
coroutineScope { emit(value) }
}
}
other.collect { value -> emit(value) }
}
}

val flow = flowOf(1)
assertFailsWith<IllegalStateException> { flow.merge(flow).toList() }
assertFailsWith<IllegalStateException> { flow.trickyMerge(flow).toList() }
}

// TODO merge artifact
private fun <T> channelFlow(bufferSize: Int = 16, @BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
flow {
coroutineScope {
val channel = produce(capacity = bufferSize, block = block)
channel.consumeEach { value ->
emit(value)
}
}
}

@Test
fun testNoMergeViolation() = runTest {
fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
launch {
collect { value -> send(value) }
}
other.collect { value -> send(value) }
}

fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = channelFlow {
coroutineScope {
launch {
collect { value ->
coroutineScope { send(value) }
}
}
other.collect { value -> send(value) }
}
}

val flow = flowOf(1)
assertEquals(listOf(1, 1), flow.merge(flow).toList())
assertEquals(listOf(1, 1), flow.trickyMerge(flow).toList())
}

@Test
fun testScopedCoroutineNoViolation() = runTest {
fun Flow<Int>.buffer(): Flow<Int> = flow {
coroutineScope {
val channel = produce {
collect {
send(it)
}
}
channel.consumeEach {
emit(it)
}
}
}

assertEquals(listOf(1, 1), flowOf(1, 1).buffer().toList())
}

private fun Flow<Int>.buffer(coroutineContext: CoroutineContext): Flow<Int> = flow {
coroutineScope {
val channel = Channel<Int>()
Expand Down
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ internal actual val CoroutineContext.coroutineName: String? get() {
}

@Suppress("NOTHING_TO_INLINE")
internal actual inline fun CoroutineContext.minusId(): CoroutineContext = minusKey(CoroutineId)
internal actual inline fun CoroutineContext.minusId(): CoroutineContext = if (DEBUG) minusKey(CoroutineId) else this

private const val DEBUG_THREAD_NAME_SEPARATOR = " @"

Expand Down