|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 | 5 | using System;
|
| 6 | +using System.Collections.Generic; |
6 | 7 | using System.IO;
|
7 | 8 | using System.Linq;
|
8 | 9 | using Microsoft.Data.DataView;
|
9 | 10 | using Microsoft.ML;
|
10 | 11 | using Microsoft.ML.Auto;
|
11 | 12 | using Microsoft.ML.Data;
|
| 13 | +using Samples.Helpers; |
12 | 14 |
|
13 | 15 | namespace Samples
|
14 | 16 | {
|
15 | 17 | public class AutoTrainMulticlassClassification
|
16 | 18 | {
|
17 | 19 | private static string BaseDatasetsLocation = @"../../../../src/Samples/Data";
|
18 |
| - private static string TrainDataPath = $"{BaseDatasetsLocation}/iris-train.txt"; |
19 |
| - private static string TestDataPath = $"{BaseDatasetsLocation}/iris-test.txt"; |
20 |
| - private static string ModelPath = $"{BaseDatasetsLocation}/IrisClassificationModel.zip"; |
| 20 | + private static string TrainDataPath = $"{BaseDatasetsLocation}/optdigits-train.csv"; |
| 21 | + private static string TestDataPath = $"{BaseDatasetsLocation}/optdigits-test.csv"; |
| 22 | + private static string ModelPath = $"{BaseDatasetsLocation}/OptDigits.zip"; |
21 | 23 | private static uint ExperimentTime = 60;
|
22 | 24 |
|
23 | 25 | public static void Run()
|
24 | 26 | {
|
25 | 27 | MLContext mlContext = new MLContext();
|
26 | 28 |
|
27 | 29 | // STEP 1: Infer columns
|
28 |
| - var columnInference = mlContext.Auto().InferColumns(TrainDataPath); |
| 30 | + ColumnInferenceResults columnInference = mlContext.Auto().InferColumns(TrainDataPath); |
| 31 | + ConsoleHelper.Print(columnInference); |
29 | 32 |
|
30 | 33 | // STEP 2: Load data
|
31 |
| - var textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs); |
32 |
| - var trainDataView = textLoader.Read(TrainDataPath); |
33 |
| - var testDataView = textLoader.Read(TestDataPath); |
| 34 | + TextLoader textLoader = mlContext.Data.CreateTextLoader(columnInference.TextLoaderArgs); |
| 35 | + IDataView trainDataView = textLoader.Read(TrainDataPath); |
| 36 | + IDataView testDataView = textLoader.Read(TestDataPath); |
34 | 37 |
|
35 | 38 | // STEP 3: Auto featurize, auto train and auto hyperparameter tune
|
36 | 39 | Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds...");
|
37 |
| - var runResults = mlContext.Auto() |
38 |
| - .CreateMulticlassClassificationExperiment(60) |
39 |
| - .Execute(trainDataView); |
| 40 | + IEnumerable<RunResult<MultiClassClassifierMetrics>> runResults = mlContext.Auto() |
| 41 | + .CreateMulticlassClassificationExperiment(60) |
| 42 | + .Execute(trainDataView); |
40 | 43 |
|
41 | 44 | // STEP 4: Print metric from the best model
|
42 |
| - var best = runResults.Best(); |
| 45 | + RunResult<MultiClassClassifierMetrics> best = runResults.Best(); |
43 | 46 | Console.WriteLine($"Total models produced: {runResults.Count()}");
|
44 | 47 | Console.WriteLine($"Best model's trainer: {best.TrainerName}");
|
45 | 48 | Console.WriteLine($"AccuracyMacro of best model from validation data: {best.ValidationMetrics.AccuracyMacro}");
|
46 | 49 |
|
47 | 50 | // STEP 5: Evaluate test data
|
48 |
| - var testDataViewWithBestScore = best.Model.Transform(testDataView); |
49 |
| - var testMetrics = mlContext.MulticlassClassification.Evaluate(testDataViewWithBestScore); |
| 51 | + IDataView testDataViewWithBestScore = best.Model.Transform(testDataView); |
| 52 | + MultiClassClassifierMetrics testMetrics = mlContext.MulticlassClassification.Evaluate(testDataViewWithBestScore); |
50 | 53 | Console.WriteLine($"AccuracyMacro of best model on test data: {testMetrics.AccuracyMacro}");
|
51 | 54 |
|
52 | 55 | // STEP 6: Save the best model for later deployment and inferencing
|
53 |
| - using (var fs = File.Create(ModelPath)) |
| 56 | + using (FileStream fs = File.Create(ModelPath)) |
54 | 57 | best.Model.SaveTo(mlContext, fs);
|
55 | 58 |
|
56 | 59 | Console.WriteLine("Press any key to continue...");
|
|
0 commit comments