Skip to content

Commit 9603fc4

Browse files
authored
Print winning iteration and runtime in CLI (dotnet#288)
* Print best metric and runtime * Print best metric and runtime * Line endings in AutoMLEngine.cs * Rename time column to duration to match Python SDK * Revert to MicroAccuracy and MacroAccuracy spellings * Revert spelling of BinaryClassificationMetricsAgent to BinaryMetricsAgent to reduce merge conflicts * Revert spelling of MulticlassMetricsAgent to MultiMetricsAgent to reduce merge conflicts * missed some files * Fix merge conflict * Update AutoMLEngine.cs
1 parent 4a6921d commit 9603fc4

File tree

9 files changed

+98
-32
lines changed

9 files changed

+98
-32
lines changed

src/Microsoft.ML.Auto/API/RunResult.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ public sealed class RunResult<T>
1515
public ITransformer Model { get { return _modelContainer.GetModel(); } }
1616
public Exception Exception { get; private set; }
1717
public string TrainerName { get; private set; }
18-
public int RuntimeInSeconds { get; private set; }
18+
public double RuntimeInSeconds { get; private set; }
1919
public IEstimator<ITransformer> Estimator { get; private set; }
2020

2121
internal Pipeline Pipeline { get; private set; }
22-
internal int PipelineInferenceTimeInSeconds { get; private set; }
22+
internal double PipelineInferenceTimeInSeconds { get; private set; }
2323

2424
private readonly ModelContainer _modelContainer;
2525

@@ -28,8 +28,8 @@ internal RunResult(ModelContainer modelContainer,
2828
IEstimator<ITransformer> estimator,
2929
Pipeline pipeline,
3030
Exception exception,
31-
int runtimeInSeconds,
32-
int pipelineInferenceTimeInSeconds)
31+
double runtimeInSeconds,
32+
double pipelineInferenceTimeInSeconds)
3333
{
3434
_modelContainer = modelContainer;
3535
ValidationMetrics = metrics;

src/Microsoft.ML.Auto/Experiment/Experiment.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ public List<RunResult<T>> Execute()
101101
// evaluate pipeline
102102
runResult = ProcessPipeline(pipeline);
103103

104-
runResult.RuntimeInSeconds = (int)iterationStopwatch.Elapsed.TotalSeconds;
105-
runResult.PipelineInferenceTimeInSeconds = (int)getPiplelineStopwatch.Elapsed.TotalSeconds;
104+
runResult.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
105+
runResult.PipelineInferenceTimeInSeconds = getPiplelineStopwatch.Elapsed.TotalSeconds;
106106
}
107107
catch (Exception ex)
108108
{

src/Microsoft.ML.Auto/Experiment/MetricsAgents/BinaryMetricsAgent.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using Microsoft.ML.Data;
76

87
namespace Microsoft.ML.Auto
@@ -18,6 +17,11 @@ public BinaryMetricsAgent(BinaryClassificationMetric optimizingMetric)
1817

1918
public double GetScore(BinaryClassificationMetrics metrics)
2019
{
20+
if (metrics == null)
21+
{
22+
return double.NaN;
23+
}
24+
2125
switch (_optimizingMetric)
2226
{
2327
case BinaryClassificationMetric.Accuracy:

src/Microsoft.ML.Auto/Experiment/MetricsAgents/MultiMetricsAgent.cs

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ public MultiMetricsAgent(MulticlassClassificationMetric optimizingMetric)
1717

1818
public double GetScore(MultiClassClassifierMetrics metrics)
1919
{
20+
if (metrics == null)
21+
{
22+
return double.NaN;
23+
}
24+
2025
switch (_optimizingMetric)
2126
{
2227
case MulticlassClassificationMetric.MacroAccuracy:

src/Microsoft.ML.Auto/Experiment/MetricsAgents/RegressionMetricsAgent.cs

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ public RegressionMetricsAgent(RegressionMetric optimizingMetric)
1717

1818
public double GetScore(RegressionMetrics metrics)
1919
{
20+
if (metrics == null)
21+
{
22+
return double.NaN;
23+
}
24+
2025
switch (_optimizingMetric)
2126
{
2227
case RegressionMetric.MeanAbsoluteError:

src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ internal class SuggestedPipelineResult<T> : SuggestedPipelineResult
3838
public ModelContainer ModelContainer { get; set; }
3939
public Exception Exception { get; set; }
4040

41-
public int RuntimeInSeconds { get; set; }
42-
public int PipelineInferenceTimeInSeconds { get; set; }
41+
public double RuntimeInSeconds { get; set; }
42+
public double PipelineInferenceTimeInSeconds { get; set; }
4343

4444
public SuggestedPipelineResult(T evaluatedMetrics, IEstimator<ITransformer> estimator,
4545
ModelContainer modelContainer, SuggestedPipeline pipeline, double score, Exception exception)

src/mlnet/AutoML/AutoMLEngine.cs

+11-5
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
5050

5151
if (taskKind == TaskKind.BinaryClassification)
5252
{
53-
var progressReporter = new ProgressHandlers.BinaryClassificationHandler();
53+
var optimizationMetric = new BinaryExperimentSettings().OptimizingMetric;
54+
var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric);
5455
var result = context.Auto()
5556
.CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
5657
{
5758
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
5859
ProgressHandler = progressReporter,
59-
EnableCaching = this.enableCaching
60+
EnableCaching = this.enableCaching,
61+
OptimizingMetric = optimizationMetric
6062
})
6163
.Execute(trainData, validationData, columnInformation);
6264
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
@@ -67,12 +69,14 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
6769

6870
if (taskKind == TaskKind.Regression)
6971
{
70-
var progressReporter = new ProgressHandlers.RegressionHandler();
72+
var optimizationMetric = new RegressionExperimentSettings().OptimizingMetric;
73+
var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric);
7174
var result = context.Auto()
7275
.CreateRegressionExperiment(new RegressionExperimentSettings()
7376
{
7477
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
7578
ProgressHandler = progressReporter,
79+
OptimizingMetric = optimizationMetric,
7680
EnableCaching = this.enableCaching
7781
}).Execute(trainData, validationData, columnInformation);
7882
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
@@ -83,13 +87,15 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
8387

8488
if (taskKind == TaskKind.MulticlassClassification)
8589
{
86-
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler();
90+
var optimizationMetric = new MulticlassExperimentSettings().OptimizingMetric;
91+
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric);
8792

8893
var experimentSettings = new MulticlassExperimentSettings()
8994
{
9095
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
9196
ProgressHandler = progressReporter,
92-
EnableCaching = this.enableCaching
97+
EnableCaching = this.enableCaching,
98+
OptimizingMetric = optimizationMetric
9399
};
94100

95101
// Inclusion list for currently supported learners. Need to remove once we have codegen support for all other learners.

src/mlnet/Utilities/ConsolePrinter.cs

+10-9
Original file line numberDiff line numberDiff line change
@@ -12,43 +12,44 @@ internal class ConsolePrinter
1212
private static NLog.Logger logger = NLog.LogManager.GetCurrentClassLogger();
1313

1414

15-
internal static void PrintBinaryClassificationMetrics(int iteration, string trainerName, BinaryClassificationMetrics metrics)
15+
internal static void PrintMetrics(int iteration, string trainerName, BinaryClassificationMetrics metrics, double bestMetric, double runtimeInSeconds)
1616
{
17-
logger.Log(LogLevel.Info, $"{iteration,-4} {trainerName,-35} {metrics?.Accuracy ?? double.NaN,9:F4} {metrics?.Auc ?? double.NaN,8:F4} {metrics?.Auprc ?? double.NaN,8:F4} {metrics?.F1Score ?? double.NaN,9:F4}");
17+
logger.Log(LogLevel.Info, $"{iteration,-4} {trainerName,-35} {metrics?.Accuracy ?? double.NaN,9:F4} {metrics?.Auc ?? double.NaN,8:F4} {metrics?.Auprc ?? double.NaN,8:F4} {metrics?.F1Score ?? double.NaN,9:F4} {bestMetric,8:F4} {runtimeInSeconds,9:F1}");
1818
}
1919

20-
internal static void PrintMulticlassClassificationMetrics(int iteration, string trainerName, MultiClassClassifierMetrics metrics)
20+
internal static void PrintMetrics(int iteration, string trainerName, MultiClassClassifierMetrics metrics, double bestMetric, double runtimeInSeconds)
2121
{
22-
logger.Log(LogLevel.Info, $"{iteration,-4} {trainerName,-35} {metrics?.AccuracyMicro ?? double.NaN,14:F4} {metrics?.AccuracyMacro ?? double.NaN,14:F4}");
22+
logger.Log(LogLevel.Info, $"{iteration,-4} {trainerName,-35} {metrics?.AccuracyMicro ?? double.NaN,14:F4} {metrics?.AccuracyMacro ?? double.NaN,14:F4} {bestMetric,14:F4} {runtimeInSeconds,9:F1}");
2323
}
2424

25-
internal static void PrintRegressionMetrics(int iteration, string trainerName, RegressionMetrics metrics)
25+
internal static void PrintMetrics(int iteration, string trainerName, RegressionMetrics metrics, double bestMetric, double runtimeInSeconds)
2626
{
27-
logger.Log(LogLevel.Info, $"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,9:F4} {metrics?.LossFn ?? double.NaN,12:F2} {metrics?.L1 ?? double.NaN,15:F2} {metrics?.L2 ?? double.NaN,15:F2} {metrics?.Rms ?? double.NaN,12:F2}");
27+
logger.Log(LogLevel.Info, $"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,9:F4} {metrics?.LossFn ?? double.NaN,12:F2} {metrics?.L1 ?? double.NaN,15:F2} {metrics?.L2 ?? double.NaN,15:F2} {metrics?.Rms ?? double.NaN,12:F2} {bestMetric,12:F4} {runtimeInSeconds,9:F1}");
2828
}
2929

30+
3031
internal static void PrintBinaryClassificationMetricsHeader()
3132
{
3233
logger.Log(LogLevel.Info, $"*************************************************");
3334
logger.Log(LogLevel.Info, $"* {Strings.MetricsForBinaryClassModels} ");
3435
logger.Log(LogLevel.Info, $"*------------------------------------------------");
35-
logger.Log(LogLevel.Info, $"{" ",-4} {"Trainer",-35} {"Accuracy",9} {"AUC",8} {"AUPRC",8} {"F1-score",9}");
36+
logger.Log(LogLevel.Info, $"{" ",-4} {"Trainer",-35} {"Accuracy",9} {"AUC",8} {"AUPRC",8} {"F1-score",9} {"Best",8} {"Duration",9}");
3637
}
3738

3839
internal static void PrintMulticlassClassificationMetricsHeader()
3940
{
4041
logger.Log(LogLevel.Info, $"*************************************************");
4142
logger.Log(LogLevel.Info, $"* {Strings.MetricsForMulticlassModels} ");
4243
logger.Log(LogLevel.Info, $"*------------------------------------------------");
43-
logger.Log(LogLevel.Info, $"{" ",-4} {"Trainer",-35} {"AccuracyMicro",14} {"AccuracyMacro",14}");
44+
logger.Log(LogLevel.Info, $"{" ",-4} {"Trainer",-35} {"AccuracyMicro",14} {"AccuracyMacro",14} {"Best",14} {"Duration",9}");
4445
}
4546

4647
internal static void PrintRegressionMetricsHeader()
4748
{
4849
logger.Log(LogLevel.Info, $"*************************************************");
4950
logger.Log(LogLevel.Info, $"* {Strings.MetricsForRegressionModels} ");
5051
logger.Log(LogLevel.Info, $"*------------------------------------------------");
51-
logger.Log(LogLevel.Info, $"{" ",-4} {"Trainer",-35} {"R2-Score",9} {"LossFn",12} {"Absolute-loss",15} {"Squared-loss",15} {"RMS-loss",12}");
52+
logger.Log(LogLevel.Info, $"{" ",-4} {"Trainer",-35} {"R2-Score",9} {"LossFn",12} {"Absolute-loss",15} {"Squared-loss",15} {"RMS-loss",12} {"Best",12} {"Duration",9}");
5253
}
5354

5455
internal static void PrintBestPipelineHeader()

src/mlnet/Utilities/ProgressHandlers.cs

+54-9
Original file line numberDiff line numberDiff line change
@@ -10,49 +10,94 @@ namespace Microsoft.ML.CLI.Utilities
1010
{
1111
internal class ProgressHandlers
1212
{
13+
private static int MetricComparator(double a, double b, bool isMaximizing)
14+
{
15+
return (isMaximizing ? a.CompareTo(b) : -a.CompareTo(b));
16+
}
17+
1318
internal class RegressionHandler : IProgress<RunResult<RegressionMetrics>>
1419
{
15-
int iterationIndex;
16-
public RegressionHandler()
20+
private readonly bool isMaximizing;
21+
private readonly Func<RunResult<RegressionMetrics>, double> GetScore;
22+
private RunResult<RegressionMetrics> bestResult;
23+
private int iterationIndex;
24+
25+
public RegressionHandler(RegressionMetric optimizationMetric)
1726
{
27+
isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing;
28+
GetScore = (RunResult<RegressionMetrics> result) => new RegressionMetricsAgent(optimizationMetric).GetScore(result?.ValidationMetrics);
1829
ConsolePrinter.PrintRegressionMetricsHeader();
1930
}
2031

2132
public void Report(RunResult<RegressionMetrics> iterationResult)
2233
{
2334
iterationIndex++;
24-
ConsolePrinter.PrintRegressionMetrics(iterationIndex, iterationResult.TrainerName, iterationResult.ValidationMetrics);
35+
UpdateBestResult(iterationResult);
36+
ConsolePrinter.PrintMetrics(iterationIndex, iterationResult.TrainerName, iterationResult.ValidationMetrics, GetScore(bestResult), iterationResult.RuntimeInSeconds);
37+
}
38+
39+
private void UpdateBestResult(RunResult<RegressionMetrics> iterationResult)
40+
{
41+
if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0)
42+
bestResult = iterationResult;
2543
}
2644
}
2745

2846
internal class BinaryClassificationHandler : IProgress<RunResult<BinaryClassificationMetrics>>
2947
{
30-
int iterationIndex;
31-
internal BinaryClassificationHandler()
48+
private readonly bool isMaximizing;
49+
private readonly Func<RunResult<BinaryClassificationMetrics>, double> GetScore;
50+
private RunResult<BinaryClassificationMetrics> bestResult;
51+
private int iterationIndex;
52+
53+
public BinaryClassificationHandler(BinaryClassificationMetric optimizationMetric)
3254
{
55+
isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing;
56+
GetScore = (RunResult<BinaryClassificationMetrics> result) => new BinaryMetricsAgent(optimizationMetric).GetScore(result?.ValidationMetrics);
3357
ConsolePrinter.PrintBinaryClassificationMetricsHeader();
3458
}
3559

3660
public void Report(RunResult<BinaryClassificationMetrics> iterationResult)
3761
{
3862
iterationIndex++;
39-
ConsolePrinter.PrintBinaryClassificationMetrics(iterationIndex, iterationResult.TrainerName, iterationResult.ValidationMetrics);
63+
UpdateBestResult(iterationResult);
64+
ConsolePrinter.PrintMetrics(iterationIndex, iterationResult.TrainerName, iterationResult.ValidationMetrics, GetScore(bestResult), iterationResult.RuntimeInSeconds);
65+
}
66+
67+
private void UpdateBestResult(RunResult<BinaryClassificationMetrics> iterationResult)
68+
{
69+
if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0)
70+
bestResult = iterationResult;
4071
}
4172
}
4273

4374
internal class MulticlassClassificationHandler : IProgress<RunResult<MultiClassClassifierMetrics>>
4475
{
45-
int iterationIndex;
46-
internal MulticlassClassificationHandler()
76+
private readonly bool isMaximizing;
77+
private readonly Func<RunResult<MultiClassClassifierMetrics>, double> GetScore;
78+
private RunResult<MultiClassClassifierMetrics> bestResult;
79+
private int iterationIndex;
80+
81+
public MulticlassClassificationHandler(MulticlassClassificationMetric optimizationMetric)
4782
{
83+
isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing;
84+
GetScore = (RunResult<MultiClassClassifierMetrics> result) => new MultiMetricsAgent(optimizationMetric).GetScore(result?.ValidationMetrics);
4885
ConsolePrinter.PrintMulticlassClassificationMetricsHeader();
4986
}
5087

5188
public void Report(RunResult<MultiClassClassifierMetrics> iterationResult)
5289
{
5390
iterationIndex++;
54-
ConsolePrinter.PrintMulticlassClassificationMetrics(iterationIndex, iterationResult.TrainerName, iterationResult.ValidationMetrics);
91+
UpdateBestResult(iterationResult);
92+
ConsolePrinter.PrintMetrics(iterationIndex, iterationResult.TrainerName, iterationResult.ValidationMetrics, GetScore(bestResult), iterationResult.RuntimeInSeconds);
93+
}
94+
95+
private void UpdateBestResult(RunResult<MultiClassClassifierMetrics> iterationResult)
96+
{
97+
if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0)
98+
bestResult = iterationResult;
5599
}
56100
}
101+
57102
}
58103
}

0 commit comments

Comments
 (0)