Skip to content

Commit db7acd8

Browse files
daholsteDmitry-A
authored andcommitted
Make optimizing metric customizable and add trainer whitelist functionality (dotnet#172)
1 parent f2b13f5 commit db7acd8

17 files changed

+414
-123
lines changed

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+21-5
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,33 @@ namespace Microsoft.ML.Auto
1414
public class BinaryExperimentSettings : ExperimentSettings
1515
{
1616
public IProgress<RunResult<BinaryClassificationMetrics>> ProgressCallback;
17-
public BinaryClassificationMetric OptimizingMetric;
17+
public BinaryClassificationMetric OptimizingMetric = BinaryClassificationMetric.Accuracy;
1818
public BinaryClassificationTrainer[] WhitelistedTrainers;
1919
}
2020

2121
public enum BinaryClassificationMetric
2222
{
23-
Accuracy
23+
Accuracy,
24+
Auc,
25+
Auprc,
26+
F1Score,
27+
PositivePrecision,
28+
PositiveRecall,
29+
NegativePrecision,
30+
NegativeRecall,
2431
}
2532

2633
public enum BinaryClassificationTrainer
2734
{
28-
LightGbm
35+
AveragedPerceptron,
36+
FastForest,
37+
FastTree,
38+
LightGbm,
39+
LinearSupportVectorMachines,
40+
LogisticRegression,
41+
StochasticDualCoordinateAscent,
42+
StochasticGradientDescent,
43+
SymbolicStochasticGradientDescent,
2944
}
3045

3146
public class BinaryClassificationExperiment
@@ -65,8 +80,9 @@ internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(MLContext c
6580

6681
// run autofit & get all pipelines run in that process
6782
var autoFitter = new AutoFitter<BinaryClassificationMetrics>(context, TaskKind.BinaryClassification, trainData, columnInfo,
68-
validationData, preFeaturizers, OptimizingMetric.Accuracy, _settings?.ProgressCallback,
69-
_settings);
83+
validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric), _settings.ProgressCallback,
84+
_settings, new BinaryDataScorer(_settings.OptimizingMetric),
85+
TrainerExtensionUtil.GetTrainerNames(_settings.WhitelistedTrainers));
7086

7187
return autoFitter.Fit();
7288
}

src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs

+19-5
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,31 @@ namespace Microsoft.ML.Auto
1414
public class MulticlassExperimentSettings : ExperimentSettings
1515
{
1616
public IProgress<RunResult<MultiClassClassifierMetrics>> ProgressCallback;
17-
public MulticlassClassificationMetric OptimizingMetric;
17+
public MulticlassClassificationMetric OptimizingMetric = MulticlassClassificationMetric.AccuracyMicro;
1818
public MulticlassClassificationTrainer[] WhitelistedTrainers;
1919
}
2020

2121
public enum MulticlassClassificationMetric
2222
{
23-
Accuracy
23+
AccuracyMicro,
24+
AccuracyMacro,
25+
LogLoss,
26+
LogLossReduction,
27+
TopKAccuracy,
2428
}
2529

2630
public enum MulticlassClassificationTrainer
2731
{
28-
LightGbm
32+
AveragedPerceptronOVA,
33+
FastForestOVA,
34+
FastTreeOVA,
35+
LightGbm,
36+
LinearSupportVectorMachinesOVA,
37+
LogisticRegression,
38+
LogisticRegressionOVA,
39+
StochasticDualCoordinateAscent,
40+
StochasticGradientDescentOVA,
41+
SymbolicStochasticGradientDescentOVA,
2942
}
3043

3144
public class MulticlassClassificationExperiment
@@ -65,8 +78,9 @@ internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(MLContext c
6578

6679
// run autofit & get all pipelines run in that process
6780
var autoFitter = new AutoFitter<MultiClassClassifierMetrics>(context, TaskKind.MulticlassClassification, trainData,
68-
columnInfo, validationData, preFeaturizers, OptimizingMetric.Accuracy,
69-
_settings?.ProgressCallback, _settings);
81+
columnInfo, validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric),
82+
_settings.ProgressCallback, _settings, new MultiDataScorer(_settings.OptimizingMetric),
83+
TrainerExtensionUtil.GetTrainerNames(_settings.WhitelistedTrainers));
7084

7185
return autoFitter.Fit();
7286
}

src/Microsoft.ML.Auto/API/RegressionExperiment.cs

+12-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace Microsoft.ML.Auto
1414
public class RegressionExperimentSettings : ExperimentSettings
1515
{
1616
public IProgress<RunResult<RegressionMetrics>> ProgressCallback;
17-
public RegressionMetric OptimizingMetric;
17+
public RegressionMetric OptimizingMetric = RegressionMetric.RSquared;
1818
public RegressionTrainer[] WhitelistedTrainers;
1919
}
2020

