Skip to content

Commit

Permalink
Avoid the overhead of a sorted stream when wrapping data loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielThomas committed Sep 1, 2023
1 parent aaa04f8 commit fc69659
Showing 1 changed file with 53 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException
import org.springframework.context.ApplicationContext
import org.springframework.util.ReflectionUtils
import java.util.function.Supplier
import kotlin.system.measureTimeMillis

/**
* Framework implementation class responsible for finding and configuring data loaders.
Expand All @@ -63,67 +64,20 @@ class DgsDataLoaderProvider(
}

fun <T> buildRegistryWithContextSupplier(contextSupplier: Supplier<T>): DataLoaderRegistry {
val startTime = System.currentTimeMillis()
val dgsDataLoaderRegistry = DgsDataLoaderRegistry()
val registry = DgsDataLoaderRegistry()
val totalTime = measureTimeMillis {
val extensionProviders = applicationContext
.getBeanProvider(DataLoaderInstrumentationExtensionProvider::class.java)
.orderedStream()
.toList()

batchLoaders.forEach {
if (it.dispatchPredicate == null) {
dgsDataLoaderRegistry.register(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry)
)
} else {
dgsDataLoaderRegistry.registerWithDispatchPredicate(it.name, createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry), it.dispatchPredicate)
}
}
mappedBatchLoaders.forEach {
if (it.dispatchPredicate == null) {
dgsDataLoaderRegistry.register(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry)
)
} else {
dgsDataLoaderRegistry.registerWithDispatchPredicate(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry),
it.dispatchPredicate
)
}
}
batchLoadersWithContext.forEach {
if (it.dispatchPredicate == null) {
dgsDataLoaderRegistry.register(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry)
)
} else {
dgsDataLoaderRegistry.registerWithDispatchPredicate(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry),
it.dispatchPredicate
)
}
}
mappedBatchLoadersWithContext.forEach {
if (it.dispatchPredicate == null) {
dgsDataLoaderRegistry.register(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry)
)
} else {
dgsDataLoaderRegistry.registerWithDispatchPredicate(
it.name,
createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry),
it.dispatchPredicate
)
}
batchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
batchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
mappedBatchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
mappedBatchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) }
}

val endTime = System.currentTimeMillis()
val totalTime = endTime - startTime
logger.debug("Created DGS dataloader registry in {}ms", totalTime)

return dgsDataLoaderRegistry
return registry
}

@PostConstruct
Expand Down Expand Up @@ -195,30 +149,32 @@ class DgsDataLoaderProvider(
batchLoader: BatchLoader<*, *>,
dgsDataLoader: DgsDataLoader,
dataLoaderName: String,
dataLoaderRegistry: DataLoaderRegistry
dataLoaderRegistry: DataLoaderRegistry,
extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
): DataLoader<*, *> {
val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)

if (batchLoader is DgsDataLoaderRegistryConsumer) {
batchLoader.setDataLoaderRegistry(dataLoaderRegistry)
}

val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName)
val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
return DataLoaderFactory.newDataLoader(extendedBatchLoader, options)
}

private fun createDataLoader(
batchLoader: MappedBatchLoader<*, *>,
dgsDataLoader: DgsDataLoader,
dataLoaderName: String,
dataLoaderRegistry: DataLoaderRegistry
dataLoaderRegistry: DataLoaderRegistry,
extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
): DataLoader<*, *> {
val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)

if (batchLoader is DgsDataLoaderRegistryConsumer) {
batchLoader.setDataLoaderRegistry(dataLoaderRegistry)
}
val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName)
val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)

return DataLoaderFactory.newMappedDataLoader(extendedBatchLoader, options)
}
Expand All @@ -228,7 +184,8 @@ class DgsDataLoaderProvider(
dgsDataLoader: DgsDataLoader,
dataLoaderName: String,
supplier: Supplier<T>,
dataLoaderRegistry: DataLoaderRegistry
dataLoaderRegistry: DataLoaderRegistry,
extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
): DataLoader<*, *> {
val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)
.setBatchLoaderContextProvider(supplier::get)
Expand All @@ -237,7 +194,7 @@ class DgsDataLoaderProvider(
batchLoader.setDataLoaderRegistry(dataLoaderRegistry)
}

val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName)
val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
return DataLoaderFactory.newDataLoader(extendedBatchLoader, options)
}

Expand All @@ -246,7 +203,8 @@ class DgsDataLoaderProvider(
dgsDataLoader: DgsDataLoader,
dataLoaderName: String,
supplier: Supplier<T>,
dataLoaderRegistry: DataLoaderRegistry
dataLoaderRegistry: DataLoaderRegistry,
extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
): DataLoader<*, *> {
val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader)
.setBatchLoaderContextProvider(supplier::get)
Expand All @@ -255,35 +213,55 @@ class DgsDataLoaderProvider(
batchLoader.setDataLoaderRegistry(dataLoaderRegistry)
}

val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName)
val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders)
return DataLoaderFactory.newMappedDataLoader(extendedBatchLoader, options)
}

private inline fun <reified T> wrappedDataLoader(loader: T, name: String): T {
try {
val stream = applicationContext
.getBeanProvider(DataLoaderInstrumentationExtensionProvider::class.java)
.orderedStream()
private fun registerDataLoader(
holder: LoaderHolder<*>,
registry: DgsDataLoaderRegistry,
contextSupplier: Supplier<*>,
extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
) {
val loader = when (holder.theLoader) {
is BatchLoader<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, registry, extensionProviders)
is BatchLoaderWithContext<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, contextSupplier, registry, extensionProviders)
is MappedBatchLoader<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, registry, extensionProviders)
is MappedBatchLoaderWithContext<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, contextSupplier, registry, extensionProviders)
else -> throw IllegalArgumentException("Data loader ${holder.name} has unknown type")
}
if (holder.dispatchPredicate == null) {
registry.register(holder.name, loader)
} else {
registry.registerWithDispatchPredicate(holder.name, loader, holder.dispatchPredicate)
}
}

private inline fun <reified T> wrappedDataLoader(
loader: T,
name: String,
extensionProviders: Iterable<DataLoaderInstrumentationExtensionProvider>
): T {
try {
when (loader) {
is BatchLoader<*, *> -> {
var wrappedBatchLoader: BatchLoader<*, *> = loader
stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
return wrappedBatchLoader as T
}
is BatchLoaderWithContext<*, *> -> {
var wrappedBatchLoader: BatchLoaderWithContext<*, *> = loader
stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
return wrappedBatchLoader as T
}
is MappedBatchLoader<*, *> -> {
var wrappedBatchLoader: MappedBatchLoader<*, *> = loader
stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
return wrappedBatchLoader as T
}
is MappedBatchLoaderWithContext<*, *> -> {
var wrappedBatchLoader: MappedBatchLoaderWithContext<*, *> = loader
stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) }
return wrappedBatchLoader as T
}
}
Expand Down

0 comments on commit fc69659

Please sign in to comment.