Skip to content

Commit 125237a

Browse files
daholsteDmitry-A
authored andcommitted
Rev samples towards private preview; ignored columns fix (dotnet#259)
1 parent 75dbb70 commit 125237a

8 files changed

+196
-8
lines changed

src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,26 @@ internal static class ColumnInformationUtil
2323
return ColumnPurpose.Weight;
2424
}
2525

26-
if (columnInfo.CategoricalColumns?.Contains(columnName) == true)
26+
if (columnInfo.CategoricalColumns.Contains(columnName))
2727
{
2828
return ColumnPurpose.CategoricalFeature;
2929
}
3030

31-
if (columnInfo.NumericColumns?.Contains(columnName) == true)
31+
if (columnInfo.NumericColumns.Contains(columnName))
3232
{
3333
return ColumnPurpose.NumericFeature;
3434
}
3535

36-
if (columnInfo.TextColumns?.Contains(columnName) == true)
36+
if (columnInfo.TextColumns.Contains(columnName))
3737
{
3838
return ColumnPurpose.TextFeature;
3939
}
4040

41+
if (columnInfo.IgnoredColumns.Contains(columnName))
42+
{
43+
return ColumnPurpose.Ignore;
44+
}
45+
4146
return null;
4247
}
4348

src/Samples/CustomizeTraining.cs renamed to src/Samples/AdvancedExperimentSettings.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
namespace Samples
1414
{
15-
static class CustomizeTraining
15+
static class AdvancedExperimentSettings
1616
{
1717
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
1818
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.IO;
8+
using System.Linq;
9+
using Microsoft.Data.DataView;
10+
using Microsoft.ML;
11+
using Microsoft.ML.Auto;
12+
using Microsoft.ML.Data;
13+
14+
namespace Samples
15+
{
16+
static class AdvancedTrainingSettings
17+
{
18+
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
19+
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
20+
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-test.csv");
21+
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "TaxiFareModel.zip");
22+
private static string LabelColumn = "FareAmount";
23+
private static uint ExperimentTime = 60;
24+
25+
public static void Run()
26+
{
27+
MLContext mlContext = new MLContext();
28+
29+
// STEP 1: Create text loader options
30+
var textLoaderOptions = new TextLoader.Options()
31+
{
32+
Columns = new[]
33+
{
34+
new TextLoader.Column("VendorId", DataKind.String, 0),
35+
new TextLoader.Column("RateCode", DataKind.Single, 1),
36+
new TextLoader.Column("PassengerCount", DataKind.Single, 2),
37+
new TextLoader.Column("TripTimeInSeconds", DataKind.Single, 3),
38+
new TextLoader.Column("TripDistance", DataKind.Single, 4),
39+
new TextLoader.Column("PaymentType", DataKind.String, 5),
40+
new TextLoader.Column("FareAmount", DataKind.Single, 6),
41+
},
42+
HasHeader = true,
43+
Separators = new[] { ',' }
44+
};
45+
46+
// STEP 2: Load data
47+
TextLoader textLoader = mlContext.Data.CreateTextLoader(textLoaderOptions);
48+
IDataView trainDataView = textLoader.Load(TrainDataPath);
49+
IDataView testDataView = textLoader.Load(TestDataPath);
50+
51+
// STEP 3: Build a pre-featurizer for use in our AutoML experiment
52+
IEstimator<ITransformer> preFeaturizer = mlContext.Transforms.Categorical.OneHotEncoding("RateCode");
53+
54+
// STEP 4: Initialize custom column information for use in AutoML experiment
55+
ColumnInformation columnInformation = new ColumnInformation() { LabelColumn = LabelColumn };
56+
columnInformation.CategoricalColumns.Add("VendorId");
57+
columnInformation.IgnoredColumns.Add("PaymentType");
58+
59+
// STEP 5: Run AutoML experiment
60+
Console.WriteLine($"Running AutoML regression experiment for {ExperimentTime} seconds...");
61+
IEnumerable<RunResult<RegressionMetrics>> runResults = mlContext.Auto()
62+
.CreateRegressionExperiment(ExperimentTime)
63+
.Execute(trainDataView, columnInformation, preFeaturizer);
64+
65+
// STEP 6: Print metric from best model
66+
RunResult<RegressionMetrics> best = runResults.Best();
67+
Console.WriteLine($"Total models produced: {runResults.Count()}");
68+
Console.WriteLine($"Best model's trainer: {best.TrainerName}");
69+
Console.WriteLine($"RSquared of best model from validation data: {best.ValidationMetrics.RSquared}");
70+
71+
// STEP 7: Save the best model for later deployment and inferencing
72+
using (FileStream fs = File.Create(ModelPath))
73+
best.Model.SaveTo(mlContext, fs);
74+
75+
Console.WriteLine("Press any key to continue...");
76+
Console.ReadKey();
77+
}
78+
}
79+
}

src/Samples/AutoTrainMulticlassClassification.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static void Run()
3838
// STEP 3: Auto featurize, auto train and auto hyperparameter tune
3939
Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds...");
4040
IEnumerable<RunResult<MultiClassClassifierMetrics>> runResults = mlContext.Auto()
41-
.CreateMulticlassClassificationExperiment(60)
41+
.CreateMulticlassClassificationExperiment(ExperimentTime)
4242
.Execute(trainDataView);
4343

4444
// STEP 4: Print metric from the best model

src/Samples/AutoTrainRegression.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ public static void Run()
3737
IDataView testDataView = textLoader.Load(TestDataPath);
3838

3939
// STEP 3: Auto featurize, auto train and auto hyperparameter tune
40-
Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds...");
40+
Console.WriteLine($"Running AutoML regression classification experiment for {ExperimentTime} seconds...");
4141
IEnumerable<RunResult<RegressionMetrics>> runResults = mlContext.Auto()
42-
.CreateRegressionExperiment(60)
42+
.CreateRegressionExperiment(ExperimentTime)
4343
.Execute(trainDataView, LabelColumn);
4444

4545
// STEP 4: Print metric from best model

src/Samples/Program.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public static void Main(string[] args)
2121
AutoTrainMulticlassClassification.Run();
2222
Console.Clear();
2323

24-
CustomizeTraining.Run();
24+
AdvancedExperimentSettings.Run();
2525
Console.Clear();
2626

2727
ObserveProgress.Run();
@@ -30,6 +30,12 @@ public static void Main(string[] args)
3030
Cancellation.Run();
3131
Console.Clear();
3232

33+
AdvancedTrainingSettings.Run();
34+
Console.Clear();
35+
36+
RefitBestModel.Run();
37+
Console.Clear();
38+
3339
Console.WriteLine("Done");
3440
}
3541
catch (Exception ex)

