diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt index 90a470bae..26d0ea65e 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DgsDataLoaderProvider.kt @@ -42,6 +42,7 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException import org.springframework.context.ApplicationContext import org.springframework.util.ReflectionUtils import java.util.function.Supplier +import kotlin.system.measureTimeMillis /** * Framework implementation class responsible for finding and configuring data loaders. @@ -63,67 +64,20 @@ class DgsDataLoaderProvider( } fun buildRegistryWithContextSupplier(contextSupplier: Supplier): DataLoaderRegistry { - val startTime = System.currentTimeMillis() - val dgsDataLoaderRegistry = DgsDataLoaderRegistry() + val registry = DgsDataLoaderRegistry() + val totalTime = measureTimeMillis { + val extensionProviders = applicationContext + .getBeanProvider(DataLoaderInstrumentationExtensionProvider::class.java) + .orderedStream() + .toList() - batchLoaders.forEach { - if (it.dispatchPredicate == null) { - dgsDataLoaderRegistry.register( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry) - ) - } else { - dgsDataLoaderRegistry.registerWithDispatchPredicate(it.name, createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry), it.dispatchPredicate) - } - } - mappedBatchLoaders.forEach { - if (it.dispatchPredicate == null) { - dgsDataLoaderRegistry.register( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry) - ) - } else { - dgsDataLoaderRegistry.registerWithDispatchPredicate( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, dgsDataLoaderRegistry), - it.dispatchPredicate - ) - } - } - batchLoadersWithContext.forEach { - if (it.dispatchPredicate == null) { - dgsDataLoaderRegistry.register( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry) - ) - } else { - dgsDataLoaderRegistry.registerWithDispatchPredicate( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry), - it.dispatchPredicate - ) - } - } - mappedBatchLoadersWithContext.forEach { - if (it.dispatchPredicate == null) { - dgsDataLoaderRegistry.register( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry) - ) - } else { - dgsDataLoaderRegistry.registerWithDispatchPredicate( - it.name, - createDataLoader(it.theLoader, it.annotation, it.name, contextSupplier, dgsDataLoaderRegistry), - it.dispatchPredicate - ) - } + batchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) } + batchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) } + mappedBatchLoaders.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) } + mappedBatchLoadersWithContext.forEach { registerDataLoader(it, registry, contextSupplier, extensionProviders) } } - - val endTime = System.currentTimeMillis() - val totalTime = endTime - startTime logger.debug("Created DGS dataloader registry in {}ms", totalTime) - - return dgsDataLoaderRegistry + return registry } @PostConstruct @@ -195,7 +149,8 @@ class DgsDataLoaderProvider( batchLoader: BatchLoader<*, *>, dgsDataLoader: DgsDataLoader, dataLoaderName: String, - dataLoaderRegistry: DataLoaderRegistry + dataLoaderRegistry: DataLoaderRegistry, + extensionProviders: Iterable ): DataLoader<*, *> { val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader) @@ -203,7 +158,7 @@ class DgsDataLoaderProvider( batchLoader.setDataLoaderRegistry(dataLoaderRegistry) } - val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName) + val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders) return DataLoaderFactory.newDataLoader(extendedBatchLoader, options) } @@ -211,14 +166,15 @@ class DgsDataLoaderProvider( batchLoader: MappedBatchLoader<*, *>, dgsDataLoader: DgsDataLoader, dataLoaderName: String, - dataLoaderRegistry: DataLoaderRegistry + dataLoaderRegistry: DataLoaderRegistry, + extensionProviders: Iterable ): DataLoader<*, *> { val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader) if (batchLoader is DgsDataLoaderRegistryConsumer) { batchLoader.setDataLoaderRegistry(dataLoaderRegistry) } - val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName) + val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders) return DataLoaderFactory.newMappedDataLoader(extendedBatchLoader, options) } @@ -228,7 +184,8 @@ class DgsDataLoaderProvider( dgsDataLoader: DgsDataLoader, dataLoaderName: String, supplier: Supplier, - dataLoaderRegistry: DataLoaderRegistry + dataLoaderRegistry: DataLoaderRegistry, + extensionProviders: Iterable ): DataLoader<*, *> { val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader) .setBatchLoaderContextProvider(supplier::get) @@ -237,7 +194,7 @@ class DgsDataLoaderProvider( batchLoader.setDataLoaderRegistry(dataLoaderRegistry) } - val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName) + val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders) return DataLoaderFactory.newDataLoader(extendedBatchLoader, options) } @@ -246,7 +203,8 @@ class DgsDataLoaderProvider( dgsDataLoader: DgsDataLoader, dataLoaderName: String, supplier: Supplier, - dataLoaderRegistry: DataLoaderRegistry + dataLoaderRegistry: DataLoaderRegistry, + extensionProviders: Iterable ): DataLoader<*, *> { val options = dataLoaderOptionsProvider.getOptions(dataLoaderName, dgsDataLoader) .setBatchLoaderContextProvider(supplier::get) @@ -255,35 +213,55 @@ class DgsDataLoaderProvider( batchLoader.setDataLoaderRegistry(dataLoaderRegistry) } - val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName) + val extendedBatchLoader = wrappedDataLoader(batchLoader, dataLoaderName, extensionProviders) return DataLoaderFactory.newMappedDataLoader(extendedBatchLoader, options) } - private inline fun wrappedDataLoader(loader: T, name: String): T { - try { - val stream = applicationContext - .getBeanProvider(DataLoaderInstrumentationExtensionProvider::class.java) - .orderedStream() + private fun registerDataLoader( + holder: LoaderHolder<*>, + registry: DgsDataLoaderRegistry, + contextSupplier: Supplier<*>, + extensionProviders: Iterable + ) { + val loader = when (holder.theLoader) { + is BatchLoader<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, registry, extensionProviders) + is BatchLoaderWithContext<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, contextSupplier, registry, extensionProviders) + is MappedBatchLoader<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, registry, extensionProviders) + is MappedBatchLoaderWithContext<*, *> -> createDataLoader(holder.theLoader, holder.annotation, holder.name, contextSupplier, registry, extensionProviders) + else -> throw IllegalArgumentException("Data loader ${holder.name} has unknown type") + } + if (holder.dispatchPredicate == null) { + registry.register(holder.name, loader) + } else { + registry.registerWithDispatchPredicate(holder.name, loader, holder.dispatchPredicate) + } + } + private inline fun wrappedDataLoader( + loader: T, + name: String, + extensionProviders: Iterable + ): T { + try { when (loader) { is BatchLoader<*, *> -> { var wrappedBatchLoader: BatchLoader<*, *> = loader - stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } + extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } return wrappedBatchLoader as T } is BatchLoaderWithContext<*, *> -> { var wrappedBatchLoader: BatchLoaderWithContext<*, *> = loader - stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } + extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } return wrappedBatchLoader as T } is MappedBatchLoader<*, *> -> { var wrappedBatchLoader: MappedBatchLoader<*, *> = loader - stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } + extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } return wrappedBatchLoader as T } is MappedBatchLoaderWithContext<*, *> -> { var wrappedBatchLoader: MappedBatchLoaderWithContext<*, *> = loader - stream.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } + extensionProviders.forEach { wrappedBatchLoader = it.provide(wrappedBatchLoader, name) } return wrappedBatchLoader as T } }