Skip to content

Commit d8bc4ed

Browse files
authored
[AutoML] Fix error handling in CLI. (#3618)
* fix error handling * renaming variables
1 parent 2c1cdc1 commit d8bc4ed

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

src/mlnet/CodeGenerator/CodeGenerationHelper.cs

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,17 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Collections.Generic;
76
using System.Diagnostics;
87
using System.IO;
98
using System.Linq;
10-
using System.Runtime.ExceptionServices;
119
using System.Threading;
12-
using System.Threading.Tasks;
1310
using Microsoft.ML.AutoML;
1411
using Microsoft.ML.CLI.CodeGenerator.CSharp;
1512
using Microsoft.ML.CLI.Data;
1613
using Microsoft.ML.CLI.ShellProgressBar;
1714
using Microsoft.ML.CLI.Utilities;
1815
using Microsoft.ML.Data;
1916
using NLog;
20-
using NLog.Targets;
2117

2218
namespace Microsoft.ML.CLI.CodeGenerator
2319
{
@@ -81,7 +77,7 @@ public void GenerateCode()
8177
// i.e there is no common class/interface to handle all three tasks together.
8278

8379
ExperimentResult<BinaryClassificationMetrics> binaryExperimentResult = default;
84-
ExperimentResult<MulticlassClassificationMetrics> multiExperimentResult = default;
80+
ExperimentResult<MulticlassClassificationMetrics> multiclassExperimentResult = default;
8581
ExperimentResult<RegressionMetrics> regressionExperimentResult = default;
8682
if (verboseLevel > LogLevel.Trace)
8783
{
@@ -111,20 +107,22 @@ public void GenerateCode()
111107

112108
if (verboseLevel > LogLevel.Trace && !Console.IsOutputRedirected)
113109
{
110+
Exception ex = null;
114111
using (var pbar = new FixedDurationBar(wait, "", options))
115112
{
116113
pbar.Message = Strings.WaitingForFirstIteration;
117114
Thread t = default;
118115
switch (taskKind)
119116
{
117+
// TODO: It may be a good idea to convert the below Threads to Tasks or get rid of this progress bar all together and use an existing one in opensource.
120118
case TaskKind.BinaryClassification:
121-
t = new Thread(() => binaryExperimentResult = automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, pbar));
119+
t = new Thread(() => SafeExecute(() => automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, pbar), out ex, out binaryExperimentResult, pbar));
122120
break;
123121
case TaskKind.Regression:
124-
t = new Thread(() => regressionExperimentResult = automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, pbar));
122+
t = new Thread(() => SafeExecute(() => automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, pbar), out ex, out regressionExperimentResult, pbar));
125123
break;
126124
case TaskKind.MulticlassClassification:
127-
t = new Thread(() => multiExperimentResult = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, pbar));
125+
t = new Thread(() => SafeExecute(() => automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, pbar), out ex, out multiclassExperimentResult, pbar));
128126
break;
129127
default:
130128
logger.Log(LogLevel.Error, Strings.UnsupportedMlTask);
@@ -147,6 +145,10 @@ public void GenerateCode()
147145
pbar.Message = originalMessage;
148146
}
149147
}
148+
if (ex != null)
149+
{
150+
throw ex;
151+
}
150152
}
151153
}
152154
else
@@ -160,7 +162,7 @@ public void GenerateCode()
160162
regressionExperimentResult = automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric);
161163
break;
162164
case TaskKind.MulticlassClassification:
163-
multiExperimentResult = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric);
165+
multiclassExperimentResult = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric);
164166
break;
165167
default:
166168
logger.Log(LogLevel.Error, Strings.UnsupportedMlTask);
@@ -204,11 +206,11 @@ public void GenerateCode()
204206
ConsolePrinter.PrintIterationSummary(regressionExperimentResult.RunDetails, new RegressionExperimentSettings().OptimizingMetric, 5);
205207
break;
206208
case TaskKind.MulticlassClassification:
207-
var bestMultiIteration = multiExperimentResult.BestRun;
208-
bestPipeline = bestMultiIteration.Pipeline;
209-
bestModel = bestMultiIteration.Model;
210-
ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), multiExperimentResult.RunDetails.Count());
211-
ConsolePrinter.PrintIterationSummary(multiExperimentResult.RunDetails, new MulticlassExperimentSettings().OptimizingMetric, 5);
209+
var bestMulticlassIteration = multiclassExperimentResult.BestRun;
210+
bestPipeline = bestMulticlassIteration.Pipeline;
211+
bestModel = bestMulticlassIteration.Model;
212+
ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), multiclassExperimentResult.RunDetails.Count());
213+
ConsolePrinter.PrintIterationSummary(multiclassExperimentResult.RunDetails, new MulticlassExperimentSettings().OptimizingMetric, 5);
212214
break;
213215
}
214216
}
@@ -278,5 +280,51 @@ private void ConsumeAutoMLSDKLogs(MLContext context)
278280
}
279281
};
280282
}
283+
284+
private void SafeExecute(Func<ExperimentResult<BinaryClassificationMetrics>> p, out Exception ex, out ExperimentResult<BinaryClassificationMetrics> binaryExperimentResult, FixedDurationBar pbar)
285+
{
286+
try
287+
{
288+
binaryExperimentResult = p.Invoke();
289+
ex = null;
290+
}
291+
catch (Exception e)
292+
{
293+
ex = e;
294+
binaryExperimentResult = null;
295+
return;
296+
}
297+
}
298+
299+
private void SafeExecute(Func<ExperimentResult<RegressionMetrics>> p, out Exception ex, out ExperimentResult<RegressionMetrics> regressionExperimentResult, FixedDurationBar pbar)
300+
{
301+
try
302+
{
303+
regressionExperimentResult = p.Invoke();
304+
ex = null;
305+
}
306+
catch (Exception e)
307+
{
308+
ex = e;
309+
regressionExperimentResult = null;
310+
return;
311+
}
312+
}
313+
314+
private void SafeExecute(Func<ExperimentResult<MulticlassClassificationMetrics>> p, out Exception ex, out ExperimentResult<MulticlassClassificationMetrics> multiClassExperimentResult, FixedDurationBar pbar)
315+
{
316+
try
317+
{
318+
multiClassExperimentResult = p.Invoke();
319+
ex = null;
320+
}
321+
catch (Exception e)
322+
{
323+
ex = e;
324+
multiClassExperimentResult = null;
325+
pbar.Dispose(); // or ((ManualResetEvent)pbar.CompletedHandle).Set();
326+
return;
327+
}
328+
}
281329
}
282330
}

src/mlnet/ProgressBar/ProgressBar.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ public void Dispose()
356356
_timer?.Dispose();
357357
_timer = null;
358358
foreach (var c in this.Children) c.Dispose();
359+
OnDone();
359360
}
360361
}
361362
}

0 commit comments

Comments
 (0)