Skip to content

Commit 769fc6c

Browse files
srsaggamDmitry-A
authored andcommitted
[AutoML] Early stopping in CLI based on the exploration time (dotnet#3641)
* early stopping in CLI * remove unused variables * change back to thread * remove sleep * fix review comments * remove ununsed usings * format message * collapse declaration * remove unused param * added environment.exit and removal of error message * correction in message * secs-> seconds * exit code * change value to 1 * reverse the declaration
1 parent c19fd08 commit 769fc6c

File tree

8 files changed

+262
-126
lines changed

8 files changed

+262
-126
lines changed

src/Microsoft.ML.AutoML/Utils/BestResultUtil.cs

+25
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,36 @@
44

55
using System.Collections.Generic;
66
using System.Linq;
7+
using Microsoft.ML.Data;
78

89
namespace Microsoft.ML.AutoML
910
{
1011
internal class BestResultUtil
1112
{
13+
public static RunDetail<BinaryClassificationMetrics> GetBestRun(IEnumerable<RunDetail<BinaryClassificationMetrics>> results,
14+
BinaryClassificationMetric metric)
15+
{
16+
var metricsAgent = new BinaryMetricsAgent(null, metric);
17+
var metricInfo = new OptimizingMetricInfo(metric);
18+
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
19+
}
20+
21+
public static RunDetail<RegressionMetrics> GetBestRun(IEnumerable<RunDetail<RegressionMetrics>> results,
22+
RegressionMetric metric)
23+
{
24+
var metricsAgent = new RegressionMetricsAgent(null, metric);
25+
var metricInfo = new OptimizingMetricInfo(metric);
26+
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
27+
}
28+
29+
public static RunDetail<MulticlassClassificationMetrics> GetBestRun(IEnumerable<RunDetail<MulticlassClassificationMetrics>> results,
30+
MulticlassClassificationMetric metric)
31+
{
32+
var metricsAgent = new MultiMetricsAgent(null, metric);
33+
var metricInfo = new OptimizingMetricInfo(metric);
34+
return GetBestRun(results, metricsAgent, metricInfo.IsMaximizing);
35+
}
36+
1237
public static RunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<RunDetail<TMetrics>> results,
1338
IMetricsAgent<TMetrics> metricsAgent, bool isMetricMaximizing)
1439
{

src/mlnet/AutoML/AutoMLEngine.cs

+13-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System.Collections.Generic;
5+
using System;
66
using Microsoft.ML.AutoML;
77
using Microsoft.ML.CLI.Data;
88
using Microsoft.ML.CLI.ShellProgressBar;
@@ -44,47 +44,44 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
4444
return columnInference;
4545
}
4646

47-
ExperimentResult<BinaryClassificationMetrics> IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar)
47+
void IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar)
4848
{
49-
var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar);
50-
var result = context.Auto()
49+
ExperimentResult<BinaryClassificationMetrics> result = context.Auto()
5150
.CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
5251
{
5352
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
5453
CacheBeforeTrainer = this.cacheBeforeTrainer,
5554
OptimizingMetric = optimizationMetric
5655
})
57-
.Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
56+
.Execute(trainData, validationData, columnInformation, progressHandler: handler);
57+
5858
logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
59-
return result;
6059
}
6160

62-
ExperimentResult<RegressionMetrics> IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar)
61+
void IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar)
6362
{
64-
var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar);
65-
var result = context.Auto()
63+
ExperimentResult<RegressionMetrics> result = context.Auto()
6664
.CreateRegressionExperiment(new RegressionExperimentSettings()
6765
{
6866
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
6967
OptimizingMetric = optimizationMetric,
7068
CacheBeforeTrainer = this.cacheBeforeTrainer
71-
}).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
69+
}).Execute(trainData, validationData, columnInformation, progressHandler: handler);
70+
7271
logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
73-
return result;
7472
}
7573

76-
ExperimentResult<MulticlassClassificationMetrics> IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar)
74+
void IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar)
7775
{
78-
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar);
79-
var result = context.Auto()
76+
ExperimentResult<MulticlassClassificationMetrics> result = context.Auto()
8077
.CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings()
8178
{
8279
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
8380
CacheBeforeTrainer = this.cacheBeforeTrainer,
8481
OptimizingMetric = optimizationMetric
85-
}).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
82+
}).Execute(trainData, validationData, columnInformation, progressHandler: handler);
83+
8684
logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
87-
return result;
8885
}
8986

9087
}

src/mlnet/AutoML/IAutoMLEngine.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.Collections.Generic;
67
using Microsoft.ML.AutoML;
78
using Microsoft.ML.CLI.ShellProgressBar;
9+
using Microsoft.ML.CLI.Utilities;
810
using Microsoft.ML.Data;
911

1012
namespace Microsoft.ML.CLI.CodeGenerator
@@ -13,11 +15,11 @@ internal interface IAutoMLEngine
1315
{
1416
ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation);
1517

16-
ExperimentResult<BinaryClassificationMetrics> ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar = null);
18+
void ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressHandlers.BinaryClassificationHandler handler, ProgressBar progressBar = null);
1719

18-
ExperimentResult<MulticlassClassificationMetrics> ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar = null);
20+
void ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressHandlers.MulticlassClassificationHandler handler, ProgressBar progressBar = null);
1921

20-
ExperimentResult<RegressionMetrics> ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar = null);
22+
void ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressHandlers.RegressionHandler handler, ProgressBar progressBar = null);
2123

2224
}
2325
}

0 commit comments

Comments
 (0)