@@ -16,22 +16,22 @@ internal class AutoFitter
16
16
private readonly IDebugLogger _debugLogger ;
17
17
private readonly IList < InferredPipelineRunResult > _history ;
18
18
private readonly string _label ;
19
- private readonly MLContext _mlContext ;
19
+ private readonly MLContext _context ;
20
20
private readonly OptimizingMetricInfo _optimizingMetricInfo ;
21
21
private readonly IDictionary < string , ColumnPurpose > _purposeOverrides ;
22
22
private readonly AutoFitSettings _settings ;
23
23
private readonly IDataView _trainData ;
24
24
private readonly TaskKind _task ;
25
25
private readonly IDataView _validationData ;
26
26
27
- public AutoFitter ( MLContext mlContext , OptimizingMetricInfo metricInfo , AutoFitSettings settings ,
27
+ public AutoFitter ( MLContext context , OptimizingMetricInfo metricInfo , AutoFitSettings settings ,
28
28
TaskKind task , string label , IDataView trainData , IDataView validationData ,
29
29
IDictionary < string , ColumnPurpose > purposeOverrides , IDebugLogger debugLogger )
30
30
{
31
31
_debugLogger = debugLogger ;
32
32
_history = new List < InferredPipelineRunResult > ( ) ;
33
33
_label = label ;
34
- _mlContext = mlContext ;
34
+ _context = context ;
35
35
_optimizingMetricInfo = metricInfo ;
36
36
_settings = settings ?? new AutoFitSettings ( ) ;
37
37
_purposeOverrides = purposeOverrides ;
@@ -49,13 +49,13 @@ public InferredPipelineRunResult[] Fit()
49
49
private void IteratePipelinesAndFit ( )
50
50
{
51
51
var stopwatch = Stopwatch . StartNew ( ) ;
52
- var transforms = TransformInferenceApi . InferTransforms ( _mlContext , _trainData , _label , _purposeOverrides ) ;
53
- var availableTrainers = RecipeInference . AllowedTrainers ( _mlContext , _task , _settings . StoppingCriteria . MaxIterations ) ;
52
+ var columns = AutoMlUtils . GetColumnInfoTuples ( _context , _trainData , _label , _purposeOverrides ) ;
54
53
55
54
do
56
55
{
57
56
// get next pipeline
58
- var pipeline = PipelineSuggester . GetNextInferredPipeline ( _history , transforms , availableTrainers , _optimizingMetricInfo . IsMaximizing ) ;
57
+ var iterationsRemaining = _settings . StoppingCriteria . MaxIterations - _history . Count ;
58
+ var pipeline = PipelineSuggester . GetNextInferredPipeline ( _history , columns , _task , iterationsRemaining , _optimizingMetricInfo . IsMaximizing ) ;
59
59
60
60
// break if no candidates returned, means no valid pipeline available
61
61
if ( pipeline == null )
@@ -113,11 +113,11 @@ private object GetEvaluatedMetrics(IDataView scoredData)
113
113
switch ( _task )
114
114
{
115
115
case TaskKind . BinaryClassification :
116
- return _mlContext . BinaryClassification . EvaluateNonCalibrated ( scoredData ) ;
116
+ return _context . BinaryClassification . EvaluateNonCalibrated ( scoredData ) ;
117
117
case TaskKind . MulticlassClassification :
118
- return _mlContext . MulticlassClassification . Evaluate ( scoredData ) ;
118
+ return _context . MulticlassClassification . Evaluate ( scoredData ) ;
119
119
case TaskKind . Regression :
120
- return _mlContext . Regression . Evaluate ( scoredData ) ;
120
+ return _context . Regression . Evaluate ( scoredData ) ;
121
121
// should not be possible to reach here
122
122
default :
123
123
throw new InvalidOperationException ( $ "unsupported machine learning task type { _task } ") ;
0 commit comments