Skip to content

Commit 95b7382

Browse files
daholste authored and Dmitry-A committed
Rev Samples (dotnet#334)
1 parent 3cacc45 commit 95b7382

18 files changed

+10377
-65
lines changed

src/Samples/AdvancedExperimentSettings.cs

+20-8
Original file line number | Diff line number | Diff line change
@@ -8,28 +8,40 @@
88
using Microsoft.ML;
99
using Microsoft.ML.Auto;
1010
using Microsoft.ML.Data;
11-
using Samples.Helpers;
1211

1312
namespace Samples
1413
{
1514
static class AdvancedExperimentSettings
1615
{
17-
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
16+
private static string BaseDatasetsLocation = "Data";
1817
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
1918
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-test.csv");
2019
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "TaxiFareModel.zip");
21-
private static string LabelColumn = "fare_amount";
20+
private static string LabelColumn = "FareAmount";
2221

2322
public static void Run()
2423
{
2524
MLContext mlContext = new MLContext();
26-
27-
// STEP 1: Infer columns
28-
ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn, ',');
29-
ConsoleHelper.Print(columnInference);
25+
26+
// STEP 1: Create text loader options
27+
var textLoaderOptions = new TextLoader.Options()
28+
{
29+
Columns = new[]
30+
{
31+
new TextLoader.Column("VendorId", DataKind.String, 0),
32+
new TextLoader.Column("RateCode", DataKind.Single, 1),
33+
new TextLoader.Column("PassengerCount", DataKind.Single, 2),
34+
new TextLoader.Column("TripTimeInSeconds", DataKind.Single, 3),
35+
new TextLoader.Column("TripDistance", DataKind.Single, 4),
36+
new TextLoader.Column("PaymentType", DataKind.String, 5),
37+
new TextLoader.Column("FareAmount", DataKind.Single, 6),
38+
},
39+
HasHeader = true,
40+
Separators = new[] { ',' }
41+
};
3042

3143
// STEP 2: Load data
32-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
44+
TextLoader textLoader = mlContext.Data.CreateTextLoader(textLoaderOptions);
3345
IDataView trainDataView = textLoader.Load(TrainDataPath);
3446
IDataView testDataView = textLoader.Load(TestDataPath);
3547

src/Samples/AdvancedTrainingSettings.cs

+1-1
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ namespace Samples
1515
{
1616
static class AdvancedTrainingSettings
1717
{
18-
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
18+
private static string BaseDatasetsLocation = "Data";
1919
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
2020
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-test.csv");
2121
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "TaxiFareModel.zip");

src/Samples/AutoTrainBinaryClassification.cs

+27-9
Original file line number | Diff line number | Diff line change
@@ -10,37 +10,44 @@
1010
using Microsoft.ML;
1111
using Microsoft.ML.Auto;
1212
using Microsoft.ML.Data;
13-
using Samples.Helpers;
13+
using Samples.DataStructures;
1414

1515
namespace Samples
1616
{
1717
public class AutoTrainBinaryClassification
1818
{
19-
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
19+
private static string BaseDatasetsLocation = "Data";
2020
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "wikipedia-detox-250-line-data.tsv");
2121
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "wikipedia-detox-250-line-test.tsv");
2222
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "SentimentModel.zip");
23-
private static string LabelColumn = "Sentiment";
2423
private static uint ExperimentTime = 60;
2524

2625
public static void Run()
2726
{
2827
MLContext mlContext = new MLContext();
2928

30-
// STEP 1: Infer columns
31-
ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn);
32-
ConsoleHelper.Print(columnInference);
29+
// STEP 1: Create text loader options
30+
var textLoaderOptions = new TextLoader.Options()
31+
{
32+
Columns = new[]
33+
{
34+
new TextLoader.Column("Label", DataKind.Boolean, 0),
35+
new TextLoader.Column("Text", DataKind.String, 1),
36+
},
37+
HasHeader = true,
38+
Separators = new[] { '\t' }
39+
};
3340

3441
// STEP 2: Load data
35-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
42+
TextLoader textLoader = mlContext.Data.CreateTextLoader(textLoaderOptions);
3643
IDataView trainDataView = textLoader.Load(TrainDataPath);
3744
IDataView testDataView = textLoader.Load(TestDataPath);
3845

3946
// STEP 3: Auto featurize, auto train and auto hyperparameter tune
4047
Console.WriteLine($"Running AutoML binary classification experiment for {ExperimentTime} seconds...");
4148
IEnumerable<RunResult<BinaryClassificationMetrics>> runResults = mlContext.Auto()
4249
.CreateBinaryClassificationExperiment(ExperimentTime)
43-
.Execute(trainDataView, LabelColumn);
50+
.Execute(trainDataView);
4451

4552
// STEP 4: Print metric from the best model
4653
RunResult<BinaryClassificationMetrics> best = runResults.Best();
@@ -50,13 +57,24 @@ public static void Run()
5057

5158
// STEP 5: Evaluate test data
5259
IDataView testDataViewWithBestScore = best.Model.Transform(testDataView);
53-
BinaryClassificationMetrics testMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testDataViewWithBestScore, label: LabelColumn);
60+
BinaryClassificationMetrics testMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testDataViewWithBestScore);
5461
Console.WriteLine($"Accuracy of best model on test data: {testMetrics.Accuracy}");
5562