@@ -28,7 +28,14 @@ public enum RegressionMetric
2828

2929
public enum RegressionTrainer
3030
{
31-
LightGbm
31+
FastForest,
32+
FastTree,
33+
FastTreeTweedie,
34+
LightGbm,
35+
OnlineGradientDescent,
36+
OrdinaryLeastSquares,
37+
PoissonRegression,
38+
StochasticDualCoordinateAscent,
3239
}
3340

3441
public class RegressionExperiment
@@ -68,8 +75,9 @@ internal IEnumerable<RunResult<RegressionMetrics>> Execute(MLContext context,
6875

6976
// run autofit & get all pipelines run in that process
7077
var autoFitter = new AutoFitter<RegressionMetrics>(context, TaskKind.Regression, trainData, columnInfo,
71-
validationData, preFeaturizers, OptimizingMetric.RSquared, _settings?.ProgressCallback,
72-
_settings);
78+
validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric),
79+
_settings.ProgressCallback, _settings, new RegressionDataScorer(_settings.OptimizingMetric),
80+
TrainerExtensionUtil.GetTrainerNames(_settings.WhitelistedTrainers));
7381

7482
return autoFitter.Fit();
7583
}

src/Microsoft.ML.Auto/AutoFitter/AutoFitter.cs

+14-30
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
using System.Collections.Generic;
77
using System.Diagnostics;
88
using System.Text;
9-
using System.Threading;
109
using Microsoft.Data.DataView;
1110
using Microsoft.ML.Core.Data;
12-
using Microsoft.ML.Data;
1311

1412
namespace Microsoft.ML.Auto
1513
{
@@ -23,6 +21,8 @@ internal class AutoFitter<T> where T : class
2321
private readonly IEstimator<ITransformer> _preFeaturizers;
2422
private readonly IProgress<RunResult<T>> _progressCallback;
2523
private readonly ExperimentSettings _experimentSettings;
24+
private readonly IDataScorer<T> _dataScorer;
25+
private readonly IEnumerable<TrainerName> _trainerWhitelist;
2626

2727
private IDataView _trainData;
2828
private IDataView _validationData;
@@ -35,9 +35,11 @@ public AutoFitter(MLContext context,
3535
ColumnInformation columnInfo,
3636
IDataView validationData,
3737
IEstimator<ITransformer> preFeaturizers,
38-
OptimizingMetric metric,
38+
OptimizingMetricInfo metricInfo,
3939
IProgress<RunResult<T>> progressCallback,
40-
ExperimentSettings experimentSettings)
40+
ExperimentSettings experimentSettings,
41+
IDataScorer<T> dataScorer,
42+
IEnumerable<TrainerName> trainerWhitelist)
4143
{
4244
if (validationData == null)
4345
{
@@ -49,11 +51,13 @@ public AutoFitter(MLContext context,
4951
_history = new List<SuggestedPipelineResult<T>>();
5052
_columnInfo = columnInfo;
5153
_context = context;
52-
_optimizingMetricInfo = new OptimizingMetricInfo(metric);
54+
_optimizingMetricInfo = metricInfo;
5355
_task = task;
5456
_preFeaturizers = preFeaturizers;
5557
_progressCallback = progressCallback;
56-
_experimentSettings = experimentSettings ?? new ExperimentSettings();
58+
_experimentSettings = experimentSettings;
59+
_dataScorer = dataScorer;
60+
_trainerWhitelist = trainerWhitelist;
5761
}
5862

5963
public List<RunResult<T>> Fit()
@@ -81,7 +85,7 @@ public List<RunResult<T>> Fit()
8185
var getPiplelineStopwatch = Stopwatch.StartNew();
8286

8387
// get next pipeline
84-
pipeline = PipelineSuggester.GetNextInferredPipeline(_history, columns, _task, _optimizingMetricInfo.IsMaximizing);
88+
pipeline = PipelineSuggester.GetNextInferredPipeline(_history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);
8589

8690
getPiplelineStopwatch.Stop();
8791

@@ -144,9 +148,9 @@ private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)
144148
{
145149
var pipelineModel = pipeline.Fit(_trainData);
146150
var scoredValidationData = pipelineModel.Transform(_validationData);
147-
var evaluatedMetrics = GetEvaluatedMetrics(scoredValidationData);
148-
var score = GetPipelineScore(evaluatedMetrics);
149-
runResult = new SuggestedPipelineResult<T>(evaluatedMetrics, pipelineModel, pipeline, score, null);
151+
var metrics = GetEvaluatedMetrics(scoredValidationData);
152+
var score = _dataScorer.GetScore(metrics);
153+
runResult = new SuggestedPipelineResult<T>(metrics, pipelineModel, pipeline, score, null);
150154
}
151155
catch(Exception ex)
152156
{
@@ -177,26 +181,6 @@ private T GetEvaluatedMetrics(IDataView scoredData)
177181
}
178182
}
179183

