@@ -12,27 +12,32 @@ namespace Microsoft.ML.Auto
12
12
/// (like <see cref="BinaryClassificationExperiment"/>) inherit from this class.
13
13
/// </summary>
14
14
/// <typeparam name="TMetrics">Metrics type used by task-specific AutoML experiments.</typeparam>
15
- public abstract class ExperimentBase < TMetrics > where TMetrics : class
15
+ /// <typeparam name="TExperimentSettings">Experiment settings type.</typeparam>
16
+ public abstract class ExperimentBase < TMetrics , TExperimentSettings >
17
+ where TMetrics : class
18
+ where TExperimentSettings : ExperimentSettings
16
19
{
17
20
private protected readonly MLContext Context ;
21
+ private protected readonly IMetricsAgent < TMetrics > MetricsAgent ;
22
+ private protected readonly OptimizingMetricInfo OptimizingMetricInfo ;
23
+ private protected readonly TExperimentSettings Settings ;
18
24
19
- private readonly IMetricsAgent < TMetrics > _metricsAgent ;
20
- private readonly OptimizingMetricInfo _optimizingMetricInfo ;
21
- private readonly ExperimentSettings _settings ;
25
+ private readonly AutoMLLogger _logger ;
22
26
private readonly TaskKind _task ;
23
27
private readonly IEnumerable < TrainerName > _trainerWhitelist ;
24
28
25
29
/// <summary>
/// Initializes the shared state of an AutoML experiment: stores each argument in the
/// corresponding field and creates the <c>AutoMLLogger</c> used by the rest of the class.
/// </summary>
/// <param name="context">Stored in <c>Context</c> and passed to the <c>AutoMLLogger</c> constructor.</param>
/// <param name="metricsAgent">Stored in <c>MetricsAgent</c>.</param>
/// <param name="optimizingMetricInfo">Stored in <c>OptimizingMetricInfo</c>.</param>
/// <param name="settings">Stored in <c>Settings</c>; typed as <c>TExperimentSettings</c> on the added side of the diff.</param>
/// <param name="task">Stored in <c>_task</c>.</param>
/// <param name="trainerWhitelist">Stored in <c>_trainerWhitelist</c>.</param>
// NOTE(review): this span is a rendered diff. Lines starting with "- " are the removed
// (pre-change) version, lines starting with "+ " are the added version, and the bare
// numbers are old/new line numbers from the diff view.
// NOTE(review): no argument is null-checked in this constructor.
internal ExperimentBase ( MLContext context ,
26
30
IMetricsAgent < TMetrics > metricsAgent ,
27
31
OptimizingMetricInfo optimizingMetricInfo ,
28
- ExperimentSettings settings ,
32
+ TExperimentSettings settings ,
29
33
TaskKind task ,
30
34
IEnumerable < TrainerName > trainerWhitelist )
31
35
{
32
36
Context = context ;
33
- _metricsAgent = metricsAgent ;
34
- _optimizingMetricInfo = optimizingMetricInfo ;
35
- _settings = settings ;
37
+ MetricsAgent = metricsAgent ;
38
+ OptimizingMetricInfo = optimizingMetricInfo ;
39
+ Settings = settings ;
// The logger is now built once here from the context; the removed code instead passed
// _settings . DebugLogger at each runner construction site.
40
+ _logger = new AutoMLLogger ( context ) ;
36
41
_task = task ;
37
42
_trainerWhitelist = trainerWhitelist ;
38
43
}
@@ -53,12 +58,11 @@ internal ExperimentBase(MLContext context,
53
58
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
54
59
/// course of the experiment.
55
60
/// </param>
56
- /// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
57
- /// for more information on the contents of a run.</returns>
61
+ /// <returns>The experiment result.</returns>
58
62
/// <remarks>
59
63
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
60
64
/// </remarks>
61
- public IEnumerable < RunDetail < TMetrics > > Execute ( IDataView trainData , string labelColumnName = DefaultColumnNames . Label ,
65
+ public ExperimentResult < TMetrics > Execute ( IDataView trainData , string labelColumnName = DefaultColumnNames . Label ,
62
66
string samplingKeyColumn = null , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
63
67
{
64
68
var columnInformation = new ColumnInformation ( )
@@ -83,12 +87,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, string labe
83
87
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
84
88
/// course of the experiment.
85
89
/// </param>
86
- /// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
87
- /// for more information on the contents of a run.</returns>
90
+ /// <returns>The experiment result.</returns>
88
91
/// <remarks>
89
92
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
90
93
/// </remarks>
91
- public IEnumerable < RunDetail < TMetrics > > Execute ( IDataView trainData , ColumnInformation columnInformation ,
94
+ public ExperimentResult < TMetrics > Execute ( IDataView trainData , ColumnInformation columnInformation ,
92
95
IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
93
96
{
94
97
// Cross val threshold for # of dataset rows --
@@ -126,12 +129,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, ColumnInfor
126
129
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
127
130
/// course of the experiment.
128
131
/// </param>
129
- /// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
130
- /// for more information on the contents of a run.</returns>
132
+ /// <returns>The experiment result.</returns>
131
133
/// <remarks>
132
134
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
133
135
/// </remarks>
134
- public IEnumerable < RunDetail < TMetrics > > Execute ( IDataView trainData , IDataView validationData , string labelColumnName = DefaultColumnNames . Label , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
136
+ public ExperimentResult < TMetrics > Execute ( IDataView trainData , IDataView validationData , string labelColumnName = DefaultColumnNames . Label , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
135
137
{
136
138
var columnInformation = new ColumnInformation ( ) { LabelColumnName = labelColumnName } ;
137
139
return Execute ( trainData , validationData , columnInformation , preFeaturizer , progressHandler ) ;
@@ -152,12 +154,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, IDataView v
152
154
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
153
155
/// course of the experiment.
154
156
/// </param>
155
- /// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
156
- /// for more information on the contents of a run.</returns>
157
+ /// <returns>The experiment result.</returns>
157
158
/// <remarks>
158
159
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
159
160
/// </remarks>
160
- public IEnumerable < RunDetail < TMetrics > > Execute ( IDataView trainData , IDataView validationData , ColumnInformation columnInformation , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
161
+ public ExperimentResult < TMetrics > Execute ( IDataView trainData , IDataView validationData , ColumnInformation columnInformation , IEstimator < ITransformer > preFeaturizer = null , IProgress < RunDetail < TMetrics > > progressHandler = null )
161
162
{
162
163
if ( validationData == null )
163
164
{
@@ -183,12 +184,11 @@ public IEnumerable<RunDetail<TMetrics>> Execute(IDataView trainData, IDataView v
183
184
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
184
185
/// course of the experiment.
185
186
/// </param>
186
- /// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
187
- /// for more information on the contents of a run.</returns>
187
+ /// <returns>The cross validation experiment result.</returns>
188
188
/// <remarks>
189
189
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
190
190
/// </remarks>
191
- public IEnumerable < CrossValidationRunDetail < TMetrics > > Execute ( IDataView trainData , uint numberOfCVFolds , ColumnInformation columnInformation = null , IEstimator < ITransformer > preFeaturizer = null , IProgress < CrossValidationRunDetail < TMetrics > > progressHandler = null )
191
+ public CrossValidationExperimentResult < TMetrics > Execute ( IDataView trainData , uint numberOfCVFolds , ColumnInformation columnInformation = null , IEstimator < ITransformer > preFeaturizer = null , IProgress < CrossValidationRunDetail < TMetrics > > progressHandler = null )
192
192
{
193
193
UserInputValidationUtil . ValidateNumberOfCVFoldsArg ( numberOfCVFolds ) ;
194
194
var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numberOfCVFolds , columnInformation ? . SamplingKeyColumnName ) ;
@@ -211,12 +211,11 @@ public IEnumerable<CrossValidationRunDetail<TMetrics>> Execute(IDataView trainDa
211
211
/// <see cref="IProgress{T}.Report(T)"/> after each model it produces during the
212
212
/// course of the experiment.
213
213
/// </param>
214
- /// <returns>An enumeration of all the runs in an experiment. See <see cref="RunDetail{TMetrics}"/>
215
- /// for more information on the contents of a run.</returns>
214
+ /// <returns>The cross validation experiment result.</returns>
216
215
/// <remarks>
217
216
/// Depending on the size of your data, the AutoML experiment could take a long time to execute.
218
217
/// </remarks>
219
- public IEnumerable < CrossValidationRunDetail < TMetrics > > Execute ( IDataView trainData ,
218
+ public CrossValidationExperimentResult < TMetrics > Execute ( IDataView trainData ,
220
219
uint numberOfCVFolds , string labelColumnName = DefaultColumnNames . Label ,
221
220
string samplingKeyColumn = null , IEstimator < ITransformer > preFeaturizer = null ,
222
221
Progress < CrossValidationRunDetail < TMetrics > > progressHandler = null )
@@ -229,7 +228,11 @@ public IEnumerable<CrossValidationRunDetail<TMetrics>> Execute(IDataView trainDa
229
228
return Execute ( trainData , numberOfCVFolds , columnInformation , preFeaturizer , progressHandler ) ;
230
229
}
231
230
232
- private IEnumerable < RunDetail < TMetrics > > ExecuteTrainValidate ( IDataView trainData ,
231
/// <summary>
/// Selects one run from the cross-validation runs of an experiment.
/// ExecuteCrossVal calls this with the runs returned by the experiment and passes the
/// returned run, together with the full run list, to the
/// CrossValidationExperimentResult constructor.
/// </summary>
/// <param name="results">All cross-validation runs produced by the experiment.</param>
/// <returns>The run to report as the best run.</returns>
// NOTE(review): the selection criterion lives in derived classes and is not visible here;
// presumably it is based on the optimizing metric — confirm against the implementations.
+ private protected abstract CrossValidationRunDetail < TMetrics > GetBestCrossValRun ( IEnumerable < CrossValidationRunDetail < TMetrics > > results ) ;
232
+
/// <summary>
/// Selects one run from the runs of an experiment.
/// The private Execute helper calls this with the runs returned by the experiment and
/// passes the returned run, together with the full run list, to the ExperimentResult
/// constructor.
/// </summary>
/// <param name="results">All runs produced by the experiment.</param>
/// <returns>The run to report as the best run.</returns>
// NOTE(review): behavior for an empty run list is up to the derived class — not visible here.
233
+ private protected abstract RunDetail < TMetrics > GetBestRun ( IEnumerable < RunDetail < TMetrics > > results ) ;
234
+
235
+ private ExperimentResult < TMetrics > ExecuteTrainValidate ( IDataView trainData ,
233
236
ColumnInformation columnInfo ,
234
237
IDataView validationData ,
235
238
IEstimator < ITransformer > preFeaturizer ,
@@ -247,13 +250,13 @@ private IEnumerable<RunDetail<TMetrics>> ExecuteTrainValidate(IDataView trainDat
247
250
validationData = preprocessorTransform . Transform ( validationData ) ;
248
251
}
249
252
250
- var runner = new TrainValidateRunner < TMetrics > ( Context , trainData , validationData , columnInfo . LabelColumnName , _metricsAgent ,
251
- preFeaturizer , preprocessorTransform , _settings . DebugLogger ) ;
253
+ var runner = new TrainValidateRunner < TMetrics > ( Context , trainData , validationData , columnInfo . LabelColumnName , MetricsAgent ,
254
+ preFeaturizer , preprocessorTransform , _logger ) ;
252
255
var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainData , columnInfo ) ;
253
256
return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
254
257
}
255
258
256
- private IEnumerable < CrossValidationRunDetail < TMetrics > > ExecuteCrossVal ( IDataView [ ] trainDatasets ,
259
+ private CrossValidationExperimentResult < TMetrics > ExecuteCrossVal ( IDataView [ ] trainDatasets ,
257
260
ColumnInformation columnInfo ,
258
261
IDataView [ ] validationDatasets ,
259
262
IEstimator < ITransformer > preFeaturizer ,
@@ -266,13 +269,21 @@ private IEnumerable<CrossValidationRunDetail<TMetrics>> ExecuteCrossVal(IDataVie
266
269
ITransformer [ ] preprocessorTransforms = null ;
267
270
( trainDatasets , validationDatasets , preprocessorTransforms ) = ApplyPreFeaturizerCrossVal ( trainDatasets , validationDatasets , preFeaturizer ) ;
268
271
269
- var runner = new CrossValRunner < TMetrics > ( Context , trainDatasets , validationDatasets , _metricsAgent , preFeaturizer ,
270
- preprocessorTransforms , columnInfo . LabelColumnName , _settings . DebugLogger ) ;
272
+ var runner = new CrossValRunner < TMetrics > ( Context , trainDatasets , validationDatasets , MetricsAgent , preFeaturizer ,
273
+ preprocessorTransforms , columnInfo . LabelColumnName , _logger ) ;
271
274
var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainDatasets [ 0 ] , columnInfo ) ;
272
- return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
275
+
276
+ // Execute experiment & get all pipelines run
277
+ var experiment = new Experiment < CrossValidationRunDetail < TMetrics > , TMetrics > ( Context , _task , OptimizingMetricInfo , progressHandler ,
278
+ Settings , MetricsAgent , _trainerWhitelist , columns , runner , _logger ) ;
279
+ var runDetails = experiment . Execute ( ) ;
280
+
281
+ var bestRun = GetBestCrossValRun ( runDetails ) ;
282
+ var experimentResult = new CrossValidationExperimentResult < TMetrics > ( runDetails , bestRun ) ;
283
+ return experimentResult ;
273
284
}
274
285
275
- private IEnumerable < RunDetail < TMetrics > > ExecuteCrossValSummary ( IDataView [ ] trainDatasets ,
286
+ private ExperimentResult < TMetrics > ExecuteCrossValSummary ( IDataView [ ] trainDatasets ,
276
287
ColumnInformation columnInfo ,
277
288
IDataView [ ] validationDatasets ,
278
289
IEstimator < ITransformer > preFeaturizer ,
@@ -285,24 +296,26 @@ private IEnumerable<RunDetail<TMetrics>> ExecuteCrossValSummary(IDataView[] trai
285
296
ITransformer [ ] preprocessorTransforms = null ;
286
297
( trainDatasets , validationDatasets , preprocessorTransforms ) = ApplyPreFeaturizerCrossVal ( trainDatasets , validationDatasets , preFeaturizer ) ;
287
298
288
- var runner = new CrossValSummaryRunner < TMetrics > ( Context , trainDatasets , validationDatasets , _metricsAgent , preFeaturizer ,
289
- preprocessorTransforms , columnInfo . LabelColumnName , _optimizingMetricInfo , _settings . DebugLogger ) ;
299
+ var runner = new CrossValSummaryRunner < TMetrics > ( Context , trainDatasets , validationDatasets , MetricsAgent , preFeaturizer ,
300
+ preprocessorTransforms , columnInfo . LabelColumnName , OptimizingMetricInfo , _logger ) ;
290
301
var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainDatasets [ 0 ] , columnInfo ) ;
291
302
return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
292
303
}
293
304
294
/// <summary>
/// Shared tail of the non-generic experiment paths: builds an Experiment from the stored
/// state plus the supplied runner, executes it, asks GetBestRun to pick a run, and wraps
/// both in an ExperimentResult.
/// </summary>
/// <param name="columnInfo">Not referenced in this body.</param>
/// <param name="columns">Dataset column info forwarded to the Experiment constructor.</param>
/// <param name="preFeaturizer">Not referenced in this body.</param>
/// <param name="progressHandler">Forwarded to the Experiment constructor.</param>
/// <param name="runner">Runner forwarded to the Experiment constructor.</param>
/// <returns>An ExperimentResult built from the executed runs and the run chosen by GetBestRun.</returns>
// NOTE(review): this span is a rendered diff. The "- " lines show the removed generic
// Execute < TRunDetail > version, which returned experiment . Execute ( ) directly; the
// "+ " lines show the added version, fixed to RunDetail < TMetrics >.
- private IEnumerable < TRunDetail > Execute < TRunDetail > ( ColumnInformation columnInfo ,
305
+ private ExperimentResult < TMetrics > Execute ( ColumnInformation columnInfo ,
295
306
DatasetColumnInfo [ ] columns ,
296
307
IEstimator < ITransformer > preFeaturizer ,
297
- IProgress < TRunDetail > progressHandler ,
298
- IRunner < TRunDetail > runner )
299
- where TRunDetail : RunDetail
308
+ IProgress < RunDetail < TMetrics > > progressHandler ,
309
+ IRunner < RunDetail < TMetrics > > runner )
300
310
{
301
311
// Execute experiment & get all pipelines run
302
- var experiment = new Experiment < TRunDetail , TMetrics > ( Context , _task , _optimizingMetricInfo , progressHandler ,
303
- _settings , _metricsAgent , _trainerWhitelist , columns , runner ) ;
// The added version also hands the shared _logger to the Experiment.
312
+ var experiment = new Experiment < RunDetail < TMetrics > , TMetrics > ( Context , _task , OptimizingMetricInfo , progressHandler ,
313
+ Settings , MetricsAgent , _trainerWhitelist , columns , runner , _logger ) ;
314
+ var runDetails = experiment . Execute ( ) ;
304
315
305
- return experiment . Execute ( ) ;
// runDetails is used twice: once by GetBestRun and once by the result constructor.
// NOTE(review): if experiment . Execute ( ) returns a lazily evaluated sequence this
// would enumerate it more than once — its concrete type is not visible here; confirm.
316
+ var bestRun = GetBestRun ( runDetails ) ;
317
+ var experimentResult = new ExperimentResult < TMetrics > ( runDetails , bestRun ) ;
318
+ return experimentResult ;
306
319
}
307
320
308
321
private static ( IDataView [ ] trainDatasets , IDataView [ ] validDatasets , ITransformer [ ] preprocessorTransforms )
0 commit comments