|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
| 5 | +using System; |
| 6 | +using System.Collections.Generic; |
5 | 7 | using Microsoft.Data.DataView;
|
6 | 8 | using Microsoft.ML.Auto;
|
7 | 9 | using Microsoft.ML.CLI.Data;
|
| 10 | +using Microsoft.ML.CLI.ShellProgressBar; |
8 | 11 | using Microsoft.ML.CLI.Utilities;
|
| 12 | +using Microsoft.ML.Data; |
9 | 13 | using NLog;
|
10 | 14 |
|
11 | 15 | namespace Microsoft.ML.CLI.CodeGenerator
|
@@ -42,68 +46,51 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
|
42 | 46 | return columnInference;
|
43 | 47 | }
|
44 | 48 |
|
45 |
| - (Pipeline, ITransformer) IAutoMLEngine.ExploreModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation) |
| 49 | + IEnumerable<RunResult<BinaryClassificationMetrics>> IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar) |
46 | 50 | {
|
47 |
| - ITransformer model = null; |
48 |
| - |
49 |
| - Pipeline pipeline = null; |
50 |
| - |
51 |
| - if (taskKind == TaskKind.BinaryClassification) |
52 |
| - { |
53 |
| - var optimizationMetric = new BinaryExperimentSettings().OptimizingMetric; |
54 |
| - var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric); |
55 |
| - var result = context.Auto() |
56 |
| - .CreateBinaryClassificationExperiment(new BinaryExperimentSettings() |
57 |
| - { |
58 |
| - MaxExperimentTimeInSeconds = settings.MaxExplorationTime, |
59 |
| - ProgressHandler = progressReporter, |
60 |
| - EnableCaching = this.enableCaching, |
61 |
| - OptimizingMetric = optimizationMetric |
62 |
| - }) |
63 |
| - .Execute(trainData, validationData, columnInformation); |
64 |
| - logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline); |
65 |
| - var bestIteration = result.Best(); |
66 |
| - pipeline = bestIteration.Pipeline; |
67 |
| - model = bestIteration.Model; |
68 |
| - } |
69 |
| - |
70 |
| - if (taskKind == TaskKind.Regression) |
71 |
| - { |
72 |
| - var optimizationMetric = new RegressionExperimentSettings().OptimizingMetric; |
73 |
| - var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric); |
74 |
| - var result = context.Auto() |
75 |
| - .CreateRegressionExperiment(new RegressionExperimentSettings() |
76 |
| - { |
77 |
| - MaxExperimentTimeInSeconds = settings.MaxExplorationTime, |
78 |
| - ProgressHandler = progressReporter, |
79 |
| - OptimizingMetric = optimizationMetric, |
80 |
| - EnableCaching = this.enableCaching |
81 |
| - }).Execute(trainData, validationData, columnInformation); |
82 |
| - logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline); |
83 |
| - var bestIteration = result.Best(); |
84 |
| - pipeline = bestIteration.Pipeline; |
85 |
| - model = bestIteration.Model; |
86 |
| - } |
| 51 | + var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar); |
| 52 | + var result = context.Auto() |
| 53 | + .CreateBinaryClassificationExperiment(new BinaryExperimentSettings() |
| 54 | + { |
| 55 | + MaxExperimentTimeInSeconds = settings.MaxExplorationTime, |
| 56 | + ProgressHandler = progressReporter, |
| 57 | + EnableCaching = this.enableCaching, |
| 58 | + OptimizingMetric = optimizationMetric |
| 59 | + }) |
| 60 | + .Execute(trainData, validationData, columnInformation); |
| 61 | + logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline); |
| 62 | + return result; |
| 63 | + } |
87 | 64 |
|
88 |
| - if (taskKind == TaskKind.MulticlassClassification) |
89 |
| - { |
90 |
| - var optimizationMetric = new MulticlassExperimentSettings().OptimizingMetric; |
91 |
| - var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric); |
92 |
| - var result = context.Auto() |
93 |
| - .CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings() |
94 |
| - { |
95 |
| - MaxExperimentTimeInSeconds = settings.MaxExplorationTime, |
96 |
| - ProgressHandler = progressReporter, |
97 |
| - EnableCaching = this.enableCaching, |
98 |
| - OptimizingMetric = optimizationMetric |
99 |
| - }).Execute(trainData, validationData, columnInformation); |
100 |
| - logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline); |
101 |
| - var bestIteration = result.Best(); |
102 |
| - pipeline = bestIteration.Pipeline; |
103 |
| - model = bestIteration.Model; |
104 |
| - } |
| 65 | + IEnumerable<RunResult<RegressionMetrics>> IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar) |
| 66 | + { |
| 67 | + var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar); |
| 68 | + var result = context.Auto() |
| 69 | + .CreateRegressionExperiment(new RegressionExperimentSettings() |
| 70 | + { |
| 71 | + MaxExperimentTimeInSeconds = settings.MaxExplorationTime, |
| 72 | + ProgressHandler = progressReporter, |
| 73 | + OptimizingMetric = optimizationMetric, |
| 74 | + EnableCaching = this.enableCaching |
| 75 | + }).Execute(trainData, validationData, columnInformation); |
| 76 | + logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline); |
| 77 | + return result; |
| 78 | + } |
105 | 79 |
|
106 |
| - return (pipeline, model); |
| 80 | + IEnumerable<RunResult<MultiClassClassifierMetrics>> IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar) |
| 81 | + { |
| 82 | + var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar); |
| 83 | + var result = context.Auto() |
| 84 | + .CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings() |
| 85 | + { |
| 86 | + MaxExperimentTimeInSeconds = settings.MaxExplorationTime, |
| 87 | + ProgressHandler = progressReporter, |
| 88 | + EnableCaching = this.enableCaching, |
| 89 | + OptimizingMetric = optimizationMetric |
| 90 | + }).Execute(trainData, validationData, columnInformation); |
| 91 | + logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline); |
| 92 | + return result; |
107 | 93 | }
|
| 94 | + |
108 | 95 | }
|
109 | 96 | }
|
0 commit comments