diff --git a/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs b/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs index 0e196114e86..e2ac1779fd2 100644 --- a/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs @@ -56,23 +56,21 @@ public static IServiceCollection AddEntityFrameworkInMemoryDatabase([NotNull] th { Check.NotNull(serviceCollection, nameof(serviceCollection)); - serviceCollection.TryAddEnumerable(ServiceDescriptor - .Singleton>()); + var serviceCollectionMap = new ServiceCollectionMap(serviceCollection) + .TryAddSingletonEnumerable>() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); - serviceCollection.TryAdd(new ServiceCollection() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); - - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollectionMap); return serviceCollection; } diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs b/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs index 32c0185c0f8..e60a53b6564 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs @@ -20,7 +20,6 @@ using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyInjection.Extensions; // Intentionally in this namespace since this is for use by other relational providers rather than // by top-level app developers. @@ -37,64 +36,57 @@ public static class ServiceCollectionRelationalProviderInfrastructure /// providers after registering provider-specific services to fill-in the remaining services with /// Entity Framework defaults. /// - /// The to add services to. - public static void TryAddDefaultRelationalServices([NotNull] IServiceCollection serviceCollection) + /// The to add services to. + public static void TryAddDefaultRelationalServices([NotNull] ServiceCollectionMap serviceCollectionMap) { - Check.NotNull(serviceCollection, nameof(serviceCollection)); + Check.NotNull(serviceCollectionMap, nameof(serviceCollectionMap)); - serviceCollection - .TryAdd(new ServiceCollection() - .AddSingleton(s => new DiagnosticListener("Microsoft.EntityFrameworkCore")) - .AddSingleton(s => s.GetService()) - .AddSingleton() - .AddSingleton, ModificationCommandComparer>() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); + serviceCollectionMap + .TryAddSingleton(s => new DiagnosticListener("Microsoft.EntityFrameworkCore")) + .TryAddSingleton(s => s.GetService()) + .TryAddSingleton() + .TryAddSingleton, ModificationCommandComparer>() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); - // Add service dependencies parameter classes. - // These are added as concrete types because the classes are sealed and the registrations should - // not be changed by provider or application code. - serviceCollection - .TryAdd(new ServiceCollection() - .AddScoped()); - - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollectionMap); } } } diff --git a/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs b/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs index 71ef457dc76..0c1c0529365 100644 --- a/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs +++ b/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Linq; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -17,13 +18,22 @@ protected EntityFrameworkServiceCollectionExtensionsTest(TestHelpers testHelpers } [Fact] - public virtual void Repeated_calls_to_add_do_not_modify_collection() + public void Calling_AddEntityFramework_explicitly_does_not_change_services() { - var expectedCollection = AddServices(new ServiceCollection()); + var services1 = AddServices(new ServiceCollection()); + var services2 = AddServices(new ServiceCollection()); + + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(services2)); - var actualCollection = AddServices(AddServices(new ServiceCollection())); + AssertServicesSame(services1, services2); + } - AssertServicesSame(expectedCollection, actualCollection); + [Fact] + public virtual void Repeated_calls_to_add_do_not_modify_collection() + { + AssertServicesSame( + AddServices(new ServiceCollection()), + AddServices(AddServices(new ServiceCollection()))); } protected virtual void AssertServicesSame(IServiceCollection services1, IServiceCollection services2) diff --git a/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs b/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs index 43f0f5bfa24..73ce6deb515 100644 --- a/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs @@ -22,7 +22,6 @@ using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.EntityFrameworkCore.ValueGeneration.Internal; -using Microsoft.Extensions.DependencyInjection.Extensions; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection @@ -47,17 +46,17 @@ public static class SqlServerServiceCollectionExtensions /// /// /// - /// public void ConfigureServices(IServiceCollection services) - /// { - /// var connectionString = "connection string to database"; - /// - /// services - /// .AddEntityFrameworkSqlServer() - /// .AddDbContext<MyContext>((serviceProvider, options) => - /// options.UseSqlServer(connectionString) - /// .UseInternalServiceProvider(serviceProvider)); - /// } - /// + /// public void ConfigureServices(IServiceCollection services) + /// { + /// var connectionString = "connection string to database"; + /// + /// services + /// .AddEntityFrameworkSqlServer() + /// .AddDbContext<MyContext>((serviceProvider, options) => + /// options.UseSqlServer(connectionString) + /// .UseInternalServiceProvider(serviceProvider)); + /// } + /// /// /// The to add services to. /// @@ -67,38 +66,36 @@ public static IServiceCollection AddEntityFrameworkSqlServer([NotNull] this ISer { Check.NotNull(serviceCollection, nameof(serviceCollection)); - serviceCollection.TryAddEnumerable( - ServiceDescriptor.Singleton>()); + var serviceCollectionMap = new ServiceCollectionMap(serviceCollection) + .TryAddSingletonEnumerable>() + .TryAddSingleton() + .TryAddSingleton(p => p.GetService()) + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); - serviceCollection.TryAdd(new ServiceCollection() - .AddSingleton() - .AddSingleton(p => p.GetService()) - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollection); + ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollectionMap); return serviceCollection; } diff --git a/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs b/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs index 195db0738f9..6ee849dbece 100644 --- a/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs @@ -19,7 +19,6 @@ using Microsoft.EntityFrameworkCore.Update; using Microsoft.EntityFrameworkCore.Update.Internal; using Microsoft.EntityFrameworkCore.Utilities; -using Microsoft.Extensions.DependencyInjection.Extensions; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection @@ -64,27 +63,26 @@ public static IServiceCollection AddEntityFrameworkSqlite([NotNull] this IServic { Check.NotNull(serviceCollection, nameof(serviceCollection)); - serviceCollection.TryAddEnumerable( - ServiceDescriptor.Singleton>()); + var serviceCollectionMap = new ServiceCollectionMap(serviceCollection) + .TryAddSingletonEnumerable>() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); + ; - serviceCollection.TryAdd(new ServiceCollection() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollection); + ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollectionMap); return serviceCollection; } diff --git a/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs new file mode 100644 index 00000000000..408a92a7290 --- /dev/null +++ b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs @@ -0,0 +1,597 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Utilities; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.EntityFrameworkCore.Infrastructure +{ + /// + /// + /// Prvoides a map over a that allows + /// entries to be conditionally added or re-written without requiring linear scans of the service + /// collection each time this is done. + /// + /// + /// Database providers are expected to create an instance of this around the service collection passed + /// to their 'Add...' method and then use the methods of this class to add services. + /// + /// + /// Note that the collection should not be modified without in other ways while it is being managed + /// by the map. The collection can be used in the normal way after modifications using the map have + /// been completed. + /// + /// + public class ServiceCollectionMap + { + private readonly IServiceCollection _serviceCollection; + private readonly IDictionary> _serviceMap = new Dictionary>(); + + /// + /// Creates a new to operate on the given . + /// + /// The collection to work with. + public ServiceCollectionMap([NotNull] IServiceCollection serviceCollection) + { + Check.NotNull(serviceCollection, nameof(serviceCollection)); + + _serviceCollection = serviceCollection; + + var index = 0; + foreach (var descriptor in serviceCollection) + { + GetOrCreateDescriptorIndexes(descriptor.ServiceType).Add(index++); + } + } + + private IList GetOrCreateDescriptorIndexes(Type serviceType) + { + IList indexes; + if (!_serviceMap.TryGetValue(serviceType, out indexes)) + { + indexes = new List(); + _serviceMap[serviceType] = indexes; + } + return indexes; + } + + /// + /// The underlying . + /// + public virtual IServiceCollection ServiceCollection => _serviceCollection; + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient() + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), typeof(TImplementation), ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped() + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), typeof(TImplementation), ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton() + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), typeof(TImplementation), ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAdd(serviceType, implementationType, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAdd(serviceType, implementationType, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAdd(serviceType, implementationType, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAdd(Type serviceType, Type implementationType, ServiceLifetime lifetime) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(implementationType, nameof(implementationType)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (!indexes.Any()) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementationType, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient([NotNull] Func factory) + where TService : class + => TryAdd(typeof(TService), factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped([NotNull] Func factory) + where TService : class + => TryAdd(typeof(TService), factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Func factory) + where TService : class + => TryAdd(typeof(TService), factory, ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that the given factory creates. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that the given factory creates. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that the given factory creates. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), factory, ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient([NotNull] Type serviceType, [NotNull] Func factory) + => TryAdd(serviceType, factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped([NotNull] Type serviceType, [NotNull] Func factory) + => TryAdd(serviceType, factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Type serviceType, [NotNull] Func factory) + => TryAdd(serviceType, factory, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAdd(Type serviceType, Func factory, ServiceLifetime lifetime) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(factory, nameof(factory)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (!indexes.Any()) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, factory, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given instance + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([CanBeNull] TService implementation) + where TService : class + => TryAdd(typeof(TService), implementation); + + /// + /// Adds a service implemented by the given instance + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Type serviceType, [CanBeNull] object implementation) + => TryAdd(serviceType, implementation); + + private ServiceCollectionMap TryAdd(Type serviceType, object implementation) + { + Check.NotNull(serviceType, nameof(serviceType)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (!indexes.Any()) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementation)); + } + + return this; + } + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransientEnumerable() + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(typeof(TService), typeof(TImplementation), ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScopedEnumerable() + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(typeof(TService), typeof(TImplementation), ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable() + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(typeof(TService), typeof(TImplementation), ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransientEnumerable([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAddEnumerable(serviceType, implementationType, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScopedEnumerable([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAddEnumerable(serviceType, implementationType, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAddEnumerable(serviceType, implementationType, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAddEnumerable(Type serviceType, Type implementationType, ServiceLifetime lifetime) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(implementationType, nameof(implementationType)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (indexes.All(i => TryGetImplementationType(_serviceCollection[i]) != implementationType)) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementationType, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given factory + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The factory that implements this service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransientEnumerable( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The factory that implements this service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScopedEnumerable( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The factory that implements this service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(factory, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAddEnumerable( + [NotNull] Func factory, ServiceLifetime lifetime) + where TService : class + where TImplementation : class, TService + { + Check.NotNull(factory, nameof(factory)); + + var indexes = GetOrCreateDescriptorIndexes(typeof(TService)); + if (indexes.All(i => TryGetImplementationType(_serviceCollection[i]) != typeof(TImplementation))) + { + AddNewDescriptor(indexes, new ServiceDescriptor(typeof(TService), factory, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given instance + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable([NotNull] TService implementation) + where TService : class + => TryAddEnumerable(typeof(TService), implementation); + + /// + /// Adds a service implemented by the given instance + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable([NotNull] Type serviceType, [NotNull] object implementation) + => TryAddEnumerable(serviceType, implementation); + + private ServiceCollectionMap TryAddEnumerable(Type serviceType, object implementation) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(implementation, nameof(implementation)); + + var implementationType = implementation.GetType(); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (indexes.All(i => TryGetImplementationType(_serviceCollection[i]) != implementationType)) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementation)); + } + + return this; + } + + private Type TryGetImplementationType(ServiceDescriptor descriptor) + => descriptor.ImplementationType + ?? descriptor.ImplementationInstance?.GetType() + // Generic arg on Func may be obejct, but this is the best we can do and matches logic in D.I. container + ?? descriptor.ImplementationFactory?.GetType().GetTypeInfo().GenericTypeArguments[1]; + + private void AddNewDescriptor(IList indexes, ServiceDescriptor newDescriptor) + { + indexes.Add(_serviceCollection.Count); + _serviceCollection.Add(newDescriptor); + } + + /// + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + /// + /// Re-writes the registration for the gicen service such that if the implementation type + /// implements , then + /// will be called while resolving + /// the service allowing additional services to be injected without breaking the existing + /// constructor. + /// + /// + /// This mechanism should only be used to allow new services to be injected in a patch or + /// point release without making binary breaking changes. + /// + /// + /// The service contract. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap DoPatchInjection() + where TService : class + { + IList indexes; + if (_serviceMap.TryGetValue(typeof(TService), out indexes)) + { + foreach (var index in indexes) + { + var descriptor = _serviceCollection[index]; + var lifetime = descriptor.Lifetime; + var implementationType = descriptor.ImplementationType; + + if (implementationType != null) + { + var implementationIndexes = GetOrCreateDescriptorIndexes(implementationType); + if (!implementationIndexes.Any()) + { + AddNewDescriptor( + implementationIndexes, + new ServiceDescriptor(implementationType, implementationType, lifetime)); + } + + var injectedDescriptor = new ServiceDescriptor( + typeof(TService), + p => InjectServices(p, implementationType), + lifetime); + + _serviceCollection[index] = injectedDescriptor; + } + else if (descriptor.ImplementationFactory != null) + { + var injectedDescriptor = new ServiceDescriptor( + typeof(TService), + p => InjectServices(p, descriptor.ImplementationFactory), + lifetime); + + _serviceCollection[index] = injectedDescriptor; + } + else + { + var injectedDescriptor = new ServiceDescriptor( + typeof(TService), + p => InjectServices(p, descriptor.ImplementationInstance), + lifetime); + + _serviceCollection[index] = injectedDescriptor; + } + } + } + + return this; + } + + private static object InjectServices(IServiceProvider serviceProvider, Type concreteType) + { + var service = serviceProvider.GetService(concreteType); + + (service as IPatchServiceInjectionSite)?.InjectServices(serviceProvider); + + return service; + } + + private static object InjectServices(IServiceProvider serviceProvider, object service) + { + (service as IPatchServiceInjectionSite)?.InjectServices(serviceProvider); + + return service; + } + + private static object InjectServices(IServiceProvider serviceProvider, Func implementationFactory) + { + var service = implementationFactory(serviceProvider); + + (service as IPatchServiceInjectionSite)?.InjectServices(serviceProvider); + + return service; + } + } +} diff --git a/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs index eae0bef871b..53764a89d18 100644 --- a/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs +++ b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs @@ -16,7 +16,6 @@ using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging; // Intentionally in this namespace since this is for use by other relational providers rather than @@ -35,77 +34,77 @@ public static class ServiceCollectionProviderInfrastructure /// Framework defaults. Relational providers should call /// 'ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices' instead. /// - /// The to add services to. - public static void TryAddDefaultEntityFrameworkServices([NotNull] IServiceCollection serviceCollection) + /// The to add services to. + public static void TryAddDefaultEntityFrameworkServices([NotNull] ServiceCollectionMap serviceCollectionMap) { - Check.NotNull(serviceCollection, nameof(serviceCollection)); + Check.NotNull(serviceCollectionMap, nameof(serviceCollectionMap)); - serviceCollection.TryAddEnumerable(new ServiceCollection() - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService())); + serviceCollectionMap + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(typeof(ISensitiveDataLogger<>), typeof(SensitiveDataLogger<>)) + .TryAddScoped(typeof(ILogger<>), typeof(InterceptingLogger<>)) + .TryAddScoped(p => GetContextServices(p).Model) + .TryAddScoped(p => GetContextServices(p).CurrentContext) + .TryAddScoped(p => GetContextServices(p).ContextOptions) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()); - serviceCollection.TryAdd(new ServiceCollection() - .AddMemoryCache() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(typeof(ISensitiveDataLogger<>), typeof(SensitiveDataLogger<>)) - .AddScoped(typeof(ILogger<>), typeof(InterceptingLogger<>)) - .AddScoped(p => GetContextServices(p).Model) - .AddScoped(p => GetContextServices(p).CurrentContext) - .AddScoped(p => GetContextServices(p).ContextOptions) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); + // Note: does TryAdd on all services + serviceCollectionMap.ServiceCollection.AddMemoryCache(); } private static IDbContextServices GetContextServices(IServiceProvider serviceProvider) diff --git a/src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs b/src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs new file mode 100644 index 00000000000..af9cb3755ad --- /dev/null +++ b/src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using JetBrains.Annotations; + +namespace Microsoft.EntityFrameworkCore.Internal +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public interface IPatchServiceInjectionSite + { + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + void InjectServices([NotNull] IServiceProvider serviceProvider); + } +} diff --git a/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs b/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs index fbaba42bed7..ff38acdedfa 100644 --- a/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs +++ b/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Linq; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Infrastructure; @@ -47,37 +48,51 @@ public virtual IServiceProvider GetOrAdd([NotNull] IDbContextOptions options) k => { var services = new ServiceCollection(); - var coreServicesAdded = false; - - foreach (var extension in options.Extensions) - { - if (extension.ApplyServices(services)) - { - coreServicesAdded = true; - } - } - - if (!coreServicesAdded) - { - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(services); - } + ApplyServices(options, services); if (replacedServices != null) { - foreach (var descriptor in services.ToList()) + // For replaced services we use the service collection to obtain the lifetime of + // the service to replace. The replaced services are added to a new collection, after + // which provider and core services are applied. This ensures that any patching happens + // to the replaced service. + var updatedServices = new ServiceCollection(); + foreach (var descriptor in services) { Type replacementType; if (replacedServices.TryGetValue(descriptor.ServiceType, out replacementType)) { - services[services.IndexOf(descriptor)] - = new ServiceDescriptor(descriptor.ServiceType, replacementType, descriptor.Lifetime); + ((IList)updatedServices).Add( + new ServiceDescriptor(descriptor.ServiceType, replacementType, descriptor.Lifetime)); } } + + ApplyServices(options, updatedServices); + services = updatedServices; } return services.BuildServiceProvider(); }); } } + + private static void ApplyServices(IDbContextOptions options, ServiceCollection services) + { + var coreServicesAdded = false; + + foreach (var extension in options.Extensions) + { + if (extension.ApplyServices(services)) + { + coreServicesAdded = true; + } + } + + if (!coreServicesAdded) + { + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices( + new ServiceCollectionMap(services)); + } + } } } diff --git a/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj b/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj index 70a0b75530c..b9b6454e087 100644 --- a/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj +++ b/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj @@ -151,7 +151,9 @@ + + diff --git a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs index 95829b1bdd4..d191a699d57 100644 --- a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs +++ b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs @@ -182,7 +182,7 @@ private class NoServicesAndNoConfigBlogContext : DbContext public void Throws_on_attempt_to_use_store_with_no_store_services() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); Assert.Equal( diff --git a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs index 19306588902..e66cc200bbc 100644 --- a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs @@ -1,26 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Specification.Tests; -using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace Microsoft.EntityFrameworkCore.InMemory.FunctionalTests { public class InMemoryServiceCollectionExtensionsTest : EntityFrameworkServiceCollectionExtensionsTest { - [Fact] - public void Calling_AddEntityFramework_explicitly_does_not_change_services() - { - var services1 = new ServiceCollection().AddEntityFrameworkInMemoryDatabase(); - var services2 = new ServiceCollection().AddEntityFrameworkInMemoryDatabase(); - - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(services2); - - AssertServicesSame(services1, services2); - } - public InMemoryServiceCollectionExtensionsTest() : base(InMemoryTestHelpers.Instance) { diff --git a/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs b/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs index 97d80bd12c2..3537c405c83 100644 --- a/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs @@ -51,7 +51,7 @@ public void Throws_with_add_when_no_EF_services_use_Database() public void Throws_with_new_when_no_provider_use_Database() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); var options = new DbContextOptionsBuilder() @@ -70,7 +70,7 @@ public void Throws_with_new_when_no_provider_use_Database() public void Throws_with_add_when_no_provider_use_Database() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var appServiceProivder = serviceCollection .AddDbContext( diff --git a/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs b/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs index 0dcf8ba0587..47e887f54cd 100644 --- a/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs +++ b/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs @@ -51,7 +51,7 @@ public static IServiceCollection AddEntityFrameworkRelationalDatabase(IServiceCo .AddScoped(_ => null) .AddScoped(_ => null)); - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollection); + ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(new ServiceCollectionMap(serviceCollection)); return serviceCollection; } diff --git a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs index c35bb275df8..766eae6de87 100644 --- a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs @@ -208,7 +208,7 @@ public class ImplicitConfigButNoServices public void Throws_on_attempt_to_use_store_with_no_store_services() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); using (SqlServerNorthwindContext.GetSharedStore()) diff --git a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs index 491311507e9..11bdf710238 100644 --- a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs @@ -1,26 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Specification.Tests; -using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests { public class SqlServerServiceCollectionExtensionsTest : EntityFrameworkServiceCollectionExtensionsTest { - [Fact] - public void Calling_AddEntityFramework_explicitly_does_not_change_services() - { - var services1 = new ServiceCollection().AddEntityFrameworkSqlServer(); - var services2 = new ServiceCollection().AddEntityFrameworkSqlServer(); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(services2); - - AssertServicesSame(services1, services2); - } - public SqlServerServiceCollectionExtensionsTest() : base(SqlServerTestHelpers.Instance) { diff --git a/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs b/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs index c6b67b93d8d..3cdc18fbdd7 100644 --- a/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs @@ -1,26 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Specification.Tests; -using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests { public class SqliteServiceCollectionExtensionsTest : EntityFrameworkServiceCollectionExtensionsTest { - [Fact] - public void Calling_AddEntityFramework_explicitly_does_not_change_services() - { - var services1 = new ServiceCollection().AddEntityFrameworkSqlite(); - var services2 = new ServiceCollection().AddEntityFrameworkSqlite(); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(services2); - - AssertServicesSame(services1, services2); - } - public SqliteServiceCollectionExtensionsTest() : base(SqliteTestHelpers.Instance) { diff --git a/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs b/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs index ac9f94bd7a5..cbc503c8914 100644 --- a/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs @@ -2071,7 +2071,7 @@ public void Can_start_with_custom_services_by_passing_in_base_service_provider() public void Required_low_level_services_are_added_if_needed() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var provider = serviceCollection.BuildServiceProvider(); @@ -2085,7 +2085,7 @@ public void Required_low_level_services_are_not_added_if_already_present() var loggerFactory = new FakeLoggerFactory(); serviceCollection.AddSingleton(loggerFactory); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var provider = serviceCollection.BuildServiceProvider(); @@ -2098,7 +2098,7 @@ public void Low_level_services_can_be_replaced_after_being_added() var serviceCollection = new ServiceCollection(); var loggerFactory = new FakeLoggerFactory(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); serviceCollection.AddSingleton(loggerFactory); @@ -4339,7 +4339,7 @@ public void Throws_with_add_when_no_EF_services_and_no_sets() public void Throws_with_new_when_no_provider() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); var options = new DbContextOptionsBuilder() @@ -4358,7 +4358,7 @@ public void Throws_with_new_when_no_provider() public void Throws_with_add_when_no_provider() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var appServiceProivder = serviceCollection .AddDbContext( @@ -4381,7 +4381,7 @@ public void Throws_with_add_when_no_provider() public void Throws_with_new_when_no_provider_and_no_sets() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); var options = new DbContextOptionsBuilder() @@ -4400,7 +4400,7 @@ public void Throws_with_new_when_no_provider_and_no_sets() public void Throws_with_add_when_no_provider_and_no_sets() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var appServiceProivder = serviceCollection .AddDbContext( diff --git a/test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs b/test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs new file mode 100644 index 00000000000..7a14787e215 --- /dev/null +++ b/test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs @@ -0,0 +1,570 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.EntityFrameworkCore.Tests.Infrastructure +{ + public class ServiceCollectionMapTest + { + [Fact] + public void Can_add_delegate_services() + { + Func factory = p => new FakeService(); + + AddServiceDelegateTest(m => m.TryAddTransient(factory), factory, ServiceLifetime.Transient); + AddServiceDelegateTest(m => m.TryAddScoped(factory), factory, ServiceLifetime.Scoped); + AddServiceDelegateTest(m => m.TryAddSingleton(factory), factory, ServiceLifetime.Singleton); + AddServiceDelegateTest(m => m.TryAddTransient(factory), factory, ServiceLifetime.Transient); + AddServiceDelegateTest(m => m.TryAddScoped(factory), factory, ServiceLifetime.Scoped); + AddServiceDelegateTest(m => m.TryAddSingleton(factory), factory, ServiceLifetime.Singleton); + AddServiceDelegateTest(m => m.TryAddTransient(typeof(IFakeService), factory), factory, ServiceLifetime.Transient); + AddServiceDelegateTest(m => m.TryAddScoped(typeof(IFakeService), factory), factory, ServiceLifetime.Scoped); + AddServiceDelegateTest(m => m.TryAddSingleton(typeof(IFakeService), factory), factory, ServiceLifetime.Singleton); + } + + private void AddServiceDelegateTest( + Func adder, + Func factory, + ServiceLifetime lifetime) + { + var serviceCollectionMap = adder(new ServiceCollectionMap(new ServiceCollection())); + + var descriptor = serviceCollectionMap.ServiceCollection.Single(); + + Assert.Same(typeof(IFakeService), descriptor.ServiceType); + Assert.Same(factory, descriptor.ImplementationFactory); + Assert.Equal(lifetime, descriptor.Lifetime); + } + + [Fact] + public void Can_add_concrete_services() + { + AddServiceConcreteTest(m => m.TryAddTransient(), ServiceLifetime.Transient); + AddServiceConcreteTest(m => m.TryAddScoped(), ServiceLifetime.Scoped); + AddServiceConcreteTest(m => m.TryAddSingleton(), ServiceLifetime.Singleton); + AddServiceConcreteTest(m => m.TryAddTransient(typeof(IFakeService), typeof(DerivedFakeService)), ServiceLifetime.Transient); + AddServiceConcreteTest(m => m.TryAddScoped(typeof(IFakeService), typeof(DerivedFakeService)), ServiceLifetime.Scoped); + AddServiceConcreteTest(m => m.TryAddSingleton(typeof(IFakeService), typeof(DerivedFakeService)), ServiceLifetime.Singleton); + } + + private void AddServiceConcreteTest( + Func adder, + ServiceLifetime lifetime) + { + var serviceCollectionMap = adder(new ServiceCollectionMap(new ServiceCollection())); + + var descriptor = serviceCollectionMap.ServiceCollection.Single(); + + Assert.Same(typeof(IFakeService), descriptor.ServiceType); + Assert.Same(typeof(DerivedFakeService), descriptor.ImplementationType); + Assert.Equal(lifetime, descriptor.Lifetime); + } + + [Fact] + public void Can_add_instance_services() + { + var instance = new FakeService(); + + AddServiceInstanceTest(m => m.TryAddSingleton(instance), instance); + AddServiceInstanceTest(m => m.TryAddSingleton(typeof(IFakeService), instance), instance); + } + + private void AddServiceInstanceTest( + Func adder, + object instance) + { + var serviceCollectionMap = adder(new ServiceCollectionMap(new ServiceCollection())); + + var descriptor = serviceCollectionMap.ServiceCollection.Single(); + + Assert.Same(typeof(IFakeService), descriptor.ServiceType); + Assert.Same(instance, descriptor.ImplementationInstance); + Assert.Equal(ServiceLifetime.Singleton, descriptor.Lifetime); + } + + [Fact] + public void Existing_services_are_not_replaced() + { + ExistingServiceTest(m => m.TryAddTransient()); + ExistingServiceTest(m => m.TryAddScoped()); + ExistingServiceTest(m => m.TryAddSingleton()); + ExistingServiceTest(m => m.TryAddTransient(typeof(IFakeService), typeof(FakeService))); + ExistingServiceTest(m => m.TryAddScoped(typeof(IFakeService), typeof(FakeService))); + ExistingServiceTest(m => m.TryAddSingleton(typeof(IFakeService), typeof(FakeService))); + ExistingServiceTest(m => m.TryAddTransient(p => new FakeService())); + ExistingServiceTest(m => m.TryAddScoped(p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(p => new FakeService())); + ExistingServiceTest(m => m.TryAddTransient(p => new FakeService())); + ExistingServiceTest(m => m.TryAddScoped(p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(p => new FakeService())); + ExistingServiceTest(m => m.TryAddTransient(typeof(IFakeService), p => new FakeService())); + ExistingServiceTest(m => m.TryAddScoped(typeof(IFakeService), p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(typeof(IFakeService), p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(typeof(IFakeService), new FakeService())); + } + + private void ExistingServiceTest(Func adder) + { + var serviceCollection = new ServiceCollection() + .AddSingleton(); + + var descriptor = serviceCollection.Single(); + + var serviceCollectionMap = adder(new ServiceCollectionMap(serviceCollection)); + + Assert.Same(serviceCollection, serviceCollectionMap.ServiceCollection); + Assert.Same(descriptor, serviceCollection.Single()); + } + + [Fact] + public void Can_add_multiple_concrete_services() + { + AddServiceConcreteEnumerableTest( + m => m.TryAddTransientEnumerable(), + m => m.TryAddTransientEnumerable(), + ServiceLifetime.Transient); + + AddServiceConcreteEnumerableTest( + m => m.TryAddScopedEnumerable(), + m => m.TryAddScopedEnumerable(), + ServiceLifetime.Scoped); + + AddServiceConcreteEnumerableTest( + m => m.TryAddSingletonEnumerable(), + m => m.TryAddSingletonEnumerable(), + ServiceLifetime.Singleton); + + AddServiceConcreteEnumerableTest( + m => m.TryAddTransientEnumerable(typeof(IFakeService), typeof(FakeService)), + m => m.TryAddTransientEnumerable(typeof(IFakeService), typeof(DerivedFakeService)), + ServiceLifetime.Transient); + + AddServiceConcreteEnumerableTest( + m => m.TryAddScopedEnumerable(typeof(IFakeService), typeof(FakeService)), + m => m.TryAddScopedEnumerable(typeof(IFakeService), typeof(DerivedFakeService)), + ServiceLifetime.Scoped); + + AddServiceConcreteEnumerableTest( + m => m.TryAddSingletonEnumerable(typeof(IFakeService), typeof(FakeService)), + m => m.TryAddSingletonEnumerable(typeof(IFakeService), typeof(DerivedFakeService)), + ServiceLifetime.Singleton); + } + + private void AddServiceConcreteEnumerableTest( + Func adder1, + Func adder2, + ServiceLifetime lifetime) + { + var serviceCollection = new ServiceCollection(); + adder2(adder1(adder2(adder1(new ServiceCollectionMap(serviceCollection))))); + + Assert.Equal(2, serviceCollection.Count); + + Assert.Same(typeof(IFakeService), serviceCollection[0].ServiceType); + Assert.Same(typeof(FakeService), serviceCollection[0].ImplementationType); + Assert.Equal(lifetime, serviceCollection[0].Lifetime); + + Assert.Same(typeof(IFakeService), serviceCollection[1].ServiceType); + Assert.Same(typeof(DerivedFakeService), serviceCollection[1].ImplementationType); + Assert.Equal(lifetime, serviceCollection[1].Lifetime); + } + + [Fact] + public void Can_add_multiple_delegate_services() + { + Func factory1 = p => new FakeService(); + Func factory2 = p => new DerivedFakeService(); + + AddServiceDelegateEnumerableTest( + m => m.TryAddTransientEnumerable(factory1), + m => m.TryAddTransientEnumerable(factory2), + factory1, factory2, ServiceLifetime.Transient); + + AddServiceDelegateEnumerableTest( + m => m.TryAddScopedEnumerable(factory1), + m => m.TryAddScopedEnumerable(factory2), + factory1, factory2, ServiceLifetime.Scoped); + + AddServiceDelegateEnumerableTest( + m => m.TryAddSingletonEnumerable(factory1), + m => m.TryAddSingletonEnumerable(factory2), + factory1, factory2, ServiceLifetime.Singleton); + } + + private void AddServiceDelegateEnumerableTest( + Func adder1, + Func adder2, + Func factory1, + Func factory2, + ServiceLifetime lifetime) + { + var serviceCollection = new ServiceCollection(); + adder2(adder1(adder2(adder1(new ServiceCollectionMap(serviceCollection))))); + + Assert.Equal(2, serviceCollection.Count); + + Assert.Same(typeof(IFakeService), serviceCollection[0].ServiceType); + Assert.Same(factory1, serviceCollection[0].ImplementationFactory); + Assert.Equal(lifetime, serviceCollection[0].Lifetime); + + Assert.Same(typeof(IFakeService), serviceCollection[1].ServiceType); + Assert.Same(factory2, serviceCollection[1].ImplementationFactory); + Assert.Equal(lifetime, serviceCollection[1].Lifetime); + } + + [Fact] + public void Can_add_multiple_instance_services() + { + var instance1 = new FakeService(); + var instance2 = new DerivedFakeService(); + + AddServiceInstanceEnumerableTest( + m => m.TryAddSingletonEnumerable(instance1), + m => m.TryAddSingletonEnumerable(instance2), + instance1, instance2); + + AddServiceInstanceEnumerableTest( + m => m.TryAddSingletonEnumerable(typeof(IFakeService), instance1), + m => m.TryAddSingletonEnumerable(typeof(IFakeService), instance2), + instance1, instance2); + } + + private void AddServiceInstanceEnumerableTest( + Func adder1, + Func adder2, + object instance1, + object instance2) + { + var serviceCollection = new ServiceCollection(); + adder2(adder1(adder2(adder1(new ServiceCollectionMap(serviceCollection))))); + + Assert.Equal(2, serviceCollection.Count); + + Assert.Same(typeof(IFakeService), serviceCollection[0].ServiceType); + Assert.Same(instance1, serviceCollection[0].ImplementationInstance); + Assert.Equal(ServiceLifetime.Singleton, serviceCollection[0].Lifetime); + + Assert.Same(typeof(IFakeService), serviceCollection[1].ServiceType); + Assert.Same(instance2, serviceCollection[1].ImplementationInstance); + Assert.Equal(ServiceLifetime.Singleton, serviceCollection[1].Lifetime); + } + + [Fact] + public void Can_patch_transient_service_with_concrete_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_service_typed_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_untyped_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(typeof(IFakeService), p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_concrete_implementation_already_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(); + serviceMap.TryAddTransient(); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_transient_service(serviceMap)); + } + + private static FakeService Can_patch_transient_service(ServiceCollectionMap serviceMap) + { + var serviceProvider = serviceMap.ServiceCollection.BuildServiceProvider(); + + FakeService service; + + using (var context = CreateContext(serviceProvider)) + { + service = (FakeService)context.GetService(); + Assert.Same(context, service.Context); + Assert.NotSame(service, context.GetService()); + } + + using (var context = CreateContext(serviceProvider)) + { + Assert.Same(context, ((FakeService)context.GetService()).Context); + Assert.NotSame(service, context.GetService()); + } + return service; + } + + [Fact] + public void Can_patch_scoped_service_with_concrete_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_service_typed_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_untyped_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(typeof(IFakeService), p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_concrete_implementation_already_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(); + serviceMap.TryAddScoped(); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_scoped_service(serviceMap)); + } + + private static FakeService Can_patch_scoped_service(ServiceCollectionMap serviceMap) + { + var serviceProvider = serviceMap.ServiceCollection.BuildServiceProvider(); + + FakeService service; + + using (var context = CreateContext(serviceProvider)) + { + service = (FakeService)context.GetService(); + Assert.Same(context, service.Context); + Assert.Same(service, context.GetService()); + } + + using (var context = CreateContext(serviceProvider)) + { + Assert.Same(context, ((FakeService)context.GetService()).Context); + Assert.NotSame(service, context.GetService()); + } + + return service; + } + + [Fact] + public void Can_patch_singleton_service_with_concrete_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(p => new FakeSingletonService()); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_service_typed_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(p => new FakeSingletonService()); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_untyped_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(typeof(IFakeSingletonService), p => new FakeSingletonService()); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_concrete_implementation_already_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(); + serviceMap.TryAddSingleton(); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_singleton_service(serviceMap)); + } + + [Fact] + public void Can_patch_singleton_service_with_instance_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(new DerivedFakeSingletonService()); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_singleton_service(serviceMap)); + } + + [Fact] + public void Can_patch_singleton_service_with_instance_registered_non_generic() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(typeof(IFakeSingletonService), new DerivedFakeSingletonService()); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_singleton_service(serviceMap)); + } + + [Fact] + public virtual void Same_INavigationFixer_is_returned_for_all_registrations() + { + using (var context = new DbContext(new DbContextOptionsBuilder().UseInMemoryDatabase().Options)) + { + var navFixer = context.GetService(); + + Assert.Contains(navFixer, context.GetService>()); + Assert.Contains(navFixer, context.GetService>()); + Assert.Contains(navFixer, context.GetService>()); + Assert.Contains(navFixer, context.GetService>()); + } + } + + private static FakeSingletonService Can_patch_singleton_service(ServiceCollectionMap serviceMap) + { + var serviceProvider = serviceMap.ServiceCollection.BuildServiceProvider(); + + FakeSingletonService singletonService; + + using (var context = CreateContext(serviceProvider)) + { + singletonService = (FakeSingletonService)context.GetService(); + Assert.Same(context.GetService(), singletonService.ModelSource); + Assert.Same(singletonService, context.GetService()); + } + + using (var context = CreateContext(serviceProvider)) + { + Assert.Same(singletonService, context.GetService()); + } + + return singletonService; + } + + private static ServiceCollectionMap CreateServiceMap() + => new ServiceCollectionMap(new ServiceCollection().AddEntityFrameworkInMemoryDatabase()); + + private static DbContext CreateContext(IServiceProvider serviceProvider) + => new DbContext(new DbContextOptionsBuilder() + .UseInternalServiceProvider(serviceProvider) + .UseInMemoryDatabase() + .Options); + + private interface IFakeService + { + } + + private class FakeService : IFakeService, IPatchServiceInjectionSite + { + public DbContext Context { get; private set; } + + void IPatchServiceInjectionSite.InjectServices(IServiceProvider serviceProvider) + => Context = serviceProvider.GetService().Context; + } + + private class DerivedFakeService : FakeService + { + } + + private interface IFakeSingletonService + { + } + + private class FakeSingletonService : IFakeSingletonService, IPatchServiceInjectionSite + { + public IModelSource ModelSource { get; private set; } + + void IPatchServiceInjectionSite.InjectServices(IServiceProvider serviceProvider) + => ModelSource = serviceProvider.GetService(); + } + + private class DerivedFakeSingletonService : FakeSingletonService + { + } + } +} diff --git a/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj b/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj index d4bff802dcc..19a8b10778d 100644 --- a/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj +++ b/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj @@ -105,6 +105,7 @@ +