diff --git a/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs b/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs index 49fe225d07..6eb5a27ecc 100644 --- a/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs +++ b/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs @@ -66,7 +66,7 @@ public void TrackDownloadCountDecreasedDuringRefresh(string packageId, string pa throw new NotImplementedException(); } - public void TrackDownloadJsonRefreshDuration(long milliseconds) + public void TrackDownloadJsonRefreshDuration(TimeSpan duration) { throw new NotImplementedException(); } @@ -370,5 +370,10 @@ public void TrackVerifyPackageKeyEvent(string packageId, string packageVersion, { throw new NotImplementedException(); } + + public void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration) + { + throw new NotImplementedException(); + } } } \ No newline at end of file diff --git a/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs b/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs index e114bc37af..73770dbe00 100644 --- a/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs +++ b/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; +using System.Linq; using NuGet.Services.Entities; namespace NuGetGallery diff --git a/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs b/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs index 54c852733d..14b9695318 100644 --- a/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs +++ b/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs @@ -16,7 +16,7 @@ public interface ITelemetryService void TrackGetPackageRegistrationDownloadCountFailed(string packageId); - void TrackDownloadJsonRefreshDuration(long milliseconds); + void TrackDownloadJsonRefreshDuration(TimeSpan duration); void TrackDownloadCountDecreasedDuringRefresh(string packageId, string packageVersion, long oldCount, long newCount); @@ -404,5 +404,11 @@ void TrackABTestEvaluated( bool isAuthenticated, int testBucket, int testPercentage); + + /// + /// Track how long it takes to populate the vulnerabilities cache + /// + /// Refresh duration for vulnerabilities cache + void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration); } } \ No newline at end of file diff --git a/src/NuGetGallery.Services/Telemetry/TelemetryService.cs b/src/NuGetGallery.Services/Telemetry/TelemetryService.cs index 7d49f8b335..c31c7c29b4 100644 --- a/src/NuGetGallery.Services/Telemetry/TelemetryService.cs +++ b/src/NuGetGallery.Services/Telemetry/TelemetryService.cs @@ -91,6 +91,7 @@ public class Events public const string ABTestEvaluated = "ABTestEvaluated"; public const string PackagePushDisconnect = "PackagePushDisconnect"; public const string SymbolPackagePushDisconnect = "SymbolPackagePushDisconnect"; + public const string VulnerabilitiesCacheRefreshDurationMs = "VulnerabilitiesCacheRefreshDurationMs"; } private readonly IDiagnosticsSource _diagnosticsSource; @@ -260,9 +261,9 @@ public void TrackGetPackageRegistrationDownloadCountFailed(string packageId) }); } - public void TrackDownloadJsonRefreshDuration(long milliseconds) + public void TrackDownloadJsonRefreshDuration(TimeSpan duration) { - TrackMetric(Events.DownloadJsonRefreshDuration, milliseconds, properties => { }); + TrackMetric(Events.DownloadJsonRefreshDuration, duration.TotalMilliseconds, properties => { }); } public void TrackDownloadCountDecreasedDuringRefresh(string packageId, string packageVersion, long oldCount, long newCount) @@ -1103,6 +1104,11 @@ public void TrackSymbolPackagePushDisconnectEvent() TrackMetric(Events.SymbolPackagePushDisconnect, 1, p => { }); } + public void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration) + { + TrackMetric(Events.VulnerabilitiesCacheRefreshDurationMs, duration.TotalMilliseconds, properties => { }); + } + /// /// We use instead of /// diff --git a/src/NuGetGallery/App_Start/AppActivator.cs b/src/NuGetGallery/App_Start/AppActivator.cs index 1a437e7fe0..20de1a9c2a 100644 --- a/src/NuGetGallery/App_Start/AppActivator.cs +++ b/src/NuGetGallery/App_Start/AppActivator.cs @@ -15,7 +15,7 @@ using System.Web.Routing; using System.Web.UI; using Elmah; -using Microsoft.WindowsAzure.ServiceRuntime; +using Microsoft.Extensions.DependencyInjection; using NuGetGallery; using NuGetGallery.Configuration; using NuGetGallery.Diagnostics; @@ -270,10 +270,17 @@ private static void BackgroundJobsPostStart(IAppConfiguration configuration) { // Perform initial refresh + schedule new refreshes every 15 minutes HostingEnvironment.QueueBackgroundWorkItem(_ => cloudDownloadCountService.RefreshAsync()); - jobs.Add(new CloudDownloadCountServiceRefreshJob(TimeSpan.FromMinutes(15), cloudDownloadCountService)); + jobs.Add(new CloudDownloadCountServiceRefreshJob(TimeSpan.FromMinutes(15), + cloudDownloadCountService)); } } + // Perform initial refresh for vulnerabilities cache + schedule new refreshes every 30 minutes + var packageVulnerabilitiesCacheService = DependencyResolver.Current.GetService(); + var serviceScopeFactory = DependencyResolver.Current.GetService(); + HostingEnvironment.QueueBackgroundWorkItem(_ => packageVulnerabilitiesCacheService.RefreshCache(serviceScopeFactory)); + jobs.Add(new PackageVulnerabilitiesCacheRefreshJob(TimeSpan.FromMinutes(30), packageVulnerabilitiesCacheService, serviceScopeFactory)); + if (jobs.AnySafe()) { var jobCoordinator = new NuGetJobCoordinator(); diff --git a/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs b/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs index af69ee6d18..96950c4ca3 100644 --- a/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs +++ b/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs @@ -460,6 +460,11 @@ protected override void Load(ContainerBuilder builder) .As() .InstancePerLifetimeScope(); + builder.RegisterType() + .AsSelf() + .As() + .SingleInstance(); + services.AddHttpClient(); services.AddScoped(); diff --git a/src/NuGetGallery/Infrastructure/Jobs/PackageVulnerabilitiesCacheRefreshJob.cs b/src/NuGetGallery/Infrastructure/Jobs/PackageVulnerabilitiesCacheRefreshJob.cs new file mode 100644 index 0000000000..3c866437a8 --- /dev/null +++ b/src/NuGetGallery/Infrastructure/Jobs/PackageVulnerabilitiesCacheRefreshJob.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using WebBackgrounder; + +namespace NuGetGallery +{ + public class PackageVulnerabilitiesCacheRefreshJob : Job + { + private readonly IPackageVulnerabilitiesCacheService _packageVulnerabilitiesCacheService; + private IServiceScopeFactory _serviceScopeFactory; + + public PackageVulnerabilitiesCacheRefreshJob(TimeSpan interval, + IPackageVulnerabilitiesCacheService packageVulnerabilitiesCacheService, + IServiceScopeFactory serviceScopeFactory) + : base("", interval) + { + _packageVulnerabilitiesCacheService = packageVulnerabilitiesCacheService; + _serviceScopeFactory = serviceScopeFactory; + } + + public override Task Execute() + { + return new Task(() => _packageVulnerabilitiesCacheService.RefreshCache(_serviceScopeFactory)); + } + } +} \ No newline at end of file diff --git a/src/NuGetGallery/NuGetGallery.csproj b/src/NuGetGallery/NuGetGallery.csproj index 6c07fd3f05..e574f390f5 100644 --- a/src/NuGetGallery/NuGetGallery.csproj +++ b/src/NuGetGallery/NuGetGallery.csproj @@ -227,6 +227,7 @@ + @@ -314,8 +315,10 @@ + + diff --git a/src/NuGetGallery/Services/CloudDownloadCountService.cs b/src/NuGetGallery/Services/CloudDownloadCountService.cs index 4380558d19..12306a2b7c 100644 --- a/src/NuGetGallery/Services/CloudDownloadCountService.cs +++ b/src/NuGetGallery/Services/CloudDownloadCountService.cs @@ -106,7 +106,7 @@ public async Task RefreshAsync() var stopwatch = Stopwatch.StartNew(); await RefreshCoreAsync(); stopwatch.Stop(); - _telemetryService.TrackDownloadJsonRefreshDuration(stopwatch.ElapsedMilliseconds); + _telemetryService.TrackDownloadJsonRefreshDuration(TimeSpan.FromMilliseconds(stopwatch.ElapsedMilliseconds)); } catch (WebException ex) diff --git a/src/NuGetGallery/Services/IPackageVulnerabilitiesCacheService.cs b/src/NuGetGallery/Services/IPackageVulnerabilitiesCacheService.cs new file mode 100644 index 0000000000..8f10fe113d --- /dev/null +++ b/src/NuGetGallery/Services/IPackageVulnerabilitiesCacheService.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.DependencyInjection; +using NuGet.Services.Entities; + +namespace NuGetGallery +{ + /// + /// This interface is used to implement a basic caching for vulnerabilities querying. + /// /// + public interface IPackageVulnerabilitiesCacheService + { + /// + /// This function is used to get the packages by id dictionary from the cache + /// + IReadOnlyDictionary> GetVulnerabilitiesById(string id); + + /// + /// This function will refresh the cache from the database, to be called at regular intervals + /// + /// The factory which will provide a new service scope for each refresh + void RefreshCache(IServiceScopeFactory serviceScopeFactory); + } +} diff --git a/src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs b/src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs new file mode 100644 index 0000000000..b305a65542 --- /dev/null +++ b/src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs @@ -0,0 +1,107 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Data.Entity; +using System.Diagnostics; +using System.Linq; +using Microsoft.Extensions.DependencyInjection; +using NuGet.Services.Entities; + +namespace NuGetGallery +{ + public class PackageVulnerabilitiesCacheService : IPackageVulnerabilitiesCacheService + { + private IDictionary>> _vulnerabilitiesByIdCache + = new Dictionary>>(); + private readonly object _refreshLock = new object(); + private bool _isRefreshing; + + private readonly ITelemetryService _telemetryService; + + public PackageVulnerabilitiesCacheService(ITelemetryService telemetryService) + { + _telemetryService = telemetryService ?? throw new ArgumentNullException(nameof(telemetryService)); + } + + public IReadOnlyDictionary> GetVulnerabilitiesById(string id) + { + if (string.IsNullOrEmpty(id)) + { + throw new ArgumentException("Must have a value.", nameof(id)); + } + + if (_vulnerabilitiesByIdCache.TryGetValue(id, out var result)) + { + return result; + } + + return null; + } + + public void RefreshCache(IServiceScopeFactory serviceScopeFactory) + { + if (serviceScopeFactory == null) + { + throw new ArgumentNullException(nameof(serviceScopeFactory)); + } + + bool shouldRefresh = false; + lock (_refreshLock) + { + if (!_isRefreshing) + { + _isRefreshing = true; + shouldRefresh = true; + } + } + + if (shouldRefresh) + { + try + { + var stopwatch = Stopwatch.StartNew(); + + // Create a unique service scope for each refresh to ensure a fresh entities context + using (var serviceScope = serviceScopeFactory.CreateScope()) + { + var serviceProvider = serviceScope.ServiceProvider; + var entitiesContext = serviceProvider.GetService(); + + // We need to build a dictionary of dictionaries. Breaking it down: + // - this give us a list of all vulnerable package version ranges + _vulnerabilitiesByIdCache = entitiesContext.Set() + .Include(x => x.Vulnerability) + // - from these we want a list in this format: (, (, )) + // which will allow us to look up the dictionary by id, and return a dictionary of version -> vulnerability + .SelectMany(x => x.Packages.Select(p => new + { PackageId = x.PackageId ?? string.Empty, KeyVulnerability = new { PackageKey = p.Key, x.Vulnerability } })) + .GroupBy(ikv => ikv.PackageId, ikv => ikv.KeyVulnerability) + // - build the outer dictionary, keyed by - each inner dictionary is paired with a time of creation (for cache invalidation) + .ToDictionary(ikv => ikv.Key, + ikv => + ikv.GroupBy(kv => kv.PackageKey, kv => kv.Vulnerability) + // - build the inner dictionaries, all under the same , each keyed by + .ToDictionary(kv => kv.Key, + kv => kv.ToList().AsReadOnly() as IReadOnlyList)); + } + + stopwatch.Stop(); + + _telemetryService.TrackVulnerabilitiesCacheRefreshDuration(TimeSpan.FromMilliseconds(stopwatch.ElapsedMilliseconds)); + } + catch (Exception ex) + { + _telemetryService.TraceException(ex); + } + finally + { + _isRefreshing = false; + } + } + } + } +} \ No newline at end of file diff --git a/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs b/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs index d8b335c019..e0d7759e8c 100644 --- a/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs +++ b/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs @@ -12,36 +12,17 @@ namespace NuGetGallery { public class PackageVulnerabilitiesService : IPackageVulnerabilitiesService { - private readonly IEntitiesContext _entitiesContext; + private readonly IPackageVulnerabilitiesCacheService _packageVulnerabilitiesCacheService; - public PackageVulnerabilitiesService(IEntitiesContext entitiesContext) + public PackageVulnerabilitiesService(IPackageVulnerabilitiesCacheService packageVulnerabilitiesCacheService) { - _entitiesContext = entitiesContext ?? throw new ArgumentNullException(nameof(entitiesContext)); + _packageVulnerabilitiesCacheService = packageVulnerabilitiesCacheService ?? + throw new ArgumentNullException( + nameof(packageVulnerabilitiesCacheService)); } - public IReadOnlyDictionary> GetVulnerabilitiesById(string id) - { - var result = new Dictionary>(); - var packagesMatchingId = _entitiesContext.Packages - .Where(p => p.PackageRegistration != null && p.PackageRegistration.Id == id) - .Include($"{nameof(Package.VulnerablePackageRanges)}.{nameof(VulnerablePackageVersionRange.Vulnerability)}"); - foreach (var package in packagesMatchingId) - { - if (package.VulnerablePackageRanges == null) - { - continue; - } - - if (package.VulnerablePackageRanges.Any()) - { - result.Add(package.Key, - package.VulnerablePackageRanges.Select(vr => vr.Vulnerability).ToList()); - } - } - - return !result.Any() ? null : - result.ToDictionary(kv => kv.Key, kv => kv.Value as IReadOnlyList); - } + public IReadOnlyDictionary> GetVulnerabilitiesById(string id) => + _packageVulnerabilitiesCacheService.GetVulnerabilitiesById(id); public bool IsPackageVulnerable(Package package) { diff --git a/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs b/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs index 47f3a21d39..3451277f76 100644 --- a/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs +++ b/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs @@ -63,7 +63,7 @@ public void TrackDownloadCountDecreasedDuringRefresh(string packageId, string pa throw new NotImplementedException(); } - public void TrackDownloadJsonRefreshDuration(long milliseconds) + public void TrackDownloadJsonRefreshDuration(TimeSpan duration) { throw new NotImplementedException(); } @@ -372,5 +372,10 @@ public void TrackVerifyPackageKeyEvent(string packageId, string packageVersion, { throw new NotImplementedException(); } + + public void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration) + { + throw new NotImplementedException(); + } } } diff --git a/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj b/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj index 4c87efaa41..5a13c85526 100644 --- a/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj +++ b/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj @@ -103,6 +103,7 @@ + diff --git a/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesCacheServiceFacts.cs b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesCacheServiceFacts.cs new file mode 100644 index 0000000000..04d9dc043b --- /dev/null +++ b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesCacheServiceFacts.cs @@ -0,0 +1,184 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Data.Entity; +using System.Linq; +using Microsoft.Extensions.DependencyInjection; +using Moq; +using NuGet.Services.Entities; +using NuGetGallery.Framework; +using Xunit; + +namespace NuGetGallery.Services +{ + public class PackageVulnerabilitiesCacheServiceFacts : TestContainer + { + [Fact] + public void RefreshesVulnerabilitiesCache() + { + // Arrange + var entitiesContext = new Mock(); + entitiesContext.Setup(x => x.Set()).Returns(GetVulnerableRanges()); + var serviceProvider = new Mock(); + serviceProvider.Setup(x => x.GetService(typeof(IEntitiesContext))).Returns(entitiesContext.Object); + var serviceScope = new Mock(); + serviceScope.Setup(x => x.ServiceProvider).Returns(serviceProvider.Object); + serviceScope.Setup(x => x.Dispose()).Verifiable(); + var serviceScopeFactory = new Mock(); + serviceScopeFactory.Setup(x => x.CreateScope()).Returns(serviceScope.Object); + var telemetryService = new Mock(); + telemetryService.Setup(x => x.TrackVulnerabilitiesCacheRefreshDuration(It.IsAny())).Verifiable(); + var cacheService = new PackageVulnerabilitiesCacheService(telemetryService.Object); + cacheService.RefreshCache(serviceScopeFactory.Object); + + // Act + var vulnerabilitiesFoo = cacheService.GetVulnerabilitiesById("Foo"); + var vulnerabilitiesBar = cacheService.GetVulnerabilitiesById("Bar"); + + // Assert + // - ensure telemetry is sent + telemetryService.Verify(x => x.TrackVulnerabilitiesCacheRefreshDuration(It.IsAny()), Times.Once); + // - ensure scope is disposed + serviceScope.Verify(x => x.Dispose(), Times.AtLeastOnce); + // - ensure contants of cache are correct + Assert.Equal(4, vulnerabilitiesFoo.Count); + Assert.Equal(1, vulnerabilitiesFoo[0].Count); + Assert.Equal(1, vulnerabilitiesFoo[1].Count); + Assert.Equal(2, vulnerabilitiesFoo[2].Count); + Assert.Equal(1234, vulnerabilitiesFoo[2][0].GitHubDatabaseKey); + Assert.Equal(5678, vulnerabilitiesFoo[2][1].GitHubDatabaseKey); + Assert.Equal(1, vulnerabilitiesFoo[3].Count); + Assert.Equal(2, vulnerabilitiesBar.Count); + Assert.Equal(1, vulnerabilitiesBar[5].Count); + Assert.Equal(9012, vulnerabilitiesBar[5][0].GitHubDatabaseKey); + Assert.Equal(1, vulnerabilitiesBar[6].Count); + } + + DbSet GetVulnerableRanges() + { + var registrationFoo = new PackageRegistration { Id = "Foo" }; + var registrationBar = new PackageRegistration { Id = "Bar" }; + + var vulnerabilityCriticalFoo = new PackageVulnerability + { + AdvisoryUrl = "http://theurl/1234", + GitHubDatabaseKey = 1234, + Severity = PackageVulnerabilitySeverity.Critical + }; + var vulnerabilityModerateFoo = new PackageVulnerability + { + AdvisoryUrl = "http://theurl/5678", + GitHubDatabaseKey = 5678, + Severity = PackageVulnerabilitySeverity.Moderate + }; + var vulnerabilityCriticalBar = new PackageVulnerability + { + AdvisoryUrl = "http://theurl/9012", + GitHubDatabaseKey = 9012, + Severity = PackageVulnerabilitySeverity.Critical + }; + + var versionRangeCriticalFoo = new VulnerablePackageVersionRange + { + Vulnerability = vulnerabilityCriticalFoo, + PackageId = "Foo", + PackageVersionRange = "1.1.1", + FirstPatchedPackageVersion = "1.1.2" + }; + var versionRangeModerateFoo = new VulnerablePackageVersionRange + { + Vulnerability = vulnerabilityModerateFoo, + PackageId = "Foo", + PackageVersionRange = "<=1.1.2", + FirstPatchedPackageVersion = "1.1.3" + }; + var versionRangeCriticalBar = new VulnerablePackageVersionRange + { + Vulnerability = vulnerabilityCriticalBar, + PackageId = "Bar", + PackageVersionRange = "<=1.1.0", + FirstPatchedPackageVersion = "1.1.1" + }; + + var packageFoo100 = new Package + { + Key = 0, + PackageRegistration = registrationFoo, + Version = "1.0.0", + VulnerablePackageRanges = new List + { + versionRangeModerateFoo + } + }; + var packageFoo110 = new Package + { + Key = 1, + PackageRegistration = registrationFoo, + Version = "1.1.0", + VulnerablePackageRanges = new List + { + versionRangeModerateFoo + } + }; + var packageFoo111 = new Package + { + Key = 2, + PackageRegistration = registrationFoo, + Version = "1.1.1", + VulnerablePackageRanges = new List + { + versionRangeModerateFoo, + versionRangeCriticalFoo + } + }; + var packageFoo112 = new Package + { + Key = 3, + PackageRegistration = registrationFoo, + Version = "1.1.2", + VulnerablePackageRanges = new List + { + versionRangeModerateFoo + } + }; + var packageBar100 = new Package + { + Key = 5, + Version = "1.0.0", + PackageRegistration = registrationBar, + VulnerablePackageRanges = new List + { + versionRangeCriticalBar + } + }; + var packageBar110 = new Package + { + Key = 6, + PackageRegistration = registrationBar, + Version = "1.1.0", + VulnerablePackageRanges = new List + { + versionRangeCriticalBar + } + }; + + versionRangeCriticalFoo.Packages = new List { packageFoo111 }; + versionRangeModerateFoo.Packages = new List { packageFoo100, packageFoo110, packageFoo111, packageFoo112 }; + versionRangeCriticalBar.Packages = new List { packageBar100, packageBar110 }; + + var vulnerableRangeList = new List { versionRangeCriticalFoo, versionRangeModerateFoo, versionRangeCriticalBar }.AsQueryable(); + var vulnerableRangeDbSet = new Mock>(); + + // boilerplate mock DbSet redirects: + vulnerableRangeDbSet.As().Setup(x => x.Provider).Returns(vulnerableRangeList.Provider); + vulnerableRangeDbSet.As().Setup(x => x.Expression).Returns(vulnerableRangeList.Expression); + vulnerableRangeDbSet.As().Setup(x => x.ElementType).Returns(vulnerableRangeList.ElementType); + vulnerableRangeDbSet.As().Setup(x => x.GetEnumerator()).Returns(vulnerableRangeList.GetEnumerator()); + vulnerableRangeDbSet.Setup(x => x.Include(It.IsAny())).Returns(vulnerableRangeDbSet.Object); // bypass includes (which break the test) + + return vulnerableRangeDbSet.Object; + } + } +} diff --git a/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs index 77121a2ab9..4324a2a5b3 100644 --- a/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs +++ b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs @@ -11,146 +11,49 @@ namespace NuGetGallery.Services { public class PackageVulnerabilitiesServiceFacts : TestContainer { - private PackageRegistration _registrationVulnerable; - - private PackageVulnerability _vulnerabilityCritical; - private PackageVulnerability _vulnerabilityModerate; - - private VulnerablePackageVersionRange _versionRangeCritical; - private VulnerablePackageVersionRange _versionRangeModerate; - - private Package _packageVulnerable100; - private Package _packageVulnerable110; - private Package _packageVulnerable111; - private Package _packageVulnerable112; - - private Package _packageNotVulnerable; - - [Fact] - public void GetsVulnerabilitiesOfPackage() - { - // Arrange - SetUp(); - var packages = new[] - { - _packageVulnerable100, - _packageVulnerable110, - _packageVulnerable111, - _packageVulnerable112, - _packageNotVulnerable - }; - var context = GetFakeContext(); - context.Packages.AddRange(packages); - var target = Get(); - - // Act - var vulnerableResult = target.GetVulnerabilitiesById("Vulnerable"); - var notVulnerableResult = target.GetVulnerabilitiesById("NotVulnerable"); - - // Assert - Assert.Equal(3, vulnerableResult.Count); - var vulnerabilitiesFor100 = vulnerableResult[_packageVulnerable100.Key]; - var vulnerabilitiesFor110 = vulnerableResult[_packageVulnerable110.Key]; - var vulnerabilitiesFor111 = vulnerableResult[_packageVulnerable111.Key]; - Assert.Equal(_vulnerabilityModerate, vulnerabilitiesFor100[0]); - Assert.Equal(_vulnerabilityModerate, vulnerabilitiesFor110[0]); - Assert.Equal(_vulnerabilityModerate, vulnerabilitiesFor111[0]); - Assert.Equal(_vulnerabilityCritical, vulnerabilitiesFor111[1]); - - Assert.Null(notVulnerableResult); - } - [Fact] public void GetsVulnerableStatusOfPackage() { // Arrange - SetUp(); - var context = GetFakeContext(); - var target = Get(); - - // Act - var shouldBeVulnerable = target.IsPackageVulnerable(_packageVulnerable100); - var shouldNotBeVulnerable = target.IsPackageVulnerable(_packageNotVulnerable); - - // Assert - Assert.True(shouldBeVulnerable); - Assert.False(shouldNotBeVulnerable); - } - - private void SetUp() - { - _registrationVulnerable = new PackageRegistration { Id = "Vulnerable" }; + var registrationVulnerable = new PackageRegistration { Id = "Vulnerable" }; - _vulnerabilityCritical = new PackageVulnerability - { - AdvisoryUrl = "http://theurl/1234", - GitHubDatabaseKey = 1234, - Severity = PackageVulnerabilitySeverity.Critical - }; - _vulnerabilityModerate = new PackageVulnerability + var vulnerabilityModerate = new PackageVulnerability { AdvisoryUrl = "http://theurl/5678", GitHubDatabaseKey = 5678, Severity = PackageVulnerabilitySeverity.Moderate }; - _versionRangeCritical = new VulnerablePackageVersionRange + var versionRangeModerate = new VulnerablePackageVersionRange { - Vulnerability = _vulnerabilityCritical, - PackageVersionRange = "1.1.1", - FirstPatchedPackageVersion = "1.1.2" - }; - _versionRangeModerate = new VulnerablePackageVersionRange - { - Vulnerability = _vulnerabilityModerate, + Vulnerability = vulnerabilityModerate, PackageVersionRange = "<=1.1.1", FirstPatchedPackageVersion = "1.1.2" }; - _packageVulnerable100 = new Package + var packageVulnerable = new Package { Key = 0, - PackageRegistration = _registrationVulnerable, + PackageRegistration = registrationVulnerable, Version = "1.0.0", - VulnerablePackageRanges = new List - { - _versionRangeModerate - } - }; - _packageVulnerable110 = new Package - { - Key = 1, - PackageRegistration = _registrationVulnerable, - Version = "1.1.0", - VulnerablePackageRanges = new List - { - _versionRangeModerate - } - }; - _packageVulnerable111 = new Package - { - Key = 3, // simulate a different order in db - create a non-contiguous range of rows, even if the range is contiguous - PackageRegistration = _registrationVulnerable, - Version = "1.1.1", - VulnerablePackageRanges = new List - { - _versionRangeModerate, - _versionRangeCritical - } - }; - _packageVulnerable112 = new Package - { - Key = 2, // simulate a different order in db - create a non-contiguous range of rows, even if the range is contiguous - PackageRegistration = _registrationVulnerable, - Version = "1.1.2", - VulnerablePackageRanges = new List() + VulnerablePackageRanges = new List {versionRangeModerate} }; - _packageNotVulnerable = new Package + var packageNotVulnerable = new Package { Key = 4, PackageRegistration = new PackageRegistration { Id = "NotVulnerable" }, VulnerablePackageRanges = new List() - }; + }; + + var target = Get(); + + // Act + var shouldBeVulnerable = target.IsPackageVulnerable(packageVulnerable); + var shouldNotBeVulnerable = target.IsPackageVulnerable(packageNotVulnerable); + + // Assert + Assert.True(shouldBeVulnerable); + Assert.False(shouldNotBeVulnerable); } } } diff --git a/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs b/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs index 7cc3362273..2fc8ea2e84 100644 --- a/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs +++ b/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs @@ -61,7 +61,7 @@ public static IEnumerable TrackMetricNames_Data }; yield return new object[] { "DownloadJsonRefreshDuration", - (TrackAction)(s => s.TrackDownloadJsonRefreshDuration(0)) + (TrackAction)(s => s.TrackDownloadJsonRefreshDuration(TimeSpan.FromMilliseconds(0))) }; yield return new object[] { "DownloadCountDecreasedDuringRefresh", @@ -343,6 +343,10 @@ public static IEnumerable TrackMetricNames_Data yield return new object[] { "SymbolPackagePushDisconnect", (TrackAction)(s => s.TrackSymbolPackagePushDisconnectEvent()) }; + + yield return new object[] { "VulnerabilitiesCacheRefreshDurationMs", + (TrackAction)(s => s.TrackVulnerabilitiesCacheRefreshDuration(TimeSpan.FromMilliseconds(0))) + }; } }