-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Use Timer and ctx.CancelExecution() to fix AutoML max-time experiment bug #5445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d0f7054
4fa26f8
48a6267
f324030
ee70024
36bf24e
d5d23de
33cf5a6
bfc93e9
c69a19f
299b05b
2e2d441
ce747fb
7635500
94a80de
abe1d7f
9585a50
1ab662f
bc9e578
71ebf23
2d8d06f
490d8c1
b0de1d3
0918afa
0922aed
ef4b34f
b4b49ce
6502fc8
28e2f2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,6 +7,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
using System.Diagnostics; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
using System.IO; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
using System.Linq; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
using System.Threading; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
using Microsoft.ML.Data; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
using Microsoft.ML.Runtime; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
namespace Microsoft.ML.AutoML | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -25,6 +27,11 @@ internal class Experiment<TRunDetail, TMetrics> where TRunDetail : RunDetail | |||||||||||||||||||||||||||||||||||||||||||||||||||||
private readonly IRunner<TRunDetail> _runner; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private readonly IList<SuggestedPipelineRunDetail> _history; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private readonly IChannel _logger; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private Timer _maxExperimentTimeTimer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private Timer _mainContextCanceledTimer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private bool _experimentTimerExpired; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private MLContext _currentModelMLContext; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
private Random _newContextSeedGenerator; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
public Experiment(MLContext context, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
TaskKind task, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -49,60 +56,125 @@ public Experiment(MLContext context, | |||||||||||||||||||||||||||||||||||||||||||||||||||||
_datasetColumnInfo = datasetColumnInfo; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_runner = runner; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_logger = logger; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_experimentTimerExpired = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
private void MaxExperimentTimeExpiredEvent(object state) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// If at least one model was run, end experiment immediately. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Else, wait for first model to run before experiment is concluded. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_experimentTimerExpired = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (_history.Any(r => r.RunSucceeded)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_logger.Warning("Allocated time for Experiment of {0} seconds has elapsed with {1} models run. Ending experiment...", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_experimentSettings.MaxExperimentTimeInSeconds, _history.Count()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_currentModelMLContext.CancelExecution(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
private void MainContextCanceledEvent(object state) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// If the main MLContext is canceled, cancel the ongoing model training and MLContext. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ((_context.Model.GetEnvironment() as ICancelable).IsCanceled) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_logger.Warning("Main MLContext has been canceled. Ending experiment..."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Stop timer to prevent restarting and prevent continuous calls to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// MainContextCanceledEvent | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_mainContextCanceledTimer.Change(Timeout.Infinite, Timeout.Infinite); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_currentModelMLContext.CancelExecution(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
public IList<TRunDetail> Execute() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var stopwatch = Stopwatch.StartNew(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var iterationResults = new List<TRunDetail>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Create a timer for the max duration of experiment. When given time has | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// elapsed, MaxExperimentTimeExpiredEvent is called to interrupt training | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// of current model. Timer is not used if no experiment time is given, or | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// is not a positive number. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (_experimentSettings.MaxExperimentTimeInSeconds > 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_maxExperimentTimeTimer = new Timer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
new TimerCallback(MaxExperimentTimeExpiredEvent), null, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_experimentSettings.MaxExperimentTimeInSeconds * 1000, Timeout.Infinite | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// If given max duration of experiment is 0, only 1 model will be trained. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// _experimentSettings.MaxExperimentTimeInSeconds is of type uint, it is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// either 0 or >0. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_experimentTimerExpired = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Add second timer to check for the cancelation signal from the main MLContext | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// to the active child MLContext. This timer will propagate the cancelation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// signal from the main to the child MLContexs if the main MLContext is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// canceled. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_mainContextCanceledTimer = new Timer(new TimerCallback(MainContextCanceledEvent), null, 1000, 1000); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Pseudo random number generator to result in deterministic runs with the provided main MLContext's seed and to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// maintain variability between training iterations. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
int? mainContextSeed = ((ISeededEnvironment)_context.Model.GetEnvironment()).Seed; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_newContextSeedGenerator = (mainContextSeed.HasValue) ? RandomUtils.Create(mainContextSeed.Value) : null; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
do | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var iterationStopwatch = Stopwatch.StartNew(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// get next pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var getPipelineStopwatch = Stopwatch.StartNew(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
var pipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// break if no candidates returned, means no valid pipeline available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (pipeline == null) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// evaluate pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_logger.Trace($"Evaluating pipeline {pipeline.ToString()}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
(SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
= _runner.Run(pipeline, _modelDirectory, _history.Count + 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
_history.Add(suggestedPipelineRunDetail); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
ReportProgress(runDetail); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
iterationResults.Add(runDetail); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// if model is perfect, break | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
try | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var iterationStopwatch = Stopwatch.StartNew(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// get next pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
var getPipelineStopwatch = Stopwatch.StartNew(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// A new MLContext is needed per model run. When max experiment time is reached, each used | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// context is canceled to stop further model training. The cancellation of the main MLContext | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// a user has instantiated is not desirable, thus additional MLContexts are used. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_currentModelMLContext = _newContextSeedGenerator == null ? new MLContext() : new MLContext(_newContextSeedGenerator.Next()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this stop user from getting ongoing training log? As the context that is used for training will be different from the context where AutoML experiment is created, and is unavailable externally. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @LittleLittleCloud: Good point. Have you seen issues? We can always duplicate the logger. Or attach a logger to the new context, and when called, have it pass the message to the original context. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, in Image classification we subscribe to the log channel to show training progress. It's quite important since image-classification training is way more time consuming than other scenarios and we need to show something to let users know the training progress |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
var pipeline = PipelineSuggester.GetNextInferredPipeline(_currentModelMLContext, _history, _datasetColumnInfo, _task, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// break if no candidates returned, means no valid pipeline available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (pipeline == null) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// evaluate pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_logger.Trace($"Evaluating pipeline {pipeline.ToString()}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
(SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
= _runner.Run(pipeline, _modelDirectory, _history.Count + 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
_history.Add(suggestedPipelineRunDetail); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
runDetail.PipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
ReportProgress(runDetail); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
iterationResults.Add(runDetail); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// if model is perfect, break | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// If after third run, all runs have failed so far, throw exception | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (_history.Count() == 3 && _history.All(r => !r.RunSucceeded)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
// If after third run, all runs have failed so far, throw exception | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (_history.Count() == 3 && _history.All(r => !r.RunSucceeded)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
catch (OperationCanceledException e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might make sense to catch the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I figured it's better to catch |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
throw new InvalidOperationException($"Training failed with the exception: {_history.Last().Exception}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// This exception is thrown when the IHost/MLContext of the trainer is canceled due to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// reaching maximum experiment time. Simply catch this exception and return finished | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
// iteration results. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
_logger.Warning("OperationCanceledException has been caught after maximum experiment time" + | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"was reached, and the running MLContext was stopped. Details: {0}", e.Message); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return iterationResults; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
} while (_history.Count < _experimentSettings.MaxModels && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
!_experimentSettings.CancellationToken.IsCancellationRequested && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
!_experimentTimerExpired); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return iterationResults; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.