@@ -12,7 +12,18 @@ internal static class PipelineSuggester
12
12
{
13
13
private const int TopKTrainers = 3 ;
14
14
15
- public static InferredPipeline GetNextPipeline ( IEnumerable < PipelineRunResult > history ,
15
+ public static Pipeline GetNextPipeline ( IEnumerable < PipelineRunResult > history ,
16
+ IEnumerable < SuggestedTransform > transforms ,
17
+ IEnumerable < SuggestedTrainer > availableTrainers ,
18
+ bool isMaximizingMetric = true )
19
+ {
20
+ var inferredHistory = history . Select ( r => InferredPipelineRunResult . FromPipelineRunResult ( r ) ) ;
21
+ var nextInferredPipeline = GetNextInferredPipeline ( inferredHistory ,
22
+ transforms , availableTrainers , isMaximizingMetric ) ;
23
+ return nextInferredPipeline . ToPipeline ( ) ;
24
+ }
25
+
26
+ public static InferredPipeline GetNextInferredPipeline ( IEnumerable < InferredPipelineRunResult > history ,
16
27
IEnumerable < SuggestedTransform > transforms ,
17
28
IEnumerable < SuggestedTrainer > availableTrainers ,
18
29
bool isMaximizingMetric = true )
@@ -49,15 +60,15 @@ public static InferredPipeline GetNextPipeline(IEnumerable<PipelineRunResult> hi
49
60
/// <summary>
50
61
/// Get top trainers from first stage
51
62
/// </summary>
52
- private static IEnumerable < SuggestedTrainer > GetTopTrainers ( IEnumerable < PipelineRunResult > history ,
63
+ private static IEnumerable < SuggestedTrainer > GetTopTrainers ( IEnumerable < InferredPipelineRunResult > history ,
53
64
IEnumerable < SuggestedTrainer > availableTrainers ,
54
65
bool isMaximizingMetric )
55
66
{
56
67
// narrow history to first stage runs
57
68
history = history . Take ( availableTrainers . Count ( ) ) ;
58
69
59
70
history = history . GroupBy ( r => r . Pipeline . Trainer . TrainerName ) . Select ( g => g . First ( ) ) ;
60
- IEnumerable < PipelineRunResult > sortedHistory = history . OrderBy ( r => r . Score ) ;
71
+ IEnumerable < InferredPipelineRunResult > sortedHistory = history . OrderBy ( r => r . Score ) ;
61
72
if ( isMaximizingMetric )
62
73
{
63
74
sortedHistory = sortedHistory . Reverse ( ) ;
@@ -66,7 +77,7 @@ private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<Pipeline
66
77
return topTrainers ;
67
78
}
68
79
69
- private static InferredPipeline GetNextFirstStagePipeline ( IEnumerable < PipelineRunResult > history ,
80
+ private static InferredPipeline GetNextFirstStagePipeline ( IEnumerable < InferredPipelineRunResult > history ,
70
81
IEnumerable < SuggestedTrainer > availableTrainers ,
71
82
IEnumerable < SuggestedTransform > transforms )
72
83
{
@@ -133,7 +144,7 @@ private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableP
133
144
return results ;
134
145
}
135
146
136
- private static void SampleHyperparameters ( SuggestedTrainer trainer , IEnumerable < PipelineRunResult > history , bool isMaximizingMetric )
147
+ private static void SampleHyperparameters ( SuggestedTrainer trainer , IEnumerable < InferredPipelineRunResult > history , bool isMaximizingMetric )
137
148
{
138
149
var sps = ConvertToValueGenerators ( trainer . SweepParams ) ;
139
150
var sweeper = new SmacSweeper (
@@ -142,7 +153,7 @@ private static void SampleHyperparameters(SuggestedTrainer trainer, IEnumerable<
142
153
SweptParameters = sps
143
154
} ) ;
144
155
145
- IEnumerable < PipelineRunResult > historyToUse = history
156
+ IEnumerable < InferredPipelineRunResult > historyToUse = history
146
157
. Where ( r => r . RunSucceded && r . Pipeline . Trainer . TrainerName == trainer . TrainerName && r . Pipeline . Trainer . HyperParamSet != null ) ;
147
158
148
159
// get new set of hyperparameter values
0 commit comments