Skip to content

Commit 2f79d34

Browse files
daholste authored and Dmitry-A committed
Add cross-validation (CV), and auto-CV for small datasets; push common API experiment methods into base class (dotnet#287)
1 parent 9023263 commit 2f79d34

File tree

57 files changed

+1413
-763
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1413
-763
lines changed

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+18-60
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ public sealed class BinaryExperimentSettings : ExperimentSettings
1414
public BinaryClassificationMetric OptimizingMetric { get; set; } = BinaryClassificationMetric.Accuracy;
1515
public ICollection<BinaryClassificationTrainer> Trainers { get; } =
1616
Enum.GetValues(typeof(BinaryClassificationTrainer)).OfType<BinaryClassificationTrainer>().ToList();
17-
public IProgress<RunResult<BinaryClassificationMetrics>> ProgressHandler { get; set; }
1817
}
1918

2019
public enum BinaryClassificationMetric
@@ -42,74 +41,33 @@ public enum BinaryClassificationTrainer
4241
SymbolicSgdLogisticRegression,
4342
}
4443

45-
public sealed class BinaryClassificationExperiment
44+
public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics>
4645
{
47-
private readonly MLContext _context;
48-
private readonly BinaryExperimentSettings _settings;
49-
5046
internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings)
47+
: base(context,
48+
new BinaryMetricsAgent(context, settings.OptimizingMetric),
49+
new OptimizingMetricInfo(settings.OptimizingMetric),
50+
settings,
51+
TaskKind.BinaryClassification,
52+
TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
5153
{
52-
_context = context;
53-
_settings = settings;
54-
}
55-
56-
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
57-
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null)
58-
{
59-
var columnInformation = new ColumnInformation()
60-
{
61-
LabelColumn = labelColumn,
62-
SamplingKeyColumn = samplingKeyColumn
63-
};
64-
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
65-
}
66-
67-
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
68-
{
69-
return Execute(_context, trainData, columnInformation, null, preFeaturizers);
70-
}
71-
72-
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizers = null)
73-
{
74-
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
75-
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers);
76-
}
77-
78-
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizers = null)
79-
{
80-
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers);
81-
}
82-
83-
internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null)
84-
{
85-
throw new NotImplementedException();
86-
}
87-
88-
internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(MLContext context,
89-
IDataView trainData,
90-
ColumnInformation columnInfo,
91-
IDataView validationData = null,
92-
IEstimator<ITransformer> preFeaturizers = null)
93-
{
94-
columnInfo = columnInfo ?? new ColumnInformation();
95-
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
96-
97-
// run autofit & get all pipelines run in that process
98-
var experiment = new Experiment<BinaryClassificationMetrics>(context, TaskKind.BinaryClassification, trainData, columnInfo,
99-
validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric), _settings.ProgressHandler,
100-
_settings, new BinaryMetricsAgent(_settings.OptimizingMetric),
101-
TrainerExtensionUtil.GetTrainerNames(_settings.Trainers));
102-
103-
return experiment.Execute();
10454
}
10555
}
10656

