Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added object pooling for PredictionFunction/PredictionEngine to eShopDashboardML sample #184

Merged
merged 5 commits into from
Dec 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions samples/csharp/common/MLModelEngine.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System.IO;
using Microsoft.ML;
//using Microsoft.Extensions.Configuration;

namespace Common
{
public class MLModelEngine<TData, TPrediction>
where TData : class
where TPrediction : class, new()
{
private readonly MLContext _mlContext;
private readonly ITransformer _model;
private readonly ObjectPool<PredictionFunction<TData, TPrediction>> _predictionEnginePool;
private readonly int _minPredictionEngineObjectsInPool;
private readonly int _maxPredictionEngineObjectsInPool;

public int CurrentPredictionEnginePoolSize
{
get { return _predictionEnginePool.CurrentPoolSize; }
}

//Constructor with modelFilePathName to load
public MLModelEngine(MLContext mlContext, string modelFilePathName, int minPredictionEngineObjectsInPool = 5, int maxPredictionEngineObjectsInPool = 1000)
{
_mlContext = mlContext;

//Load the ProductSalesForecast model from the .ZIP file
using (var fileStream = File.OpenRead(modelFilePathName))
{
_model = mlContext.Model.Load(fileStream);
}

_minPredictionEngineObjectsInPool = minPredictionEngineObjectsInPool;
_maxPredictionEngineObjectsInPool = maxPredictionEngineObjectsInPool;

//Create PredictionEngine Object Pool
_predictionEnginePool = CreatePredictionEngineObjectPool();
}

//Constructor with ITransformer model already created
public MLModelEngine(MLContext mlContext, ITransformer model, int minPredictionEngineObjectsInPool = 5, int maxPredictionEngineObjectsInPool = 1000)
{
_mlContext = mlContext;
_model = model;
_minPredictionEngineObjectsInPool = minPredictionEngineObjectsInPool;
_maxPredictionEngineObjectsInPool = maxPredictionEngineObjectsInPool;

//Create PredictionEngine Object Pool
_predictionEnginePool = CreatePredictionEngineObjectPool();
}

private ObjectPool<PredictionFunction<TData, TPrediction>> CreatePredictionEngineObjectPool()
{
return new ObjectPool<PredictionFunction<TData, TPrediction>>(() => _model.MakePredictionFunction<TData, TPrediction>(_mlContext),
_minPredictionEngineObjectsInPool,
_maxPredictionEngineObjectsInPool);
}

public TPrediction Predict(TData dataSample)
{
//Get PredictionEngine object from the Object Pool
PredictionFunction<TData, TPrediction> predictionEngine = _predictionEnginePool.GetObject();

//Predict
TPrediction prediction = predictionEngine.Predict(dataSample);

//Release used PredictionEngine object into the Object Pool
_predictionEnginePool.PutObject(predictionEngine);

return prediction;
}

}
}
62 changes: 62 additions & 0 deletions samples/csharp/common/ObjectPool.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

