Skip to content

Commit 2674426

Browse files
Merge pull request #184 from CESARDELATORRE/features/objectpooling
Added object pooling for PredictionFunction/PredictionEngine to eShopDashboardML sample
2 parents db4f723 + 651dee7 commit 2674426

18 files changed

+399
-148
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
using Microsoft.ML.Core.Data;
2+
using Microsoft.ML.Runtime.Data;
3+
using System.IO;
4+
using Microsoft.ML;
5+
//using Microsoft.Extensions.Configuration;
6+
7+
namespace Common
8+
{
9+
public class MLModelEngine<TData, TPrediction>
10+
where TData : class
11+
where TPrediction : class, new()
12+
{
13+
private readonly MLContext _mlContext;
14+
private readonly ITransformer _model;
15+
private readonly ObjectPool<PredictionFunction<TData, TPrediction>> _predictionEnginePool;
16+
private readonly int _minPredictionEngineObjectsInPool;
17+
private readonly int _maxPredictionEngineObjectsInPool;
18+
19+
public int CurrentPredictionEnginePoolSize
20+
{
21+
get { return _predictionEnginePool.CurrentPoolSize; }
22+
}
23+
24+
//Constructor with modelFilePathName to load
25+
public MLModelEngine(MLContext mlContext, string modelFilePathName, int minPredictionEngineObjectsInPool = 5, int maxPredictionEngineObjectsInPool = 1000)
26+
{
27+
_mlContext = mlContext;
28+
29+
//Load the ProductSalesForecast model from the .ZIP file
30+
using (var fileStream = File.OpenRead(modelFilePathName))
31+
{
32+
_model = mlContext.Model.Load(fileStream);
33+
}
34+
35+
_minPredictionEngineObjectsInPool = minPredictionEngineObjectsInPool;
36+
_maxPredictionEngineObjectsInPool = maxPredictionEngineObjectsInPool;
37+
38+
//Create PredictionEngine Object Pool
39+
_predictionEnginePool = CreatePredictionEngineObjectPool();
40+
}
41+
42+
//Constructor with ITransformer model already created
43+
public MLModelEngine(MLContext mlContext, ITransformer model, int minPredictionEngineObjectsInPool = 5, int maxPredictionEngineObjectsInPool = 1000)
44+
{
45+
_mlContext = mlContext;
46+
_model = model;
47+
_minPredictionEngineObjectsInPool = minPredictionEngineObjectsInPool;
48+
_maxPredictionEngineObjectsInPool = maxPredictionEngineObjectsInPool;
49+
50+
//Create PredictionEngine Object Pool
51+
_predictionEnginePool = CreatePredictionEngineObjectPool();
52+
}
53+
54+
private ObjectPool<PredictionFunction<TData, TPrediction>> CreatePredictionEngineObjectPool()
55+
{
56+
return new ObjectPool<PredictionFunction<TData, TPrediction>>(() => _model.MakePredictionFunction<TData, TPrediction>(_mlContext),
57+
_minPredictionEngineObjectsInPool,
58+
_maxPredictionEngineObjectsInPool);
59+
}
60+
61+
public TPrediction Predict(TData dataSample)
62+
{
63+
//Get PredictionEngine object from the Object Pool
64+
PredictionFunction<TData, TPrediction> predictionEngine = _predictionEnginePool.GetObject();
65+
66+
//Predict
67+
TPrediction prediction = predictionEngine.Predict(dataSample);
68+
69+
//Release used PredictionEngine object into the Object Pool
70+
_predictionEnginePool.PutObject(predictionEngine);
71+
72+
return prediction;
73+
}
74+
75+
}
76+
}

