|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 |
| -using System.Collections.Generic; |
| 5 | +using System; |
6 | 6 | using Microsoft.ML.AutoML;
|
7 | 7 | using Microsoft.ML.CLI.Data;
|
8 | 8 | using Microsoft.ML.CLI.ShellProgressBar;
|
@@ -44,47 +44,44 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
|
44 | 44 | return columnInference;
|
45 | 45 | }
|
46 | 46 |
|
47 |
| - ExperimentResult<BinaryClassificationMetrics> IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar) |
| 47 | + void IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar) |
48 | 48 | {
|
49 |
| - var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar); |
50 |
| - var result = context.Auto() |
| 49 | + ExperimentResult<BinaryClassificationMetrics> result = context.Auto() |
51 | 50 | .CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
|
52 | 51 | {
|
53 | 52 | MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
|
54 | 53 | CacheBeforeTrainer = this.cacheBeforeTrainer,
|
55 | 54 | OptimizingMetric = optimizationMetric
|
56 | 55 | })
|
57 |
| - .Execute(trainData, validationData, columnInformation, progressHandler: progressReporter); |
| 56 | + .Execute(trainData, validationData, columnInformation, progressHandler: handler); |
| 57 | + |
58 | 58 | logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
|
59 |
| - return result; |
60 | 59 | }
|
61 | 60 |
|
62 |
| - ExperimentResult<RegressionMetrics> IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar) |
| 61 | + void IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar) |
63 | 62 | {
|
64 |
| - var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar); |
65 |
| - var result = context.Auto() |
| 63 | + ExperimentResult<RegressionMetrics> result = context.Auto() |
66 | 64 | .CreateRegressionExperiment(new RegressionExperimentSettings()
|
67 | 65 | {
|
68 | 66 | MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
|
69 | 67 | OptimizingMetric = optimizationMetric,
|
70 | 68 | CacheBeforeTrainer = this.cacheBeforeTrainer
|
71 |
| - }).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter); |
| 69 | + }).Execute(trainData, validationData, columnInformation, progressHandler: handler); |
| 70 | + |
72 | 71 | logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
|
73 |
| - return result; |
74 | 72 | }
|
75 | 73 |
|
76 |
| - ExperimentResult<MulticlassClassificationMetrics> IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar) |
| 74 | + void IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar) |
77 | 75 | {
|
78 |
| - var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar); |
79 |
| - var result = context.Auto() |
| 76 | + ExperimentResult<MulticlassClassificationMetrics> result = context.Auto() |
80 | 77 | .CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings()
|
81 | 78 | {
|
82 | 79 | MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
|
83 | 80 | CacheBeforeTrainer = this.cacheBeforeTrainer,
|
84 | 81 | OptimizingMetric = optimizationMetric
|
85 |
| - }).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter); |
| 82 | + }).Execute(trainData, validationData, columnInformation, progressHandler: handler); |
| 83 | + |
86 | 84 | logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
|
87 |
| - return result; |
88 | 85 | }
|
89 | 86 |
|
90 | 87 | }
|
|
0 commit comments