diff --git a/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs b/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs
index 49fe225d07..6eb5a27ecc 100644
--- a/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs
+++ b/src/GitHubVulnerabilities2Db/Gallery/ThrowingTelemetryService.cs
@@ -66,7 +66,7 @@ public void TrackDownloadCountDecreasedDuringRefresh(string packageId, string pa
throw new NotImplementedException();
}
- public void TrackDownloadJsonRefreshDuration(long milliseconds)
+ public void TrackDownloadJsonRefreshDuration(TimeSpan duration)
{
throw new NotImplementedException();
}
@@ -370,5 +370,10 @@ public void TrackVerifyPackageKeyEvent(string packageId, string packageVersion,
{
throw new NotImplementedException();
}
+
+ public void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration)
+ {
+ throw new NotImplementedException();
+ }
}
}
\ No newline at end of file
diff --git a/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs b/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs
index e114bc37af..73770dbe00 100644
--- a/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs
+++ b/src/NuGetGallery.Services/PackageManagement/IPackageVulnerabilitiesManagementService.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
+using System.Linq;
using NuGet.Services.Entities;
namespace NuGetGallery
diff --git a/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs b/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs
index 54c852733d..14b9695318 100644
--- a/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs
+++ b/src/NuGetGallery.Services/Telemetry/ITelemetryService.cs
@@ -16,7 +16,7 @@ public interface ITelemetryService
void TrackGetPackageRegistrationDownloadCountFailed(string packageId);
- void TrackDownloadJsonRefreshDuration(long milliseconds);
+ void TrackDownloadJsonRefreshDuration(TimeSpan duration);
void TrackDownloadCountDecreasedDuringRefresh(string packageId, string packageVersion, long oldCount, long newCount);
@@ -404,5 +404,11 @@ void TrackABTestEvaluated(
bool isAuthenticated,
int testBucket,
int testPercentage);
+
+ ///
+ /// Track how long it takes to populate the vulnerabilities cache
+ ///
+ /// Refresh duration for vulnerabilities cache
+ void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration);
}
}
\ No newline at end of file
diff --git a/src/NuGetGallery.Services/Telemetry/TelemetryService.cs b/src/NuGetGallery.Services/Telemetry/TelemetryService.cs
index 7d49f8b335..c31c7c29b4 100644
--- a/src/NuGetGallery.Services/Telemetry/TelemetryService.cs
+++ b/src/NuGetGallery.Services/Telemetry/TelemetryService.cs
@@ -91,6 +91,7 @@ public class Events
public const string ABTestEvaluated = "ABTestEvaluated";
public const string PackagePushDisconnect = "PackagePushDisconnect";
public const string SymbolPackagePushDisconnect = "SymbolPackagePushDisconnect";
+ public const string VulnerabilitiesCacheRefreshDurationMs = "VulnerabilitiesCacheRefreshDurationMs";
}
private readonly IDiagnosticsSource _diagnosticsSource;
@@ -260,9 +261,9 @@ public void TrackGetPackageRegistrationDownloadCountFailed(string packageId)
});
}
- public void TrackDownloadJsonRefreshDuration(long milliseconds)
+ public void TrackDownloadJsonRefreshDuration(TimeSpan duration)
{
- TrackMetric(Events.DownloadJsonRefreshDuration, milliseconds, properties => { });
+ TrackMetric(Events.DownloadJsonRefreshDuration, duration.TotalMilliseconds, properties => { });
}
public void TrackDownloadCountDecreasedDuringRefresh(string packageId, string packageVersion, long oldCount, long newCount)
@@ -1103,6 +1104,11 @@ public void TrackSymbolPackagePushDisconnectEvent()
TrackMetric(Events.SymbolPackagePushDisconnect, 1, p => { });
}
+ public void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration)
+ {
+ TrackMetric(Events.VulnerabilitiesCacheRefreshDurationMs, duration.TotalMilliseconds, properties => { });
+ }
+
///
/// We use instead of
///
diff --git a/src/NuGetGallery/App_Start/AppActivator.cs b/src/NuGetGallery/App_Start/AppActivator.cs
index 1a437e7fe0..20de1a9c2a 100644
--- a/src/NuGetGallery/App_Start/AppActivator.cs
+++ b/src/NuGetGallery/App_Start/AppActivator.cs
@@ -15,7 +15,7 @@
using System.Web.Routing;
using System.Web.UI;
using Elmah;
-using Microsoft.WindowsAzure.ServiceRuntime;
+using Microsoft.Extensions.DependencyInjection;
using NuGetGallery;
using NuGetGallery.Configuration;
using NuGetGallery.Diagnostics;
@@ -270,10 +270,17 @@ private static void BackgroundJobsPostStart(IAppConfiguration configuration)
{
// Perform initial refresh + schedule new refreshes every 15 minutes
HostingEnvironment.QueueBackgroundWorkItem(_ => cloudDownloadCountService.RefreshAsync());
- jobs.Add(new CloudDownloadCountServiceRefreshJob(TimeSpan.FromMinutes(15), cloudDownloadCountService));
+ jobs.Add(new CloudDownloadCountServiceRefreshJob(TimeSpan.FromMinutes(15),
+ cloudDownloadCountService));
}
}
+ // Perform initial refresh for vulnerabilities cache + schedule new refreshes every 30 minutes
+ var packageVulnerabilitiesCacheService = DependencyResolver.Current.GetService();
+ var serviceScopeFactory = DependencyResolver.Current.GetService();
+ HostingEnvironment.QueueBackgroundWorkItem(_ => packageVulnerabilitiesCacheService.RefreshCache(serviceScopeFactory));
+ jobs.Add(new PackageVulnerabilitiesCacheRefreshJob(TimeSpan.FromMinutes(30), packageVulnerabilitiesCacheService, serviceScopeFactory));
+
if (jobs.AnySafe())
{
var jobCoordinator = new NuGetJobCoordinator();
diff --git a/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs b/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs
index af69ee6d18..96950c4ca3 100644
--- a/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs
+++ b/src/NuGetGallery/App_Start/DefaultDependenciesModule.cs
@@ -460,6 +460,11 @@ protected override void Load(ContainerBuilder builder)
.As()
.InstancePerLifetimeScope();
+ builder.RegisterType()
+ .AsSelf()
+ .As()
+ .SingleInstance();
+
services.AddHttpClient();
services.AddScoped();
diff --git a/src/NuGetGallery/Infrastructure/Jobs/PackageVulnerabilitiesCacheRefreshJob.cs b/src/NuGetGallery/Infrastructure/Jobs/PackageVulnerabilitiesCacheRefreshJob.cs
new file mode 100644
index 0000000000..3c866437a8
--- /dev/null
+++ b/src/NuGetGallery/Infrastructure/Jobs/PackageVulnerabilitiesCacheRefreshJob.cs
@@ -0,0 +1,30 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
+using WebBackgrounder;
+
+namespace NuGetGallery
+{
+ public class PackageVulnerabilitiesCacheRefreshJob : Job
+ {
+ private readonly IPackageVulnerabilitiesCacheService _packageVulnerabilitiesCacheService;
+ private IServiceScopeFactory _serviceScopeFactory;
+
+ public PackageVulnerabilitiesCacheRefreshJob(TimeSpan interval,
+ IPackageVulnerabilitiesCacheService packageVulnerabilitiesCacheService,
+ IServiceScopeFactory serviceScopeFactory)
+ : base("", interval)
+ {
+ _packageVulnerabilitiesCacheService = packageVulnerabilitiesCacheService;
+ _serviceScopeFactory = serviceScopeFactory;
+ }
+
+ public override Task Execute()
+ {
+ return new Task(() => _packageVulnerabilitiesCacheService.RefreshCache(_serviceScopeFactory));
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/NuGetGallery/NuGetGallery.csproj b/src/NuGetGallery/NuGetGallery.csproj
index 6c07fd3f05..e574f390f5 100644
--- a/src/NuGetGallery/NuGetGallery.csproj
+++ b/src/NuGetGallery/NuGetGallery.csproj
@@ -227,6 +227,7 @@
+
@@ -314,8 +315,10 @@
+
+
diff --git a/src/NuGetGallery/Services/CloudDownloadCountService.cs b/src/NuGetGallery/Services/CloudDownloadCountService.cs
index 4380558d19..12306a2b7c 100644
--- a/src/NuGetGallery/Services/CloudDownloadCountService.cs
+++ b/src/NuGetGallery/Services/CloudDownloadCountService.cs
@@ -106,7 +106,7 @@ public async Task RefreshAsync()
var stopwatch = Stopwatch.StartNew();
await RefreshCoreAsync();
stopwatch.Stop();
- _telemetryService.TrackDownloadJsonRefreshDuration(stopwatch.ElapsedMilliseconds);
+ _telemetryService.TrackDownloadJsonRefreshDuration(TimeSpan.FromMilliseconds(stopwatch.ElapsedMilliseconds));
}
catch (WebException ex)
diff --git a/src/NuGetGallery/Services/IPackageVulnerabilitiesCacheService.cs b/src/NuGetGallery/Services/IPackageVulnerabilitiesCacheService.cs
new file mode 100644
index 0000000000..8f10fe113d
--- /dev/null
+++ b/src/NuGetGallery/Services/IPackageVulnerabilitiesCacheService.cs
@@ -0,0 +1,27 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Extensions.DependencyInjection;
+using NuGet.Services.Entities;
+
+namespace NuGetGallery
+{
+ ///
+ /// This interface is used to implement a basic caching for vulnerabilities querying.
+ /// ///
+ public interface IPackageVulnerabilitiesCacheService
+ {
+ ///
+ /// This function is used to get the packages by id dictionary from the cache
+ ///
+ IReadOnlyDictionary> GetVulnerabilitiesById(string id);
+
+ ///
+ /// This function will refresh the cache from the database, to be called at regular intervals
+ ///
+ /// The factory which will provide a new service scope for each refresh
+ void RefreshCache(IServiceScopeFactory serviceScopeFactory);
+ }
+}
diff --git a/src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs b/src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs
new file mode 100644
index 0000000000..b305a65542
--- /dev/null
+++ b/src/NuGetGallery/Services/PackageVulnerabilitiesCacheService.cs
@@ -0,0 +1,107 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Data.Entity;
+using System.Diagnostics;
+using System.Linq;
+using Microsoft.Extensions.DependencyInjection;
+using NuGet.Services.Entities;
+
+namespace NuGetGallery
+{
+ public class PackageVulnerabilitiesCacheService : IPackageVulnerabilitiesCacheService
+ {
+ private IDictionary>> _vulnerabilitiesByIdCache
+ = new Dictionary>>();
+ private readonly object _refreshLock = new object();
+ private bool _isRefreshing;
+
+ private readonly ITelemetryService _telemetryService;
+
+ public PackageVulnerabilitiesCacheService(ITelemetryService telemetryService)
+ {
+ _telemetryService = telemetryService ?? throw new ArgumentNullException(nameof(telemetryService));
+ }
+
+ public IReadOnlyDictionary> GetVulnerabilitiesById(string id)
+ {
+ if (string.IsNullOrEmpty(id))
+ {
+ throw new ArgumentException("Must have a value.", nameof(id));
+ }
+
+ if (_vulnerabilitiesByIdCache.TryGetValue(id, out var result))
+ {
+ return result;
+ }
+
+ return null;
+ }
+
+ public void RefreshCache(IServiceScopeFactory serviceScopeFactory)
+ {
+ if (serviceScopeFactory == null)
+ {
+ throw new ArgumentNullException(nameof(serviceScopeFactory));
+ }
+
+ bool shouldRefresh = false;
+ lock (_refreshLock)
+ {
+ if (!_isRefreshing)
+ {
+ _isRefreshing = true;
+ shouldRefresh = true;
+ }
+ }
+
+ if (shouldRefresh)
+ {
+ try
+ {
+ var stopwatch = Stopwatch.StartNew();
+
+ // Create a unique service scope for each refresh to ensure a fresh entities context
+ using (var serviceScope = serviceScopeFactory.CreateScope())
+ {
+ var serviceProvider = serviceScope.ServiceProvider;
+ var entitiesContext = serviceProvider.GetService();
+
+ // We need to build a dictionary of dictionaries. Breaking it down:
+ // - this give us a list of all vulnerable package version ranges
+ _vulnerabilitiesByIdCache = entitiesContext.Set()
+ .Include(x => x.Vulnerability)
+ // - from these we want a list in this format: (, (, ))
+ // which will allow us to look up the dictionary by id, and return a dictionary of version -> vulnerability
+ .SelectMany(x => x.Packages.Select(p => new
+ { PackageId = x.PackageId ?? string.Empty, KeyVulnerability = new { PackageKey = p.Key, x.Vulnerability } }))
+ .GroupBy(ikv => ikv.PackageId, ikv => ikv.KeyVulnerability)
+ // - build the outer dictionary, keyed by - each inner dictionary is paired with a time of creation (for cache invalidation)
+ .ToDictionary(ikv => ikv.Key,
+ ikv =>
+ ikv.GroupBy(kv => kv.PackageKey, kv => kv.Vulnerability)
+ // - build the inner dictionaries, all under the same , each keyed by
+ .ToDictionary(kv => kv.Key,
+ kv => kv.ToList().AsReadOnly() as IReadOnlyList));
+ }
+
+ stopwatch.Stop();
+
+ _telemetryService.TrackVulnerabilitiesCacheRefreshDuration(TimeSpan.FromMilliseconds(stopwatch.ElapsedMilliseconds));
+ }
+ catch (Exception ex)
+ {
+ _telemetryService.TraceException(ex);
+ }
+ finally
+ {
+ _isRefreshing = false;
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs b/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs
index d8b335c019..e0d7759e8c 100644
--- a/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs
+++ b/src/NuGetGallery/Services/PackageVulnerabilitiesService.cs
@@ -12,36 +12,17 @@ namespace NuGetGallery
{
public class PackageVulnerabilitiesService : IPackageVulnerabilitiesService
{
- private readonly IEntitiesContext _entitiesContext;
+ private readonly IPackageVulnerabilitiesCacheService _packageVulnerabilitiesCacheService;
- public PackageVulnerabilitiesService(IEntitiesContext entitiesContext)
+ public PackageVulnerabilitiesService(IPackageVulnerabilitiesCacheService packageVulnerabilitiesCacheService)
{
- _entitiesContext = entitiesContext ?? throw new ArgumentNullException(nameof(entitiesContext));
+ _packageVulnerabilitiesCacheService = packageVulnerabilitiesCacheService ??
+ throw new ArgumentNullException(
+ nameof(packageVulnerabilitiesCacheService));
}
- public IReadOnlyDictionary> GetVulnerabilitiesById(string id)
- {
- var result = new Dictionary>();
- var packagesMatchingId = _entitiesContext.Packages
- .Where(p => p.PackageRegistration != null && p.PackageRegistration.Id == id)
- .Include($"{nameof(Package.VulnerablePackageRanges)}.{nameof(VulnerablePackageVersionRange.Vulnerability)}");
- foreach (var package in packagesMatchingId)
- {
- if (package.VulnerablePackageRanges == null)
- {
- continue;
- }
-
- if (package.VulnerablePackageRanges.Any())
- {
- result.Add(package.Key,
- package.VulnerablePackageRanges.Select(vr => vr.Vulnerability).ToList());
- }
- }
-
- return !result.Any() ? null :
- result.ToDictionary(kv => kv.Key, kv => kv.Value as IReadOnlyList);
- }
+ public IReadOnlyDictionary> GetVulnerabilitiesById(string id) =>
+ _packageVulnerabilitiesCacheService.GetVulnerabilitiesById(id);
public bool IsPackageVulnerable(Package package)
{
diff --git a/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs b/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs
index 47f3a21d39..3451277f76 100644
--- a/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs
+++ b/src/VerifyMicrosoftPackage/Fakes/FakeTelemetryService.cs
@@ -63,7 +63,7 @@ public void TrackDownloadCountDecreasedDuringRefresh(string packageId, string pa
throw new NotImplementedException();
}
- public void TrackDownloadJsonRefreshDuration(long milliseconds)
+ public void TrackDownloadJsonRefreshDuration(TimeSpan duration)
{
throw new NotImplementedException();
}
@@ -372,5 +372,10 @@ public void TrackVerifyPackageKeyEvent(string packageId, string packageVersion,
{
throw new NotImplementedException();
}
+
+ public void TrackVulnerabilitiesCacheRefreshDuration(TimeSpan duration)
+ {
+ throw new NotImplementedException();
+ }
}
}
diff --git a/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj b/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj
index 4c87efaa41..5a13c85526 100644
--- a/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj
+++ b/tests/NuGetGallery.Facts/NuGetGallery.Facts.csproj
@@ -103,6 +103,7 @@
+
diff --git a/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesCacheServiceFacts.cs b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesCacheServiceFacts.cs
new file mode 100644
index 0000000000..04d9dc043b
--- /dev/null
+++ b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesCacheServiceFacts.cs
@@ -0,0 +1,184 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Collections.Generic;
+using System.Data.Entity;
+using System.Linq;
+using Microsoft.Extensions.DependencyInjection;
+using Moq;
+using NuGet.Services.Entities;
+using NuGetGallery.Framework;
+using Xunit;
+
+namespace NuGetGallery.Services
+{
+ public class PackageVulnerabilitiesCacheServiceFacts : TestContainer
+ {
+ [Fact]
+ public void RefreshesVulnerabilitiesCache()
+ {
+ // Arrange
+ var entitiesContext = new Mock();
+ entitiesContext.Setup(x => x.Set()).Returns(GetVulnerableRanges());
+ var serviceProvider = new Mock();
+ serviceProvider.Setup(x => x.GetService(typeof(IEntitiesContext))).Returns(entitiesContext.Object);
+ var serviceScope = new Mock();
+ serviceScope.Setup(x => x.ServiceProvider).Returns(serviceProvider.Object);
+ serviceScope.Setup(x => x.Dispose()).Verifiable();
+ var serviceScopeFactory = new Mock();
+ serviceScopeFactory.Setup(x => x.CreateScope()).Returns(serviceScope.Object);
+ var telemetryService = new Mock();
+ telemetryService.Setup(x => x.TrackVulnerabilitiesCacheRefreshDuration(It.IsAny())).Verifiable();
+ var cacheService = new PackageVulnerabilitiesCacheService(telemetryService.Object);
+ cacheService.RefreshCache(serviceScopeFactory.Object);
+
+ // Act
+ var vulnerabilitiesFoo = cacheService.GetVulnerabilitiesById("Foo");
+ var vulnerabilitiesBar = cacheService.GetVulnerabilitiesById("Bar");
+
+ // Assert
+ // - ensure telemetry is sent
+ telemetryService.Verify(x => x.TrackVulnerabilitiesCacheRefreshDuration(It.IsAny()), Times.Once);
+ // - ensure scope is disposed
+ serviceScope.Verify(x => x.Dispose(), Times.AtLeastOnce);
+ // - ensure contants of cache are correct
+ Assert.Equal(4, vulnerabilitiesFoo.Count);
+ Assert.Equal(1, vulnerabilitiesFoo[0].Count);
+ Assert.Equal(1, vulnerabilitiesFoo[1].Count);
+ Assert.Equal(2, vulnerabilitiesFoo[2].Count);
+ Assert.Equal(1234, vulnerabilitiesFoo[2][0].GitHubDatabaseKey);
+ Assert.Equal(5678, vulnerabilitiesFoo[2][1].GitHubDatabaseKey);
+ Assert.Equal(1, vulnerabilitiesFoo[3].Count);
+ Assert.Equal(2, vulnerabilitiesBar.Count);
+ Assert.Equal(1, vulnerabilitiesBar[5].Count);
+ Assert.Equal(9012, vulnerabilitiesBar[5][0].GitHubDatabaseKey);
+ Assert.Equal(1, vulnerabilitiesBar[6].Count);
+ }
+
+ DbSet GetVulnerableRanges()
+ {
+ var registrationFoo = new PackageRegistration { Id = "Foo" };
+ var registrationBar = new PackageRegistration { Id = "Bar" };
+
+ var vulnerabilityCriticalFoo = new PackageVulnerability
+ {
+ AdvisoryUrl = "http://theurl/1234",
+ GitHubDatabaseKey = 1234,
+ Severity = PackageVulnerabilitySeverity.Critical
+ };
+ var vulnerabilityModerateFoo = new PackageVulnerability
+ {
+ AdvisoryUrl = "http://theurl/5678",
+ GitHubDatabaseKey = 5678,
+ Severity = PackageVulnerabilitySeverity.Moderate
+ };
+ var vulnerabilityCriticalBar = new PackageVulnerability
+ {
+ AdvisoryUrl = "http://theurl/9012",
+ GitHubDatabaseKey = 9012,
+ Severity = PackageVulnerabilitySeverity.Critical
+ };
+
+ var versionRangeCriticalFoo = new VulnerablePackageVersionRange
+ {
+ Vulnerability = vulnerabilityCriticalFoo,
+ PackageId = "Foo",
+ PackageVersionRange = "1.1.1",
+ FirstPatchedPackageVersion = "1.1.2"
+ };
+ var versionRangeModerateFoo = new VulnerablePackageVersionRange
+ {
+ Vulnerability = vulnerabilityModerateFoo,
+ PackageId = "Foo",
+ PackageVersionRange = "<=1.1.2",
+ FirstPatchedPackageVersion = "1.1.3"
+ };
+ var versionRangeCriticalBar = new VulnerablePackageVersionRange
+ {
+ Vulnerability = vulnerabilityCriticalBar,
+ PackageId = "Bar",
+ PackageVersionRange = "<=1.1.0",
+ FirstPatchedPackageVersion = "1.1.1"
+ };
+
+ var packageFoo100 = new Package
+ {
+ Key = 0,
+ PackageRegistration = registrationFoo,
+ Version = "1.0.0",
+ VulnerablePackageRanges = new List
+ {
+ versionRangeModerateFoo
+ }
+ };
+ var packageFoo110 = new Package
+ {
+ Key = 1,
+ PackageRegistration = registrationFoo,
+ Version = "1.1.0",
+ VulnerablePackageRanges = new List
+ {
+ versionRangeModerateFoo
+ }
+ };
+ var packageFoo111 = new Package
+ {
+ Key = 2,
+ PackageRegistration = registrationFoo,
+ Version = "1.1.1",
+ VulnerablePackageRanges = new List
+ {
+ versionRangeModerateFoo,
+ versionRangeCriticalFoo
+ }
+ };
+ var packageFoo112 = new Package
+ {
+ Key = 3,
+ PackageRegistration = registrationFoo,
+ Version = "1.1.2",
+ VulnerablePackageRanges = new List
+ {
+ versionRangeModerateFoo
+ }
+ };
+ var packageBar100 = new Package
+ {
+ Key = 5,
+ Version = "1.0.0",
+ PackageRegistration = registrationBar,
+ VulnerablePackageRanges = new List
+ {
+ versionRangeCriticalBar
+ }
+ };
+ var packageBar110 = new Package
+ {
+ Key = 6,
+ PackageRegistration = registrationBar,
+ Version = "1.1.0",
+ VulnerablePackageRanges = new List
+ {
+ versionRangeCriticalBar
+ }
+ };
+
+ versionRangeCriticalFoo.Packages = new List { packageFoo111 };
+ versionRangeModerateFoo.Packages = new List { packageFoo100, packageFoo110, packageFoo111, packageFoo112 };
+ versionRangeCriticalBar.Packages = new List { packageBar100, packageBar110 };
+
+ var vulnerableRangeList = new List { versionRangeCriticalFoo, versionRangeModerateFoo, versionRangeCriticalBar }.AsQueryable();
+ var vulnerableRangeDbSet = new Mock>();
+
+ // boilerplate mock DbSet redirects:
+ vulnerableRangeDbSet.As().Setup(x => x.Provider).Returns(vulnerableRangeList.Provider);
+ vulnerableRangeDbSet.As().Setup(x => x.Expression).Returns(vulnerableRangeList.Expression);
+ vulnerableRangeDbSet.As().Setup(x => x.ElementType).Returns(vulnerableRangeList.ElementType);
+ vulnerableRangeDbSet.As().Setup(x => x.GetEnumerator()).Returns(vulnerableRangeList.GetEnumerator());
+ vulnerableRangeDbSet.Setup(x => x.Include(It.IsAny())).Returns(vulnerableRangeDbSet.Object); // bypass includes (which break the test)
+
+ return vulnerableRangeDbSet.Object;
+ }
+ }
+}
diff --git a/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs
index 77121a2ab9..4324a2a5b3 100644
--- a/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs
+++ b/tests/NuGetGallery.Facts/Services/PackageVulnerabilitiesServiceFacts.cs
@@ -11,146 +11,49 @@ namespace NuGetGallery.Services
{
public class PackageVulnerabilitiesServiceFacts : TestContainer
{
- private PackageRegistration _registrationVulnerable;
-
- private PackageVulnerability _vulnerabilityCritical;
- private PackageVulnerability _vulnerabilityModerate;
-
- private VulnerablePackageVersionRange _versionRangeCritical;
- private VulnerablePackageVersionRange _versionRangeModerate;
-
- private Package _packageVulnerable100;
- private Package _packageVulnerable110;
- private Package _packageVulnerable111;
- private Package _packageVulnerable112;
-
- private Package _packageNotVulnerable;
-
- [Fact]
- public void GetsVulnerabilitiesOfPackage()
- {
- // Arrange
- SetUp();
- var packages = new[]
- {
- _packageVulnerable100,
- _packageVulnerable110,
- _packageVulnerable111,
- _packageVulnerable112,
- _packageNotVulnerable
- };
- var context = GetFakeContext();
- context.Packages.AddRange(packages);
- var target = Get();
-
- // Act
- var vulnerableResult = target.GetVulnerabilitiesById("Vulnerable");
- var notVulnerableResult = target.GetVulnerabilitiesById("NotVulnerable");
-
- // Assert
- Assert.Equal(3, vulnerableResult.Count);
- var vulnerabilitiesFor100 = vulnerableResult[_packageVulnerable100.Key];
- var vulnerabilitiesFor110 = vulnerableResult[_packageVulnerable110.Key];
- var vulnerabilitiesFor111 = vulnerableResult[_packageVulnerable111.Key];
- Assert.Equal(_vulnerabilityModerate, vulnerabilitiesFor100[0]);
- Assert.Equal(_vulnerabilityModerate, vulnerabilitiesFor110[0]);
- Assert.Equal(_vulnerabilityModerate, vulnerabilitiesFor111[0]);
- Assert.Equal(_vulnerabilityCritical, vulnerabilitiesFor111[1]);
-
- Assert.Null(notVulnerableResult);
- }
-
[Fact]
public void GetsVulnerableStatusOfPackage()
{
// Arrange
- SetUp();
- var context = GetFakeContext();
- var target = Get();
-
- // Act
- var shouldBeVulnerable = target.IsPackageVulnerable(_packageVulnerable100);
- var shouldNotBeVulnerable = target.IsPackageVulnerable(_packageNotVulnerable);
-
- // Assert
- Assert.True(shouldBeVulnerable);
- Assert.False(shouldNotBeVulnerable);
- }
-
- private void SetUp()
- {
- _registrationVulnerable = new PackageRegistration { Id = "Vulnerable" };
+ var registrationVulnerable = new PackageRegistration { Id = "Vulnerable" };
- _vulnerabilityCritical = new PackageVulnerability
- {
- AdvisoryUrl = "http://theurl/1234",
- GitHubDatabaseKey = 1234,
- Severity = PackageVulnerabilitySeverity.Critical
- };
- _vulnerabilityModerate = new PackageVulnerability
+ var vulnerabilityModerate = new PackageVulnerability
{
AdvisoryUrl = "http://theurl/5678",
GitHubDatabaseKey = 5678,
Severity = PackageVulnerabilitySeverity.Moderate
};
- _versionRangeCritical = new VulnerablePackageVersionRange
+ var versionRangeModerate = new VulnerablePackageVersionRange
{
- Vulnerability = _vulnerabilityCritical,
- PackageVersionRange = "1.1.1",
- FirstPatchedPackageVersion = "1.1.2"
- };
- _versionRangeModerate = new VulnerablePackageVersionRange
- {
- Vulnerability = _vulnerabilityModerate,
+ Vulnerability = vulnerabilityModerate,
PackageVersionRange = "<=1.1.1",
FirstPatchedPackageVersion = "1.1.2"
};
- _packageVulnerable100 = new Package
+ var packageVulnerable = new Package
{
Key = 0,
- PackageRegistration = _registrationVulnerable,
+ PackageRegistration = registrationVulnerable,
Version = "1.0.0",
- VulnerablePackageRanges = new List
- {
- _versionRangeModerate
- }
- };
- _packageVulnerable110 = new Package
- {
- Key = 1,
- PackageRegistration = _registrationVulnerable,
- Version = "1.1.0",
- VulnerablePackageRanges = new List
- {
- _versionRangeModerate
- }
- };
- _packageVulnerable111 = new Package
- {
- Key = 3, // simulate a different order in db - create a non-contiguous range of rows, even if the range is contiguous
- PackageRegistration = _registrationVulnerable,
- Version = "1.1.1",
- VulnerablePackageRanges = new List
- {
- _versionRangeModerate,
- _versionRangeCritical
- }
- };
- _packageVulnerable112 = new Package
- {
- Key = 2, // simulate a different order in db - create a non-contiguous range of rows, even if the range is contiguous
- PackageRegistration = _registrationVulnerable,
- Version = "1.1.2",
- VulnerablePackageRanges = new List()
+ VulnerablePackageRanges = new List {versionRangeModerate}
};
- _packageNotVulnerable = new Package
+ var packageNotVulnerable = new Package
{
Key = 4,
PackageRegistration = new PackageRegistration { Id = "NotVulnerable" },
VulnerablePackageRanges = new List()
- };
+ };
+
+ var target = Get();
+
+ // Act
+ var shouldBeVulnerable = target.IsPackageVulnerable(packageVulnerable);
+ var shouldNotBeVulnerable = target.IsPackageVulnerable(packageNotVulnerable);
+
+ // Assert
+ Assert.True(shouldBeVulnerable);
+ Assert.False(shouldNotBeVulnerable);
}
}
}
diff --git a/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs b/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs
index 7cc3362273..2fc8ea2e84 100644
--- a/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs
+++ b/tests/NuGetGallery.Facts/Services/TelemetryServiceFacts.cs
@@ -61,7 +61,7 @@ public static IEnumerable