-
-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
229 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<TargetFrameworks>netstandard2.0;netcoreapp2.2;netcoreapp3.0</TargetFrameworks> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.Extensions.Logging.Console" /> | ||
<PackageReference Include="Microsoft.Extensions.Logging.Debug" /> | ||
<PackageReference Include="Microsoft.Extensions.ObjectPool" /> | ||
<PackageReference Include="Microsoft.ML" /> | ||
</ItemGroup> | ||
|
||
</Project> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
using Microsoft.ML; | ||
|
||
namespace Bet.Extensions.ML.Engine | ||
{ | ||
public interface IMLModelEngine<TData, TPrediction> where TData : class where TPrediction : class | ||
{ | ||
/// <summary> | ||
/// The transformer is a component that transforms data. It also supports 'schema | ||
/// propagation' to answer the question of 'how will the data with this schema look, | ||
/// after you transform it?'. | ||
/// </summary> | ||
ITransformer MLModel { get; } | ||
|
||
/// <summary> | ||
/// Predict based on <see cref="MLModel"/> that was loaded. | ||
/// </summary> | ||
/// <param name="dataSample">The data sample to be predicted on.</param> | ||
/// <returns></returns> | ||
TPrediction Predict(TData dataSample); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
using System; | ||
using System.IO; | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.Extensions.ObjectPool; | ||
using Microsoft.ML; | ||
|
||
namespace Bet.Extensions.ML.Engine | ||
{ | ||
public class MLModelEngineObjectPool<TData, TPrediction> | ||
: IMLModelEngine<TData, TPrediction> | ||
where TData : class | ||
where TPrediction : class, new() | ||
{ | ||
private readonly MLContext _mlContext; | ||
private readonly ILogger<PredictionEnginePooledObjectPolicy<TData, TPrediction>> _logger; | ||
private readonly int _maximumObjectsRetained; | ||
private readonly ObjectPool<PredictionEngine<TData, TPrediction>> _predictionEnginePool; | ||
|
||
public ITransformer MLModel { get; private set; } | ||
|
||
public MLModelEngineObjectPool( | ||
string modelFilePathName, | ||
ILogger<PredictionEnginePooledObjectPolicy<TData, TPrediction>> logger, | ||
int maximumObjectsRetained = -1) | ||
{ | ||
//Create the MLContext object to use under the scope of this class | ||
_mlContext = new MLContext(); | ||
|
||
_logger = logger; | ||
//Load the ProductSalesForecast model from the .ZIP file | ||
using (var fileStream = File.OpenRead(modelFilePathName)) | ||
{ | ||
MLModel = _mlContext.Model.Load(fileStream); | ||
} | ||
|
||
_maximumObjectsRetained = maximumObjectsRetained; | ||
|
||
// create PredictionEngine Object Pool | ||
_predictionEnginePool = CreatePredictionEngineObjectPool(); | ||
} | ||
|
||
public TPrediction Predict(TData dataSample) | ||
{ | ||
// get instance of PredictionEngine from the object pool | ||
|
||
var predictionEngine = _predictionEnginePool.Get(); | ||
|
||
try | ||
{ | ||
return predictionEngine.Predict(dataSample); | ||
} | ||
catch (Exception ex) | ||
{ | ||
_logger.LogError("PredictionEngine failed: {ex}", ex.ToString()); | ||
} | ||
finally | ||
{ | ||
// release used PredictionEngine object into the Object pool. | ||
_predictionEnginePool.Return(predictionEngine); | ||
} | ||
|
||
// all other cases return null prediction. | ||
return null; | ||
} | ||
|
||
private ObjectPool<PredictionEngine<TData,TPrediction>> CreatePredictionEngineObjectPool() | ||
{ | ||
var pooledObjectPolicy = new PredictionEnginePooledObjectPolicy<TData, TPrediction>(_mlContext, MLModel, _logger); | ||
|
||
DefaultObjectPool<PredictionEngine<TData, TPrediction>> pool; | ||
|
||
if (_maximumObjectsRetained != -1) | ||
{ | ||
pool = new DefaultObjectPool<PredictionEngine<TData, TPrediction>>(pooledObjectPolicy, _maximumObjectsRetained); | ||
} | ||
else | ||
{ | ||
//default maximumRetained is Environment.ProcessorCount * 2, if not explicitly provided | ||
pool = new DefaultObjectPool<PredictionEngine<TData, TPrediction>>(pooledObjectPolicy); | ||
} | ||
|
||
return pool; | ||
} | ||
} | ||
} |
57 changes: 57 additions & 0 deletions
57
src/Bet.Extensions.ML/Engine/PredictionEnginePooledObjectPolicy.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.Extensions.ObjectPool; | ||
using Microsoft.ML; | ||
using System; | ||
using System.Diagnostics; | ||
|
||
namespace Bet.Extensions.ML.Engine | ||
{ | ||
/// <summary> | ||
/// Creates instance of the <see cref="PredictionEngine{TSrc, TDst}"/> for the dataset. | ||
/// </summary> | ||
/// <typeparam name="TData"></typeparam> | ||
/// <typeparam name="TPrediction"></typeparam> | ||
public class PredictionEnginePooledObjectPolicy<TData, TPrediction> | ||
: IPooledObjectPolicy<PredictionEngine<TData, TPrediction>> | ||
where TData : class | ||
where TPrediction : class, new() | ||
{ | ||
private readonly MLContext _mlContext; | ||
|
||
private readonly ITransformer _model; | ||
|
||
private readonly ILogger<PredictionEnginePooledObjectPolicy<TData,TPrediction>> _logger; | ||
|
||
public PredictionEnginePooledObjectPolicy( | ||
MLContext mlContext, | ||
ITransformer model, | ||
ILogger<PredictionEnginePooledObjectPolicy<TData,TPrediction>> logger) | ||
{ | ||
_mlContext = mlContext ?? throw new ArgumentNullException(nameof(mlContext)); | ||
_model = model ?? throw new ArgumentNullException(nameof(model)); | ||
_logger = logger; | ||
} | ||
|
||
public PredictionEngine<TData, TPrediction> Create() | ||
{ | ||
var watch = Stopwatch.StartNew(); | ||
|
||
var predictionEngine = _model.CreatePredictionEngine<TData, TPrediction>(_mlContext); | ||
|
||
watch.Stop(); | ||
_logger.LogDebug("Time took to create the prediction engine: {elapsed}", watch.ElapsedMilliseconds); | ||
|
||
return predictionEngine; | ||
} | ||
|
||
public bool Return(PredictionEngine<TData, TPrediction> obj) | ||
{ | ||
if(obj == null) | ||
{ | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
using Bet.Extensions.ML.Engine; | ||
using Microsoft.Extensions.DependencyInjection.Extensions; | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.Extensions.ObjectPool; | ||
using System; | ||
using System.IO; | ||
|
||
namespace Microsoft.Extensions.DependencyInjection | ||
{ | ||
public static class ServiceCollectionExtensions | ||
{ | ||
/// <summary> | ||
/// Add <see cref="IMLModelEngine{TData, TPrediction}"/> prediction engine with specified model. | ||
/// </summary> | ||
/// <typeparam name="TData"></typeparam> | ||
/// <typeparam name="TPrediction"></typeparam> | ||
/// <param name="services"></param> | ||
/// <param name="modelFilePath"></param> | ||
/// <param name="maximumObjectRetained"></param> | ||
/// <returns></returns> | ||
public static IServiceCollection AddMLModelEngine<TData, TPrediction>( | ||
this IServiceCollection services, | ||
string modelFilePath, | ||
int maximumObjectRetained = -1) where TData : class where TPrediction : class, new() | ||
{ | ||
services.TryAddSingleton<ObjectPoolProvider, DefaultObjectPoolProvider>(); | ||
|
||
services.AddSingleton<IMLModelEngine<TData, TPrediction>>(sp => | ||
{ | ||
if (!File.Exists(modelFilePath)) | ||
{ | ||
throw new ArgumentException($"File: {modelFilePath} doesn't exist"); | ||
} | ||
var logger = sp.GetRequiredService<ILogger<PredictionEnginePooledObjectPolicy<TData, TPrediction>>>(); | ||
return new MLModelEngineObjectPool<TData, TPrediction>(modelFilePath,logger,maximumObjectRetained); | ||
}); | ||
|
||
return services; | ||
} | ||
} | ||
} |