Skip to content

Commit cab5809

Browse files
authored
CLI ML.NET version upgrade (dotnet#345)
1 parent: 3c492c4 · commit: cab5809

36 files changed, with 489 additions (+489) and 499 deletions (-499).

src/Microsoft.ML.Auto/API/ExperimentBase.cs

+1-1
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, ColumnInfo
4848
{
4949
// Cross val threshold for # of dataset rows --
5050
// If dataset has < threshold # of rows, use cross val.
51-
// Else, use run experiment using train-validate split.
51+
// Else, run experiment using train-validate split.
5252
const int crossValRowCountThreshold = 15000;
5353

5454
var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);

src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionUtil.cs

+2-2
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,8 @@ internal enum TrainerName
4646

4747
internal static class TrainerExtensionUtil
4848
{
49-
private const string WeightColumn = "WeightColumn";
50-
private const string LabelColumn = "LabelColumn";
49+
private const string WeightColumn = "ExampleWeightColumnName";
50+
private const string LabelColumn = "LabelColumnName";
5151

5252
public static T CreateOptions<T>(IEnumerable<SweepableParam> sweepParams, string labelColumn) where T : TrainerInputBaseWithLabel
5353
{

src/Test/TrainerExtensionsTests.cs

+8-8
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ public void BuildLightGbmPipelineNode()
7373
""L1Regularization"": 0.5
7474
}
7575
},
76-
""LabelColumn"": ""Label""
76+
""LabelColumnName"": ""Label""
7777
}
7878
}";
7979
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
@@ -105,7 +105,7 @@ public void BuildSdcaPipelineNode()
105105
""MaximumNumberOfIterations"": 10,
106106
""Shuffle"": true,
107107
""BiasLearningRate"": 0.01,
108-
""LabelColumn"": ""Label""
108+
""LabelColumnName"": ""Label""
109109
}
110110
}";
111111
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
@@ -127,7 +127,7 @@ public void BuildLightGbmPipelineNodeDefaultParams()
127127
""Score""
128128
],
129129
""Properties"": {
130-
""LabelColumn"": ""Label""
130+
""LabelColumnName"": ""Label""
131131
}
132132
}";
133133
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
@@ -161,8 +161,8 @@ public void BuildPipelineNodeWithCustomColumns()
161161
""NumberOfLeaves"": 1,
162162
""MinimumExampleCountPerLeaf"": 10,
163163
""NumberOfTrees"": 100,
164-
""LabelColumn"": ""L"",
165-
""WeightColumn"": ""W""
164+
""LabelColumnName"": ""L"",
165+
""ExampleWeightColumnName"": ""W""
166166
}
167167
}";
168168
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
@@ -182,7 +182,7 @@ public void BuildDefaultAveragedPerceptronPipelineNode()
182182
""Score""
183183
],
184184
""Properties"": {
185-
""LabelColumn"": ""L"",
185+
""LabelColumnName"": ""L"",
186186
""NumberOfIterations"": 10
187187
}
188188
}";
@@ -199,7 +199,7 @@ public void BuildOvaPipelineNode()
199199
""InColumns"": null,
200200
""OutColumns"": null,
201201
""Properties"": {
202-
""LabelColumn"": ""Label"",
202+
""LabelColumnName"": ""Label"",
203203
""BinaryTrainer"": {
204204
""Name"": ""FastForestBinary"",
205205
""NodeType"": ""Trainer"",
@@ -210,7 +210,7 @@ public void BuildOvaPipelineNode()
210210
""Score""
211211
],
212212
""Properties"": {
213-
""LabelColumn"": ""Label""
213+
""LabelColumnName"": ""Label""
214214
}
215215
}
216216
}

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt

+18-18
Original file line number | Diff line number | Diff line change
@@ -20,20 +20,20 @@ namespace TestNamespace.Train
2020
Console.WriteLine($"*************************************************");
2121
Console.WriteLine($"* Metrics for regression model ");
2222
Console.WriteLine($"*------------------------------------------------");
23-
Console.WriteLine($"* LossFn: {metrics.LossFn:0.##}");
23+
Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}");
2424
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
25-
Console.WriteLine($"* Absolute loss: {metrics.L1:#.##}");
26-
Console.WriteLine($"* Squared loss: {metrics.L2:#.##}");
27-
Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");
25+
Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}");
26+
Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}");
27+
Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}");
2828
Console.WriteLine($"*************************************************");
2929
}
3030

31-
public static void PrintRegressionFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult<RegressionMetrics>[] crossValidationResults)
31+
public static void PrintRegressionFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RegressionMetrics>> crossValidationResults)
3232
{
33-
var L1 = crossValidationResults.Select(r => r.Metrics.L1);
34-
var L2 = crossValidationResults.Select(r => r.Metrics.L2);
35-
var RMS = crossValidationResults.Select(r => r.Metrics.L1);
36-
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFn);
33+
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
34+
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
35+
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
36+
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
3737
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
3838