samples/csharp/common/ObjectPool.cs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Threading.Tasks;
6+
7+
namespace Common
8+
{
9+
public class ObjectPool<T>
10+
{
11+
private ConcurrentBag<T> _objects;
12+
private Func<T> _objectGenerator;
13+
private int _maxPoolSize;
14+
15+
public int CurrentPoolSize
16+
{
17+
get { return _objects.Count; }
18+
}
19+
20+
public ObjectPool(Func<T> objectGenerator, int minPoolSize = 5, int maxPoolSize = 50000)
21+
{
22+
if (objectGenerator == null) throw new ArgumentNullException("objectGenerator");
23+
_objects = new ConcurrentBag<T>();
24+
_objectGenerator = objectGenerator;
25+
_maxPoolSize = maxPoolSize;
26+
27+
//Measure total time of minimum objects creation
28+
var watch = System.Diagnostics.Stopwatch.StartNew();
29+
30+
//Create minimum number of objects in pool
31+
for (int i = 0; i < minPoolSize; i++)
32+
{
33+
_objects.Add(_objectGenerator());
34+
}
35+
36+
//Stop measuring time
37+
watch.Stop();
38+
long elapsedMs = watch.ElapsedMilliseconds;
39+
}
40+
41+
public T GetObject()
42+
{
43+
T item;
44+
if (_objects.TryTake(out item))
45+
{
46+
return item;
47+
}
48+
else
49+
{
50+
if(_objects.Count <= _maxPoolSize)
51+
return _objectGenerator();
52+
else
53+
throw new InvalidOperationException("MaxPoolSize reached");
54+
}
55+
}
56+
57+
public void PutObject(T item)
58+
{
59+
_objects.Add(item);
60+
}
61+
}
62+
}

