Skip to content

Commit 43c62ae

Browse files
srsaggamDmitry-A
authored andcommitted
Initial version of Progress bar impl and CLI UI experience (dotnet#325)
* progressbar * added progressbar and refactoring * reverted * revert sign assembly * added headers and removed exception rethrow
1 parent 6162944 commit 43c62ae

21 files changed

+1072
-133
lines changed

src/Microsoft.ML.Auto/API/RegressionExperiment.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ internal IEnumerable<RunResult<RegressionMetrics>> Execute(MLContext context,
9191
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
9292

9393
// run autofit & get all pipelines run in that process
94-
var experiment = new Experiment<RegressionMetrics>(context, TaskKind.Regression, trainData, columnInfo,
95-
validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric),
94+
var experiment = new Experiment<RegressionMetrics>(context, TaskKind.Regression, trainData, columnInfo,
95+
validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric),
9696
_settings.ProgressHandler, _settings, new RegressionMetricsAgent(_settings.OptimizingMetric),
9797
TrainerExtensionUtil.GetTrainerNames(_settings.Trainers));
9898

src/Microsoft.ML.Auto/Utils/RunResultUtil.cs

+11
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,16 @@ public static RunResult<T> GetBestRunResult<T>(IEnumerable<RunResult<T>> results
1818
double maxScore = results.Select(r => metricsAgent.GetScore(r.ValidationMetrics)).Max();
1919
return results.First(r => Math.Abs(metricsAgent.GetScore(r.ValidationMetrics) - maxScore) < 1E-20);
2020
}
21+
22+
public static IEnumerable<RunResult<T>> GetTopNRunResults<T>(IEnumerable<RunResult<T>> results,
23+
IMetricsAgent<T> metricsAgent, int n)
24+
{
25+
results = results.Where(r => r.ValidationMetrics != null);
26+
if (!results.Any()) { return null; }
27+
28+
var orderedResults = results.OrderByDescending(t => metricsAgent.GetScore(t.ValidationMetrics));
29+
30+
return orderedResults.Take(n);
31+
}
2132
}
2233
}

src/mlnet/AutoML/AutoMLEngine.cs

+46-59
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
57
using Microsoft.Data.DataView;
68
using Microsoft.ML.Auto;
79
using Microsoft.ML.CLI.Data;
10+
using Microsoft.ML.CLI.ShellProgressBar;
811
using Microsoft.ML.CLI.Utilities;
12+
using Microsoft.ML.Data;
913
using NLog;
1014

1115
namespace Microsoft.ML.CLI.CodeGenerator
@@ -42,68 +46,51 @@ public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation
4246
return columnInference;
4347
}
4448

45-
(Pipeline, ITransformer) IAutoMLEngine.ExploreModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation)
49+
IEnumerable<RunResult<BinaryClassificationMetrics>> IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar)
4650
{
47-
ITransformer model = null;
48-
49-
Pipeline pipeline = null;
50-
51-
if (taskKind == TaskKind.BinaryClassification)
52-
{
53-
var optimizationMetric = new BinaryExperimentSettings().OptimizingMetric;
54-
var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric);
55-
var result = context.Auto()
56-
.CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
57-
{
58-
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
59-
ProgressHandler = progressReporter,
60-
EnableCaching = this.enableCaching,
61-
OptimizingMetric = optimizationMetric
62-
})
63-
.Execute(trainData, validationData, columnInformation);
64-
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
65-
var bestIteration = result.Best();
66-
pipeline = bestIteration.Pipeline;
67-
model = bestIteration.Model;
68-
}
69-
70-
if (taskKind == TaskKind.Regression)
71-
{
72-
var optimizationMetric = new RegressionExperimentSettings().OptimizingMetric;
73-
var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric);
74-
var result = context.Auto()
75-
.CreateRegressionExperiment(new RegressionExperimentSettings()
76-
{
77-
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
78-
ProgressHandler = progressReporter,
79-
OptimizingMetric = optimizationMetric,
80-
EnableCaching = this.enableCaching
81-
}).Execute(trainData, validationData, columnInformation);
82-
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
83-
var bestIteration = result.Best();
84-
pipeline = bestIteration.Pipeline;
85-
model = bestIteration.Model;
86-
}
51+
var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar);
52+
var result = context.Auto()
53+
.CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
54+
{
55+
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
56+
ProgressHandler = progressReporter,
57+
EnableCaching = this.enableCaching,
58+
OptimizingMetric = optimizationMetric
59+
})
60+
.Execute(trainData, validationData, columnInformation);
61+
logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
62+
return result;
63+
}
8764

