Skip to content

Commit e50c4d2

Browse files
authored
Merge pull request #3882 from Dmitry-A/master
[AutoML] bring AutoML API library to master
2 parents a997cd6 + 3db7d98 commit e50c4d2

File tree

135 files changed

+13618
-88
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

135 files changed

+13618
-88
lines changed

Microsoft.ML.sln

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Microsoft Visual Studio Solution File, Format Version 12.00
2-
# Visual Studio 15
3-
VisualStudioVersion = 15.0.27130.2026
2+
# Visual Studio Version 16
3+
VisualStudioVersion = 16.0.29209.152
44
MinimumVisualStudioVersion = 10.0.40219.1
55
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Core", "src\Microsoft.ML.Core\Microsoft.ML.Core.csproj", "{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}"
66
EndProject
@@ -264,6 +264,18 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StableApi", "t
264264
EndProject
265265
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Dnn", "src\Microsoft.ML.Dnn\Microsoft.ML.Dnn.csproj", "{4C2D1A8F-7AC1-4036-B5E3-4B31769D73B8}"
266266
EndProject
267+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Tests", "test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj", "{C2652287-CD6D-40FB-B042-95FB56D09DB8}"
268+
EndProject
269+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML", "src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj", "{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}"
270+
EndProject
271+
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.AutoML", "Microsoft.ML.AutoML", "{F5D11F71-2D61-4AE9-99D7-0F0B54649B15}"
272+
ProjectSection(SolutionItems) = preProject
273+
pkg\Microsoft.ML.AutoML\Microsoft.ML.AutoML.nupkgproj = pkg\Microsoft.ML.AutoML\Microsoft.ML.AutoML.nupkgproj
274+
pkg\Microsoft.ML.AutoML\Microsoft.ML.AutoML.symbols.nupkgproj = pkg\Microsoft.ML.AutoML\Microsoft.ML.AutoML.symbols.nupkgproj
275+
EndProjectSection
276+
EndProject
277+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Samples", "docs\samples\Microsoft.ML.AutoML.Samples\Microsoft.ML.AutoML.Samples.csproj", "{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}"
278+
EndProject
267279
Global
268280
GlobalSection(SolutionConfigurationPlatforms) = preSolution
269281
Debug|Any CPU = Debug|Any CPU
@@ -1528,6 +1540,77 @@ Global
15281540
{4C2D1A8F-7AC1-4036-B5E3-4B31769D73B8}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
15291541
{4C2D1A8F-7AC1-4036-B5E3-4B31769D73B8}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU
15301542
{4C2D1A8F-7AC1-4036-B5E3-4B31769D73B8}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU
1543+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
1544+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug|Any CPU.Build.0 = Debug|Any CPU
1545+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug|x64.ActiveCfg = Debug|Any CPU
1546+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug|x64.Build.0 = Debug|Any CPU
1547+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netcoreapp3_0|Any CPU.ActiveCfg = Debug-netcoreapp3_0|Any CPU
1548+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netcoreapp3_0|Any CPU.Build.0 = Debug-netcoreapp3_0|Any CPU
1549+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netcoreapp3_0|x64.ActiveCfg = Debug-netcoreapp3_0|Any CPU
1550+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netcoreapp3_0|x64.Build.0 = Debug-netcoreapp3_0|Any CPU
1551+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
1552+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
1553+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netfx|x64.ActiveCfg = Debug-netfx|Any CPU
1554+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Debug-netfx|x64.Build.0 = Debug-netfx|Any CPU
1555+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release|Any CPU.ActiveCfg = Release|Any CPU
1556+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release|Any CPU.Build.0 = Release|Any CPU
1557+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release|x64.ActiveCfg = Release|Any CPU
1558+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release|x64.Build.0 = Release|Any CPU
1559+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netcoreapp3_0|Any CPU.ActiveCfg = Release-netcoreapp3_0|Any CPU
1560+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netcoreapp3_0|Any CPU.Build.0 = Release-netcoreapp3_0|Any CPU
1561+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netcoreapp3_0|x64.ActiveCfg = Release-netcoreapp3_0|Any CPU
1562+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netcoreapp3_0|x64.Build.0 = Release-netcoreapp3_0|Any CPU
1563+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
1564+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
1565+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU
1566+
{C2652287-CD6D-40FB-B042-95FB56D09DB8}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU
1567+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
1568+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug|Any CPU.Build.0 = Debug|Any CPU
1569+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug|x64.ActiveCfg = Debug|Any CPU
1570+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug|x64.Build.0 = Debug|Any CPU
1571+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netcoreapp3_0|Any CPU.ActiveCfg = Debug-netcoreapp3_0|Any CPU
1572+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netcoreapp3_0|Any CPU.Build.0 = Debug-netcoreapp3_0|Any CPU
1573+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netcoreapp3_0|x64.ActiveCfg = Debug-netcoreapp3_0|Any CPU
1574+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netcoreapp3_0|x64.Build.0 = Debug-netcoreapp3_0|Any CPU
1575+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
1576+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
1577+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netfx|x64.ActiveCfg = Debug-netfx|Any CPU
1578+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Debug-netfx|x64.Build.0 = Debug-netfx|Any CPU
1579+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release|Any CPU.ActiveCfg = Release|Any CPU
1580+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release|Any CPU.Build.0 = Release|Any CPU
1581+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release|x64.ActiveCfg = Release|Any CPU
1582+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release|x64.Build.0 = Release|Any CPU
1583+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netcoreapp3_0|Any CPU.ActiveCfg = Release-netcoreapp3_0|Any CPU
1584+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netcoreapp3_0|Any CPU.Build.0 = Release-netcoreapp3_0|Any CPU
1585+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netcoreapp3_0|x64.ActiveCfg = Release-netcoreapp3_0|Any CPU
1586+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netcoreapp3_0|x64.Build.0 = Release-netcoreapp3_0|Any CPU
1587+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
1588+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
1589+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU
1590+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}.Release-netfx|x64.Build.0 = Release-netfx|Any CPU
1591+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
1592+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug|Any CPU.Build.0 = Debug|Any CPU
1593+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug|x64.ActiveCfg = Debug|Any CPU
1594+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug|x64.Build.0 = Debug|Any CPU
1595+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netcoreapp3_0|Any CPU.ActiveCfg = Debug-netcoreapp3_0|Any CPU
1596+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netcoreapp3_0|Any CPU.Build.0 = Debug-netcoreapp3_0|Any CPU
1597+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netcoreapp3_0|x64.ActiveCfg = Debug-netcoreapp3_0|Any CPU
1598+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netcoreapp3_0|x64.Build.0 = Debug-netcoreapp3_0|Any CPU
1599+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
1600+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
1601+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netfx|x64.ActiveCfg = Debug-netfx|Any CPU
1602+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Debug-netfx|x64.Build.0 = Debug-netfx|Any CPU
1603+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release|Any CPU.ActiveCfg = Release|Any CPU
1604+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release|Any CPU.Build.0 = Release|Any CPU
1605+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release|x64.ActiveCfg = Release|Any CPU
1606+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release|x64.Build.0 = Release|Any CPU
1607+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netcoreapp3_0|Any CPU.ActiveCfg = Release-netcoreapp3_0|Any CPU
1608+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netcoreapp3_0|Any CPU.Build.0 = Release-netcoreapp3_0|Any CPU
1609+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netcoreapp3_0|x64.ActiveCfg = Release-netcoreapp3_0|Any CPU
1610+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netcoreapp3_0|x64.Build.0 = Release-netcoreapp3_0|Any CPU
1611+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
1612+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
1613+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}.Release-netfx|x64.ActiveCfg = Release-netfx|Any CPU
15311614
EndGlobalSection
15321615
GlobalSection(SolutionProperties) = preSolution
15331616
HideSolutionNode = FALSE
@@ -1610,6 +1693,10 @@ Global
16101693
{AE4F7569-26F3-4160-8A8B-7A57D0DA3350} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
16111694
{F308DC6B-7E59-40D7-A581-834E8CD99CFE} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
16121695
{4C2D1A8F-7AC1-4036-B5E3-4B31769D73B8} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
1696+
{C2652287-CD6D-40FB-B042-95FB56D09DB8} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
1697+
{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
1698+
{F5D11F71-2D61-4AE9-99D7-0F0B54649B15} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
1699+
{A6924919-9E37-4023-8B7F-E85C8E3CC9B3} = {DA452A53-2E94-4433-B08C-041EDEC729E6}
16131700
EndGlobalSection
16141701
GlobalSection(ExtensibilityGlobals) = postSolution
16151702
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using System;
2+
using System.IO;
3+
using System.Linq;
4+
using Microsoft.ML.AutoML;
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.AutoML.Samples
8+
{
9+
public static class BinaryClassificationExperiment
10+
{
11+
private static string TrainDataPath = "<Path to your train dataset goes here>";
12+
private static string TestDataPath = "<Path to your test dataset goes here>";
13+
private static string ModelPath = @"<Desired model output directory goes here>\SentimentModel.zip";
14+
private static uint ExperimentTime = 60;
15+
16+
public static void Run()
17+
{
18+
MLContext mlContext = new MLContext();
19+
20+
// STEP 1: Load data
21+
IDataView trainDataView = mlContext.Data.LoadFromTextFile<SentimentIssue>(TrainDataPath, hasHeader: true);
22+
IDataView testDataView = mlContext.Data.LoadFromTextFile<SentimentIssue>(TestDataPath, hasHeader: true);
23+
24+
// STEP 2: Run AutoML experiment
25+
Console.WriteLine($"Running AutoML binary classification experiment for {ExperimentTime} seconds...");
26+
ExperimentResult<BinaryClassificationMetrics> experimentResult = mlContext.Auto()
27+
.CreateBinaryClassificationExperiment(ExperimentTime)
28+
.Execute(trainDataView);
29+
30+
// STEP 3: Print metric from the best model
31+
RunDetail<BinaryClassificationMetrics> bestRun = experimentResult.BestRun;
32+
Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}");
33+
Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}");
34+
Console.WriteLine($"Metrics of best model from validation data --");
35+
PrintMetrics(bestRun.ValidationMetrics);
36+
37+
// STEP 4: Evaluate test data
38+
IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView);
39+
BinaryClassificationMetrics testMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testDataViewWithBestScore);
40+
Console.WriteLine($"Metrics of best model on test data --");
41+
PrintMetrics(testMetrics);
42+
43+
// STEP 5: Save the best model for later deployment and inferencing
44+
using (FileStream fs = File.Create(ModelPath))
45+
mlContext.Model.Save(bestRun.Model, trainDataView.Schema, fs);
46+
47+
// STEP 6: Create prediction engine from the best trained model
48+
var predictionEngine = mlContext.Model.CreatePredictionEngine<SentimentIssue, SentimentPrediction>(bestRun.Model);
49+
50+
// STEP 7: Initialize a new sentiment issue, and get the predicted sentiment
51+
var testSentimentIssue = new SentimentIssue
52+
{
53+
Text = "I hope this helps."
54+
};
55+
var prediction = predictionEngine.Predict(testSentimentIssue);
56+
Console.WriteLine($"Predicted sentiment for test issue: {prediction.Prediction}");
57+
58+
Console.WriteLine("Press any key to continue...");
59+
Console.ReadKey();
60+
}
61+
62+
private static void PrintMetrics(BinaryClassificationMetrics metrics)
63+
{
64+
Console.WriteLine($"Accuracy: {metrics.Accuracy}");
65+
Console.WriteLine($"AreaUnderPrecisionRecallCurve: {metrics.AreaUnderPrecisionRecallCurve}");
66+
Console.WriteLine($"AreaUnderRocCurve: {metrics.AreaUnderRocCurve}");
67+
Console.WriteLine($"F1Score: {metrics.F1Score}");
68+
Console.WriteLine($"NegativePrecision: {metrics.NegativePrecision}");
69+
Console.WriteLine($"NegativeRecall: {metrics.NegativeRecall}");
70+
Console.WriteLine($"PositivePrecision: {metrics.PositivePrecision}");
71+
Console.WriteLine($"PositiveRecall: {metrics.PositiveRecall}");
72+
}
73+
}
74+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.AutoML.Samples
4+
{
5+
public class PixelData
6+
{
7+
[LoadColumn(0, 63)]
8+
[VectorType(64)]
9+
public float[] PixelValues;
10+
11+
[LoadColumn(64)]
12+
public float Number;
13+
}
14+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.AutoML.Samples
4+
{
5+
public class PixelPrediction
6+
{
7+
[ColumnName("PredictedLabel")]
8+
public float Prediction;
9+
}
10+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.AutoML.Samples
4+
{
5+
public class SentimentIssue
6+
{
7+
[LoadColumn(0)]
8+
public bool Label { get; set; }
9+
10+
[LoadColumn(1)]
11+
public string Text { get; set; }
12+
}
13+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.AutoML.Samples
4+
{
5+
public class SentimentPrediction
6+
{
7+
// ColumnName attribute is used to change the column name from
8+
// its default value, which is the name of the field.
9+
[ColumnName("PredictedLabel")]
10+
public bool Prediction { get; set; }
11+
12+
public float Score { get; set; }
13+
}
14+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.AutoML.Samples
4+
{
5+
public class TaxiTrip
6+
{
7+
[LoadColumn(0)]
8+
public string VendorId;
9+
10+
[LoadColumn(1)]
11+
public float RateCode;
12+
13+
[LoadColumn(2)]
14+
public float PassengerCount;
15+
16+
[LoadColumn(3)]
17+
public float TripTimeInSeconds;
18+
19+
[LoadColumn(4)]
20+
public float TripDistance;
21+
22+
[LoadColumn(5)]
23+
public string PaymentType;
24+
25+
[LoadColumn(6)]
26+
public float FareAmount;
27+
}
28+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.AutoML.Samples
4+
{
5+
public class TaxiTripFarePrediction
6+
{
7+
[ColumnName("Score")]
8+
public float FareAmount;
9+
}
10+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>netcoreapp2.1</TargetFramework>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<ProjectReference Include="..\..\..\src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj" />
10+
</ItemGroup>
11+
12+
</Project>
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System;
2+
using System.IO;
3+
using System.Linq;
4+
using Microsoft.ML.AutoML;
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.AutoML.Samples
8+
{
9+
public static class MulticlassClassificationExperiment
10+
{
11+
private static string TrainDataPath = "<Path to your train dataset goes here>";
12+
private static string TestDataPath = "<Path to your test dataset goes here>";
13+
private static string ModelPath = @"<Desired model output directory goes here>\OptDigitsModel.zip";
14+
private static string LabelColumnName = "Number";
15+
private static uint ExperimentTime = 60;
16+
17+
public static void Run()
18+
{
19+
MLContext mlContext = new MLContext();
20+
21+
// STEP 1: Load data
22+
IDataView trainDataView = mlContext.Data.LoadFromTextFile<PixelData>(TrainDataPath, separatorChar: ',');
23+
IDataView testDataView = mlContext.Data.LoadFromTextFile<PixelData>(TestDataPath, separatorChar: ',');
24+
25+
// STEP 2: Run AutoML experiment
26+
Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds...");
27+
ExperimentResult<MulticlassClassificationMetrics> experimentResult = mlContext.Auto()
28+
.CreateMulticlassClassificationExperiment(ExperimentTime)
29+
.Execute(trainDataView, LabelColumnName);
30+
31+
// STEP 3: Print metric from the best model
32+
RunDetail<MulticlassClassificationMetrics> bestRun = experimentResult.BestRun;
33+
Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}");
34+
Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}");
35+
Console.WriteLine($"Metrics of best model from validation data --");
36+
PrintMetrics(bestRun.ValidationMetrics);
37+
38+
// STEP 4: Evaluate test data
39+
IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView);
40+
MulticlassClassificationMetrics testMetrics = mlContext.MulticlassClassification.Evaluate(testDataViewWithBestScore, labelColumnName: LabelColumnName);
41+
Console.WriteLine($"Metrics of best model on test data --");
42+
PrintMetrics(testMetrics);
43+
44+
// STEP 5: Save the best model for later deployment and inferencing
45+
using (FileStream fs = File.Create(ModelPath))
46+
mlContext.Model.Save(bestRun.Model, trainDataView.Schema, fs);
47+
48+
// STEP 6: Create prediction engine from the best trained model
49+
var predictionEngine = mlContext.Model.CreatePredictionEngine<PixelData, PixelPrediction>(bestRun.Model);
50+
51+
// STEP 7: Initialize new pixel data, and get the predicted number
52+
var testPixelData = new PixelData
53+
{
54+
PixelValues = new float[] { 0, 0, 1, 8, 15, 10, 0, 0, 0, 3, 13, 15, 14, 14, 0, 0, 0, 5, 10, 0, 10, 12, 0, 0, 0, 0, 3, 5, 15, 10, 2, 0, 0, 0, 16, 16, 16, 16, 12, 0, 0, 1, 8, 12, 14, 8, 3, 0, 0, 0, 0, 10, 13, 0, 0, 0, 0, 0, 0, 11, 9, 0, 0, 0 }
55+
};
56+
var prediction = predictionEngine.Predict(testPixelData);
57+
Console.WriteLine($"Predicted number for test pixels: {prediction.Prediction}");
58+
59+
Console.WriteLine("Press any key to continue...");
60+
Console.ReadKey();
61+
}
62+
63+
private static void PrintMetrics(MulticlassClassificationMetrics metrics)
64+
{
65+
Console.WriteLine($"LogLoss: {metrics.LogLoss}");
66+
Console.WriteLine($"LogLossReduction: {metrics.LogLossReduction}");
67+
Console.WriteLine($"MacroAccuracy: {metrics.MacroAccuracy}");
68+
Console.WriteLine($"MicroAccuracy: {metrics.MicroAccuracy}");
69+
}
70+
}
71+
}

0 commit comments

Comments
 (0)