From c0277033fb52405a18d574192da9b63976d62e41 Mon Sep 17 00:00:00 2001 From: kdcllc Date: Mon, 25 Mar 2019 19:57:03 -0400 Subject: [PATCH] initial check in for the issue #23 --- Bet.AspNetCore.sln | 7 ++ build/dependencies.props | 2 + .../Bet.Extensions.ML.csproj | 14 +++ .../Engine/IMLModelEngine.cs | 21 +++++ .../Engine/MLModelEngineObjectPool.cs | 85 +++++++++++++++++++ .../PredictionEnginePooledObjectPolicy.cs | 57 +++++++++++++ .../ServiceCollectionExtensions.cs | 43 ++++++++++ 7 files changed, 229 insertions(+) create mode 100644 src/Bet.Extensions.ML/Bet.Extensions.ML.csproj create mode 100644 src/Bet.Extensions.ML/Engine/IMLModelEngine.cs create mode 100644 src/Bet.Extensions.ML/Engine/MLModelEngineObjectPool.cs create mode 100644 src/Bet.Extensions.ML/Engine/PredictionEnginePooledObjectPolicy.cs create mode 100644 src/Bet.Extensions.ML/ServiceCollectionExtensions.cs diff --git a/Bet.AspNetCore.sln b/Bet.AspNetCore.sln index 1586225..3695484 100644 --- a/Bet.AspNetCore.sln +++ b/Bet.AspNetCore.sln @@ -55,6 +55,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AppAuthentication", "src\Ap EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Bet.AspNetCore.ReCapture", "src\Bet.AspNetCore.ReCapture\Bet.AspNetCore.ReCapture.csproj", "{6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Bet.Extensions.ML", "src\Bet.Extensions.ML\Bet.Extensions.ML.csproj", "{5FD99A79-7BFB-4112-9B7F-6E308306FC46}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -109,6 +111,10 @@ Global {6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}.Debug|Any CPU.Build.0 = Debug|Any CPU {6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}.Release|Any CPU.ActiveCfg = Release|Any CPU {6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}.Release|Any CPU.Build.0 = Release|Any CPU + {5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Debug|Any CPU.Build.0 = Debug|Any CPU + {5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Release|Any CPU.ActiveCfg = Release|Any CPU + {5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -126,6 +132,7 @@ Global {AA34055E-D487-44B7-BD4C-FEEB29FFDA05} = {8382D62B-2BB8-4406-95FD-63AF167F5D65} {47756F33-5FF7-4FDB-8E7D-F30CCFBE8F34} = {D5655917-C1A5-44AA-85D2-BF9132205E4E} {6472D7D3-D27B-4B57-ABE1-1708AB64E1D1} = {8382D62B-2BB8-4406-95FD-63AF167F5D65} + {5FD99A79-7BFB-4112-9B7F-6E308306FC46} = {8382D62B-2BB8-4406-95FD-63AF167F5D65} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {A53F67D4-7CE9-4F03-8DD4-2C00E7AE2F46} diff --git a/build/dependencies.props b/build/dependencies.props index f21bf58..1ea9ecf 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -40,6 +40,7 @@ + @@ -73,6 +74,7 @@ + diff --git a/src/Bet.Extensions.ML/Bet.Extensions.ML.csproj b/src/Bet.Extensions.ML/Bet.Extensions.ML.csproj new file mode 100644 index 0000000..660661e --- /dev/null +++ b/src/Bet.Extensions.ML/Bet.Extensions.ML.csproj @@ -0,0 +1,14 @@ + + + + netstandard2.0;netcoreapp2.2;netcoreapp3.0 + + + + + + + + + + diff --git a/src/Bet.Extensions.ML/Engine/IMLModelEngine.cs b/src/Bet.Extensions.ML/Engine/IMLModelEngine.cs new file mode 100644 index 0000000..9882ef9 --- /dev/null +++ b/src/Bet.Extensions.ML/Engine/IMLModelEngine.cs @@ -0,0 +1,21 @@ +using Microsoft.ML; + +namespace Bet.Extensions.ML.Engine +{ + public interface IMLModelEngine where TData : class where TPrediction : class + { + /// + /// The transformer is a component that transforms data. It also supports 'schema + /// propagation' to answer the question of 'how will the data with this schema look, + /// after you transform it?'. + /// + ITransformer MLModel { get; } + + /// + /// Predict based on that was loaded. + /// + /// The data sample to be predicted on. + /// + TPrediction Predict(TData dataSample); + } +} diff --git a/src/Bet.Extensions.ML/Engine/MLModelEngineObjectPool.cs b/src/Bet.Extensions.ML/Engine/MLModelEngineObjectPool.cs new file mode 100644 index 0000000..7f5e496 --- /dev/null +++ b/src/Bet.Extensions.ML/Engine/MLModelEngineObjectPool.cs @@ -0,0 +1,85 @@ +using System; +using System.IO; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.ObjectPool; +using Microsoft.ML; + +namespace Bet.Extensions.ML.Engine +{ + public class MLModelEngineObjectPool + : IMLModelEngine + where TData : class + where TPrediction : class, new() + { + private readonly MLContext _mlContext; + private readonly ILogger> _logger; + private readonly int _maximumObjectsRetained; + private readonly ObjectPool> _predictionEnginePool; + + public ITransformer MLModel { get; private set; } + + public MLModelEngineObjectPool( + string modelFilePathName, + ILogger> logger, + int maximumObjectsRetained = -1) + { + //Create the MLContext object to use under the scope of this class + _mlContext = new MLContext(); + + _logger = logger; + //Load the ProductSalesForecast model from the .ZIP file + using (var fileStream = File.OpenRead(modelFilePathName)) + { + MLModel = _mlContext.Model.Load(fileStream); + } + + _maximumObjectsRetained = maximumObjectsRetained; + + // create PredictionEngine Object Pool + _predictionEnginePool = CreatePredictionEngineObjectPool(); + } + + public TPrediction Predict(TData dataSample) + { + // get instance of PredictionEngine from the object pool + + var predictionEngine = _predictionEnginePool.Get(); + + try + { + return predictionEngine.Predict(dataSample); + } + catch (Exception ex) + { + _logger.LogError("PredictionEngine failed: {ex}", ex.ToString()); + } + finally + { + // release used PredictionEngine object into the Object pool. + _predictionEnginePool.Return(predictionEngine); + } + + // all other cases return null prediction. + return null; + } + + private ObjectPool> CreatePredictionEngineObjectPool() + { + var pooledObjectPolicy = new PredictionEnginePooledObjectPolicy(_mlContext, MLModel, _logger); + + DefaultObjectPool> pool; + + if (_maximumObjectsRetained != -1) + { + pool = new DefaultObjectPool>(pooledObjectPolicy, _maximumObjectsRetained); + } + else + { + //default maximumRetained is Environment.ProcessorCount * 2, if not explicitly provided + pool = new DefaultObjectPool>(pooledObjectPolicy); + } + + return pool; + } + } +} diff --git a/src/Bet.Extensions.ML/Engine/PredictionEnginePooledObjectPolicy.cs b/src/Bet.Extensions.ML/Engine/PredictionEnginePooledObjectPolicy.cs new file mode 100644 index 0000000..36965e3 --- /dev/null +++ b/src/Bet.Extensions.ML/Engine/PredictionEnginePooledObjectPolicy.cs @@ -0,0 +1,57 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.ObjectPool; +using Microsoft.ML; +using System; +using System.Diagnostics; + +namespace Bet.Extensions.ML.Engine +{ + /// + /// Creates instance of the for the dataset. + /// + /// + /// + public class PredictionEnginePooledObjectPolicy + : IPooledObjectPolicy> + where TData : class + where TPrediction : class, new() + { + private readonly MLContext _mlContext; + + private readonly ITransformer _model; + + private readonly ILogger> _logger; + + public PredictionEnginePooledObjectPolicy( + MLContext mlContext, + ITransformer model, + ILogger> logger) + { + _mlContext = mlContext ?? throw new ArgumentNullException(nameof(mlContext)); + _model = model ?? throw new ArgumentNullException(nameof(model)); + _logger = logger; + } + + public PredictionEngine Create() + { + var watch = Stopwatch.StartNew(); + + var predictionEngine = _model.CreatePredictionEngine(_mlContext); + + watch.Stop(); + _logger.LogDebug("Time took to create the prediction engine: {elapsed}", watch.ElapsedMilliseconds); + + return predictionEngine; + } + + public bool Return(PredictionEngine obj) + { + if(obj == null) + { + return false; + } + + return true; + } + } +} diff --git a/src/Bet.Extensions.ML/ServiceCollectionExtensions.cs b/src/Bet.Extensions.ML/ServiceCollectionExtensions.cs new file mode 100644 index 0000000..4113185 --- /dev/null +++ b/src/Bet.Extensions.ML/ServiceCollectionExtensions.cs @@ -0,0 +1,43 @@ +using Bet.Extensions.ML.Engine; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.ObjectPool; +using System; +using System.IO; + +namespace Microsoft.Extensions.DependencyInjection +{ + public static class ServiceCollectionExtensions + { + /// + /// Add prediction engine with specified model. + /// + /// + /// + /// + /// + /// + /// + public static IServiceCollection AddMLModelEngine( + this IServiceCollection services, + string modelFilePath, + int maximumObjectRetained = -1) where TData : class where TPrediction : class, new() + { + services.TryAddSingleton(); + + services.AddSingleton>(sp => + { + if (!File.Exists(modelFilePath)) + { + throw new ArgumentException($"File: {modelFilePath} doesn't exist"); + } + + var logger = sp.GetRequiredService>>(); + + return new MLModelEngineObjectPool(modelFilePath,logger,maximumObjectRetained); + }); + + return services; + } + } +}