88-
if (taskKind == TaskKind.MulticlassClassification)
89-
{
90-
var optimizationMetric = new MulticlassExperimentSettings().OptimizingMetric;
91-
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric);
92-
var result = context.Auto()
93-
.CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings()
94-
{
95-
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
96-
ProgressHandler = progressReporter,
97-
EnableCaching = this.enableCaching,
98-
OptimizingMetric = optimizationMetric
99-
}).Execute(trainData, validationData, columnInformation);
100-
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
101-
var bestIteration = result.Best();
102-
pipeline = bestIteration.Pipeline;
103-
model = bestIteration.Model;
104-
}
65+
IEnumerable<RunResult<RegressionMetrics>> IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar)
66+
{
67+
var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar);
68+
var result = context.Auto()
69+
.CreateRegressionExperiment(new RegressionExperimentSettings()
70+
{
71+
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
72+
ProgressHandler = progressReporter,
73+
OptimizingMetric = optimizationMetric,
74+
EnableCaching = this.enableCaching
75+
}).Execute(trainData, validationData, columnInformation);
76+
logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
77+
return result;
78+
}
10579

106-
return (pipeline, model);
80+
IEnumerable<RunResult<MultiClassClassifierMetrics>> IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar)
81+
{
82+
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar);
83+
var result = context.Auto()
84+
.CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings()
85+
{
86+
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
87+
ProgressHandler = progressReporter,
88+
EnableCaching = this.enableCaching,
89+
OptimizingMetric = optimizationMetric
90+
}).Execute(trainData, validationData, columnInformation);
91+
logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
92+
return result;
10793
}
94+
10895
}
10996
}

src/mlnet/AutoML/IAutoMLEngine.cs

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,23 @@
33
// The .NET Foundation licenses this file to you under the MIT license.
44
// See the LICENSE file in the project root for more information.
55

6+
using System.Collections.Generic;
67
using Microsoft.Data.DataView;
78
using Microsoft.ML.Auto;
9+
using Microsoft.ML.CLI.ShellProgressBar;
10+
using Microsoft.ML.Data;
811

912
namespace Microsoft.ML.CLI.CodeGenerator
1013
{
1114
internal interface IAutoMLEngine
1215
{
1316
ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation);
1417

15-
(Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation);
18+
IEnumerable<RunResult<BinaryClassificationMetrics>> ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar);
19+
20+
IEnumerable<RunResult<MultiClassClassifierMetrics>> ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar);
21+
22+
IEnumerable<RunResult<RegressionMetrics>> ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar);
1623

1724
}
1825
}

src/mlnet/CodeGenerator/CodeGenerationHelper.cs

+81-7
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using System.IO;
8+
using System.Linq;
9+
using System.Runtime.ExceptionServices;
10+
using System.Threading;
11+
using System.Threading.Tasks;
712
using Microsoft.Data.DataView;
813
using Microsoft.ML.Auto;
914
using Microsoft.ML.CLI.CodeGenerator.CSharp;
1015
using Microsoft.ML.CLI.Data;
16+
using Microsoft.ML.CLI.ShellProgressBar;
1117
using Microsoft.ML.CLI.Utilities;
1218
using Microsoft.ML.Data;
1319
using NLog;
@@ -21,6 +27,7 @@ internal class CodeGenerationHelper
2127
private NewCommandSettings settings;
2228
private static Logger logger = LogManager.GetCurrentClassLogger();
2329
private TaskKind taskKind;
30+
2431
public CodeGenerationHelper(IAutoMLEngine automlEngine, NewCommandSettings settings)
2532
{
2633
this.automlEngine = automlEngine;
@@ -64,11 +71,54 @@ public void GenerateCode()
6471
(IDataView trainData, IDataView validationData) = LoadData(context, textLoaderOptions);
6572

6673
// Explore the models
67-
(Pipeline, ITransformer) result = default;
74+
75+
// The reason why we are doing this way of defining 3 different results is because of the AutoML API
76+
// i.e there is no common class/interface to handle all three tasks together.
77+
78+
IEnumerable<RunResult<BinaryClassificationMetrics>> binaryRunResults = default;
79+
IEnumerable<RunResult<MultiClassClassifierMetrics>> multiRunResults = default;
80+
IEnumerable<RunResult<RegressionMetrics>> regressionRunResults = default;
81+
6882
Console.WriteLine($"{Strings.ExplorePipeline}: {settings.MlTask}");
6983
try
7084
{
71-
result = automlEngine.ExploreModels(context, trainData, validationData, columnInformation);
85+
var options = new ProgressBarOptions
86+
{
87+
ForegroundColor = ConsoleColor.Yellow,
88+
ForegroundColorDone = ConsoleColor.DarkGreen,
89+
BackgroundColor = ConsoleColor.DarkGray,
90+
BackgroundCharacter = '\u2593'
91+
};
92+
var wait = TimeSpan.FromSeconds(settings.MaxExplorationTime);
93+
using (var pbar = new FixedDurationBar(wait, "", options))
94+
{
95+
Task t = default;
96+
switch (taskKind)
97+
{
98+
case TaskKind.BinaryClassification:
99+
t = Task.Run(() => binaryRunResults = automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, pbar));
100+
break;
101+
case TaskKind.Regression:
102+
t = Task.Run(() => regressionRunResults = automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, pbar));
103+
break;
104+
case TaskKind.MulticlassClassification:
105+
t = Task.Run(() => multiRunResults = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, pbar));
106+
break;
107+
default:
108+
logger.Log(LogLevel.Error, Strings.UnsupportedMlTask);
109+
break;
110+
}
111+
112+
if (!pbar.CompletedHandle.WaitOne(wait))
113+
Console.Error.WriteLine($"{nameof(FixedDurationBar)} did not signal {nameof(FixedDurationBar.CompletedHandle)} after {wait}");
114+
115+
if (t.IsCompleted == false)
116+
{
117+
logger.Log(LogLevel.Info, "Waiting for the last iteration to complete ...");
118+
}
119+
t.Wait();
120+
}
121+
72122
}
73123
catch (Exception e)
74124
{
@@ -80,18 +130,42 @@ public void GenerateCode()
80130
}
81131

