Skip to content

Commit

Permalink
feat(batching): Exclude operations that failed pre-executions (#1942) (
Browse files Browse the repository at this point in the history
…#1946)

### 📝 Description

Remove operations that failed pre-executions from batching state, an
operation can become invalid if parsing, validation or any other logic
from preparsed document provider failed (such as persisted queries).
  • Loading branch information
samuelAndalon authored Apr 2, 2024
1 parent a937b8d commit d5c312f
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,9 +37,8 @@ import org.dataloader.DataLoader
class DataLoaderLevelDispatchedInstrumentation : AbstractExecutionLevelDispatchedInstrumentation() {
override fun getOnLevelDispatchedCallback(
parameters: ExecutionLevelDispatchedInstrumentationParameters
): OnLevelDispatchedCallback = { _, executions: List<ExecutionInput> ->
executions
.getOrNull(0)
): OnLevelDispatchedCallback = { _, _ ->
parameters.executionContext.executionInput
?.dataLoaderRegistry
?.dispatchAll()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,24 +22,26 @@ import com.expediagroup.graphql.dataloader.instrumentation.level.state.Level
import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.execution.ExecutionContext
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.Instrumentation
import graphql.execution.instrumentation.InstrumentationContext
import graphql.execution.instrumentation.InstrumentationState
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.execution.instrumentation.SimplePerformantInstrumentation
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters
import graphql.schema.DataFetcher

/**
* Represents the signature of a callback that will be executed when a [Level] is dispatched
*/
internal typealias OnLevelDispatchedCallback = (Level, List<ExecutionInput>) -> Unit
internal typealias OnLevelDispatchedCallback = (Level, List<ExecutionId>) -> Unit
/**
* Custom GraphQL [graphql.execution.instrumentation.Instrumentation] that calculate the state of executions
* of all queries sharing the same GraphQLContext map
*/
abstract class AbstractExecutionLevelDispatchedInstrumentation : Instrumentation {
abstract class AbstractExecutionLevelDispatchedInstrumentation : SimplePerformantInstrumentation() {
/**
* This is invoked each time instrumentation attempts to calculate a level dispatched state, this can be called from either
* `beginFieldField` or `beginExecutionStrategy`.
Expand All @@ -52,13 +54,13 @@ abstract class AbstractExecutionLevelDispatchedInstrumentation : Instrumentation
parameters: ExecutionLevelDispatchedInstrumentationParameters
): OnLevelDispatchedCallback

override fun beginExecuteOperation(
parameters: InstrumentationExecuteOperationParameters,
override fun beginExecution(
parameters: InstrumentationExecutionParameters,
state: InstrumentationState?
): InstrumentationContext<ExecutionResult>? =
parameters.executionContext.takeUnless(ExecutionContext::isMutation)
parameters.executionInput
?.graphQLContext?.get<ExecutionLevelDispatchedState>(ExecutionLevelDispatchedState::class)
?.beginExecuteOperation(parameters)
?.beginExecution(parameters)

override fun beginExecutionStrategy(
parameters: InstrumentationExecutionStrategyParameters,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,38 +20,63 @@ import com.expediagroup.graphql.dataloader.instrumentation.extensions.getExpecte
import com.expediagroup.graphql.dataloader.instrumentation.level.execution.OnLevelDispatchedCallback
import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.execution.ExecutionId
import graphql.execution.FieldValueInfo
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.InstrumentationContext
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.execution.instrumentation.SimpleInstrumentationContext
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters
import graphql.schema.DataFetcher
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference

/**
* Orchestrate the [ExecutionBatchState] of all [ExecutionInput] sharing the same graphQLContext map,
* when a certain state is reached will invoke [OnLevelDispatchedCallback]
*/
class ExecutionLevelDispatchedState(
private val totalExecutions: Int
totalOperations: Int
) {
val executions = ConcurrentHashMap<ExecutionInput, ExecutionBatchState>()
private val totalExecutions: AtomicReference<Int> = AtomicReference(totalOperations)
val executions = ConcurrentHashMap<ExecutionId, ExecutionBatchState>()

/**
* Remove an [ExecutionBatchState] from the state in case operation does not qualify for execution,
* for example:
* parsing, validation, execution errors
* persisted query errors
*/
private fun removeExecution(executionId: ExecutionId) {
if (executions.containsKey(executionId)) {
executions.remove(executionId)
totalExecutions.set(totalExecutions.get() - 1)
}
}

/**
* Initialize the [ExecutionBatchState] of this [ExecutionInput]
*
* @param parameters contains information of which [ExecutionInput] will start his execution
* @return a nullable [InstrumentationContext]
*/
fun beginExecuteOperation(
parameters: InstrumentationExecuteOperationParameters
): InstrumentationContext<ExecutionResult>? {
executions.computeIfAbsent(parameters.executionContext.executionInput) {
fun beginExecution(
parameters: InstrumentationExecutionParameters
): InstrumentationContext<ExecutionResult> {
executions.computeIfAbsent(parameters.executionInput.executionId) {
ExecutionBatchState()
}
return null
return object : SimpleInstrumentationContext<ExecutionResult>() {
override fun onCompleted(result: ExecutionResult?, t: Throwable?) {
result?.let {
if (result.errors.size > 0) {
removeExecution(parameters.executionInput.executionId)
}
}
}
}
}

/**
Expand All @@ -64,11 +89,11 @@ class ExecutionLevelDispatchedState(
parameters: InstrumentationExecutionStrategyParameters,
onLevelDispatched: OnLevelDispatchedCallback
): ExecutionStrategyInstrumentationContext {
val executionInput = parameters.executionContext.executionInput
val executionId = parameters.executionContext.executionInput.executionId
val level = Level(parameters.executionStrategyParameters.path.level + 1)
val fieldCount = parameters.executionStrategyParameters.fields.size()

executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also {
it.initializeLevelStateIfNeeded(level)
it.increaseExpectedFetches(level, fieldCount)
Expand All @@ -86,7 +111,7 @@ class ExecutionLevelDispatchedState(
override fun onFieldValuesInfo(fieldValueInfoList: List<FieldValueInfo>) {
val nextLevel = level.next()

executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also {
it.increaseOnFieldValueInfos(level)
it.increaseExpectedExecutionStrategies(
Expand All @@ -104,7 +129,7 @@ class ExecutionLevelDispatchedState(
}

override fun onFieldValuesException() {
executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also {
it.increaseOnFieldValueInfos(level)
}
Expand All @@ -123,14 +148,13 @@ class ExecutionLevelDispatchedState(
parameters: InstrumentationFieldFetchParameters,
onLevelDispatched: OnLevelDispatchedCallback
): InstrumentationContext<Any> {
val executionInput = parameters.executionContext.executionInput
val executionId = parameters.executionContext.executionInput.executionId
val path = parameters.executionStepInfo.path
val level = Level(path.level)

return object : InstrumentationContext<Any> {
return object : SimpleInstrumentationContext<Any>() {
override fun onDispatched(result: CompletableFuture<Any?>) {

executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also { it.increaseDispatchedFetches(level) }
}

Expand All @@ -140,9 +164,6 @@ class ExecutionLevelDispatchedState(
executions.forEach { (_, executionState) -> executionState.completeDataFetchers(level) }
}
}

override fun onCompleted(result: Any?, t: Throwable?) {
}
}
}

Expand All @@ -161,7 +182,7 @@ class ExecutionLevelDispatchedState(
parameters: InstrumentationFieldFetchParameters
): DataFetcher<*> {
var manuallyCompletableDataFetcher: DataFetcher<*> = dataFetcher
executions.computeIfPresent(parameters.executionContext.executionInput) { _, executionState ->
executions.computeIfPresent(parameters.executionContext.executionInput.executionId) { _, executionState ->
executionState.also {
manuallyCompletableDataFetcher = it.toManuallyCompletableDataFetcher(
Level(parameters.executionStepInfo.path.level),
Expand All @@ -180,9 +201,11 @@ class ExecutionLevelDispatchedState(
* @param level that execution state will be calculated
* @return Boolean for allExecutionsDispatched statement
*/
fun allExecutionsDispatched(level: Level): Boolean =
executions
.takeIf { executions -> executions.size == totalExecutions }
?.all { (_, executionState) -> executionState.isLevelDispatched(level) }
?: false
fun allExecutionsDispatched(level: Level): Boolean = synchronized(executions) {
val operationsToExecute = totalExecutions.get()
when {
executions.size < operationsToExecute -> false
else -> executions.all { (_, executionState) -> executionState.isLevelDispatched(level) }
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@ import com.expediagroup.graphql.dataloader.instrumentation.syncexhaustion.execut
import com.expediagroup.graphql.dataloader.instrumentation.syncexhaustion.execution.SyncExecutionExhaustedInstrumentationParameters
import graphql.ExecutionInput
import graphql.GraphQLContext
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.Instrumentation
import graphql.schema.DataFetcher
import org.dataloader.DataLoader
Expand All @@ -37,10 +38,10 @@ import java.util.concurrent.CompletableFuture
class DataLoaderSyncExecutionExhaustedInstrumentation : AbstractSyncExecutionExhaustedInstrumentation() {
override fun getOnSyncExecutionExhaustedCallback(
parameters: SyncExecutionExhaustedInstrumentationParameters
): OnSyncExecutionExhaustedCallback = { executions: List<ExecutionInput> ->
executions
.getOrNull(0)
?.dataLoaderRegistry
?.dispatchAll()
): OnSyncExecutionExhaustedCallback = { _: List<ExecutionId> ->
parameters
.executionContext.executionInput
.dataLoaderRegistry
.dispatchAll()
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,24 +22,26 @@ import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.GraphQLContext
import graphql.execution.ExecutionContext
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.Instrumentation
import graphql.execution.instrumentation.InstrumentationContext
import graphql.execution.instrumentation.InstrumentationState
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.execution.instrumentation.SimplePerformantInstrumentation
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters

/**
* typealias that represents the signature of a callback that will be executed when sync execution is exhausted
*/
internal typealias OnSyncExecutionExhaustedCallback = (List<ExecutionInput>) -> Unit
internal typealias OnSyncExecutionExhaustedCallback = (List<ExecutionId>) -> Unit

/**
* Custom GraphQL [Instrumentation] that calculate the synchronous execution exhaustion
* of all GraphQL operations sharing the same [GraphQLContext]
*/
abstract class AbstractSyncExecutionExhaustedInstrumentation : Instrumentation {
abstract class AbstractSyncExecutionExhaustedInstrumentation : SimplePerformantInstrumentation() {
/**
* This is invoked each time instrumentation attempts to calculate exhaustion state, this can be called from either
* `beginFieldField.dispatch` or `beginFieldFetch.complete`.
Expand All @@ -51,13 +53,13 @@ abstract class AbstractSyncExecutionExhaustedInstrumentation : Instrumentation {
parameters: SyncExecutionExhaustedInstrumentationParameters
): OnSyncExecutionExhaustedCallback

override fun beginExecuteOperation(
parameters: InstrumentationExecuteOperationParameters,
override fun beginExecution(
parameters: InstrumentationExecutionParameters,
state: InstrumentationState?
): InstrumentationContext<ExecutionResult>? =
parameters.executionContext.takeUnless(ExecutionContext::isMutation)
?.graphQLContext?.get<SyncExecutionExhaustedState>(SyncExecutionExhaustedState::class)
?.beginExecuteOperation(parameters)
parameters.graphQLContext
?.get<SyncExecutionExhaustedState>(SyncExecutionExhaustedState::class)
?.beginExecution(parameters)

override fun beginExecutionStrategy(
parameters: InstrumentationExecutionStrategyParameters,
Expand Down
Loading

0 comments on commit d5c312f

Please sign in to comment.