3939
Console.WriteLine($"*************************************************************************************************************");
@@ -53,12 +53,12 @@ namespace TestNamespace.Train
5353
Console.WriteLine($"* Metrics for binary classification model ");
5454
Console.WriteLine($"*-----------------------------------------------------------");
5555
Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}");
56-
Console.WriteLine($"* Auc: {metrics.Auc:P2}");
56+
Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}");
5757
Console.WriteLine($"************************************************************");
5858
}
5959

6060

61-
public static void PrintBinaryClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>[] crossValResults)
61+
public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>> crossValResults)
6262
{
6363
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
6464

@@ -76,31 +76,31 @@ namespace TestNamespace.Train
7676

7777
}
7878

79-
public static void PrintMultiClassClassificationMetrics(MultiClassClassifierMetrics metrics)
79+
public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
8080
{
8181
Console.WriteLine($"************************************************************");
8282
Console.WriteLine($"* Metrics for multi-class classification model ");
8383
Console.WriteLine($"*-----------------------------------------------------------");
84-
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
85-
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
84+
Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
85+
Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
8686
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
87-
for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)
87+
for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
8888
{
8989
Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
9090
}
9191
Console.WriteLine($"************************************************************");
9292
}
9393

94-
public static void PrintMulticlassClassificationFoldsAverageMetrics(TrainCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[] crossValResults)
94+
public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
9595
{
9696
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
9797

98-
var microAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMicro);
98+
var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
9999
var microAccuracyAverage = microAccuracyValues.Average();
100100
var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
101101
var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);
102102

103-
var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);
103+
var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
104104
var macroAccuracyAverage = macroAccuracyValues.Average();
105105
var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
106106
var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeBinaryClassificationTest.approved.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ namespace MyNamespace
7272
.AppendCacheCheckpoint(mlContext);
7373

7474
// Set the training algorithm, then create and config the modelBuilder
75-
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
75+
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumnName = "Label", FeatureColumnName = "Features" });
7676
var trainingPipeline = dataProcessPipeline.Append(trainer);
7777

7878
// Train the model fitting to the DataSet

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeMulticlassTest.approved.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ namespace MyNamespace
7272
.AppendCacheCheckpoint(mlContext);
7373

7474
// Set the training algorithm, then create and config the modelBuilder
75-
var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
75+
var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumnName = "Label", FeatureColumnName = "Features" });
7676
var trainingPipeline = dataProcessPipeline.Append(trainer);
7777

7878
// Train the model fitting to the DataSet

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeRegressionTest.approved.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ namespace MyNamespace
7272
.AppendCacheCheckpoint(mlContext);
7373

7474
// Set the training algorithm, then create and config the modelBuilder
75-
var trainer = mlContext.Regression.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
75+
var trainer = mlContext.Regression.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumnName = "Label", FeatureColumnName = "Features" });
7676
var trainingPipeline = dataProcessPipeline.Append(trainer);
7777

7878
// Train the model fitting to the DataSet

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt

+1-1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
</RestoreSources>
1010
</PropertyGroup>
1111
<ItemGroup>
12-
<PackageReference Include="Microsoft.ML" Version="0.11.0" />
12+
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
1313
</ItemGroup>
1414

1515
<ItemGroup>

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.PredictProgramCSFileContentTest.approved.txt

+2-3
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@ using System.Linq;
1010
using System.Collections.Generic;
1111
using Microsoft.ML;
1212
using Microsoft.ML.Data;
13-
using Microsoft.Data.DataView;
1413
using TestNamespace.Model.DataModels;
1514

1615