82132
//Get the best pipeline
83-
Pipeline pipeline = null;
84-
pipeline = result.Item1;
85-
var model = result.Item2;
133+
Pipeline bestPipeline = null;
134+
ITransformer bestModel = null;
135+
136+
switch (taskKind)
137+
{
138+
case TaskKind.BinaryClassification:
139+
var bestBinaryIteration = binaryRunResults.Best();
140+
bestPipeline = bestBinaryIteration.Pipeline;
141+
bestModel = bestBinaryIteration.Model;
142+
ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumn, settings.MaxExplorationTime.ToString(), binaryRunResults.Count());
143+
ConsolePrinter.PrintIterationSummary(binaryRunResults, new BinaryExperimentSettings().OptimizingMetric, 5);
144+
break;
145+
case TaskKind.Regression:
146+
var bestRegressionIteration = regressionRunResults.Best();
147+
bestPipeline = bestRegressionIteration.Pipeline;
148+
bestModel = bestRegressionIteration.Model;
149+
ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumn, settings.MaxExplorationTime.ToString(), regressionRunResults.Count());
150+
ConsolePrinter.PrintIterationSummary(regressionRunResults, new RegressionExperimentSettings().OptimizingMetric, 5);
151+
break;
152+
case TaskKind.MulticlassClassification:
153+
var bestMultiIteration = multiRunResults.Best();
154+
bestPipeline = bestMultiIteration.Pipeline;
155+
bestModel = bestMultiIteration.Model;
156+
ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumn, settings.MaxExplorationTime.ToString(), multiRunResults.Count());
157+
ConsolePrinter.PrintIterationSummary(multiRunResults, new MulticlassExperimentSettings().OptimizingMetric, 5);
158+
break;
159+
}
86160

87161
// Save the model
88162
logger.Log(LogLevel.Info, Strings.SavingBestModel);
89163
var modelprojectDir = Path.Combine(settings.OutputPath.FullName, $"{settings.Name}.Model");
90164
var modelPath = new FileInfo(Path.Combine(modelprojectDir, "MLModel.zip"));
91-
Utils.SaveModel(model, modelPath, context);
165+
Utils.SaveModel(bestModel, modelPath, context);
92166

93167
// Generate the Project
94-
GenerateProject(columnInference, pipeline, columnInformation.LabelColumn, modelPath);
168+
GenerateProject(columnInference, bestPipeline, columnInformation.LabelColumn, modelPath);
95169
}
96170

97171
internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline pipeline, string labelName, FileInfo modelPath)

src/mlnet/NLog.config

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
44

55
<targets>
6-
<target name="logfile" xsi:type="File" fileName="debug_log.txt" />
6+
<target name="logfile" xsi:type="File" fileName="debug_log.txt" layout="${message}" />
77
<target name="logconsole" xsi:type="Console" layout="${message}" />
88
</targets>
99

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
7+
namespace Microsoft.ML.CLI.ShellProgressBar
8+
{
9+
public class ChildProgressBar : ProgressBarBase, IProgressBar
10+
{
11+
private readonly Action _scheduleDraw;
12+
private readonly Action<ProgressBarHeight> _growth;
13+
14+
public DateTime StartDate { get; } = DateTime.Now;
15+
16+
protected override void DisplayProgress() => _scheduleDraw?.Invoke();
17+
18+
internal ChildProgressBar(int maxTicks, string message, Action scheduleDraw, ProgressBarOptions options = null, Action<ProgressBarHeight> growth = null)
19+
: base(maxTicks, message, options)
20+
{
21+
_scheduleDraw = scheduleDraw;
22+
_growth = growth;
23+
_growth?.Invoke(ProgressBarHeight.Increment);
24+
}
25+
26+
private bool _calledDone;
27+
private readonly object _callOnce = new object();
28+
29+
protected override void OnDone()
30+
{
31+
if (_calledDone) return;
32+
lock (_callOnce)
33+
{
34+
if (_calledDone) return;
35+
36+
if (this.EndTime == null)
37+
this.EndTime = DateTime.Now;
38+
39+
if (this.Collapse)
40+
_growth?.Invoke(ProgressBarHeight.Decrement);
41+
42+
_calledDone = true;
43+
}
44+
}
45+
46+
public void Dispose()
47+
{
48+
OnDone();
49+
foreach (var c in this.Children) c.Dispose();
50+
}
51+
}
52+
}

0 commit comments

Comments
 (0)