Skip to content

Commit

Permalink
initial check in for the issue #23
Browse files Browse the repository at this point in the history
  • Loading branch information
kdcllc committed Mar 25, 2019
1 parent ed0b203 commit c027703
Show file tree
Hide file tree
Showing 7 changed files with 229 additions and 0 deletions.
7 changes: 7 additions & 0 deletions Bet.AspNetCore.sln
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AppAuthentication", "src\Ap
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Bet.AspNetCore.ReCapture", "src\Bet.AspNetCore.ReCapture\Bet.AspNetCore.ReCapture.csproj", "{6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Bet.Extensions.ML", "src\Bet.Extensions.ML\Bet.Extensions.ML.csproj", "{5FD99A79-7BFB-4112-9B7F-6E308306FC46}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -109,6 +111,10 @@ Global
{6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6472D7D3-D27B-4B57-ABE1-1708AB64E1D1}.Release|Any CPU.Build.0 = Release|Any CPU
{5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Debug|Any CPU.Build.0 = Debug|Any CPU
{5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Release|Any CPU.ActiveCfg = Release|Any CPU
{5FD99A79-7BFB-4112-9B7F-6E308306FC46}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -126,6 +132,7 @@ Global
{AA34055E-D487-44B7-BD4C-FEEB29FFDA05} = {8382D62B-2BB8-4406-95FD-63AF167F5D65}
{47756F33-5FF7-4FDB-8E7D-F30CCFBE8F34} = {D5655917-C1A5-44AA-85D2-BF9132205E4E}
{6472D7D3-D27B-4B57-ABE1-1708AB64E1D1} = {8382D62B-2BB8-4406-95FD-63AF167F5D65}
{5FD99A79-7BFB-4112-9B7F-6E308306FC46} = {8382D62B-2BB8-4406-95FD-63AF167F5D65}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {A53F67D4-7CE9-4F03-8DD4-2C00E7AE2F46}
Expand Down
2 changes: 2 additions & 0 deletions build/dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
<PackageReference Update="Microsoft.Extensions.Logging.Configuration" Version="$(NetCoreCommonVersion)"/>
<PackageReference Update="Microsoft.Extensions.Logging.Console" Version="$(NetCoreCommonVersion)" />
<PackageReference Update="Microsoft.Extensions.Logging.Debug" Version="$(NetCoreCommonVersion)" />
<PackageReference Update="Microsoft.Extensions.ObjectPool" Version="$(NetCoreCommonVersion)" />
<PackageReference Update="Microsoft.Extensions.Options" Version="$(NetCoreCommonVersion)" />
<PackageReference Update="Microsoft.Extensions.Options.ConfigurationExtensions" Version="$(NetCoreCommonVersion)" />
<PackageReference Update="Microsoft.Extensions.Options.ConfigurationExtensions" Version="$(NetCoreCommonVersion)" />
Expand Down Expand Up @@ -73,6 +74,7 @@
<PackageReference Update="Colorful.Console" Version="1.2.9" />
<PackageReference Update="McMaster.Extensions.CommandLineUtils" Version="2.3.3"/>
<PackageReference Update="Microsoft.AspNetCore.Diagnostics.HealthChecks" Version="2.2.0" />
<PackageReference Update="Microsoft.ML" Version="0.11.0" />
<PackageReference Update="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.4.10" />
<PackageReference Update="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Update="Polly" Version="7.1.0" />
Expand Down
14 changes: 14 additions & 0 deletions src/Bet.Extensions.ML/Bet.Extensions.ML.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>netstandard2.0;netcoreapp2.2;netcoreapp3.0</TargetFrameworks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Logging.Console" />
<PackageReference Include="Microsoft.Extensions.Logging.Debug" />
<PackageReference Include="Microsoft.Extensions.ObjectPool" />
<PackageReference Include="Microsoft.ML" />
</ItemGroup>

</Project>
21 changes: 21 additions & 0 deletions src/Bet.Extensions.ML/Engine/IMLModelEngine.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using Microsoft.ML;

namespace Bet.Extensions.ML.Engine
{
public interface IMLModelEngine<TData, TPrediction> where TData : class where TPrediction : class
{
/// <summary>
/// The transformer is a component that transforms data. It also supports 'schema
/// propagation' to answer the question of 'how will the data with this schema look,
/// after you transform it?'.
/// </summary>
ITransformer MLModel { get; }

/// <summary>
/// Predict based on <see cref="MLModel"/> that was loaded.
/// </summary>
/// <param name="dataSample">The data sample to be predicted on.</param>
/// <returns></returns>
TPrediction Predict(TData dataSample);
}
}
85 changes: 85 additions & 0 deletions src/Bet.Extensions.ML/Engine/MLModelEngineObjectPool.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using System;
using System.IO;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using Microsoft.ML;

namespace Bet.Extensions.ML.Engine
{
public class MLModelEngineObjectPool<TData, TPrediction>
: IMLModelEngine<TData, TPrediction>
where TData : class
where TPrediction : class, new()
{
private readonly MLContext _mlContext;
private readonly ILogger<PredictionEnginePooledObjectPolicy<TData, TPrediction>> _logger;
private readonly int _maximumObjectsRetained;
private readonly ObjectPool<PredictionEngine<TData, TPrediction>> _predictionEnginePool;

public ITransformer MLModel { get; private set; }

public MLModelEngineObjectPool(
string modelFilePathName,
ILogger<PredictionEnginePooledObjectPolicy<TData, TPrediction>> logger,
int maximumObjectsRetained = -1)
{
//Create the MLContext object to use under the scope of this class
_mlContext = new MLContext();

_logger = logger;
//Load the ProductSalesForecast model from the .ZIP file
using (var fileStream = File.OpenRead(modelFilePathName))
{
MLModel = _mlContext.Model.Load(fileStream);
}

_maximumObjectsRetained = maximumObjectsRetained;

// create PredictionEngine Object Pool
_predictionEnginePool = CreatePredictionEngineObjectPool();
}

public TPrediction Predict(TData dataSample)
{
// get instance of PredictionEngine from the object pool

var predictionEngine = _predictionEnginePool.Get();

try
{
return predictionEngine.Predict(dataSample);
}
catch (Exception ex)
{
_logger.LogError("PredictionEngine failed: {ex}", ex.ToString());
}
finally
{
// release used PredictionEngine object into the Object pool.
_predictionEnginePool.Return(predictionEngine);
}

// all other cases return null prediction.
return null;
}

private ObjectPool<PredictionEngine<TData,TPrediction>> CreatePredictionEngineObjectPool()
{
var pooledObjectPolicy = new PredictionEnginePooledObjectPolicy<TData, TPrediction>(_mlContext, MLModel, _logger);

DefaultObjectPool<PredictionEngine<TData, TPrediction>> pool;

if (_maximumObjectsRetained != -1)
{
pool = new DefaultObjectPool<PredictionEngine<TData, TPrediction>>(pooledObjectPolicy, _maximumObjectsRetained);
}
else
{
//default maximumRetained is Environment.ProcessorCount * 2, if not explicitly provided
pool = new DefaultObjectPool<PredictionEngine<TData, TPrediction>>(pooledObjectPolicy);
}

return pool;
}
}
}
57 changes: 57 additions & 0 deletions src/Bet.Extensions.ML/Engine/PredictionEnginePooledObjectPolicy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using Microsoft.ML;
using System;
using System.Diagnostics;

namespace Bet.Extensions.ML.Engine
{
/// <summary>
/// Creates instance of the <see cref="PredictionEngine{TSrc, TDst}"/> for the dataset.
/// </summary>
/// <typeparam name="TData"></typeparam>
/// <typeparam name="TPrediction"></typeparam>
public class PredictionEnginePooledObjectPolicy<TData, TPrediction>
: IPooledObjectPolicy<PredictionEngine<TData, TPrediction>>
where TData : class
where TPrediction : class, new()
{
private readonly MLContext _mlContext;

private readonly ITransformer _model;

private readonly ILogger<PredictionEnginePooledObjectPolicy<TData,TPrediction>> _logger;

public PredictionEnginePooledObjectPolicy(
MLContext mlContext,
ITransformer model,
ILogger<PredictionEnginePooledObjectPolicy<TData,TPrediction>> logger)
{
_mlContext = mlContext ?? throw new ArgumentNullException(nameof(mlContext));
_model = model ?? throw new ArgumentNullException(nameof(model));
_logger = logger;
}

public PredictionEngine<TData, TPrediction> Create()
{
var watch = Stopwatch.StartNew();

var predictionEngine = _model.CreatePredictionEngine<TData, TPrediction>(_mlContext);

watch.Stop();
_logger.LogDebug("Time took to create the prediction engine: {elapsed}", watch.ElapsedMilliseconds);

return predictionEngine;
}

public bool Return(PredictionEngine<TData, TPrediction> obj)
{
if(obj == null)
{
return false;
}

return true;
}
}
}
43 changes: 43 additions & 0 deletions src/Bet.Extensions.ML/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using Bet.Extensions.ML.Engine;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using System;
using System.IO;

namespace Microsoft.Extensions.DependencyInjection
{
public static class ServiceCollectionExtensions
{
/// <summary>
/// Add <see cref="IMLModelEngine{TData, TPrediction}"/> prediction engine with specified model.
/// </summary>
/// <typeparam name="TData"></typeparam>
/// <typeparam name="TPrediction"></typeparam>
/// <param name="services"></param>
/// <param name="modelFilePath"></param>
/// <param name="maximumObjectRetained"></param>
/// <returns></returns>
public static IServiceCollection AddMLModelEngine<TData, TPrediction>(
this IServiceCollection services,
string modelFilePath,
int maximumObjectRetained = -1) where TData : class where TPrediction : class, new()
{
services.TryAddSingleton<ObjectPoolProvider, DefaultObjectPoolProvider>();

services.AddSingleton<IMLModelEngine<TData, TPrediction>>(sp =>
{
if (!File.Exists(modelFilePath))
{
throw new ArgumentException($"File: {modelFilePath} doesn't exist");
}
var logger = sp.GetRequiredService<ILogger<PredictionEnginePooledObjectPolicy<TData, TPrediction>>>();
return new MLModelEngineObjectPool<TData, TPrediction>(modelFilePath,logger,maximumObjectRetained);
});

return services;
}
}
}

0 comments on commit c027703

Please sign in to comment.