@@ -44,7 +43,7 @@ namespace TestNamespace.Predict
4443
private static void Predict(MLContext mlContext, ITransformer mlModel, SampleObservation sampleData)
4544
{
4645
// Create prediction engine related to the loaded ML model
47-
var predEngine = mlModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
46+
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
4847

4948
// Try a single prediction
5049
var predictionResult = predEngine.Predict(sampleData);
@@ -56,7 +55,7 @@ namespace TestNamespace.Predict
5655
ITransformer mlModel;
5756
using (var stream = new FileStream(modelFilePath, FileMode.Open, FileAccess.Read, FileShare.Read))
5857
{
59-
mlModel = mlContext.Model.Load(stream);
58+
mlModel = mlContext.Model.Load(stream, out var modelInputSchema);
6059
}
6160

6261
return mlModel;

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.PredictProjectFileContentTest.approved.txt

+3-3
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
<TargetFramework>netcoreapp2.1</TargetFramework>
66
</PropertyGroup>
77
<ItemGroup>
8-
<PackageReference Include="Microsoft.ML" Version="0.11.0" />
9-
<PackageReference Include="Microsoft.ML.LightGBM" Version="0.11.0" />
10-
<PackageReference Include="Microsoft.ML.HalLearners" Version="0.11.0" />
8+
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
9+
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
10+
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
1111
</ItemGroup>
1212
<ItemGroup>
1313
<ProjectReference Include="..\TestNamespace.Model\TestNamespace.Model.csproj" />

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentOvaTest.approved.txt

+5-7
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,6 @@ using System;
88
using System.IO;
99
using System.Linq;
1010
using Microsoft.ML;
11-
using Microsoft.ML.Data;
12-
using Microsoft.Data.DataView;
1311
using TestNamespace.Model.DataModels;
1412

1513
namespace TestNamespace.Train
@@ -50,7 +48,7 @@ namespace TestNamespace.Train
5048
EvaluateModel(mlContext, mlModel, testDataView);
5149

5250
// Save model
53-
SaveModel(mlContext, mlModel, MODEL_FILEPATH);
51+
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
5452

5553
Console.WriteLine("=============== End of process, hit any key to finish ===============");
5654
Console.ReadKey();
@@ -63,7 +61,7 @@ namespace TestNamespace.Train
6361
.AppendCacheCheckpoint(mlContext);
6462

6563
// Set the training algorithm
66-
var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(numLeaves: 2, labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label");
64+
var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label");
6765
var trainingPipeline = dataProcessPipeline.Append(trainer);
6866

6967
return trainingPipeline;
@@ -85,14 +83,14 @@ namespace TestNamespace.Train
8583
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
8684
IDataView predictions = mlModel.Transform(testDataView);
8785
var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score");
88-
ConsoleHelper.PrintMultiClassClassificationMetrics(metrics);
86+
ConsoleHelper.PrintMulticlassClassificationMetrics(metrics);
8987
}
90-
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath)
88+
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
9189
{
9290
// Save/persist the trained model to a .ZIP file
9391
Console.WriteLine($"=============== Saving the model ===============");
9492
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
95-
mlContext.Model.Save(mlModel, fs);
93+
mlContext.Model.Save(mlModel, modelInputSchema, fs);
9694

9795
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
9896
}

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentTest.approved.txt

+4-7
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,7 @@ using System;
88
using System.IO;
99
using System.Linq;
1010
using Microsoft.ML;
11-
using Microsoft.ML.Data;
12-
using Microsoft.Data.DataView;
1311
using TestNamespace.Model.DataModels;
14-
using Microsoft.ML.LightGBM;
1512

1613
namespace TestNamespace.Train
1714
{
@@ -51,7 +48,7 @@ namespace TestNamespace.Train
5148
EvaluateModel(mlContext, mlModel, testDataView);
5249

5350
// Save model
54-
SaveModel(mlContext, mlModel, MODEL_FILEPATH);
51+
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
5552

5653
Console.WriteLine("=============== End of process, hit any key to finish ===============");
5754
Console.ReadKey();
@@ -64,7 +61,7 @@ namespace TestNamespace.Train
6461
.AppendCacheCheckpoint(mlContext);
6562

6663
// Set the training algorithm
67-
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
64+
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(labelColumnName: "Label", featureColumnName: "Features");
6865
var trainingPipeline = dataProcessPipeline.Append(trainer);
6966

7067
return trainingPipeline;
@@ -88,12 +85,12 @@ namespace TestNamespace.Train
8885
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(predictions, "Label", "Score");
8986
ConsoleHelper.PrintBinaryClassificationMetrics(metrics);
9087
}
91-
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath)
88+
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
9289
{
9390
// Save/persist the trained model to a .ZIP file
9491
Console.WriteLine($"=============== Saving the model ===============");
9592
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
96-
mlContext.Model.Save(mlModel, fs);
93+
mlContext.Model.Save(mlModel, modelInputSchema, fs);
9794

9895
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
9996
}

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProjectFileContentTest.approved.txt

+3-3
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
<TargetFramework>netcoreapp2.1</TargetFramework>
66
</PropertyGroup>
77
<ItemGroup>
8-
<PackageReference Include="Microsoft.ML" Version="0.11.0" />
9-
<PackageReference Include="Microsoft.ML.LightGBM" Version="0.11.0" />
10-
<PackageReference Include="Microsoft.ML.HalLearners" Version="0.11.0" />
8+
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
9+
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
10+
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
1111
</ItemGroup>
1212
<ItemGroup>
1313
<ProjectReference Include="..\TestNamespace.Model\TestNamespace.Model.csproj" />

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs

-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515

1616
namespace mlnet.Test
1717
{
18-
[Ignore]
1918
[TestClass]
2019
[UseReporter(typeof(DiffReporter))]
2120
public class ConsoleCodeGeneratorTests

0 commit comments

Comments
 (0)