src/Samples/RefitBestModel.cs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.IO;
8+
using System.Linq;
9+
using Microsoft.Data.DataView;
10+
using Microsoft.ML;
11+
using Microsoft.ML.Auto;
12+
using Microsoft.ML.Data;
13+
using Samples.Helpers;
14+
15+
namespace Samples
16+
{
17+
static class RefitBestModel
18+
{
19+
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
20+
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
21+
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-test.csv");
22+
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "TaxiFareModel.zip");
23+
private static string LabelColumn = "fare_amount";
24+
private static uint ExperimentTime = 60;
25+
26+
public static void Run()
27+
{
28+
MLContext mlContext = new MLContext();
29+
30+
// STEP 1: Infer columns
31+
ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn);
32+
ConsoleHelper.Print(columnInference);
33+
34+
// STEP 2: Load data
35+
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
36+
IDataView trainDataView = textLoader.Load(TrainDataPath);
37+
IDataView testDataView = textLoader.Load(TestDataPath);
38+
39+
// STEP 3: Subsample training data, for faster AutoML experimentation time
40+
IDataView smallTrainDataView = mlContext.Data.TakeRows(trainDataView, 50000);
41+
42+
// STEP 4: Auto-featurization, model selection, and hyperparameter tuning
43+
Console.WriteLine($"Running AutoML regression classification experiment for {ExperimentTime} seconds...");
44+
IEnumerable<RunResult<RegressionMetrics>> runResults = mlContext.Auto()
45+
.CreateRegressionExperiment(ExperimentTime)
46+
.Execute(smallTrainDataView, LabelColumn);
47+
48+
// STEP 5: Refit best model on entire training data
49+
RunResult<RegressionMetrics> best = runResults.Best();
50+
var refitBestModel = best.Estimator.Fit(trainDataView);
51+
52+
// STEP 6: Evaluate test data
53+
IDataView testDataViewWithBestScore = refitBestModel.Transform(testDataView);
54+
RegressionMetrics testMetrics = mlContext.Regression.Evaluate(testDataViewWithBestScore, label: LabelColumn);
55+
Console.WriteLine($"RSquared of the re-fit model on test data: {testMetrics.RSquared}");
56+
57+
// STEP 7: Save the re-fit best model for later deployment and inferencing
58+
using (FileStream fs = File.Create(ModelPath))
59+
refitBestModel.SaveTo(mlContext, fs);
60+
61+
Console.WriteLine("Press any key to continue...");
62+
Console.ReadKey();
63+
}
64+
}
65+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Microsoft.VisualStudio.TestTools.UnitTesting;
5+
6+
namespace Microsoft.ML.Auto.Test
7+
{
8+
[TestClass]
9+
public class ColumnInformationUtilTests
10+
{
11+
[TestMethod]
12+
public void GetColumnPurpose()
13+
{
14+
var columnInfo = new ColumnInformation()
15+
{
16+
LabelColumn = "Label",
17+
WeightColumn = "Weight",
18+
};
19+
columnInfo.CategoricalColumns.Add("Cat");
20+
columnInfo.NumericColumns.Add("Num");
21+
columnInfo.TextColumns.Add("Text");
22+
columnInfo.IgnoredColumns.Add("Ignored");
23+
24+
Assert.AreEqual(ColumnPurpose.Label, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Label"));
25+
Assert.AreEqual(ColumnPurpose.Weight, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Weight"));
26+
Assert.AreEqual(ColumnPurpose.CategoricalFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Cat"));
27+
Assert.AreEqual(ColumnPurpose.NumericFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Num"));
28+
Assert.AreEqual(ColumnPurpose.TextFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Text"));
29+
Assert.AreEqual(ColumnPurpose.Ignore, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Ignored"));
30+
Assert.AreEqual(null, ColumnInformationUtil.GetColumnPurpose(columnInfo, "NonExistent"));
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)