Skip to content

Commit db06f02

Browse files
Moved MLModelEngine to Common and created a sample Console app for testing object pooling
1 parent ac4664c commit db06f02

15 files changed

+254
-15
lines changed
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
using Microsoft.ML.Runtime.Data;
33
using System.IO;
44
using Microsoft.ML;
5-
using Microsoft.Extensions.Configuration;
5+
//using Microsoft.Extensions.Configuration;
66

7-
namespace eShopDashboard.Forecast
7+
namespace Common
88
{
9-
public class MLModel<TData, TPrediction>
9+
public class MLModelEngine<TData, TPrediction>
1010
where TData : class
1111
where TPrediction : class, new()
1212
{
@@ -16,8 +16,13 @@ public class MLModel<TData, TPrediction>
1616
private readonly int _minPredictionEngineObjectsInPool;
1717
private readonly int _maxPredictionEngineObjectsInPool;
1818

19+
public int CurrentPredictionEnginePoolSize
20+
{
21+
get { return _predictionEnginePool.CurrentPoolSize; }
22+
}
23+
1924
//Constructor with modelFilePathName to load
20-
public MLModel(MLContext mlContext, string modelFilePathName, int minPredictionEngineObjectsInPool = 10, int maxPredictionEngineObjectsInPool = 1000)
25+
public MLModelEngine(MLContext mlContext, string modelFilePathName, int minPredictionEngineObjectsInPool = 5, int maxPredictionEngineObjectsInPool = 1000)
2126
{
2227
_mlContext = mlContext;
2328

@@ -35,7 +40,7 @@ public MLModel(MLContext mlContext, string modelFilePathName, int minPredictionE
3540
}
3641

3742
//Constructor with ITransformer model already created
38-
public MLModel(MLContext mlContext, ITransformer model, int minPredictionEngineObjectsInPool = 10, int maxPredictionEngineObjectsInPool = 1000)
43+
public MLModelEngine(MLContext mlContext, ITransformer model, int minPredictionEngineObjectsInPool = 5, int maxPredictionEngineObjectsInPool = 1000)
3944
{
4045
_mlContext = mlContext;
4146
_model = model;
Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,20 @@
44
using System.Linq;
55
using System.Threading.Tasks;
66

7-
namespace eShopDashboard.Forecast
7+
namespace Common
88
{
99
public class ObjectPool<T>
1010
{
1111
private ConcurrentBag<T> _objects;
1212
private Func<T> _objectGenerator;
1313
private int _maxPoolSize;
1414

15-
public ObjectPool(Func<T> objectGenerator, int minPoolSize = 5, int maxPoolSize = 5000)
15+
public int CurrentPoolSize
16+
{
17+
get { return _objects.Count; }
18+
}
19+
20+
public ObjectPool(Func<T> objectGenerator, int minPoolSize = 5, int maxPoolSize = 50000)
1621
{
1722
if (objectGenerator == null) throw new ArgumentNullException("objectGenerator");
1823
_objects = new ConcurrentBag<T>();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the input to the trained model.
6+
/// </summary>
7+
public class CountryData
8+
{
9+
// next,country,year,month,max,min,std,count,sales,med,prev
10+
public CountryData(string country, int year, int month, float max, float min, float std, int count, float sales, float med, float prev)
11+
{
12+
this.country = country;
13+
14+
this.year = year;
15+
this.month = month;
16+
this.max = max;
17+
this.min = min;
18+
this.std = std;
19+
this.count = count;
20+
this.sales = sales;
21+
this.med = med;
22+
this.prev = prev;
23+
}
24+
25+
public float next;
26+
27+
public string country;
28+
29+
public float year;
30+
public float month;
31+
public float max;
32+
public float min;
33+
public float std;
34+
public float count;
35+
public float sales;
36+
public float med;
37+
public float prev;
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the output of the scored model, the prediction.
6+
/// </summary>
7+
public class CountrySalesPrediction
8+
{
9+
// Below columns are produced by the model's predictor.
10+
public float Score;
11+
}
12+
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the input to the trained model.
6+
/// </summary>
7+
public class ProductData
8+
{
9+
// next,productId,year,month,units,avg,count,max,min,prev
10+
public ProductData(string productId, int year, int month, float units, float avg,
11+
int count, float max, float min, float prev)
12+
{
13+
this.productId = productId;
14+
this.year = year;
15+
this.month = month;
16+
this.units = units;
17+
this.avg = avg;
18+
this.count = count;
19+
this.max = max;
20+
this.min = min;
21+
this.prev = prev;
22+
}
23+
24+
public float next;
25+
26+
public string productId;
27+
28+
public float year;
29+
public float month;
30+
public float units;
31+
public float avg;
32+
public float count;
33+
public float max;
34+
public float min;
35+
public float prev;
36+
}
37+
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
namespace TestObjectPoolingConsoleApp.DataStructures
3+
{
4+
/// <summary>
5+
/// This is the output of the scored model, the prediction.
6+
/// </summary>
7+
public class ProductUnitPrediction
8+
{
9+
// Below columns are produced by the model's predictor.
10+
public float Score;
11+
}
12+
13+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using Common;
2+
using Microsoft.ML;
3+
using System;
4+
using System.Threading;
5+
using System.Threading.Tasks;
6+
7+
using TestObjectPoolingConsoleApp.DataStructures;
8+
9+
namespace TestObjectPoolingConsoleApp
10+
{
11+
class Program
12+
{
13+
static void Main(string[] args)
14+
{
15+
CancellationTokenSource cts = new CancellationTokenSource();
16+
17+
// Create an opportunity for the user to cancel.
18+
Task.Run(() =>
19+
{
20+
if (Console.ReadKey().KeyChar == 'c' || Console.ReadKey().KeyChar == 'C')
21+
cts.Cancel();
22+
});
23+
24+
MLContext mlContext = new MLContext(seed:1);
25+
string modelFolder = $"Forecast/ModelFiles";
26+
string modelFilePathName = $"ModelFiles/country_month_fastTreeTweedie.zip";
27+
var countrySalesModel = new MLModelEngine<CountryData, CountrySalesPrediction>(mlContext,
28+
modelFilePathName,
29+
minPredictionEngineObjectsInPool: 2);
30+
31+
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);
32+
33+
//Single Prediction
34+
var singleCountrySample = new CountryData("Australia", 2017, 1, 477, 164, 2486, 9, 10345, 281, 1029);
35+
var singleNextMonthPrediction = countrySalesModel.Predict(singleCountrySample);
36+
37+
Console.WriteLine("Prediction: {0:####.####}", singleNextMonthPrediction.Score);
38+
39+
// Create a high demand for the modelEngine objects.
40+
Parallel.For(0, 1000000, (i, loopState) =>
41+
{
42+
//Sample country data
43+
//next,country,year,month,max,min,std,count,sales,med,prev
44+
//4.23056080166201,Australia,2017,1,477.34,164.916,2486.1346772137,9,10345.71,281.7,1029.11
45+
46+
var countrySample = new CountryData("Australia", 2017, 1, 477, 164, 2486, 9, 10345, 281, i);
47+
48+
// This is the bottleneck in our application. All threads in this loop
49+
// must serialize their access to the static Console class.
50+
Console.CursorLeft = 0;
51+
var nextMonthPrediction = countrySalesModel.Predict(countrySample);
52+
53+
Console.WriteLine("Prediction: {0:####.####}", nextMonthPrediction.Score);
54+
Console.WriteLine("-----------------------------------------");
55+
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);
56+
57+
if (cts.Token.IsCancellationRequested)
58+
loopState.Stop();
59+
60+
});
61+
62+
Console.WriteLine("-----------------------------------------");
63+
Console.WriteLine("Current number of objects in pool: {0:####.####}", countrySalesModel.CurrentPredictionEnginePoolSize);
64+
65+
66+
Console.WriteLine("Press the Enter key to exit.");
67+
Console.ReadLine();
68+
cts.Dispose();
69+
}
70+
71+
}
72+
73+
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>netcoreapp2.1</TargetFramework>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<Compile Include="..\..\..\..\common\MLModelEngine.cs" Link="Common\MLModelEngine.cs" />
10+
<Compile Include="..\..\..\..\common\ObjectPool.cs" Link="Common\ObjectPool.cs" />
11+
</ItemGroup>
12+
13+
<ItemGroup>
14+
<Folder Include="Common\" />
15+
</ItemGroup>
16+
17+
<ItemGroup>
18+
<PackageReference Include="Microsoft.ML" Version="$(MicrosoftMLVersion)" />
19+
</ItemGroup>
20+
21+
<ItemGroup>
22+
<None Update="ModelFiles\country_month_fastTreeTweedie.zip">
23+
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
24+
</None>
25+
<None Update="ModelFiles\product_month_fastTreeTweedie.zip">
26+
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
27+
</None>
28+
</ItemGroup>
29+
30+
</Project>

samples/csharp/end-to-end-apps/Regression-SalesForecast/src/eShopDashboard/Controllers/CountrySalesForecastController.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111
using Microsoft.ML.Runtime.Data;
1212
using Serilog;
1313

14+
using Common;
15+
1416
namespace eShopDashboard.Controllers
1517
{
1618
[Produces("application/json")]
1719
[Route("api/countrysalesforecast")]
1820
public class CountrySalesForecastController : Controller
1921
{
2022
private readonly AppSettings appSettings;
21-
private readonly MLModel<CountryData, CountrySalesPrediction> countrySalesModel;
23+
private readonly MLModelEngine<CountryData, CountrySalesPrediction> countrySalesModel;
2224
private readonly ILogger<CountrySalesForecastController> logger;
2325

2426
public CountrySalesForecastController(IOptionsSnapshot<AppSettings> appSettings,
25-
MLModel<CountryData, CountrySalesPrediction> countrySalesModel,
27+
MLModelEngine<CountryData, CountrySalesPrediction> countrySalesModel,
2628
ILogger<CountrySalesForecastController> logger)
2729
{
2830
this.appSettings = appSettings.Value;

samples/csharp/end-to-end-apps/Regression-SalesForecast/src/eShopDashboard/Controllers/ProductDemandForecastController.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@
1010
using Microsoft.ML.Runtime.Data;
1111
using Serilog;
1212

13+
using Common;
14+
1315
namespace eShopDashboard.Controllers
1416
{
1517
[Produces("application/json")]
1618
[Route("api/productdemandforecast")]
1719
public class ProductDemandForecastController : Controller
1820
{
1921
private readonly AppSettings appSettings;
20-
private readonly MLModel<ProductData, ProductUnitPrediction> productSalesModel;
22+
private readonly MLModelEngine<ProductData, ProductUnitPrediction> productSalesModel;
2123

2224
public ProductDemandForecastController(IOptionsSnapshot<AppSettings> appSettings,
23-
MLModel<ProductData, ProductUnitPrediction> productSalesModel)
25+
MLModelEngine<ProductData, ProductUnitPrediction> productSalesModel)
2426
{
2527
this.appSettings = appSettings.Value;
2628

samples/csharp/end-to-end-apps/Regression-SalesForecast/src/eShopDashboard/Startup.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
using Microsoft.ML.Runtime.Data;
1414
using Serilog;
1515

16+
using Common;
17+
1618
namespace eShopDashboard
1719
{
1820
public class Startup
@@ -48,20 +50,20 @@ public void ConfigureServices(IServiceCollection services)
4850
return new MLContext(seed: 1);
4951
});
5052

51-
services.AddSingleton <MLModel<ProductData, ProductUnitPrediction>>((ctx) =>
53+
services.AddSingleton <MLModelEngine<ProductData, ProductUnitPrediction>>((ctx) =>
5254
{
5355
MLContext mlContext = ctx.GetRequiredService<MLContext>();
5456
string modelFolder = Configuration["ForecastModelsPath"];
5557
string modelFilePathName = $"{modelFolder}/product_month_fastTreeTweedie.zip";
56-
return new MLModel<ProductData, ProductUnitPrediction>(mlContext, modelFilePathName);
58+
return new MLModelEngine<ProductData, ProductUnitPrediction>(mlContext, modelFilePathName);
5759
});
5860

59-
services.AddSingleton<MLModel<CountryData, CountrySalesPrediction>>((ctx) =>
61+
services.AddSingleton<MLModelEngine<CountryData, CountrySalesPrediction>>((ctx) =>
6062
{
6163
MLContext mlContext = ctx.GetRequiredService<MLContext>();
6264
string modelFolder = Configuration["ForecastModelsPath"];
6365
string modelFilePathName = $"{modelFolder}/country_month_fastTreeTweedie.zip";
64-
return new MLModel<CountryData, CountrySalesPrediction>(mlContext, modelFilePathName, minPredictionEngineObjectsInPool:50);
66+
return new MLModelEngine<CountryData, CountrySalesPrediction>(mlContext, modelFilePathName, minPredictionEngineObjectsInPool:50);
6567
});
6668

6769
services.Configure<CatalogSettings>(Configuration.GetSection("CatalogSettings"));

samples/csharp/end-to-end-apps/Regression-SalesForecast/src/eShopDashboard/eShopDashboard.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
<ItemGroup>
2525
<Folder Include="Infrastructure\Migrations\Ordering\" />
26+
<Folder Include="Common\" />
2627
<Folder Include="ReportsContext\" />
2728
</ItemGroup>
2829
<ItemGroup>
@@ -32,6 +33,10 @@
3233
<ItemGroup>
3334
<None Remove="ProductImages\coming_soon.png" />
3435
</ItemGroup>
36+
<ItemGroup>
37+
<Compile Include="..\..\..\..\common\MLModelEngine.cs" Link="Common\MLModelEngine.cs" />
38+
<Compile Include="..\..\..\..\common\ObjectPool.cs" Link="Common\ObjectPool.cs" />
39+
</ItemGroup>
3540
<ItemGroup>
3641
<None Update="Forecast\ModelFiles\country_month_fastTreeTweedie.zip">
3742
<CopyToOutputDirectory>Always</CopyToOutputDirectory>

0 commit comments

Comments
 (0)