Skip to content

Commit 6609cd9

Browse files
authored
Fix bug where the entire run terminates as soon as one pipeline's hyperparameter optimization converges (dotnet#36)
1 parent bf42ba5 commit 6609cd9

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

src/AutoML/PipelineSuggesters/PipelineSuggester.cs

+34-16
Original file line number | Diff line number | Diff line change
@@ -41,25 +41,33 @@ public static InferredPipeline GetNextInferredPipeline(IEnumerable<InferredPipel
4141
return GetNextFirstStagePipeline(history, availableTrainers, transforms);
4242
}
4343

44-
// get next trainer
44+
// get top trainers from stage 1 runs
4545
var topTrainers = GetTopTrainers(history, availableTrainers, isMaximizingMetric);
46-
var nextTrainerIndex = (history.Count() - availableTrainers.Count()) % topTrainers.Count();
47-
var trainer = topTrainers.ElementAt(nextTrainerIndex).Clone();
48-
49-
// make sure we have not seen pipeline before.
50-
// repeat until passes or runs out of chances.
51-
var visitedPipelines = new HashSet<InferredPipeline>(history.Select(h => h.Pipeline));
52-
const int maxNumberAttempts = 10;
53-
var count = 0;
54-
do
46+
47+
// sort top trainers by # of times they've been run, from lowest to highest
48+
var orderedTopTrainers = OrderTrainersByNumTrials(history, topTrainers);
49+
50+
// iterate over top trainers (from least run to most run),
51+
// to find next pipeline
52+
foreach(var trainer in orderedTopTrainers)
5553
{
56-
SampleHyperparameters(trainer, history, isMaximizingMetric);
57-
var pipeline = new InferredPipeline(transforms, trainer);
58-
if(!visitedPipelines.Contains(pipeline))
54+
var newTrainer = trainer.Clone();
55+
56+
// make sure we have not seen pipeline before.
57+
// repeat until passes or runs out of chances
58+
var visitedPipelines = new HashSet<InferredPipeline>(history.Select(h => h.Pipeline));
59+
const int maxNumberAttempts = 10;
60+
var count = 0;
61+
do
5962
{
60-
return pipeline;
61-
}
62-
} while (++count <= maxNumberAttempts);
63+
SampleHyperparameters(newTrainer, history, isMaximizingMetric);
64+
var pipeline = new InferredPipeline(transforms, newTrainer);
65+
if (!visitedPipelines.Contains(pipeline))
66+
{
67+
return pipeline;
68+
}
69+
} while (++count <= maxNumberAttempts);
70+
}
6371

6472
return null;
6573
}
@@ -84,6 +92,16 @@ private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<Inferred
8492
return topTrainers;
8593
}
8694

95+
private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerable<InferredPipelineRunResult> history,
96+
IEnumerable<SuggestedTrainer> selectedTrainers)
97+
{
98+
var selectedTrainerNames = new HashSet<TrainerName>(selectedTrainers.Select(t => t.TrainerName));
99+
return history.Where(h => selectedTrainerNames.Contains(h.Pipeline.Trainer.TrainerName))
100+
.GroupBy(h => h.Pipeline.Trainer.TrainerName)
101+
.OrderBy(x => x.Count())
102+
.Select(x => x.First().Pipeline.Trainer);
103+
}
104+
87105
private static InferredPipeline GetNextFirstStagePipeline(IEnumerable<InferredPipelineRunResult> history,
88106
IEnumerable<SuggestedTrainer> availableTrainers,
89107
IEnumerable<SuggestedTransform> transforms)

0 commit comments

Comments
 (0)