Skip to content

Commit fb5f418

Browse files
daholste and Dmitry-A
authored and committed
[AutoML] Rev AutoML public API; add required native references to AutoML projects (dotnet#3364)
1 parent 5d9e058 commit fb5f418

28 files changed

+322
-237
lines changed

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+5-27
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public enum BinaryClassificationTrainer
132132
/// <summary>
133133
/// AutoML experiment on binary classification datasets.
134134
/// </summary>
135-
public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics>
135+
public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics, BinaryExperimentSettings>
136136
{
137137
internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings)
138138
: base(context,
@@ -143,37 +143,15 @@ internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSetti
143143
TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
144144
{
145145
}
146-
}
147146

148-
/// <summary>
149-
/// Extension methods that operate over binary experiment run results.
150-
/// </summary>
151-
public static class BinaryExperimentResultExtensions
152-
{
153-
/// <summary>
154-
/// Select the best run from an enumeration of experiment runs.
155-
/// </summary>
156-
/// <param name="results">Enumeration of AutoML experiment run results.</param>
157-
/// <param name="metric">Metric to consider when selecting the best run.</param>
158-
/// <returns>The best experiment run.</returns>
159-
public static RunDetail<BinaryClassificationMetrics> Best(this IEnumerable<RunDetail<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
147+
private protected override RunDetail<BinaryClassificationMetrics> GetBestRun(IEnumerable<RunDetail<BinaryClassificationMetrics>> results)
160148
{
161-
var metricsAgent = new BinaryMetricsAgent(null, metric);
162-
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
163-
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
149+
return BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
164150
}
165151

166-
/// <summary>
167-
/// Select the best run from an enumeration of experiment cross validation runs.
168-
/// </summary>
169-
/// <param name="results">Enumeration of AutoML experiment cross validation run results.</param>
170-
/// <param name="metric">Metric to consider when selecting the best run.</param>
171-
/// <returns>The best experiment run.</returns>
172-
public static CrossValidationRunDetail<BinaryClassificationMetrics> Best(this IEnumerable<CrossValidationRunDetail<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
152+
private protected override CrossValidationRunDetail<BinaryClassificationMetrics> GetBestCrossValRun(IEnumerable<CrossValidationRunDetail<BinaryClassificationMetrics>> results)
173153
{
174-
var metricsAgent = new BinaryMetricsAgent(null, metric);
175-
var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
176-
return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
154+
return BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
177155
}
178156
}
179157
}

src/Microsoft.ML.Auto/API/ColumnInference.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public sealed class ColumnInferenceResults
2828
/// <remarks>
2929
/// <para>Contains the inferred purposes of each column. See <see cref="Auto.ColumnInformation"/> for more details.</para>
3030
/// <para>This can be fed to the AutoML API when running an experiment.
31-
/// See <typeref cref="ExperimentBase{TMetrics}.Execute(IDataView, ColumnInformation, IEstimator{ITransformer}, System.IProgress{RunDetail{TMetrics}})" />
31+
/// See <typeref cref="ExperimentBase{TMetrics, TExperimentSettings}.Execute(IDataView, ColumnInformation, IEstimator{ITransformer}, System.IProgress{RunDetail{TMetrics}})" />
3232
/// for example.</para>
3333
/// </remarks>
3434
public ColumnInformation ColumnInformation { get; internal set; } = new ColumnInformation();
@@ -42,7 +42,7 @@ public sealed class ColumnInferenceResults
4242
/// it enumerates the dataset columns that AutoML should treat as categorical,
4343
/// the columns AutoML should ignore, which column is the label, etc.</para>
4444
/// <para><see cref="ColumnInformation"/> can be fed to the AutoML API when running an experiment.
45-
/// See <typeref cref="ExperimentBase{TMetrics}.Execute(IDataView, ColumnInformation, IEstimator{ITransformer}, System.IProgress{RunDetail{TMetrics}})" />
45+
/// See <typeref cref="ExperimentBase{TMetrics, TExperimentSettings}.Execute(IDataView, ColumnInformation, IEstimator{ITransformer}, System.IProgress{RunDetail{TMetrics}})" />
4646
/// for example.</para>
4747
/// </remarks>
4848
public sealed class ColumnInformation

src/Microsoft.ML.Auto/API/ExperimentBase.cs

+56-43
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,32 @@ namespace Microsoft.ML.Auto
1212
/// (like <see cref="BinaryClassificationExperiment"/>) inherit from this class.
1313
/// </summary>
1414
/// <typeparam name="TMetrics">Metrics type used by task-specific AutoML experiments.</typeparam>
15-
public abstract class ExperimentBase<TMetrics> where TMetrics : class
15+
/// <typeparam name="TExperimentSettings">Experiment settings type.</typeparam>
16+
public abstract class ExperimentBase<TMetrics, TExperimentSettings>
17+
where TMetrics : class
18+
where TExperimentSettings : ExperimentSettings
1619
{
1720
private protected readonly MLContext Context;
21+
private protected readonly IMetricsAgent<TMetrics> MetricsAgent;
22+
private protected readonly OptimizingMetricInfo OptimizingMetricInfo;
23+
private protected readonly TExperimentSettings Settings;
1824

19-
private readonly IMetricsAgent<TMetrics> _metricsAgent;
20-
private readonly OptimizingMetricInfo _optimizingMetricInfo;
21-
private readonly ExperimentSettings _settings;
25+
private readonly AutoMLLogger _logger;
2226
private readonly TaskKind _task;
2327
private readonly IEnumerable<TrainerName> _trainerWhitelist;
2428

2529
internal ExperimentBase(MLContext context,
2630
IMetricsAgent<TMetrics> metricsAgent,
2731
OptimizingMetricInfo optimizingMetricInfo,
28-
ExperimentSettings settings,
32+
TExperimentSettings settings,
2933
TaskKind task,
3034
IEnumerable<TrainerName> trainerWhitelist)
3135
{
3236
Context = context;
33-
_metricsAgent = metricsAgent;
34-
_optimizingMetricInfo = optimizingMetricInfo;
35-
_settings = settings;
37+
MetricsAgent = metricsAgent;
38+
OptimizingMetricInfo = optimizingMetricInfo;
39+
Settings = settings;
40+
_logger = new AutoMLLogger(context);
3641
_task = task;
3742
_trainerWhitelist = trainerWhitelist;
3843
}
@@ -53,12 +58,11 @@ internal ExperimentBase(MLContext context,
5358
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
5459
/// course of the experiment.
5560
/// </param>
56-
/// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
57-
/// for more information on the contents of a run.</returns>
61+
/// <returns>The experiment result.</returns>
5862
/// <remarks>
5963
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
6064
/// </remarks>
61-
public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, string labelColumnName = DefaultColumnNames.Label,
65+
public ExperimentResult<TMetrics> Execute(IDataView trainData, string labelColumnName = DefaultColumnNames.Label,
6266
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
6367
{
6468
var columnInformation = new ColumnInformation()
@@ -83,12 +87,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, string labe
8387
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
8488
/// course of the experiment.
8589
/// </param>
86-
/// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
87-
/// for more information on the contents of a run.</returns>
90+
/// <returns>The experiment result.</returns>
8891
/// <remarks>
8992
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
9093
/// </remarks>
91-
public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation,
94+
public ExperimentResult<TMetrics> Execute(IDataView trainData, ColumnInformation columnInformation,
9295
IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
9396
{
9497
// Cross val threshold for # of dataset rows --
@@ -126,12 +129,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, ColumnInfor
126129
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
127130
/// course of the experiment.
128131
/// </param>
129-
/// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
130-
/// for more information on the contents of a run.</returns>
132+
/// <returns>The experiment result.</returns>
131133
/// <remarks>
132134
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
133135
/// </remarks>
134-
public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumnName = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
136+
public ExperimentResult<TMetrics> Execute(IDataView trainData, IDataView validationData, string labelColumnName = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
135137
{
136138
var columnInformation = new ColumnInformation() { LabelColumnName = labelColumnName };
137139
return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
@@ -152,12 +154,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, IDataView v
152154
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
153155
/// course of the experiment.
154156
/// </param>
155-
/// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
156-
/// for more information on the contents of a run.</returns>
157+
/// <returns>The experiment result.</returns>
157158
/// <remarks>
158159
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
159160
/// </remarks>
160-
public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
161+
public ExperimentResult<TMetrics> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<TMetrics>> progressHandler = null)
161162
{
162163
if (validationData == null)
163164
{
@@ -183,12 +184,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, IDataView v
183184
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
184185
/// course of the experiment.
185186
/// </param>
186-
/// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
187-
/// for more information on the contents of a run.</returns>
187+
/// <returns>The cross validation experiment result.</returns>
188188
/// <remarks>
189189
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
190190
/// </remarks>
191-
public IEnumerable<CrossValidationRunDetail<TMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
191+
public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
192192
{
193193
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
194194
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName);
@@ -211,12 +211,11 @@ public IEnumerable<CrossValidationRunDetail<TMetrics>> Execute(IDataView trainDa
211211
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
212212
/// course of the experiment.
213213
/// </param>
214-
/// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
215-
/// for more information on the contents of a run.</returns>
214+
/// <returns>The cross validation experiment result.</returns>
216215
/// <remarks>
217216
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
218217
/// </remarks>
219-
public IEnumerable<CrossValidationRunDetail<TMetrics>> Execute(IDataView trainData,
218+
public CrossValidationExperimentResult<TMetrics> Execute(IDataView trainData,
220219
uint numberOfCVFolds, string labelColumnName = DefaultColumnNames.Label,
221220
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null,
222221
Progress<CrossValidationRunDetail<TMetrics>> progressHandler = null)
@@ -229,7 +228,11 @@ public IEnumerable<CrossValidationRunDetail<TMetrics>> Execute(IDataView trainDa
229228
return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizer, progressHandler);
230229
}
231230

232-
private IEnumerable<RunDetail<TMetrics>> ExecuteTrainValidate(IDataView trainData,
231+
private protected abstract CrossValidationRunDetail<TMetrics> GetBestCrossValRun(IEnumerable<CrossValidationRunDetail<TMetrics>> results);
232+
233+
private protected abstract RunDetail<TMetrics> GetBestRun(IEnumerable<RunDetail<TMetrics>> results);
234+
235+
private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
233236
ColumnInformation columnInfo,
234237
IDataView validationData,
235238
IEstimator<ITransformer> preFeaturizer,
@@ -247,13 +250,13 @@ private IEnumerable<RunDetail<TMetrics>> ExecuteTrainValidate(IDataView trainDat
247250
validationData = preprocessorTransform.Transform(validationData);
248251
}
249252

250-
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumnName, _metricsAgent,
251-
preFeaturizer, preprocessorTransform, _settings.DebugLogger);
253+
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumnName, MetricsAgent,
254+
preFeaturizer, preprocessorTransform, _logger);
252255
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo);
253256
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
254257
}
255258

256-
private IEnumerable<CrossValidationRunDetail<TMetrics>> ExecuteCrossVal(IDataView[] trainDatasets,
259+
private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] trainDatasets,
257260
ColumnInformation columnInfo,
258261
IDataView[] validationDatasets,
259262
IEstimator<ITransformer> preFeaturizer,
@@ -266,13 +269,21 @@ private IEnumerable<CrossValidationRunDetail<TMetrics>> ExecuteCrossVal(IDataVie
266269
ITransformer[] preprocessorTransforms = null;
267270
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
268271

269-
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
270-
preprocessorTransforms, columnInfo.LabelColumnName, _settings.DebugLogger);
272+
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, MetricsAgent, preFeaturizer,
273+
preprocessorTransforms, columnInfo.LabelColumnName, _logger);
271274
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
272-
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
275+
276+
// Execute experiment & get all pipelines run
277+
var experiment = new Experiment<CrossValidationRunDetail<TMetrics>, TMetrics>(Context, _task, OptimizingMetricInfo, progressHandler,
278+
Settings, MetricsAgent, _trainerWhitelist, columns, runner, _logger);
279+
var runDetails = experiment.Execute();
280+
281+
var bestRun = GetBestCrossValRun(runDetails);
282+
var experimentResult = new CrossValidationExperimentResult<TMetrics>(runDetails, bestRun);
283+
return experimentResult;
273284
}
274285

275-
private IEnumerable<RunDetail<TMetrics>> ExecuteCrossValSummary(IDataView[] trainDatasets,
286+
private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatasets,
276287
ColumnInformation columnInfo,
277288
IDataView[] validationDatasets,
278289
IEstimator<ITransformer> preFeaturizer,
@@ -285,24 +296,26 @@ private IEnumerable<RunDetail<TMetrics>> ExecuteCrossValSummary(IDataView[] trai
285296
ITransformer[] preprocessorTransforms = null;
286297
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
287298

288-
var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
289-
preprocessorTransforms, columnInfo.LabelColumnName, _optimizingMetricInfo, _settings.DebugLogger);
299+
var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, MetricsAgent, preFeaturizer,
300+
preprocessorTransforms, columnInfo.LabelColumnName, OptimizingMetricInfo, _logger);
290301
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
291302
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
292303
}
293304

294-
private IEnumerable<TRunDetail> Execute<TRunDetail>(ColumnInformation columnInfo,
305+
private ExperimentResult<TMetrics> Execute(ColumnInformation columnInfo,
295306
DatasetColumnInfo[] columns,
296307
IEstimator<ITransformer> preFeaturizer,
297-
IProgress<TRunDetail> progressHandler,
298-
IRunner<TRunDetail> runner)
299-
where TRunDetail : RunDetail
308+
IProgress<RunDetail<TMetrics>> progressHandler,
309+
IRunner<RunDetail<TMetrics>> runner)
300310
{
301311
// Execute experiment & get all pipelines run
302-
var experiment = new Experiment<TRunDetail, TMetrics>(Context, _task, _optimizingMetricInfo, progressHandler,
303-
_settings, _metricsAgent, _trainerWhitelist, columns, runner);
312+
var experiment = new Experiment<RunDetail<TMetrics>, TMetrics>(Context, _task, OptimizingMetricInfo, progressHandler,
313+
Settings, MetricsAgent, _trainerWhitelist, columns, runner, _logger);
314+
var runDetails = experiment.Execute();
304315

305-
return experiment.Execute();
316+
var bestRun = GetBestRun(runDetails);
317+
var experimentResult = new ExperimentResult<TMetrics>(runDetails, bestRun);
318+
return experimentResult;
306319
}
307320

308321
private static (IDataView[] trainDatasets, IDataView[] validDatasets, ITransformer[] preprocessorTransforms)

0 commit comments

Comments (0)