@@ -14,32 +14,32 @@ internal static class PipelineSuggester
14
14
{
15
15
private const int TopKTrainers = 3 ;
16
16
17
- public static Pipeline GetNextPipeline ( IEnumerable < PipelineScore > history ,
17
+ public static Pipeline GetNextPipeline ( MLContext context ,
18
+ IEnumerable < PipelineScore > history ,
18
19
( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ] columns ,
19
20
TaskKind task ,
20
21
bool isMaximizingMetric = true )
21
22
{
22
- var inferredHistory = history . Select ( r => SuggestedPipelineResult . FromPipelineRunResult ( r ) ) ;
23
- var nextInferredPipeline = GetNextInferredPipeline ( inferredHistory , columns , task , isMaximizingMetric ) ;
23
+ var inferredHistory = history . Select ( r => SuggestedPipelineResult . FromPipelineRunResult ( context , r ) ) ;
24
+ var nextInferredPipeline = GetNextInferredPipeline ( context , inferredHistory , columns , task , isMaximizingMetric ) ;
24
25
return nextInferredPipeline ? . ToPipeline ( ) ;
25
26
}
26
27
27
- public static SuggestedPipeline GetNextInferredPipeline ( IEnumerable < SuggestedPipelineResult > history ,
28
+ public static SuggestedPipeline GetNextInferredPipeline ( MLContext context ,
29
+ IEnumerable < SuggestedPipelineResult > history ,
28
30
( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ] columns ,
29
31
TaskKind task ,
30
32
bool isMaximizingMetric ,
31
33
IEnumerable < TrainerName > trainerWhitelist = null )
32
34
{
33
- var context = new MLContext ( ) ;
34
-
35
35
var availableTrainers = RecipeInference . AllowedTrainers ( context , task , trainerWhitelist ) ;
36
36
var transforms = CalculateTransforms ( context , columns , task ) ;
37
37
//var transforms = TransformInferenceApi.InferTransforms(context, columns, task);
38
38
39
39
// if we haven't run all pipelines once
40
40
if ( history . Count ( ) < availableTrainers . Count ( ) )
41
41
{
42
- return GetNextFirstStagePipeline ( history , availableTrainers , transforms ) ;
42
+ return GetNextFirstStagePipeline ( context , history , availableTrainers , transforms ) ;
43
43
}
44
44
45
45
// get top trainers from stage 1 runs
@@ -63,14 +63,14 @@ public static SuggestedPipeline GetNextInferredPipeline(IEnumerable<SuggestedPip
63
63
do
64
64
{
65
65
// sample new hyperparameters for the learner
66
- if ( ! SampleHyperparameters ( newTrainer , history , isMaximizingMetric ) )
66
+ if ( ! SampleHyperparameters ( context , newTrainer , history , isMaximizingMetric ) )
67
67
{
68
68
// if unable to sample new hyperparameters for the learner
69
69
// (ie SMAC returned 0 suggestions), break
70
70
break ;
71
71
}
72
72
73
- var suggestedPipeline = new SuggestedPipeline ( transforms , newTrainer ) ;
73
+ var suggestedPipeline = new SuggestedPipeline ( transforms , newTrainer , context ) ;
74
74
75
75
// make sure we have not seen pipeline before
76
76
if ( ! visitedPipelines . Contains ( suggestedPipeline ) )
@@ -113,12 +113,13 @@ private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerabl
113
113
. Select ( x => x . First ( ) . Pipeline . Trainer ) ;
114
114
}
115
115
116
- private static SuggestedPipeline GetNextFirstStagePipeline ( IEnumerable < SuggestedPipelineResult > history ,
116
+ private static SuggestedPipeline GetNextFirstStagePipeline ( MLContext context ,
117
+ IEnumerable < SuggestedPipelineResult > history ,
117
118
IEnumerable < SuggestedTrainer > availableTrainers ,
118
119
IEnumerable < SuggestedTransform > transforms )
119
120
{
120
121
var trainer = availableTrainers . ElementAt ( history . Count ( ) ) ;
121
- return new SuggestedPipeline ( transforms , trainer ) ;
122
+ return new SuggestedPipeline ( transforms , trainer , context ) ;
122
123
}
123
124
124
125
private static IValueGenerator [ ] ConvertToValueGenerators ( IEnumerable < SweepableParam > hps )
@@ -184,10 +185,10 @@ private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableP
184
185
/// Samples new hyperparameters for the trainer, and sets them.
185
186
/// Returns true if success (new hyperparams were suggested and set). Else, returns false.
186
187
/// </summary>
187
- private static bool SampleHyperparameters ( SuggestedTrainer trainer , IEnumerable < SuggestedPipelineResult > history , bool isMaximizingMetric )
188
+ private static bool SampleHyperparameters ( MLContext context , SuggestedTrainer trainer , IEnumerable < SuggestedPipelineResult > history , bool isMaximizingMetric )
188
189
{
189
190
var sps = ConvertToValueGenerators ( trainer . SweepParams ) ;
190
- var sweeper = new SmacSweeper (
191
+ var sweeper = new SmacSweeper ( context ,
191
192
new SmacSweeper . Arguments
192
193
{
193
194
SweptParameters = sps
0 commit comments