samples/csharp/end-to-end-apps/Regression-SalesForecast/eShopDashboardML.sln

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
2121
.editorconfig = .editorconfig
2222
EndProjectSection
2323
EndProject
24+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestObjectPoolingConsoleApp", "src\TestObjectPoolingConsoleApp\TestObjectPoolingConsoleApp.csproj", "{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}"
25+
EndProject
2426
Global
2527
GlobalSection(SolutionConfigurationPlatforms) = preSolution
2628
Debug|Any CPU = Debug|Any CPU
@@ -39,6 +41,10 @@ Global
3941
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A}.Debug|Any CPU.Build.0 = Debug|Any CPU
4042
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A}.Release|Any CPU.ActiveCfg = Release|Any CPU
4143
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A}.Release|Any CPU.Build.0 = Release|Any CPU
44+
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
45+
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Debug|Any CPU.Build.0 = Debug|Any CPU
46+
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Release|Any CPU.ActiveCfg = Release|Any CPU
47+
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A}.Release|Any CPU.Build.0 = Release|Any CPU
4248
EndGlobalSection
4349
GlobalSection(SolutionProperties) = preSolution
4450
HideSolutionNode = FALSE
@@ -47,6 +53,7 @@ Global
4753
{29DB8569-F5D6-4190-9DF4-8D18CA0AABA8} = {F395612F-24C7-4666-90B2-62E417033B4B}
4854
{5AB1C510-FEF6-4930-AE05-D16AF802084D} = {F395612F-24C7-4666-90B2-62E417033B4B}
4955
{F5DC33CF-35B3-45DD-A4A2-977DEA38060A} = {B3AF01E5-D172-47F9-991E-A85504958F43}
56+
{CF3DE8C7-81D6-4B2B-A2F0-82D15701F10A} = {F395612F-24C7-4666-90B2-62E417033B4B}
5057
EndGlobalSection
5158
GlobalSection(ExtensibilityGlobals) = postSolution
5259
SolutionGuid = {1E47A71B-4F99-48EA-9267-DEE93B23BA31}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the input to the trained model.
6+
/// </summary>
7+
public class CountryData
8+
{
9+
// next,country,year,month,max,min,std,count,sales,med,prev
10+
public CountryData(string country, int year, int month, float max, float min, float std, int count, float sales, float med, float prev)
11+
{
12+
this.country = country;
13+
14+
this.year = year;
15+
this.month = month;
16+
this.max = max;
17+
this.min = min;
18+
this.std = std;
19+
this.count = count;
20+
this.sales = sales;
21+
this.med = med;
22+
this.prev = prev;
23+
}
24+
25+
public float next;
26+
27+
public string country;
28+
29+
public float year;
30+
public float month;
31+
public float max;
32+
public float min;
33+
public float std;
34+
public float count;
35+
public float sales;
36+
public float med;
37+
public float prev;
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the output of the scored model, the prediction.
6+
/// </summary>
7+
public class CountrySalesPrediction
8+
{
9+
// Below columns are produced by the model's predictor.
10+
public float Score;
11+
}
12+
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the input to the trained model.
6+
/// </summary>
7+
public class ProductData
8+
{
9+
// next,productId,year,month,units,avg,count,max,min,prev
10+
public ProductData(string productId, int year, int month, float units, float avg,
11+
int count, float max, float min, float prev)
12+
{
13+
this.productId = productId;
14+
this.year = year;
15+
this.month = month;
16+
this.units = units;
17+
this.avg = avg;
18+
this.count = count;
19+
this.max = max;
20+
this.min = min;
21+
this.prev = prev;
22+
}
23+
24+
public float next;
25+
26+
public string productId;
27+
28+
public float year;
29+
public float month;
30+
public float units;
31+
public float avg;
32+
public float count;
33+
public float max;
34+
public float min;
35+
public float prev;
36+
}
37+
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the output of the scored model, the prediction.
6+
/// </summary>
7+
public class ProductUnitPrediction
8+
{
9+
// Below columns are produced by the model's predictor.
10+
public float Score;
11+
}
12+
13+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using Common;
2+
using Microsoft.ML;
3+
using System;
4+
using System.Threading;
5+
using System.Threading.Tasks;
6+
7+
using TestObjectPoolingConsoleApp.DataStructures;
8+
9+
namespace TestObjectPoolingConsoleApp
10+
{
11+
class Program
12+
{
13+
static void Main(string[] args)
14+
{
15+
CancellationTokenSource cts = new CancellationTokenSource();
16+
17+
// Create an opportunity for the user to cancel.
18+
Task.Run(() =>
19+
{
20+
if (Console.ReadKey().KeyChar == 'c' || Console.ReadKey().KeyChar == 'C')
21+
cts.Cancel();
22+
});
23+
24+
MLContext mlContext = new MLContext(seed:1);
25+
string modelFolder = $"Forecast/ModelFiles";
26+
string modelFilePathName = $"ModelFiles/country_month_fastTreeTweedie.zip";
27+
var countrySalesModel = new MLModelEngine<CountryData, CountrySalesPrediction>(mlContext,
28+
modelFilePathName,
29+
minPredictionEngineObjectsInPool: 2);
30+
31+
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);
32+
33+
//Single Prediction
34+
var singleCountrySample = new CountryData("Australia", 2017, 1, 477, 164, 2486, 9, 10345, 281, 1029);
35+
var singleNextMonthPrediction = countrySalesModel.Predict(singleCountrySample);
36+
37+
Console.WriteLine("Prediction: {0:####.####}", singleNextMonthPrediction.Score);
38+
39+
// Create a high demand for the modelEngine objects.
40+
Parallel.For(0, 1000000, (i, loopState) =>
41+
{
42+
//Sample country data
43+
//next,country,year,month,max,min,std,count,sales,med,prev
44+
//4.23056080166201,Australia,2017,1,477.34,164.916,2486.1346772137,9,10345.71,281.7,1029.11
45+
46+
var countrySample = new CountryData("Australia", 2017, 1, 477, 164, 2486, 9, 10345, 281, i);
47+
48+
// This is the bottleneck in our application. All threads in this loop
49+
// must serialize their access to the static Console class.
50+
Console.CursorLeft = 0;
51+
var nextMonthPrediction = countrySalesModel.Predict(countrySample);
52+
53+
Console.WriteLine("Prediction: {0:####.####}", nextMonthPrediction.Score);
54+
Console.WriteLine("-----------------------------------------");
55+
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);
56+
57+
if (cts.Token.IsCancellationRequested)
58+
loopState.Stop();
59+
60+
});
61+
62+
Console.WriteLine("-----------------------------------------");
63+
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);
64+
65+
66+
Console.WriteLine("Press the Enter key to exit.");
67+
Console.ReadLine();
68+
cts.Dispose();
69+
}
70+
71+
}
72+
73+
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>netcoreapp2.1</TargetFramework>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<Compile Include="..\..\..\..\common\MLModelEngine.cs" Link="Common\MLModelEngine.cs" />
10+
<Compile Include="..\..\..\..\common\ObjectPool.cs" Link="Common\ObjectPool.cs" />
11+
</ItemGroup>
12+
13+
<ItemGroup>
14+
<Folder Include="Common\" />
15+
</ItemGroup>
16+
17+
<ItemGroup>
18+
<PackageReference Include="Microsoft.ML" Version="$(MicrosoftMLVersion)" />
19+
</ItemGroup>
20+
21+
<ItemGroup>
22+
<None Update="ModelFiles\country_month_fastTreeTweedie.zip">
23+
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
24+
</None>
25+
<None Update="ModelFiles\product_month_fastTreeTweedie.zip">
26+
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
27+
</None>
28+
</ItemGroup>
29+
30+
</Project>

0 commit comments

Comments
 (0)