Skip to content

Commit 184e661

Browse files
smac - ignore fail trial during initialize (#6738)
1 parent 8f8905e commit 184e661

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/Microsoft.ML.AutoML/Tuner/SmacTuner.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ public Parameter Propose(TrialSettings settings)
111111
}
112112
}
113113

114+
// test purpose
115+
internal Queue<Parameter> Candidates => _candidates;
116+
117+
// test purpose
118+
internal List<TrialResult> Histories => _histories;
119+
114120
private FastForestRegressionModelParameters FitModel(IEnumerable<TrialResult> history)
115121
{
116122
Single[] losses = new Single[history.Count()];
@@ -357,7 +363,10 @@ private double ComputeEI(double bestLoss, double[] forestStatistics)
357363

358364
public void Update(TrialResult result)
359365
{
360-
_histories.Add(result);
366+
if (!double.IsNaN(result.Loss) && !double.IsInfinity(result.Loss))
367+
{
368+
_histories.Add(result);
369+
}
361370
}
362371
}
363372
}

test/Microsoft.ML.AutoML.Tests/TunerTests.cs

+33
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,39 @@ public void tuner_e2e_test()
6666
}
6767
}
6868

69+
[Fact]
70+
public void Smac_should_ignore_fail_trials_during_initialize()
71+
{
72+
// fix for https://github.com/dotnet/machinelearning-modelbuilder/issues/2721
73+
var context = new MLContext(1);
74+
var searchSpace = new SearchSpace<LbfgsOption>();
75+
var tuner = new SmacTuner(context, searchSpace, seed: 1);
76+
for (int i = 0; i != 1000; ++i)
77+
{
78+
var trialSettings = new TrialSettings()
79+
{
80+
TrialId = i,
81+
};
82+
83+
var param = tuner.Propose(trialSettings);
84+
trialSettings.Parameter = param;
85+
var option = param.AsType<LbfgsOption>();
86+
87+
option.L1Regularization.Should().BeInRange(0.03125f, 32768.0f);
88+
option.L2Regularization.Should().BeInRange(0.03125f, 32768.0f);
89+
90+
tuner.Update(new TrialResult()
91+
{
92+
DurationInMilliseconds = i * 1000,
93+
Loss = double.NaN,
94+
TrialSettings = trialSettings,
95+
});
96+
}
97+
98+
tuner.Candidates.Count.Should().Be(0);
99+
tuner.Histories.Count.Should().Be(0);
100+
}
101+
69102
[Fact]
70103
public void CFO_should_be_recoverd_if_history_provided()
71104
{

0 commit comments

Comments
 (0)