5663
// STEP 6: Save the best model for later deployment and inferencing
5764
using (FileStream fs = File.Create(ModelPath))
5865
best.Model.SaveTo(mlContext, fs);
5966

67+
// STEP 7: Create prediction engine from the best trained model
68+
var predictionEngine = best.Model.CreatePredictionEngine<SentimentIssue, SentimentPrediction>(mlContext);
69+
70+
// STEP 8: Initialize a new sentiment issue, and get the predicted sentiment
71+
var testSentimentIssue = new SentimentIssue
72+
{
73+
Text = "I hope this helps."
74+
};
75+
var prediction = predictionEngine.Predict(testSentimentIssue);
76+
Console.WriteLine($"Predicted sentiment for test issue: {prediction.Prediction}");
77+
6078
Console.WriteLine("Press any key to continue...");
6179
Console.ReadKey();
6280
}

src/Samples/AutoTrainMulticlassClassification.cs

+25-6
Original file line number | Diff line number | Diff line change
@@ -10,13 +10,13 @@
1010
using Microsoft.ML;
1111
using Microsoft.ML.Auto;
1212
using Microsoft.ML.Data;
13-
using Samples.Helpers;
13+
using Samples.DataStructures;
1414

1515
namespace Samples
1616
{
1717
public class AutoTrainMulticlassClassification
1818
{
19-
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
19+
private static string BaseDatasetsLocation = "Data";
2020
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "optdigits-train.csv");
2121
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "optdigits-test.csv");
2222
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "OptDigits.zip");
@@ -26,12 +26,20 @@ public static void Run()
2626
{
2727
MLContext mlContext = new MLContext();
2828

29-
// STEP 1: Infer columns
30-
ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath);
31-
ConsoleHelper.Print(columnInference);
29+
// STEP 1: Create text loader options
30+
var textLoaderOptions = new TextLoader.Options()
31+
{
32+
Columns = new[]
33+
{
34+
new TextLoader.Column("PixelValues", DataKind.Single, 0, 63),
35+
new TextLoader.Column("Label", DataKind.Single, 64),
36+
},
37+
HasHeader = true,
38+
Separators = new[] { ',' }
39+
};
3240

3341
// STEP 2: Load data
34-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
42+
TextLoader textLoader = mlContext.Data.CreateTextLoader(textLoaderOptions);
3543
IDataView trainDataView = textLoader.Load(TrainDataPath);
3644
IDataView testDataView = textLoader.Load(TestDataPath);
3745

@@ -56,6 +64,17 @@ public static void Run()
5664
using (FileStream fs = File.Create(ModelPath))
5765
best.Model.SaveTo(mlContext, fs);
5866