namespace Common
{
public class ObjectPool<T>
{
private ConcurrentBag<T> _objects;
private Func<T> _objectGenerator;
private int _maxPoolSize;

public int CurrentPoolSize
{
get { return _objects.Count; }
}

public ObjectPool(Func<T> objectGenerator, int minPoolSize = 5, int maxPoolSize = 50000)
{
if (objectGenerator == null) throw new ArgumentNullException("objectGenerator");
_objects = new ConcurrentBag<T>();
_objectGenerator = objectGenerator;
_maxPoolSize = maxPoolSize;

//Measure total time of minimum objects creation
var watch = System.Diagnostics.Stopwatch.StartNew();

//Create minimum number of objects in pool
for (int i = 0; i < minPoolSize; i++)
{
_objects.Add(_objectGenerator());
}

//Stop measuring time
watch.Stop();
long elapsedMs = watch.ElapsedMilliseconds;
}

public T GetObject()
{
T item;
if (_objects.TryTake(out item))
{
return item;
}
else
{
if(_objects.Count <= _maxPoolSize)
return _objectGenerator();
else
throw new InvalidOperationException("MaxPoolSize reached");
}
}

public void PutObject(T item)
{
_objects.Add(item);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
.editorconfig = .editorconfig
EndProjectSection
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestObjectPoolingConsoleApp", "src\TestObjectPoolingConsoleApp\TestObjectPoolingConsoleApp.csproj", "{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand All @@ -39,6 +41,10 @@ Global
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A}.Release|Any CPU.Build.0 = Release|Any CPU
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Debug|Any CPU.Build.0 = Debug|Any CPU
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Release|Any CPU.ActiveCfg = Release|Any CPU
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -47,6 +53,7 @@ Global
{29DB8569-F5D6-4190-9DF4-8D18CA0AABA8} = {F395612F-24C7-4666-90B2-62E417033B4B}
{5AB1C510-FEF6-4930-AE05-D16AF802084D} = {F395612F-24C7-4666-90B2-62E417033B4B}
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A} = {B3AF01E5-D172-47F9-991E-A85504958F43}
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A} = {F395612F-24C7-4666-90B2-62E417033B4B}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {1E47A71B-4F99-48EA-9267-DEE93B23BA31}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

namespace TestObjectPoolingConsoleApp.DataStructures
{
/// <summary>
/// This is the input to the trained model.
/// </summary>
public class CountryData
{
// next,country,year,month,max,min,std,count,sales,med,prev
public CountryData(string country, int year, int month, float max, float min, float std, int count, float sales, float med, float prev)
{
this.country = country;

this.year = year;
this.month = month;
this.max = max;
this.min = min;
this.std = std;
this.count = count;
this.sales = sales;
this.med = med;
this.prev = prev;
}

public float next;

public string country;

public float year;
public float month;
public float max;
public float min;
public float std;
public float count;
public float sales;
public float med;
public float prev;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

namespace TestObjectPoolingConsoleApp.DataStructures
{
/// <summary>
/// This is the output of the scored model, the prediction.
/// </summary>
public class CountrySalesPrediction
{
// Below columns are produced by the model's predictor.
public float Score;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

namespace TestObjectPoolingConsoleApp.DataStructures
{
/// <summary>
/// This is the input to the trained model.
/// </summary>
public class ProductData
{
// next,productId,year,month,units,avg,count,max,min,prev
public ProductData(string productId, int year, int month, float units, float avg,
int count, float max, float min, float prev)
{
this.productId = productId;
this.year = year;
this.month = month;
this.units = units;
this.avg = avg;
this.count = count;
this.max = max;
this.min = min;
this.prev = prev;
}

public float next;

public string productId;

public float year;
public float month;
public float units;
public float avg;
public float count;
public float max;
public float min;
public float prev;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

namespace TestObjectPoolingConsoleApp.DataStructures
{
/// <summary>
/// This is the output of the scored model, the prediction.
/// </summary>
public class ProductUnitPrediction
{
// Below columns are produced by the model's predictor.
public float Score;
}

}
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using Common;
using Microsoft.ML;
using System;
using System.Threading;
using System.Threading.Tasks;

using TestObjectPoolingConsoleApp.DataStructures;

namespace TestObjectPoolingConsoleApp
{
class Program
{
static void Main(string[] args)
{
CancellationTokenSource cts = new CancellationTokenSource();

// Create an opportunity for the user to cancel.
Task.Run(() =>
{
if (Console.ReadKey().KeyChar == 'c' || Console.ReadKey().KeyChar == 'C')
cts.Cancel();
});

MLContext mlContext = new MLContext(seed:1);
string modelFolder = $"Forecast/ModelFiles";
string modelFilePathName = $"ModelFiles/country_month_fastTreeTweedie.zip";
var countrySalesModel = new MLModelEngine<CountryData, CountrySalesPrediction>(mlContext,
modelFilePathName,
minPredictionEngineObjectsInPool: 2);

Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);

//Single Prediction
var singleCountrySample = new CountryData("Australia", 2017, 1, 477, 164, 2486, 9, 10345, 281, 1029);
var singleNextMonthPrediction = countrySalesModel.Predict(singleCountrySample);

Console.WriteLine("Prediction: {0:####.####}", singleNextMonthPrediction.Score);

// Create a high demand for the modelEngine objects.
Parallel.For(0, 1000000, (i, loopState) =>
{
//Sample country data
//next,country,year,month,max,min,std,count,sales,med,prev
//4.23056080166201,Australia,2017,1,477.34,164.916,2486.1346772137,9,10345.71,281.7,1029.11

var countrySample = new CountryData("Australia", 2017, 1, 477, 164, 2486, 9, 10345, 281, i);

// This is the bottleneck in our application. All threads in this loop
// must serialize their access to the static Console class.
Console.CursorLeft = 0;
var nextMonthPrediction = countrySalesModel.Predict(countrySample);

Console.WriteLine("Prediction: {0:####.####}", nextMonthPrediction.Score);
Console.WriteLine("-----------------------------------------");
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);

if (cts.Token.IsCancellationRequested)
loopState.Stop();

});

Console.WriteLine("-----------------------------------------");
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);


Console.WriteLine("Press the Enter key to exit.");
Console.ReadLine();
cts.Dispose();
}

}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
</PropertyGroup>

<ItemGroup>
<Compile Include="..\..\..\..\common\MLModelEngine.cs" Link="Common\MLModelEngine.cs" />
<Compile Include="..\..\..\..\common\ObjectPool.cs" Link="Common\ObjectPool.cs" />
</ItemGroup>

<ItemGroup>
<Folder Include="Common\" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.ML" Version="$(MicrosoftMLVersion)" />
</ItemGroup>

<ItemGroup>
<None Update="ModelFiles\country_month_fastTreeTweedie.zip">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="ModelFiles\product_month_fastTreeTweedie.zip">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
Loading