10757
public static class BinaryExperimentResultExtensions
10858
{
109-
public static RunResult<BinaryClassificationMetrics> Best(this IEnumerable<RunResult<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
59+
public static RunDetails<BinaryClassificationMetrics> Best(this IEnumerable<RunDetails<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
60+
{
61+
var metricsAgent = new BinaryMetricsAgent(null, metric);
62+
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
63+
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
64+
}
65+
66+
public static CrossValidationRunDetails<BinaryClassificationMetrics> Best(this IEnumerable<CrossValidationRunDetails<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
11067
{
111-
var metricsAgent = new BinaryMetricsAgent(metric);
112-
return RunResultUtil.GetBestRunResult(results, metricsAgent);
68+
var metricsAgent = new BinaryMetricsAgent(null, metric);
69+
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
70+
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
11371
}
11472
}
11573
}
src/Microsoft.ML.Auto/API/ExperimentBase.cs

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
8+
namespace Microsoft.ML.Auto
9+
{
10+
public abstract class ExperimentBase<TMetrics> where TMetrics : class
11+
{
12+
protected readonly MLContext Context;
13+
14+
private readonly IMetricsAgent<TMetrics> _metricsAgent;
15+
private readonly OptimizingMetricInfo _optimizingMetricInfo;
16+
private readonly ExperimentSettings _settings;
17+
private readonly TaskKind _task;
18+
private readonly IEnumerable<TrainerName> _trainerWhitelist;
19+
20+
internal ExperimentBase(MLContext context,
21+
IMetricsAgent<TMetrics> metricsAgent,
22+
OptimizingMetricInfo optimizingMetricInfo,
23+
ExperimentSettings settings,
24+
TaskKind task,
25+
IEnumerable<TrainerName> trainerWhitelist)
26+
{
27+
Context = context;
28+
_metricsAgent = metricsAgent;
29+
_optimizingMetricInfo = optimizingMetricInfo;
30+
_settings = settings;
31+
_task = task;
32+
_trainerWhitelist = trainerWhitelist;
33+
}
34+
35+
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
36+
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
37+
{
38+
var columnInformation = new ColumnInformation()
39+
{
40+
LabelColumn = labelColumn,
41+
SamplingKeyColumn = samplingKeyColumn
42+
};
43+
return Execute(trainData, columnInformation, preFeaturizers, progressHandler);
44+
}
45+
46+
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation,
47+
IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
48+
{
49+
// Cross val threshold for # of dataset rows --
50+
// If dataset has < threshold # of rows, use cross val.
51+
// Else, run the experiment using a train-validate split.
52+
const int crossValRowCountThreshold = 15000;
53+
54+
var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold);
55+
56+
if (rowCount < crossValRowCountThreshold)
57+
{
58+
const int numCrossValFolds = 10;
59+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumn);
60+
return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
61+
}
62+
else
63+
{
64+
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumn);
65+
return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler);
66+
}
67+
}
68+
69+
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
70+
{
71+
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
72+
return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
73+
}
74+
75+
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
76+
{
77+
if (validationData == null)
78+
{
79+
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumn);
80+
trainData = splitResult.trainData;
81+
validationData = splitResult.validationData;
82+
}
83+
return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
84+
}
85+
86+
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null, IProgress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
87+
{
88+
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
89+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumn);
90+
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizers, progressHandler);
91+
}
92+
93+
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData,
94+
uint numberOfCVFolds, string labelColumn = DefaultColumnNames.Label,
95+
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null,
96+
Progress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
97+
{
98+
var columnInformation = new ColumnInformation()
99+
{
100+
LabelColumn = labelColumn,
101+
SamplingKeyColumn = samplingKeyColumn
102+
};
103+
return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizers, progressHandler);
104+
}
105+
106+
private IEnumerable<RunDetails<TMetrics>> ExecuteTrainValidate(IDataView trainData,
107+
ColumnInformation columnInfo,
108+
IDataView validationData,
109+
IEstimator<ITransformer> preFeaturizer,
110+
IProgress<RunDetails<TMetrics>> progressHandler)
111+
{
112+
columnInfo = columnInfo ?? new ColumnInformation();
113+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
114+
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumn, _metricsAgent,
115+
preFeaturizer, _settings.DebugLogger);
116+
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo);
117+
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
118+
}
119+
120+
private IEnumerable<CrossValidationRunDetails<TMetrics>> ExecuteCrossVal(IDataView[] trainDatasets,
121+
ColumnInformation columnInfo,
122+
IDataView[] validationDatasets,
123+
IEstimator<ITransformer> preFeaturizer,
124+
IProgress<CrossValidationRunDetails<TMetrics>> progressHandler)
125+
{
126+
columnInfo = columnInfo ?? new ColumnInformation();
127+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
128+
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
129+
columnInfo.LabelColumn, _settings.DebugLogger);
130+
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
131+
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
132+
}
133+
134+
private IEnumerable<RunDetails<TMetrics>> ExecuteCrossValSummary(IDataView[] trainDatasets,
135+
ColumnInformation columnInfo,
136+
IDataView[] validationDatasets,
137+
IEstimator<ITransformer> preFeaturizer,
138+
IProgress<RunDetails<TMetrics>> progressHandler)
139+
{
140+
columnInfo = columnInfo ?? new ColumnInformation();
141+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
142+
var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
143+
columnInfo.LabelColumn, _optimizingMetricInfo, _settings.DebugLogger);
144+
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
145+
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
146+
}
147+
148+
private IEnumerable<TRunDetails> Execute<TRunDetails>(ColumnInformation columnInfo,
149+
DatasetColumnInfo[] columns,
150+
IEstimator<ITransformer> preFeaturizer,
151+
IProgress<TRunDetails> progressHandler,
152+
IRunner<TRunDetails> runner)
153+
where TRunDetails : RunDetails
154+
{
155+
// Execute experiment & get all pipelines run
156+
var experiment = new Experiment<TRunDetails, TMetrics>(Context, _task, _optimizingMetricInfo, progressHandler,
157+
_settings, _metricsAgent, _trainerWhitelist, columns, runner);
158+
159+
return experiment.Execute();
160+
}
161+
}
162+
}

src/Microsoft.ML.Auto/API/ExperimentSettings.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ public class ExperimentSettings
1818
/// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in
1919
/// memory could cause a system to run out of memory.)
2020
/// </summary>
21-
public DirectoryInfo ModelDirectory { get; set; } = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.Auto"));
21+
public DirectoryInfo CacheDirectory { get; set; } = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.Auto"));
2222

2323
/// <summary>
2424
/// This setting controls whether or not an AutoML experiment will make use of ML.NET-provided caching.
2525
/// If set to true, caching will be forced on for all pipelines. If set to false, caching will be forced off.
2626
/// If set to null (default value), AutoML will decide whether to enable caching for each model.
2727
/// </summary>
28-
public bool? EnableCaching = null;
28+
public bool? CacheBeforeTrainer = null;
2929

3030
internal int MaxModels = int.MaxValue;
3131
internal IDebugLogger DebugLogger;

0 commit comments

Comments
 (0)