67+
// STEP 7: Create prediction engine from the best trained model
68+
var predictionEngine = best.Model.CreatePredictionEngine<PixelData, PixelPrediction>(mlContext);
69+
70+
// STEP 8: Initialize new pixel data, and get the predicted number
71+
var testPixelData = new PixelData
72+
{
73+
PixelValues = new float[] { 0, 0, 1, 8, 15, 10, 0, 0, 0, 3, 13, 15, 14, 14, 0, 0, 0, 5, 10, 0, 10, 12, 0, 0, 0, 0, 3, 5, 15, 10, 2, 0, 0, 0, 16, 16, 16, 16, 12, 0, 0, 1, 8, 12, 14, 8, 3, 0, 0, 0, 0, 10, 13, 0, 0, 0, 0, 0, 0, 11, 9, 0, 0, 0 }
74+
};
75+
var prediction = predictionEngine.Predict(testPixelData);
76+
Console.WriteLine($"Predicted number for test pixels: {prediction.Prediction}");
77+
5978
Console.WriteLine("Press any key to continue...");
6079
Console.ReadKey();
6180
}

src/Samples/AutoTrainRegression.cs

+36-7
Original file line number | Diff line number | Diff line change
@@ -10,29 +10,42 @@
1010
using Microsoft.ML;
1111
using Microsoft.ML.Auto;
1212
using Microsoft.ML.Data;
13-
using Samples.Helpers;
13+
using Samples.DataStructures;
1414

1515
namespace Samples
1616
{
1717
static class AutoTrainRegression
1818
{
19-
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
19+
private static string BaseDatasetsLocation = "Data";
2020
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
2121
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-test.csv");
2222
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "TaxiFareModel.zip");
23-
private static string LabelColumn = "fare_amount";
23+
private static string LabelColumn = "FareAmount";
2424
private static uint ExperimentTime = 60;
2525

2626
public static void Run()
2727
{
2828
MLContext mlContext = new MLContext();
2929

30-
// STEP 1: Infer columns
31-
ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn);
32-
ConsoleHelper.Print(columnInference);
30+
// STEP 1: Create text loader options
31+
var textLoaderOptions = new TextLoader.Options()
32+
{
33+
Columns = new[]
34+
{
35+
new TextLoader.Column("VendorId", DataKind.String, 0),
36+
new TextLoader.Column("RateCode", DataKind.Single, 1),
37+
new TextLoader.Column("PassengerCount", DataKind.Single, 2),
38+
new TextLoader.Column("TripTimeInSeconds", DataKind.Single, 3),
39+
new TextLoader.Column("TripDistance", DataKind.Single, 4),
40+
new TextLoader.Column("PaymentType", DataKind.String, 5),
41+
new TextLoader.Column("FareAmount", DataKind.Single, 6),
42+
},
43+
HasHeader = true,
44+
Separators = new[] { ',' }
45+
};
3346

3447
// STEP 2: Load data
35-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
48+
TextLoader textLoader = mlContext.Data.CreateTextLoader(textLoaderOptions);
3649
IDataView trainDataView = textLoader.Load(TrainDataPath);
3750
IDataView testDataView = textLoader.Load(TestDataPath);
3851

@@ -57,6 +70,22 @@ public static void Run()
5770
using (FileStream fs = File.Create(ModelPath))
5871
best.Model.SaveTo(mlContext, fs);
5972

73+
// STEP 7: Create prediction engine from the best trained model
74+
var predictionEngine = best.Model.CreatePredictionEngine<TaxiTrip, TaxiTripFarePrediction>(mlContext);
75+
76+
// STEP 8: Initialize a new test taxi trip, and get the predicted fare
77+
var testTaxiTrip = new TaxiTrip
78+
{
79+
VendorId = "VTS",
80+
RateCode = 1,
81+
PassengerCount = 1,
82+
TripTimeInSeconds = 1140,
83+
TripDistance = 3.75f,
84+
PaymentType = "CRD"
85+
};
86+
var prediction = predictionEngine.Predict(testTaxiTrip);
87+
Console.WriteLine($"Predicted fare for test taxi trip: {prediction.FareAmount}");
88+
6089
Console.WriteLine("Press any key to continue...");
6190
Console.ReadKey();
6291
}

src/Samples/Cancellation.cs

+39-20
Original file line number | Diff line number | Diff line change
@@ -3,57 +3,76 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.IO;
76
using System.Collections.Generic;
87
using System.Diagnostics;
8+
using System.IO;
99
using System.Linq;
1010
using System.Threading;
11+
using System.Threading.Tasks;
1112
using Microsoft.Data.DataView;
1213
using Microsoft.ML;
1314
using Microsoft.ML.Auto;
1415
using Microsoft.ML.Data;
15-
using Samples.Helpers;
1616

