6
6
using System . Collections . Generic ;
7
7
using System . Diagnostics ;
8
8
using System . Text ;
9
- using System . Threading ;
10
9
using Microsoft . Data . DataView ;
11
10
using Microsoft . ML . Core . Data ;
12
- using Microsoft . ML . Data ;
13
11
14
12
namespace Microsoft . ML . Auto
15
13
{
@@ -23,6 +21,8 @@ internal class AutoFitter<T> where T : class
23
21
private readonly IEstimator < ITransformer > _preFeaturizers ;
24
22
private readonly IProgress < RunResult < T > > _progressCallback ;
25
23
private readonly ExperimentSettings _experimentSettings ;
24
+ private readonly IDataScorer < T > _dataScorer ;
25
+ private readonly IEnumerable < TrainerName > _trainerWhitelist ;
26
26
27
27
private IDataView _trainData ;
28
28
private IDataView _validationData ;
@@ -35,9 +35,11 @@ public AutoFitter(MLContext context,
35
35
ColumnInformation columnInfo ,
36
36
IDataView validationData ,
37
37
IEstimator < ITransformer > preFeaturizers ,
38
- OptimizingMetric metric ,
38
+ OptimizingMetricInfo metricInfo ,
39
39
IProgress < RunResult < T > > progressCallback ,
40
- ExperimentSettings experimentSettings )
40
+ ExperimentSettings experimentSettings ,
41
+ IDataScorer < T > dataScorer ,
42
+ IEnumerable < TrainerName > trainerWhitelist )
41
43
{
42
44
if ( validationData == null )
43
45
{
@@ -49,11 +51,13 @@ public AutoFitter(MLContext context,
49
51
_history = new List < SuggestedPipelineResult < T > > ( ) ;
50
52
_columnInfo = columnInfo ;
51
53
_context = context ;
52
- _optimizingMetricInfo = new OptimizingMetricInfo ( metric ) ;
54
+ _optimizingMetricInfo = metricInfo ;
53
55
_task = task ;
54
56
_preFeaturizers = preFeaturizers ;
55
57
_progressCallback = progressCallback ;
56
- _experimentSettings = experimentSettings ?? new ExperimentSettings ( ) ;
58
+ _experimentSettings = experimentSettings ;
59
+ _dataScorer = dataScorer ;
60
+ _trainerWhitelist = trainerWhitelist ;
57
61
}
58
62
59
63
public List < RunResult < T > > Fit ( )
@@ -81,7 +85,7 @@ public List<RunResult<T>> Fit()
81
85
var getPiplelineStopwatch = Stopwatch . StartNew ( ) ;
82
86
83
87
// get next pipeline
84
- pipeline = PipelineSuggester . GetNextInferredPipeline ( _history , columns , _task , _optimizingMetricInfo . IsMaximizing ) ;
88
+ pipeline = PipelineSuggester . GetNextInferredPipeline ( _history , columns , _task , _optimizingMetricInfo . IsMaximizing , _trainerWhitelist ) ;
85
89
86
90
getPiplelineStopwatch . Stop ( ) ;
87
91
@@ -144,9 +148,9 @@ private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)
144
148
{
145
149
var pipelineModel = pipeline . Fit ( _trainData ) ;
146
150
var scoredValidationData = pipelineModel . Transform ( _validationData ) ;
147
- var evaluatedMetrics = GetEvaluatedMetrics ( scoredValidationData ) ;
148
- var score = GetPipelineScore ( evaluatedMetrics ) ;
149
- runResult = new SuggestedPipelineResult < T > ( evaluatedMetrics , pipelineModel , pipeline , score , null ) ;
151
+ var metrics = GetEvaluatedMetrics ( scoredValidationData ) ;
152
+ var score = _dataScorer . GetScore ( metrics ) ;
153
+ runResult = new SuggestedPipelineResult < T > ( metrics , pipelineModel , pipeline , score , null ) ;
150
154
}
151
155
catch ( Exception ex )
152
156
{
@@ -177,26 +181,6 @@ private T GetEvaluatedMetrics(IDataView scoredData)
177
181
}
178
182
}
179
183
180
- private double GetPipelineScore ( object evaluatedMetrics )
181
- {
182
- var type = evaluatedMetrics . GetType ( ) ;
183
- if ( type == typeof ( BinaryClassificationMetrics ) )
184
- {
185
- return ( ( BinaryClassificationMetrics ) evaluatedMetrics ) . Accuracy ;
186
- }
187
- if ( type == typeof ( MultiClassClassifierMetrics ) )
188
- {
189
- return ( ( MultiClassClassifierMetrics ) evaluatedMetrics ) . AccuracyMicro ;
190
- }
191
- if ( type == typeof ( RegressionMetrics ) )
192
- {
193
- return ( ( RegressionMetrics ) evaluatedMetrics ) . RSquared ;
194
- }
195
-
196
- // should not be possible to reach here
197
- throw new InvalidOperationException ( $ "unsupported machine learning task type { _task } ") ;
198
- }
199
-
200
184
private void WriteIterationLog ( SuggestedPipeline pipeline , SuggestedPipelineResult runResult , Stopwatch stopwatch )
201
185
{
202
186
// debug log pipeline result
0 commit comments