180-
private double GetPipelineScore(object evaluatedMetrics)
181-
{
182-
var type = evaluatedMetrics.GetType();
183-
if(type == typeof(BinaryClassificationMetrics))
184-
{
185-
return ((BinaryClassificationMetrics)evaluatedMetrics).Accuracy;
186-
}
187-
if (type == typeof(MultiClassClassifierMetrics))
188-
{
189-
return ((MultiClassClassifierMetrics)evaluatedMetrics).AccuracyMicro;
190-
}
191-
if (type == typeof(RegressionMetrics))
192-
{
193-
return ((RegressionMetrics)evaluatedMetrics).RSquared;
194-
}
195-
196-
// should not be possible to reach here
197-
throw new InvalidOperationException($"unsupported machine learning task type {_task}");
198-
}
199-
200184
private void WriteIterationLog(SuggestedPipeline pipeline, SuggestedPipelineResult runResult, Stopwatch stopwatch)
201185
{
202186
// debug log pipeline result
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.Auto
9+
{
10+
internal class BinaryDataScorer : IDataScorer<BinaryClassificationMetrics>
11+
{
12+
private readonly BinaryClassificationMetric _metric;
13+
14+
public BinaryDataScorer(BinaryClassificationMetric metric)
15+
{
16+
this._metric = metric;
17+
}
18+
19+
public double GetScore(BinaryClassificationMetrics metrics)
20+
{
21+
switch(_metric)
22+
{
23+
case BinaryClassificationMetric.Accuracy:
24+
return metrics.Accuracy;
25+
case BinaryClassificationMetric.Auc:
26+
return metrics.Auc;
27+
case BinaryClassificationMetric.Auprc:
28+
return metrics.Auprc;
29+
case BinaryClassificationMetric.F1Score:
30+
return metrics.F1Score;
31+
case BinaryClassificationMetric.NegativePrecision:
32+
return metrics.NegativePrecision;
33+
case BinaryClassificationMetric.NegativeRecall:
34+
return metrics.NegativeRecall;
35+
case BinaryClassificationMetric.PositivePrecision:
36+
return metrics.PositivePrecision;
37+
case BinaryClassificationMetric.PositiveRecall:
38+
return metrics.PositiveRecall;
39+
}
40+
41+
// never expected to reach here
42+
throw new NotSupportedException($"{_metric} is not a supported sweep metric");
43+
}
44+
}
45+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace Microsoft.ML.Auto
6+
{
7+
internal interface IDataScorer<T>
8+
{
9+
double GetScore(T metrics);
10+
}
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.Auto
9+
{
10+
internal class MultiDataScorer : IDataScorer<MultiClassClassifierMetrics>
11+
{
12+
private readonly MulticlassClassificationMetric _metric;
13+
14+
public MultiDataScorer(MulticlassClassificationMetric metric)
15+
{
16+
this._metric = metric;
17+
}
18+
19+
public double GetScore(MultiClassClassifierMetrics metrics)
20+
{
21+
switch (_metric)
22+
{
23+
case MulticlassClassificationMetric.AccuracyMacro:
24+
return metrics.AccuracyMacro;
25+
case MulticlassClassificationMetric.AccuracyMicro:
26+
return metrics.AccuracyMicro;
27+
case MulticlassClassificationMetric.LogLoss:
28+
return metrics.LogLoss;
29+
case MulticlassClassificationMetric.LogLossReduction:
30+
return metrics.LogLossReduction;
31+
case MulticlassClassificationMetric.TopKAccuracy:
32+
return metrics.TopKAccuracy;
33+
}
34+
35+
// never expected to reach here
36+
throw new NotSupportedException($"{_metric} is not a supported sweep metric");
37+
}
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.Auto
9+
{
10+
internal class RegressionDataScorer : IDataScorer<RegressionMetrics>
11+
{
12+
private readonly RegressionMetric _metric;
13+
14+
public RegressionDataScorer(RegressionMetric metric)
15+
{
16+
this._metric = metric;
17+
}
18+
19+
public double GetScore(RegressionMetrics metrics)
20+
{
21+
switch(_metric)
22+
{
23+
case RegressionMetric.L1:
24+
return metrics.L1;
25+
case RegressionMetric.L2:
26+
return metrics.L2;
27+
case RegressionMetric.Rms:
28+
return metrics.Rms;
29+
case RegressionMetric.RSquared:
30+
return metrics.RSquared;
31+
}
32+
33+
// never expected to reach here
34+
throw new NotSupportedException($"{_metric} is not a supported sweep metric");
35+
}
36+
}
37+
}

0 commit comments

Comments
 (0)