1717
namespace Samples
1818
{
1919
static class Cancellation
2020
{
21-
private static string BaseDatasetsLocation = Path.Combine("..", "..", "..", "..", "src", "Samples", "Data");
21+
private static string BaseDatasetsLocation = "Data";
2222
private static string TrainDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-train.csv");
2323
private static string TestDataPath = Path.Combine(BaseDatasetsLocation, "taxi-fare-test.csv");
2424
private static string ModelPath = Path.Combine(BaseDatasetsLocation, "TaxiFareModel.zip");
25-
private static string LabelColumn = "fare_amount";
25+
private static string LabelColumn = "FareAmount";
2626

2727
public static void Run()
2828
{
2929
MLContext mlContext = new MLContext();
3030

31-
// STEP 1: Infer columns
32-
ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath, LabelColumn, ',');
33-
ConsoleHelper.Print(columnInference);
31+
// STEP 1: Create text loader options
32+
var textLoaderOptions = new TextLoader.Options()
33+
{
34+
Columns = new[]
35+
{
36+
new TextLoader.Column("VendorId", DataKind.String, 0),
37+
new TextLoader.Column("RateCode", DataKind.Single, 1),
38+
new TextLoader.Column("PassengerCount", DataKind.Single, 2),
39+
new TextLoader.Column("TripTimeInSeconds", DataKind.Single, 3),
40+
new TextLoader.Column("TripDistance", DataKind.Single, 4),
41+
new TextLoader.Column("PaymentType", DataKind.String, 5),
42+
new TextLoader.Column("FareAmount", DataKind.Single, 6),
43+
},
44+
HasHeader = true,
45+
Separators = new[] { ',' }
46+
};
3447

3548
// STEP 2: Load data
36-
TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderOptions);
49+
TextLoader textLoader = mlContext.Data.CreateTextLoader(textLoaderOptions);
3750
IDataView trainDataView = textLoader.Load(TrainDataPath);
3851
IDataView testDataView = textLoader.Load(TestDataPath);
3952

40-
int cancelAfterInSeconds = 20;
53+
// STEP 3: Auto inference with a cancellation token in a new task
54+
Stopwatch stopwatch = Stopwatch.StartNew();
4155
CancellationTokenSource cts = new CancellationTokenSource();
42-
cts.CancelAfter(cancelAfterInSeconds * 1000);
43-
44-
Stopwatch watch = Stopwatch.StartNew();
45-
46-
// STEP 3: Auto inference with a cancellation token
47-
Console.WriteLine($"Invoking an experiment that will be cancelled after {cancelAfterInSeconds} seconds");
48-
IEnumerable<RunResult<RegressionMetrics>> runResults = mlContext.Auto()
56+
var experiment = mlContext.Auto()
4957
.CreateRegressionExperiment(new RegressionExperimentSettings()
5058
{
51-
MaxExperimentTimeInSeconds = 60,
59+
MaxExperimentTimeInSeconds = 3600,
5260
CancellationToken = cts.Token
53-
})
54-
.Execute(trainDataView, LabelColumn);
61+
});
62+
IEnumerable<RunResult<RegressionMetrics>> runResults = new List<RunResult<RegressionMetrics>>();
63+
Console.WriteLine($"Running AutoML experiment...");
64+
Task experimentTask = Task.Run(() =>
65+
{
66+
runResults = experiment.Execute(trainDataView, LabelColumn);
67+
});
68+
69+
// STEP 4: Stop the experiment run after any key is pressed
70+
Console.WriteLine($"Press any key to stop the experiment run...");
71+
Console.ReadKey();
72+
cts.Cancel();
73+
experimentTask.Wait();
5574

56-
Console.WriteLine($"{runResults.Count()} models were returned after {cancelAfterInSeconds} seconds");
75+
Console.WriteLine($"{runResults.Count()} models were returned after {stopwatch.Elapsed.TotalSeconds:0.00} seconds");
5776

5877
Console.WriteLine("Press any key to continue...");
5978
Console.ReadKey();

0 commit comments

Comments
 (0)