From 4112c260e423a4a0a2ce910440d037c3eef9ed1f Mon Sep 17 00:00:00 2001 From: Arthur Vickers Date: Thu, 2 Feb 2017 14:26:36 -0800 Subject: [PATCH] Implement service-injection mechanism for patch/point releases Issue #7465 This change introduces a new class ServiceCollectionMap that is used by providers and by our service-adding methods. The new class builds a map from service type to indexes of the ServiceDescriptors in the list for the given type. This allows all the TryAdd calls that we make to work without scanning the list every time. In addition, this type has a method for re-writing services to do property injection when we need to add a service dependency without breaking a constructor in a patch or point release. This is used by making a call at the end of our service registrations like so: ```C# ... .TryAddScoped() .DoPatchInjection(); ``` Also, ReplaceService code has been updated to ensure that service-injection is also run on any replaced services, since these services may also inherit from our base classes. --- .../InMemoryServiceCollectionExtensions.cs | 30 +- ...lectionRelationalProviderInfrastructure.cs | 104 ++- ...rameworkServiceCollectionExtensionsTest.cs | 18 +- .../SqlServerServiceCollectionExtensions.cs | 83 ++- .../SqliteServiceCollectionExtensions.cs | 40 +- .../Infrastructure/ServiceCollectionMap.cs | 597 ++++++++++++++++++ ...ServiceCollectionProviderInfrastructure.cs | 137 ++-- .../Internal/IPatchServiceInjectionSite.cs | 21 + .../Internal/ServiceProviderCache.cs | 49 +- .../Microsoft.EntityFrameworkCore.csproj | 2 + .../ConfigPatternsInMemoryTest.cs | 2 +- ...InMemoryServiceCollectionExtensionsTest.cs | 14 - .../RelationalConnectionTest.cs | 4 +- .../FakeRelationalOptionsExtension.cs | 2 +- .../SqlServerConfigPatternsTest.cs | 2 +- ...qlServerServiceCollectionExtensionsTest.cs | 14 - .../SqliteServiceCollectionExtensionsTest.cs | 14 - .../DbContextTest.cs | 14 +- .../ServiceCollectionMapTest.cs | 570 +++++++++++++++++ ...Microsoft.EntityFrameworkCore.Tests.csproj | 1 + 20 files changed, 1438 insertions(+), 280 deletions(-) create mode 100644 src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs create mode 100644 src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs create mode 100644 test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs diff --git a/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs b/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs index 0e196114e86..e2ac1779fd2 100644 --- a/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.InMemory/Extensions/InMemoryServiceCollectionExtensions.cs @@ -56,23 +56,21 @@ public static IServiceCollection AddEntityFrameworkInMemoryDatabase([NotNull] th { Check.NotNull(serviceCollection, nameof(serviceCollection)); - serviceCollection.TryAddEnumerable(ServiceDescriptor - .Singleton>()); + var serviceCollectionMap = new ServiceCollectionMap(serviceCollection) + .TryAddSingletonEnumerable>() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); - serviceCollection.TryAdd(new ServiceCollection() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); - - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollectionMap); return serviceCollection; } diff --git a/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs b/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs index 32c0185c0f8..e60a53b6564 100644 --- a/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs +++ b/src/Microsoft.EntityFrameworkCore.Relational/Infrastructure/ServiceCollectionRelationalProviderInfrastructure.cs @@ -20,7 +20,6 @@ using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyInjection.Extensions; // Intentionally in this namespace since this is for use by other relational providers rather than // by top-level app developers. @@ -37,64 +36,57 @@ public static class ServiceCollectionRelationalProviderInfrastructure /// providers after registering provider-specific services to fill-in the remaining services with /// Entity Framework defaults. /// - /// The to add services to. - public static void TryAddDefaultRelationalServices([NotNull] IServiceCollection serviceCollection) + /// The to add services to. + public static void TryAddDefaultRelationalServices([NotNull] ServiceCollectionMap serviceCollectionMap) { - Check.NotNull(serviceCollection, nameof(serviceCollection)); + Check.NotNull(serviceCollectionMap, nameof(serviceCollectionMap)); - serviceCollection - .TryAdd(new ServiceCollection() - .AddSingleton(s => new DiagnosticListener("Microsoft.EntityFrameworkCore")) - .AddSingleton(s => s.GetService()) - .AddSingleton() - .AddSingleton, ModificationCommandComparer>() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); + serviceCollectionMap + .TryAddSingleton(s => new DiagnosticListener("Microsoft.EntityFrameworkCore")) + .TryAddSingleton(s => s.GetService()) + .TryAddSingleton() + .TryAddSingleton, ModificationCommandComparer>() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); - // Add service dependencies parameter classes. - // These are added as concrete types because the classes are sealed and the registrations should - // not be changed by provider or application code. - serviceCollection - .TryAdd(new ServiceCollection() - .AddScoped()); - - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollectionMap); } } } diff --git a/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs b/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs index 71ef457dc76..0c1c0529365 100644 --- a/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs +++ b/src/Microsoft.EntityFrameworkCore.Specification.Tests/EntityFrameworkServiceCollectionExtensionsTest.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Linq; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -17,13 +18,22 @@ protected EntityFrameworkServiceCollectionExtensionsTest(TestHelpers testHelpers } [Fact] - public virtual void Repeated_calls_to_add_do_not_modify_collection() + public void Calling_AddEntityFramework_explicitly_does_not_change_services() { - var expectedCollection = AddServices(new ServiceCollection()); + var services1 = AddServices(new ServiceCollection()); + var services2 = AddServices(new ServiceCollection()); + + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(services2)); - var actualCollection = AddServices(AddServices(new ServiceCollection())); + AssertServicesSame(services1, services2); + } - AssertServicesSame(expectedCollection, actualCollection); + [Fact] + public virtual void Repeated_calls_to_add_do_not_modify_collection() + { + AssertServicesSame( + AddServices(new ServiceCollection()), + AddServices(AddServices(new ServiceCollection()))); } protected virtual void AssertServicesSame(IServiceCollection services1, IServiceCollection services2) diff --git a/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs b/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs index 43f0f5bfa24..73ce6deb515 100644 --- a/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.SqlServer/Extensions/SqlServerServiceCollectionExtensions.cs @@ -22,7 +22,6 @@ using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.EntityFrameworkCore.ValueGeneration.Internal; -using Microsoft.Extensions.DependencyInjection.Extensions; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection @@ -47,17 +46,17 @@ public static class SqlServerServiceCollectionExtensions /// /// /// - /// public void ConfigureServices(IServiceCollection services) - /// { - /// var connectionString = "connection string to database"; - /// - /// services - /// .AddEntityFrameworkSqlServer() - /// .AddDbContext<MyContext>((serviceProvider, options) => - /// options.UseSqlServer(connectionString) - /// .UseInternalServiceProvider(serviceProvider)); - /// } - /// + /// public void ConfigureServices(IServiceCollection services) + /// { + /// var connectionString = "connection string to database"; + /// + /// services + /// .AddEntityFrameworkSqlServer() + /// .AddDbContext<MyContext>((serviceProvider, options) => + /// options.UseSqlServer(connectionString) + /// .UseInternalServiceProvider(serviceProvider)); + /// } + /// /// /// The to add services to. /// @@ -67,38 +66,36 @@ public static IServiceCollection AddEntityFrameworkSqlServer([NotNull] this ISer { Check.NotNull(serviceCollection, nameof(serviceCollection)); - serviceCollection.TryAddEnumerable( - ServiceDescriptor.Singleton>()); + var serviceCollectionMap = new ServiceCollectionMap(serviceCollection) + .TryAddSingletonEnumerable>() + .TryAddSingleton() + .TryAddSingleton(p => p.GetService()) + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); - serviceCollection.TryAdd(new ServiceCollection() - .AddSingleton() - .AddSingleton(p => p.GetService()) - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollection); + ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollectionMap); return serviceCollection; } diff --git a/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs b/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs index 195db0738f9..6ee849dbece 100644 --- a/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs +++ b/src/Microsoft.EntityFrameworkCore.Sqlite/Infrastructure/SqliteServiceCollectionExtensions.cs @@ -19,7 +19,6 @@ using Microsoft.EntityFrameworkCore.Update; using Microsoft.EntityFrameworkCore.Update.Internal; using Microsoft.EntityFrameworkCore.Utilities; -using Microsoft.Extensions.DependencyInjection.Extensions; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection @@ -64,27 +63,26 @@ public static IServiceCollection AddEntityFrameworkSqlite([NotNull] this IServic { Check.NotNull(serviceCollection, nameof(serviceCollection)); - serviceCollection.TryAddEnumerable( - ServiceDescriptor.Singleton>()); + var serviceCollectionMap = new ServiceCollectionMap(serviceCollection) + .TryAddSingletonEnumerable>() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(p => p.GetService()) + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(); + ; - serviceCollection.TryAdd(new ServiceCollection() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(p => p.GetService()) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollection); + ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollectionMap); return serviceCollection; } diff --git a/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs new file mode 100644 index 00000000000..408a92a7290 --- /dev/null +++ b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionMap.cs @@ -0,0 +1,597 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Utilities; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.EntityFrameworkCore.Infrastructure +{ + /// + /// + /// Prvoides a map over a that allows + /// entries to be conditionally added or re-written without requiring linear scans of the service + /// collection each time this is done. + /// + /// + /// Database providers are expected to create an instance of this around the service collection passed + /// to their 'Add...' method and then use the methods of this class to add services. + /// + /// + /// Note that the collection should not be modified without in other ways while it is being managed + /// by the map. The collection can be used in the normal way after modifications using the map have + /// been completed. + /// + /// + public class ServiceCollectionMap + { + private readonly IServiceCollection _serviceCollection; + private readonly IDictionary> _serviceMap = new Dictionary>(); + + /// + /// Creates a new to operate on the given . + /// + /// The collection to work with. + public ServiceCollectionMap([NotNull] IServiceCollection serviceCollection) + { + Check.NotNull(serviceCollection, nameof(serviceCollection)); + + _serviceCollection = serviceCollection; + + var index = 0; + foreach (var descriptor in serviceCollection) + { + GetOrCreateDescriptorIndexes(descriptor.ServiceType).Add(index++); + } + } + + private IList GetOrCreateDescriptorIndexes(Type serviceType) + { + IList indexes; + if (!_serviceMap.TryGetValue(serviceType, out indexes)) + { + indexes = new List(); + _serviceMap[serviceType] = indexes; + } + return indexes; + } + + /// + /// The underlying . + /// + public virtual IServiceCollection ServiceCollection => _serviceCollection; + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient() + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), typeof(TImplementation), ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped() + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), typeof(TImplementation), ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton() + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), typeof(TImplementation), ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAdd(serviceType, implementationType, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAdd(serviceType, implementationType, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAdd(serviceType, implementationType, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAdd(Type serviceType, Type implementationType, ServiceLifetime lifetime) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(implementationType, nameof(implementationType)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (!indexes.Any()) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementationType, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient([NotNull] Func factory) + where TService : class + => TryAdd(typeof(TService), factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped([NotNull] Func factory) + where TService : class + => TryAdd(typeof(TService), factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Func factory) + where TService : class + => TryAdd(typeof(TService), factory, ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that the given factory creates. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that the given factory creates. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The concrete type that the given factory creates. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAdd(typeof(TService), factory, ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransient([NotNull] Type serviceType, [NotNull] Func factory) + => TryAdd(serviceType, factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScoped([NotNull] Type serviceType, [NotNull] Func factory) + => TryAdd(serviceType, factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The factory that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Type serviceType, [NotNull] Func factory) + => TryAdd(serviceType, factory, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAdd(Type serviceType, Func factory, ServiceLifetime lifetime) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(factory, nameof(factory)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (!indexes.Any()) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, factory, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given instance + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([CanBeNull] TService implementation) + where TService : class + => TryAdd(typeof(TService), implementation); + + /// + /// Adds a service implemented by the given instance + /// if no service for the given service type has already been registered. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingleton([NotNull] Type serviceType, [CanBeNull] object implementation) + => TryAdd(serviceType, implementation); + + private ServiceCollectionMap TryAdd(Type serviceType, object implementation) + { + Check.NotNull(serviceType, nameof(serviceType)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (!indexes.Any()) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementation)); + } + + return this; + } + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransientEnumerable() + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(typeof(TService), typeof(TImplementation), ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScopedEnumerable() + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(typeof(TService), typeof(TImplementation), ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable() + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(typeof(TService), typeof(TImplementation), ServiceLifetime.Singleton); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransientEnumerable([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAddEnumerable(serviceType, implementationType, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScopedEnumerable([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAddEnumerable(serviceType, implementationType, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given concrete + /// type to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable([NotNull] Type serviceType, [NotNull] Type implementationType) + => TryAddEnumerable(serviceType, implementationType, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAddEnumerable(Type serviceType, Type implementationType, ServiceLifetime lifetime) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(implementationType, nameof(implementationType)); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (indexes.All(i => TryGetImplementationType(_serviceCollection[i]) != implementationType)) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementationType, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given factory + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The factory that implements this service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddTransientEnumerable( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(factory, ServiceLifetime.Transient); + + /// + /// Adds a service implemented by the given factory + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The factory that implements this service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddScopedEnumerable( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(factory, ServiceLifetime.Scoped); + + /// + /// Adds a service implemented by the given factory + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The concrete type that implements the service. + /// The factory that implements this service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable( + [NotNull] Func factory) + where TService : class + where TImplementation : class, TService + => TryAddEnumerable(factory, ServiceLifetime.Singleton); + + private ServiceCollectionMap TryAddEnumerable( + [NotNull] Func factory, ServiceLifetime lifetime) + where TService : class + where TImplementation : class, TService + { + Check.NotNull(factory, nameof(factory)); + + var indexes = GetOrCreateDescriptorIndexes(typeof(TService)); + if (indexes.All(i => TryGetImplementationType(_serviceCollection[i]) != typeof(TImplementation))) + { + AddNewDescriptor(indexes, new ServiceDescriptor(typeof(TService), factory, lifetime)); + } + + return this; + } + + /// + /// Adds a service implemented by the given instance + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable([NotNull] TService implementation) + where TService : class + => TryAddEnumerable(typeof(TService), implementation); + + /// + /// Adds a service implemented by the given instance + /// to ths list of services that implement the given contract. The service is only added + /// if the collection contains no other registration for the same service and implementation type. + /// + /// The contract for the service. + /// The object that implements the service. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap TryAddSingletonEnumerable([NotNull] Type serviceType, [NotNull] object implementation) + => TryAddEnumerable(serviceType, implementation); + + private ServiceCollectionMap TryAddEnumerable(Type serviceType, object implementation) + { + Check.NotNull(serviceType, nameof(serviceType)); + Check.NotNull(implementation, nameof(implementation)); + + var implementationType = implementation.GetType(); + + var indexes = GetOrCreateDescriptorIndexes(serviceType); + if (indexes.All(i => TryGetImplementationType(_serviceCollection[i]) != implementationType)) + { + AddNewDescriptor(indexes, new ServiceDescriptor(serviceType, implementation)); + } + + return this; + } + + private Type TryGetImplementationType(ServiceDescriptor descriptor) + => descriptor.ImplementationType + ?? descriptor.ImplementationInstance?.GetType() + // Generic arg on Func may be obejct, but this is the best we can do and matches logic in D.I. container + ?? descriptor.ImplementationFactory?.GetType().GetTypeInfo().GenericTypeArguments[1]; + + private void AddNewDescriptor(IList indexes, ServiceDescriptor newDescriptor) + { + indexes.Add(_serviceCollection.Count); + _serviceCollection.Add(newDescriptor); + } + + /// + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + /// + /// Re-writes the registration for the gicen service such that if the implementation type + /// implements , then + /// will be called while resolving + /// the service allowing additional services to be injected without breaking the existing + /// constructor. + /// + /// + /// This mechanism should only be used to allow new services to be injected in a patch or + /// point release without making binary breaking changes. + /// + /// + /// The service contract. + /// The map, such that further calls can be chained. + public virtual ServiceCollectionMap DoPatchInjection() + where TService : class + { + IList indexes; + if (_serviceMap.TryGetValue(typeof(TService), out indexes)) + { + foreach (var index in indexes) + { + var descriptor = _serviceCollection[index]; + var lifetime = descriptor.Lifetime; + var implementationType = descriptor.ImplementationType; + + if (implementationType != null) + { + var implementationIndexes = GetOrCreateDescriptorIndexes(implementationType); + if (!implementationIndexes.Any()) + { + AddNewDescriptor( + implementationIndexes, + new ServiceDescriptor(implementationType, implementationType, lifetime)); + } + + var injectedDescriptor = new ServiceDescriptor( + typeof(TService), + p => InjectServices(p, implementationType), + lifetime); + + _serviceCollection[index] = injectedDescriptor; + } + else if (descriptor.ImplementationFactory != null) + { + var injectedDescriptor = new ServiceDescriptor( + typeof(TService), + p => InjectServices(p, descriptor.ImplementationFactory), + lifetime); + + _serviceCollection[index] = injectedDescriptor; + } + else + { + var injectedDescriptor = new ServiceDescriptor( + typeof(TService), + p => InjectServices(p, descriptor.ImplementationInstance), + lifetime); + + _serviceCollection[index] = injectedDescriptor; + } + } + } + + return this; + } + + private static object InjectServices(IServiceProvider serviceProvider, Type concreteType) + { + var service = serviceProvider.GetService(concreteType); + + (service as IPatchServiceInjectionSite)?.InjectServices(serviceProvider); + + return service; + } + + private static object InjectServices(IServiceProvider serviceProvider, object service) + { + (service as IPatchServiceInjectionSite)?.InjectServices(serviceProvider); + + return service; + } + + private static object InjectServices(IServiceProvider serviceProvider, Func implementationFactory) + { + var service = implementationFactory(serviceProvider); + + (service as IPatchServiceInjectionSite)?.InjectServices(serviceProvider); + + return service; + } + } +} diff --git a/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs index eae0bef871b..53764a89d18 100644 --- a/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs +++ b/src/Microsoft.EntityFrameworkCore/Infrastructure/ServiceCollectionProviderInfrastructure.cs @@ -16,7 +16,6 @@ using Microsoft.EntityFrameworkCore.Utilities; using Microsoft.EntityFrameworkCore.ValueGeneration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging; // Intentionally in this namespace since this is for use by other relational providers rather than @@ -35,77 +34,77 @@ public static class ServiceCollectionProviderInfrastructure /// Framework defaults. Relational providers should call /// 'ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices' instead. /// - /// The to add services to. - public static void TryAddDefaultEntityFrameworkServices([NotNull] IServiceCollection serviceCollection) + /// The to add services to. + public static void TryAddDefaultEntityFrameworkServices([NotNull] ServiceCollectionMap serviceCollectionMap) { - Check.NotNull(serviceCollection, nameof(serviceCollection)); + Check.NotNull(serviceCollectionMap, nameof(serviceCollectionMap)); - serviceCollection.TryAddEnumerable(new ServiceCollection() - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService()) - .AddScoped(p => p.GetService())); + serviceCollectionMap + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddSingleton() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped() + .TryAddScoped(typeof(ISensitiveDataLogger<>), typeof(SensitiveDataLogger<>)) + .TryAddScoped(typeof(ILogger<>), typeof(InterceptingLogger<>)) + .TryAddScoped(p => GetContextServices(p).Model) + .TryAddScoped(p => GetContextServices(p).CurrentContext) + .TryAddScoped(p => GetContextServices(p).ContextOptions) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()) + .TryAddScopedEnumerable(p => p.GetService()); - serviceCollection.TryAdd(new ServiceCollection() - .AddMemoryCache() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddSingleton() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped(typeof(ISensitiveDataLogger<>), typeof(SensitiveDataLogger<>)) - .AddScoped(typeof(ILogger<>), typeof(InterceptingLogger<>)) - .AddScoped(p => GetContextServices(p).Model) - .AddScoped(p => GetContextServices(p).CurrentContext) - .AddScoped(p => GetContextServices(p).ContextOptions) - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped() - .AddScoped()); + // Note: does TryAdd on all services + serviceCollectionMap.ServiceCollection.AddMemoryCache(); } private static IDbContextServices GetContextServices(IServiceProvider serviceProvider) diff --git a/src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs b/src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs new file mode 100644 index 00000000000..af9cb3755ad --- /dev/null +++ b/src/Microsoft.EntityFrameworkCore/Internal/IPatchServiceInjectionSite.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using JetBrains.Annotations; + +namespace Microsoft.EntityFrameworkCore.Internal +{ + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public interface IPatchServiceInjectionSite + { + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + void InjectServices([NotNull] IServiceProvider serviceProvider); + } +} diff --git a/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs b/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs index fbaba42bed7..ff38acdedfa 100644 --- a/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs +++ b/src/Microsoft.EntityFrameworkCore/Internal/ServiceProviderCache.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Linq; using JetBrains.Annotations; using Microsoft.EntityFrameworkCore.Infrastructure; @@ -47,37 +48,51 @@ public virtual IServiceProvider GetOrAdd([NotNull] IDbContextOptions options) k => { var services = new ServiceCollection(); - var coreServicesAdded = false; - - foreach (var extension in options.Extensions) - { - if (extension.ApplyServices(services)) - { - coreServicesAdded = true; - } - } - - if (!coreServicesAdded) - { - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(services); - } + ApplyServices(options, services); if (replacedServices != null) { - foreach (var descriptor in services.ToList()) + // For replaced services we use the service collection to obtain the lifetime of + // the service to replace. The replaced services are added to a new collection, after + // which provider and core services are applied. This ensures that any patching happens + // to the replaced service. + var updatedServices = new ServiceCollection(); + foreach (var descriptor in services) { Type replacementType; if (replacedServices.TryGetValue(descriptor.ServiceType, out replacementType)) { - services[services.IndexOf(descriptor)] - = new ServiceDescriptor(descriptor.ServiceType, replacementType, descriptor.Lifetime); + ((IList)updatedServices).Add( + new ServiceDescriptor(descriptor.ServiceType, replacementType, descriptor.Lifetime)); } } + + ApplyServices(options, updatedServices); + services = updatedServices; } return services.BuildServiceProvider(); }); } } + + private static void ApplyServices(IDbContextOptions options, ServiceCollection services) + { + var coreServicesAdded = false; + + foreach (var extension in options.Extensions) + { + if (extension.ApplyServices(services)) + { + coreServicesAdded = true; + } + } + + if (!coreServicesAdded) + { + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices( + new ServiceCollectionMap(services)); + } + } } } diff --git a/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj b/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj index 70a0b75530c..b9b6454e087 100644 --- a/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj +++ b/src/Microsoft.EntityFrameworkCore/Microsoft.EntityFrameworkCore.csproj @@ -151,7 +151,9 @@ + + diff --git a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs index 95829b1bdd4..d191a699d57 100644 --- a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs +++ b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/ConfigPatternsInMemoryTest.cs @@ -182,7 +182,7 @@ private class NoServicesAndNoConfigBlogContext : DbContext public void Throws_on_attempt_to_use_store_with_no_store_services() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); Assert.Equal( diff --git a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs index 19306588902..e66cc200bbc 100644 --- a/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.InMemory.FunctionalTests/InMemoryServiceCollectionExtensionsTest.cs @@ -1,26 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Specification.Tests; -using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace Microsoft.EntityFrameworkCore.InMemory.FunctionalTests { public class InMemoryServiceCollectionExtensionsTest : EntityFrameworkServiceCollectionExtensionsTest { - [Fact] - public void Calling_AddEntityFramework_explicitly_does_not_change_services() - { - var services1 = new ServiceCollection().AddEntityFrameworkInMemoryDatabase(); - var services2 = new ServiceCollection().AddEntityFrameworkInMemoryDatabase(); - - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(services2); - - AssertServicesSame(services1, services2); - } - public InMemoryServiceCollectionExtensionsTest() : base(InMemoryTestHelpers.Instance) { diff --git a/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs b/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs index 97d80bd12c2..3537c405c83 100644 --- a/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Relational.Tests/RelationalConnectionTest.cs @@ -51,7 +51,7 @@ public void Throws_with_add_when_no_EF_services_use_Database() public void Throws_with_new_when_no_provider_use_Database() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); var options = new DbContextOptionsBuilder() @@ -70,7 +70,7 @@ public void Throws_with_new_when_no_provider_use_Database() public void Throws_with_add_when_no_provider_use_Database() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var appServiceProivder = serviceCollection .AddDbContext( diff --git a/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs b/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs index 0dcf8ba0587..47e887f54cd 100644 --- a/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs +++ b/test/Microsoft.EntityFrameworkCore.Relational.Tests/TestUtilities/FakeProvider/FakeRelationalOptionsExtension.cs @@ -51,7 +51,7 @@ public static IServiceCollection AddEntityFrameworkRelationalDatabase(IServiceCo .AddScoped(_ => null) .AddScoped(_ => null)); - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(serviceCollection); + ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(new ServiceCollectionMap(serviceCollection)); return serviceCollection; } diff --git a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs index c35bb275df8..766eae6de87 100644 --- a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerConfigPatternsTest.cs @@ -208,7 +208,7 @@ public class ImplicitConfigButNoServices public void Throws_on_attempt_to_use_store_with_no_store_services() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); using (SqlServerNorthwindContext.GetSharedStore()) diff --git a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs index 491311507e9..11bdf710238 100644 --- a/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests/SqlServerServiceCollectionExtensionsTest.cs @@ -1,26 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Specification.Tests; -using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace Microsoft.EntityFrameworkCore.SqlServer.FunctionalTests { public class SqlServerServiceCollectionExtensionsTest : EntityFrameworkServiceCollectionExtensionsTest { - [Fact] - public void Calling_AddEntityFramework_explicitly_does_not_change_services() - { - var services1 = new ServiceCollection().AddEntityFrameworkSqlServer(); - var services2 = new ServiceCollection().AddEntityFrameworkSqlServer(); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(services2); - - AssertServicesSame(services1, services2); - } - public SqlServerServiceCollectionExtensionsTest() : base(SqlServerTestHelpers.Instance) { diff --git a/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs b/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs index c6b67b93d8d..3cdc18fbdd7 100644 --- a/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests/SqliteServiceCollectionExtensionsTest.cs @@ -1,26 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Specification.Tests; -using Microsoft.Extensions.DependencyInjection; -using Xunit; namespace Microsoft.EntityFrameworkCore.Sqlite.FunctionalTests { public class SqliteServiceCollectionExtensionsTest : EntityFrameworkServiceCollectionExtensionsTest { - [Fact] - public void Calling_AddEntityFramework_explicitly_does_not_change_services() - { - var services1 = new ServiceCollection().AddEntityFrameworkSqlite(); - var services2 = new ServiceCollection().AddEntityFrameworkSqlite(); - - ServiceCollectionRelationalProviderInfrastructure.TryAddDefaultRelationalServices(services2); - - AssertServicesSame(services1, services2); - } - public SqliteServiceCollectionExtensionsTest() : base(SqliteTestHelpers.Instance) { diff --git a/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs b/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs index ac9f94bd7a5..cbc503c8914 100644 --- a/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs +++ b/test/Microsoft.EntityFrameworkCore.Tests/DbContextTest.cs @@ -2071,7 +2071,7 @@ public void Can_start_with_custom_services_by_passing_in_base_service_provider() public void Required_low_level_services_are_added_if_needed() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var provider = serviceCollection.BuildServiceProvider(); @@ -2085,7 +2085,7 @@ public void Required_low_level_services_are_not_added_if_already_present() var loggerFactory = new FakeLoggerFactory(); serviceCollection.AddSingleton(loggerFactory); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var provider = serviceCollection.BuildServiceProvider(); @@ -2098,7 +2098,7 @@ public void Low_level_services_can_be_replaced_after_being_added() var serviceCollection = new ServiceCollection(); var loggerFactory = new FakeLoggerFactory(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); serviceCollection.AddSingleton(loggerFactory); @@ -4339,7 +4339,7 @@ public void Throws_with_add_when_no_EF_services_and_no_sets() public void Throws_with_new_when_no_provider() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); var options = new DbContextOptionsBuilder() @@ -4358,7 +4358,7 @@ public void Throws_with_new_when_no_provider() public void Throws_with_add_when_no_provider() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var appServiceProivder = serviceCollection .AddDbContext( @@ -4381,7 +4381,7 @@ public void Throws_with_add_when_no_provider() public void Throws_with_new_when_no_provider_and_no_sets() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var serviceProvider = serviceCollection.BuildServiceProvider(); var options = new DbContextOptionsBuilder() @@ -4400,7 +4400,7 @@ public void Throws_with_new_when_no_provider_and_no_sets() public void Throws_with_add_when_no_provider_and_no_sets() { var serviceCollection = new ServiceCollection(); - ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(serviceCollection); + ServiceCollectionProviderInfrastructure.TryAddDefaultEntityFrameworkServices(new ServiceCollectionMap(serviceCollection)); var appServiceProivder = serviceCollection .AddDbContext( diff --git a/test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs b/test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs new file mode 100644 index 00000000000..7a14787e215 --- /dev/null +++ b/test/Microsoft.EntityFrameworkCore.Tests/Infrastructure/ServiceCollectionMapTest.cs @@ -0,0 +1,570 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.EntityFrameworkCore.ChangeTracking.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; +using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.EntityFrameworkCore.Tests.Infrastructure +{ + public class ServiceCollectionMapTest + { + [Fact] + public void Can_add_delegate_services() + { + Func factory = p => new FakeService(); + + AddServiceDelegateTest(m => m.TryAddTransient(factory), factory, ServiceLifetime.Transient); + AddServiceDelegateTest(m => m.TryAddScoped(factory), factory, ServiceLifetime.Scoped); + AddServiceDelegateTest(m => m.TryAddSingleton(factory), factory, ServiceLifetime.Singleton); + AddServiceDelegateTest(m => m.TryAddTransient(factory), factory, ServiceLifetime.Transient); + AddServiceDelegateTest(m => m.TryAddScoped(factory), factory, ServiceLifetime.Scoped); + AddServiceDelegateTest(m => m.TryAddSingleton(factory), factory, ServiceLifetime.Singleton); + AddServiceDelegateTest(m => m.TryAddTransient(typeof(IFakeService), factory), factory, ServiceLifetime.Transient); + AddServiceDelegateTest(m => m.TryAddScoped(typeof(IFakeService), factory), factory, ServiceLifetime.Scoped); + AddServiceDelegateTest(m => m.TryAddSingleton(typeof(IFakeService), factory), factory, ServiceLifetime.Singleton); + } + + private void AddServiceDelegateTest( + Func adder, + Func factory, + ServiceLifetime lifetime) + { + var serviceCollectionMap = adder(new ServiceCollectionMap(new ServiceCollection())); + + var descriptor = serviceCollectionMap.ServiceCollection.Single(); + + Assert.Same(typeof(IFakeService), descriptor.ServiceType); + Assert.Same(factory, descriptor.ImplementationFactory); + Assert.Equal(lifetime, descriptor.Lifetime); + } + + [Fact] + public void Can_add_concrete_services() + { + AddServiceConcreteTest(m => m.TryAddTransient(), ServiceLifetime.Transient); + AddServiceConcreteTest(m => m.TryAddScoped(), ServiceLifetime.Scoped); + AddServiceConcreteTest(m => m.TryAddSingleton(), ServiceLifetime.Singleton); + AddServiceConcreteTest(m => m.TryAddTransient(typeof(IFakeService), typeof(DerivedFakeService)), ServiceLifetime.Transient); + AddServiceConcreteTest(m => m.TryAddScoped(typeof(IFakeService), typeof(DerivedFakeService)), ServiceLifetime.Scoped); + AddServiceConcreteTest(m => m.TryAddSingleton(typeof(IFakeService), typeof(DerivedFakeService)), ServiceLifetime.Singleton); + } + + private void AddServiceConcreteTest( + Func adder, + ServiceLifetime lifetime) + { + var serviceCollectionMap = adder(new ServiceCollectionMap(new ServiceCollection())); + + var descriptor = serviceCollectionMap.ServiceCollection.Single(); + + Assert.Same(typeof(IFakeService), descriptor.ServiceType); + Assert.Same(typeof(DerivedFakeService), descriptor.ImplementationType); + Assert.Equal(lifetime, descriptor.Lifetime); + } + + [Fact] + public void Can_add_instance_services() + { + var instance = new FakeService(); + + AddServiceInstanceTest(m => m.TryAddSingleton(instance), instance); + AddServiceInstanceTest(m => m.TryAddSingleton(typeof(IFakeService), instance), instance); + } + + private void AddServiceInstanceTest( + Func adder, + object instance) + { + var serviceCollectionMap = adder(new ServiceCollectionMap(new ServiceCollection())); + + var descriptor = serviceCollectionMap.ServiceCollection.Single(); + + Assert.Same(typeof(IFakeService), descriptor.ServiceType); + Assert.Same(instance, descriptor.ImplementationInstance); + Assert.Equal(ServiceLifetime.Singleton, descriptor.Lifetime); + } + + [Fact] + public void Existing_services_are_not_replaced() + { + ExistingServiceTest(m => m.TryAddTransient()); + ExistingServiceTest(m => m.TryAddScoped()); + ExistingServiceTest(m => m.TryAddSingleton()); + ExistingServiceTest(m => m.TryAddTransient(typeof(IFakeService), typeof(FakeService))); + ExistingServiceTest(m => m.TryAddScoped(typeof(IFakeService), typeof(FakeService))); + ExistingServiceTest(m => m.TryAddSingleton(typeof(IFakeService), typeof(FakeService))); + ExistingServiceTest(m => m.TryAddTransient(p => new FakeService())); + ExistingServiceTest(m => m.TryAddScoped(p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(p => new FakeService())); + ExistingServiceTest(m => m.TryAddTransient(p => new FakeService())); + ExistingServiceTest(m => m.TryAddScoped(p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(p => new FakeService())); + ExistingServiceTest(m => m.TryAddTransient(typeof(IFakeService), p => new FakeService())); + ExistingServiceTest(m => m.TryAddScoped(typeof(IFakeService), p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(typeof(IFakeService), p => new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(new FakeService())); + ExistingServiceTest(m => m.TryAddSingleton(typeof(IFakeService), new FakeService())); + } + + private void ExistingServiceTest(Func adder) + { + var serviceCollection = new ServiceCollection() + .AddSingleton(); + + var descriptor = serviceCollection.Single(); + + var serviceCollectionMap = adder(new ServiceCollectionMap(serviceCollection)); + + Assert.Same(serviceCollection, serviceCollectionMap.ServiceCollection); + Assert.Same(descriptor, serviceCollection.Single()); + } + + [Fact] + public void Can_add_multiple_concrete_services() + { + AddServiceConcreteEnumerableTest( + m => m.TryAddTransientEnumerable(), + m => m.TryAddTransientEnumerable(), + ServiceLifetime.Transient); + + AddServiceConcreteEnumerableTest( + m => m.TryAddScopedEnumerable(), + m => m.TryAddScopedEnumerable(), + ServiceLifetime.Scoped); + + AddServiceConcreteEnumerableTest( + m => m.TryAddSingletonEnumerable(), + m => m.TryAddSingletonEnumerable(), + ServiceLifetime.Singleton); + + AddServiceConcreteEnumerableTest( + m => m.TryAddTransientEnumerable(typeof(IFakeService), typeof(FakeService)), + m => m.TryAddTransientEnumerable(typeof(IFakeService), typeof(DerivedFakeService)), + ServiceLifetime.Transient); + + AddServiceConcreteEnumerableTest( + m => m.TryAddScopedEnumerable(typeof(IFakeService), typeof(FakeService)), + m => m.TryAddScopedEnumerable(typeof(IFakeService), typeof(DerivedFakeService)), + ServiceLifetime.Scoped); + + AddServiceConcreteEnumerableTest( + m => m.TryAddSingletonEnumerable(typeof(IFakeService), typeof(FakeService)), + m => m.TryAddSingletonEnumerable(typeof(IFakeService), typeof(DerivedFakeService)), + ServiceLifetime.Singleton); + } + + private void AddServiceConcreteEnumerableTest( + Func adder1, + Func adder2, + ServiceLifetime lifetime) + { + var serviceCollection = new ServiceCollection(); + adder2(adder1(adder2(adder1(new ServiceCollectionMap(serviceCollection))))); + + Assert.Equal(2, serviceCollection.Count); + + Assert.Same(typeof(IFakeService), serviceCollection[0].ServiceType); + Assert.Same(typeof(FakeService), serviceCollection[0].ImplementationType); + Assert.Equal(lifetime, serviceCollection[0].Lifetime); + + Assert.Same(typeof(IFakeService), serviceCollection[1].ServiceType); + Assert.Same(typeof(DerivedFakeService), serviceCollection[1].ImplementationType); + Assert.Equal(lifetime, serviceCollection[1].Lifetime); + } + + [Fact] + public void Can_add_multiple_delegate_services() + { + Func factory1 = p => new FakeService(); + Func factory2 = p => new DerivedFakeService(); + + AddServiceDelegateEnumerableTest( + m => m.TryAddTransientEnumerable(factory1), + m => m.TryAddTransientEnumerable(factory2), + factory1, factory2, ServiceLifetime.Transient); + + AddServiceDelegateEnumerableTest( + m => m.TryAddScopedEnumerable(factory1), + m => m.TryAddScopedEnumerable(factory2), + factory1, factory2, ServiceLifetime.Scoped); + + AddServiceDelegateEnumerableTest( + m => m.TryAddSingletonEnumerable(factory1), + m => m.TryAddSingletonEnumerable(factory2), + factory1, factory2, ServiceLifetime.Singleton); + } + + private void AddServiceDelegateEnumerableTest( + Func adder1, + Func adder2, + Func factory1, + Func factory2, + ServiceLifetime lifetime) + { + var serviceCollection = new ServiceCollection(); + adder2(adder1(adder2(adder1(new ServiceCollectionMap(serviceCollection))))); + + Assert.Equal(2, serviceCollection.Count); + + Assert.Same(typeof(IFakeService), serviceCollection[0].ServiceType); + Assert.Same(factory1, serviceCollection[0].ImplementationFactory); + Assert.Equal(lifetime, serviceCollection[0].Lifetime); + + Assert.Same(typeof(IFakeService), serviceCollection[1].ServiceType); + Assert.Same(factory2, serviceCollection[1].ImplementationFactory); + Assert.Equal(lifetime, serviceCollection[1].Lifetime); + } + + [Fact] + public void Can_add_multiple_instance_services() + { + var instance1 = new FakeService(); + var instance2 = new DerivedFakeService(); + + AddServiceInstanceEnumerableTest( + m => m.TryAddSingletonEnumerable(instance1), + m => m.TryAddSingletonEnumerable(instance2), + instance1, instance2); + + AddServiceInstanceEnumerableTest( + m => m.TryAddSingletonEnumerable(typeof(IFakeService), instance1), + m => m.TryAddSingletonEnumerable(typeof(IFakeService), instance2), + instance1, instance2); + } + + private void AddServiceInstanceEnumerableTest( + Func adder1, + Func adder2, + object instance1, + object instance2) + { + var serviceCollection = new ServiceCollection(); + adder2(adder1(adder2(adder1(new ServiceCollectionMap(serviceCollection))))); + + Assert.Equal(2, serviceCollection.Count); + + Assert.Same(typeof(IFakeService), serviceCollection[0].ServiceType); + Assert.Same(instance1, serviceCollection[0].ImplementationInstance); + Assert.Equal(ServiceLifetime.Singleton, serviceCollection[0].Lifetime); + + Assert.Same(typeof(IFakeService), serviceCollection[1].ServiceType); + Assert.Same(instance2, serviceCollection[1].ImplementationInstance); + Assert.Equal(ServiceLifetime.Singleton, serviceCollection[1].Lifetime); + } + + [Fact] + public void Can_patch_transient_service_with_concrete_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_service_typed_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_untyped_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(typeof(IFakeService), p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_transient_service(serviceMap); + } + + [Fact] + public void Can_patch_transient_service_with_concrete_implementation_already_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddTransient(); + serviceMap.TryAddTransient(); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_transient_service(serviceMap)); + } + + private static FakeService Can_patch_transient_service(ServiceCollectionMap serviceMap) + { + var serviceProvider = serviceMap.ServiceCollection.BuildServiceProvider(); + + FakeService service; + + using (var context = CreateContext(serviceProvider)) + { + service = (FakeService)context.GetService(); + Assert.Same(context, service.Context); + Assert.NotSame(service, context.GetService()); + } + + using (var context = CreateContext(serviceProvider)) + { + Assert.Same(context, ((FakeService)context.GetService()).Context); + Assert.NotSame(service, context.GetService()); + } + return service; + } + + [Fact] + public void Can_patch_scoped_service_with_concrete_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_service_typed_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_untyped_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(typeof(IFakeService), p => new FakeService()); + serviceMap.DoPatchInjection(); + + Can_patch_scoped_service(serviceMap); + } + + [Fact] + public void Can_patch_scoped_service_with_concrete_implementation_already_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddScoped(); + serviceMap.TryAddScoped(); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_scoped_service(serviceMap)); + } + + private static FakeService Can_patch_scoped_service(ServiceCollectionMap serviceMap) + { + var serviceProvider = serviceMap.ServiceCollection.BuildServiceProvider(); + + FakeService service; + + using (var context = CreateContext(serviceProvider)) + { + service = (FakeService)context.GetService(); + Assert.Same(context, service.Context); + Assert.Same(service, context.GetService()); + } + + using (var context = CreateContext(serviceProvider)) + { + Assert.Same(context, ((FakeService)context.GetService()).Context); + Assert.NotSame(service, context.GetService()); + } + + return service; + } + + [Fact] + public void Can_patch_singleton_service_with_concrete_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(p => new FakeSingletonService()); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_service_typed_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(p => new FakeSingletonService()); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_untyped_delegate_implementation() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(typeof(IFakeSingletonService), p => new FakeSingletonService()); + serviceMap.DoPatchInjection(); + + Can_patch_singleton_service(serviceMap); + } + + [Fact] + public void Can_patch_singleton_service_with_concrete_implementation_already_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(); + serviceMap.TryAddSingleton(); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_singleton_service(serviceMap)); + } + + [Fact] + public void Can_patch_singleton_service_with_instance_registered() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(new DerivedFakeSingletonService()); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_singleton_service(serviceMap)); + } + + [Fact] + public void Can_patch_singleton_service_with_instance_registered_non_generic() + { + var serviceMap = CreateServiceMap(); + + serviceMap.TryAddSingleton(typeof(IFakeSingletonService), new DerivedFakeSingletonService()); + serviceMap.DoPatchInjection(); + + Assert.IsType(Can_patch_singleton_service(serviceMap)); + } + + [Fact] + public virtual void Same_INavigationFixer_is_returned_for_all_registrations() + { + using (var context = new DbContext(new DbContextOptionsBuilder().UseInMemoryDatabase().Options)) + { + var navFixer = context.GetService(); + + Assert.Contains(navFixer, context.GetService>()); + Assert.Contains(navFixer, context.GetService>()); + Assert.Contains(navFixer, context.GetService>()); + Assert.Contains(navFixer, context.GetService>()); + } + } + + private static FakeSingletonService Can_patch_singleton_service(ServiceCollectionMap serviceMap) + { + var serviceProvider = serviceMap.ServiceCollection.BuildServiceProvider(); + + FakeSingletonService singletonService; + + using (var context = CreateContext(serviceProvider)) + { + singletonService = (FakeSingletonService)context.GetService(); + Assert.Same(context.GetService(), singletonService.ModelSource); + Assert.Same(singletonService, context.GetService()); + } + + using (var context = CreateContext(serviceProvider)) + { + Assert.Same(singletonService, context.GetService()); + } + + return singletonService; + } + + private static ServiceCollectionMap CreateServiceMap() + => new ServiceCollectionMap(new ServiceCollection().AddEntityFrameworkInMemoryDatabase()); + + private static DbContext CreateContext(IServiceProvider serviceProvider) + => new DbContext(new DbContextOptionsBuilder() + .UseInternalServiceProvider(serviceProvider) + .UseInMemoryDatabase() + .Options); + + private interface IFakeService + { + } + + private class FakeService : IFakeService, IPatchServiceInjectionSite + { + public DbContext Context { get; private set; } + + void IPatchServiceInjectionSite.InjectServices(IServiceProvider serviceProvider) + => Context = serviceProvider.GetService().Context; + } + + private class DerivedFakeService : FakeService + { + } + + private interface IFakeSingletonService + { + } + + private class FakeSingletonService : IFakeSingletonService, IPatchServiceInjectionSite + { + public IModelSource ModelSource { get; private set; } + + void IPatchServiceInjectionSite.InjectServices(IServiceProvider serviceProvider) + => ModelSource = serviceProvider.GetService(); + } + + private class DerivedFakeSingletonService : FakeSingletonService + { + } + } +} diff --git a/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj b/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj index d4bff802dcc..19a8b10778d 100644 --- a/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj +++ b/test/Microsoft.EntityFrameworkCore.Tests/Microsoft.EntityFrameworkCore.Tests.csproj @@ -105,6 +105,7 @@ +