From 631778a290dd93cbafb07e031f83844c53f6a584 Mon Sep 17 00:00:00 2001 From: Martin Costello Date: Wed, 21 Aug 2024 21:17:29 +0100 Subject: [PATCH] Add GetKeyedService overload with Type (#105860) --- ...nsions.DependencyInjection.Abstractions.cs | 1 + .../ServiceProviderKeyedServiceExtensions.cs | 20 ++++++++++++ ...edDependencyInjectionSpecificationTests.cs | 32 +++++++++++++++++-- 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/ref/Microsoft.Extensions.DependencyInjection.Abstractions.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/ref/Microsoft.Extensions.DependencyInjection.Abstractions.cs index c9b3d241f0f23..dfb07e7ea9b7a 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/ref/Microsoft.Extensions.DependencyInjection.Abstractions.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/ref/Microsoft.Extensions.DependencyInjection.Abstractions.cs @@ -218,6 +218,7 @@ public static partial class ServiceProviderKeyedServiceExtensions public static System.Collections.Generic.IEnumerable GetKeyedServices(this System.IServiceProvider provider, System.Type serviceType, object? serviceKey) { throw null; } public static System.Collections.Generic.IEnumerable GetKeyedServices(this System.IServiceProvider provider, object? serviceKey) { throw null; } public static T? GetKeyedService(this System.IServiceProvider provider, object? serviceKey) { throw null; } + public static object? GetKeyedService(this System.IServiceProvider provider, System.Type serviceType, object? serviceKey) { throw null; } public static object GetRequiredKeyedService(this System.IServiceProvider provider, System.Type serviceType, object? serviceKey) { throw null; } public static T GetRequiredKeyedService(this System.IServiceProvider provider, object? serviceKey) where T : notnull { throw null; } } diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ServiceProviderKeyedServiceExtensions.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ServiceProviderKeyedServiceExtensions.cs index 4ae52a0e9ed49..03075fbc9c2e6 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ServiceProviderKeyedServiceExtensions.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ServiceProviderKeyedServiceExtensions.cs @@ -31,6 +31,26 @@ public static class ServiceProviderKeyedServiceExtensions throw new InvalidOperationException(SR.KeyedServicesNotSupported); } + /// + /// Get service of type from the . + /// + /// The to retrieve the service object from. + /// An object that specifies the type of service object to get. + /// An object that specifies the key of service object to get. + /// A service object of type or null if there is no such service. + public static object? GetKeyedService(this IServiceProvider provider, Type serviceType, object? serviceKey) + { + ThrowHelper.ThrowIfNull(provider); + ThrowHelper.ThrowIfNull(serviceType); + + if (provider is IKeyedServiceProvider keyedServiceProvider) + { + return keyedServiceProvider.GetKeyedService(serviceType, serviceKey); + } + + throw new InvalidOperationException(SR.KeyedServicesNotSupported); + } + /// /// Get service of type from the . /// diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs index 8916a906d8572..ee7daad0be7eb 100644 --- a/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs +++ b/src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs @@ -27,6 +27,10 @@ public void ResolveKeyedService() Assert.Null(provider.GetService()); Assert.Same(service1, provider.GetKeyedService("service1")); Assert.Same(service2, provider.GetKeyedService("service2")); + + Assert.Null(provider.GetService(typeof(IService))); + Assert.Same(service1, provider.GetKeyedService(typeof(IService), "service1")); + Assert.Same(service2, provider.GetKeyedService(typeof(IService), "service2")); } [Fact] @@ -39,10 +43,12 @@ public void ResolveNullKeyedService() var provider = CreateServiceProvider(serviceCollection); var nonKeyed = provider.GetService(); - var nullKey = provider.GetKeyedService(null); + var nullKeyOfT = provider.GetKeyedService(null); + var nullKeyOfType = provider.GetKeyedService(typeof(IService), null); Assert.Same(service1, nonKeyed); - Assert.Same(service1, nullKey); + Assert.Same(service1, nullKeyOfT); + Assert.Same(service1, nullKeyOfType); } [Fact] @@ -192,6 +198,7 @@ public void ResolveKeyedServiceSingletonInstance() Assert.Null(provider.GetService()); Assert.Same(service, provider.GetKeyedService("service1")); + Assert.Same(service, provider.GetKeyedService(typeof(IService), "service1")); } [Fact] @@ -355,6 +362,7 @@ public void ResolveKeyedServiceSingletonFactory() Assert.Null(provider.GetService()); Assert.Same(service, provider.GetKeyedService("service1")); + Assert.Same(service, provider.GetKeyedService(typeof(IService), "service1")); } [Fact] @@ -388,6 +396,7 @@ public void ResolveKeyedServiceSingletonFactoryWithAnyKeyIgnoreWrongType() Assert.Null(provider.GetService()); Assert.NotNull(provider.GetKeyedService(87)); Assert.ThrowsAny(() => provider.GetKeyedService(new object())); + Assert.ThrowsAny(() => provider.GetKeyedService(typeof(IService), new object())); } [Fact] @@ -554,6 +563,20 @@ public void ResolveKeyedTransientFromScopeServiceProvider() Assert.NotSame(serviceA1, serviceB1); } + [Fact] + public void ResolveKeyedServiceThrowsIfNotSupported() + { + var provider = new NonKeyedServiceProvider(); + var serviceKey = new object(); + + Assert.Throws(() => provider.GetKeyedService(serviceKey)); + Assert.Throws(() => provider.GetKeyedService(typeof(IService), serviceKey)); + Assert.Throws(() => provider.GetKeyedServices(serviceKey)); + Assert.Throws(() => provider.GetKeyedServices(typeof(IService), serviceKey)); + Assert.Throws(() => provider.GetRequiredKeyedService(serviceKey)); + Assert.Throws(() => provider.GetRequiredKeyedService(typeof(IService), serviceKey)); + } + public interface IService { } public class Service : IService @@ -664,5 +687,10 @@ public interface ISimpleService { } public class SimpleService : ISimpleService { } public class AnotherSimpleService : ISimpleService { } + + public class NonKeyedServiceProvider : IServiceProvider + { + public object GetService(Type serviceType) => throw new NotImplementedException(); + } } }