@@ -41,25 +41,33 @@ public static InferredPipeline GetNextInferredPipeline(IEnumerable<InferredPipel
41
41
return GetNextFirstStagePipeline ( history , availableTrainers , transforms ) ;
42
42
}
43
43
44
- // get next trainer
44
+ // get top trainers from stage 1 runs
45
45
var topTrainers = GetTopTrainers ( history , availableTrainers , isMaximizingMetric ) ;
46
- var nextTrainerIndex = ( history . Count ( ) - availableTrainers . Count ( ) ) % topTrainers . Count ( ) ;
47
- var trainer = topTrainers . ElementAt ( nextTrainerIndex ) . Clone ( ) ;
48
-
49
- // make sure we have not seen pipeline before.
50
- // repeat until passes or runs out of chances.
51
- var visitedPipelines = new HashSet < InferredPipeline > ( history . Select ( h => h . Pipeline ) ) ;
52
- const int maxNumberAttempts = 10 ;
53
- var count = 0 ;
54
- do
46
+
47
+ // sort top trainers by # of times they've been run, from lowest to highest
48
+ var orderedTopTrainers = OrderTrainersByNumTrials ( history , topTrainers ) ;
49
+
50
+ // iterate over top trainers (from least run to most run),
51
+ // to find next pipeline
52
+ foreach ( var trainer in orderedTopTrainers )
55
53
{
56
- SampleHyperparameters ( trainer , history , isMaximizingMetric ) ;
57
- var pipeline = new InferredPipeline ( transforms , trainer ) ;
58
- if ( ! visitedPipelines . Contains ( pipeline ) )
54
+ var newTrainer = trainer . Clone ( ) ;
55
+
56
+ // make sure we have not seen pipeline before.
57
+ // repeat until passes or runs out of chances
58
+ var visitedPipelines = new HashSet < InferredPipeline > ( history . Select ( h => h . Pipeline ) ) ;
59
+ const int maxNumberAttempts = 10 ;
60
+ var count = 0 ;
61
+ do
59
62
{
60
- return pipeline ;
61
- }
62
- } while ( ++ count <= maxNumberAttempts ) ;
63
+ SampleHyperparameters ( newTrainer , history , isMaximizingMetric ) ;
64
+ var pipeline = new InferredPipeline ( transforms , newTrainer ) ;
65
+ if ( ! visitedPipelines . Contains ( pipeline ) )
66
+ {
67
+ return pipeline ;
68
+ }
69
+ } while ( ++ count <= maxNumberAttempts ) ;
70
+ }
63
71
64
72
return null ;
65
73
}
@@ -84,6 +92,16 @@ private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<Inferred
84
92
return topTrainers ;
85
93
}
86
94
95
/// <summary>
/// Orders the selected trainers by how many runs in <paramref name="history"/> used them,
/// from fewest runs to most runs.
/// </summary>
/// <param name="history">Previously executed pipeline runs.</param>
/// <param name="selectedTrainers">Trainers whose names decide which runs are counted.</param>
/// <returns>
/// One trainer per distinct trainer name found in the matching runs, taken from the first
/// matching run of that name. A selected trainer with no run in <paramref name="history"/>
/// does not appear in the result. The query is evaluated lazily on enumeration.
/// </returns>
private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerable<InferredPipelineRunResult> history,
    IEnumerable<SuggestedTrainer> selectedTrainers)
{
    // Built eagerly so the name lookup inside the deferred query below is a hash probe.
    var allowedNames = new HashSet<TrainerName>(from candidate in selectedTrainers
                                                select candidate.TrainerName);

    // Keep only the runs that used one of the selected trainers.
    var matchingRuns = history.Where(run => allowedNames.Contains(run.Pipeline.Trainer.TrainerName));

    // Bucket those runs per trainer name; each bucket's size is that trainer's trial count.
    var runsPerTrainer = matchingRuns.GroupBy(run => run.Pipeline.Trainer.TrainerName);

    // Least-run trainers first.
    var leastRunFirst = runsPerTrainer.OrderBy(runs => runs.Count());

    // Represent each bucket by the trainer of its first recorded run.
    return leastRunFirst.Select(runs => runs.First().Pipeline.Trainer);
}
104
+
87
105
private static InferredPipeline GetNextFirstStagePipeline ( IEnumerable < InferredPipelineRunResult > history ,
88
106
IEnumerable < SuggestedTrainer > availableTrainers ,
89
107
IEnumerable < SuggestedTransform > transforms )
0 commit comments