Skip to content

Commit e52d7f0

Browse files
authored
Exit if perfect model produced (dotnet#220)
1 parent 95dc1fc commit e52d7f0

File tree

7 files changed

+272
-10
lines changed

7 files changed

+272
-10
lines changed

src/Microsoft.ML.Auto/Experiment/Experiment.cs

+7
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ public List<RunResult<T>> Execute()
114114
var iterationResult = runResult.ToIterationResult();
115115
ReportProgress(iterationResult);
116116
iterationResults.Add(iterationResult);
117+
118+
// if model is perfect, break
119+
if (_metricsAgent.IsModelPerfect(iterationResult.Metrics))
120+
{
121+
break;
122+
}
123+
117124
} while (_history.Count < _experimentSettings.MaxModels &&
118125
!_experimentSettings.CancellationToken.IsCancellationRequested &&
119126
stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxInferenceTimeInSeconds);

src/Microsoft.ML.Auto/Experiment/MetricsAgents/BinaryMetricsAgent.cs

+31-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using Microsoft.ML.Data;
76

87
namespace Microsoft.ML.Auto
@@ -36,10 +35,39 @@ public double GetScore(BinaryClassificationMetrics metrics)
3635
return metrics.PositivePrecision;
3736
case BinaryClassificationMetric.PositiveRecall:
3837
return metrics.PositiveRecall;
38+
default:
39+
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
40+
}
41+
}
42+
43+
public bool IsModelPerfect(BinaryClassificationMetrics metrics)
44+
{
45+
if (metrics == null)
46+
{
47+
return false;
3948
}
4049

41-
// never expected to reach here
42-
throw new NotSupportedException($"{_optimizingMetric} is not a supported sweep metric");
50+
switch (_optimizingMetric)
51+
{
52+
case BinaryClassificationMetric.Accuracy:
53+
return metrics.Accuracy == 1;
54+
case BinaryClassificationMetric.Auc:
55+
return metrics.Auc == 1;
56+
case BinaryClassificationMetric.Auprc:
57+
return metrics.Auprc == 1;
58+
case BinaryClassificationMetric.F1Score:
59+
return metrics.F1Score == 1;
60+
case BinaryClassificationMetric.NegativePrecision:
61+
return metrics.NegativePrecision == 1;
62+
case BinaryClassificationMetric.NegativeRecall:
63+
return metrics.NegativeRecall == 1;
64+
case BinaryClassificationMetric.PositivePrecision:
65+
return metrics.PositivePrecision == 1;
66+
case BinaryClassificationMetric.PositiveRecall:
67+
return metrics.PositiveRecall == 1;
68+
default:
69+
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
70+
}
4371
}
4472
}
4573
}

src/Microsoft.ML.Auto/Experiment/MetricsAgents/IMetricsAgent.cs

+2
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,7 @@ namespace Microsoft.ML.Auto
77
internal interface IMetricsAgent<T>
88
{
99
double GetScore(T metrics);
10+
11+
bool IsModelPerfect(T metrics);
1012
}
1113
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
7+
namespace Microsoft.ML.Auto
8+
{
9+
internal static class MetricsAgentUtil
10+
{
11+
public static NotSupportedException BuildMetricNotSupportedException<T>(T optimizingMetric)
12+
{
13+
return new NotSupportedException($"{optimizingMetric} is not a supported sweep metric");
14+
}
15+
}
16+
}

src/Microsoft.ML.Auto/Experiment/MetricsAgents/MultiMetricsAgent.cs

+26-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using Microsoft.ML.Data;
76

87
namespace Microsoft.ML.Auto
@@ -30,10 +29,33 @@ public double GetScore(MultiClassClassifierMetrics metrics)
3029
return metrics.LogLossReduction;
3130
case MulticlassClassificationMetric.TopKAccuracy:
3231
return metrics.TopKAccuracy;
32+
default:
33+
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
3334
}
35+
}
3436

35-
// never expected to reach here
36-
throw new NotSupportedException($"{_optimizingMetric} is not a supported sweep metric");
37+
public bool IsModelPerfect(MultiClassClassifierMetrics metrics)
38+
{
39+
if (metrics == null)
40+
{
41+
return false;
42+
}
43+
44+
switch (_optimizingMetric)
45+
{
46+
case MulticlassClassificationMetric.AccuracyMacro:
47+
return metrics.AccuracyMacro == 1;
48+
case MulticlassClassificationMetric.AccuracyMicro:
49+
return metrics.AccuracyMicro == 1;
50+
case MulticlassClassificationMetric.LogLoss:
51+
return metrics.LogLoss == 0;
52+
case MulticlassClassificationMetric.LogLossReduction:
53+
return metrics.LogLossReduction == 1;
54+
case MulticlassClassificationMetric.TopKAccuracy:
55+
return metrics.TopKAccuracy == 1;
56+
default:
57+
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
58+
}
3759
}
3860
}
39-
}
61+
}

src/Microsoft.ML.Auto/Experiment/MetricsAgents/RegressionMetricsAgent.cs

+23-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using Microsoft.ML.Data;
76

87
namespace Microsoft.ML.Auto
@@ -28,10 +27,31 @@ public double GetScore(RegressionMetrics metrics)
2827
return metrics.Rms;
2928
case RegressionMetric.RSquared:
3029
return metrics.RSquared;
30+
default:
31+
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
32+
}
33+
}
34+
35+
public bool IsModelPerfect(RegressionMetrics metrics)
36+
{
37+
if (metrics == null)
38+
{
39+
return false;
3140
}
3241

33-
// never expected to reach here
34-
throw new NotSupportedException($"{_optimizingMetric} is not a supported sweep metric");
42+
switch (_optimizingMetric)
43+
{
44+
case RegressionMetric.L1:
45+
return metrics.L1 == 0;
46+
case RegressionMetric.L2:
47+
return metrics.L2 == 0;
48+
case RegressionMetric.Rms:
49+
return metrics.Rms == 0;
50+
case RegressionMetric.RSquared:
51+
return metrics.RSquared == 1;
52+
default:
53+
throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
54+
}
3555
}
3656
}
3757
}

src/Test/MetricsAgentsTests.cs

+167
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Reflection;
7+
using Microsoft.ML.Data;
8+
using Microsoft.VisualStudio.TestTools.UnitTesting;
9+
10+
namespace Microsoft.ML.Auto.Test
11+
{
12+
[TestClass]
13+
public class MetricsAgentsTests
14+
{
15+
[TestMethod]
16+
public void BinaryMetricsGetScoreTest()
17+
{
18+
var metrics = CreateInstance<BinaryClassificationMetrics>(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
19+
Assert.AreEqual(0.1, GetScore(metrics, BinaryClassificationMetric.Auc));
20+
Assert.AreEqual(0.2, GetScore(metrics, BinaryClassificationMetric.Accuracy));
21+
Assert.AreEqual(0.3, GetScore(metrics, BinaryClassificationMetric.PositivePrecision));
22+
Assert.AreEqual(0.4, GetScore(metrics, BinaryClassificationMetric.PositiveRecall));
23+
Assert.AreEqual(0.5, GetScore(metrics, BinaryClassificationMetric.NegativePrecision));
24+
Assert.AreEqual(0.6, GetScore(metrics, BinaryClassificationMetric.NegativeRecall));
25+
Assert.AreEqual(0.7, GetScore(metrics, BinaryClassificationMetric.F1Score));
26+
Assert.AreEqual(0.8, GetScore(metrics, BinaryClassificationMetric.Auprc));
27+
}
28+
29+
[TestMethod]
30+
public void BinaryMetricsNonPerfectTest()
31+
{
32+
var metrics = CreateInstance<BinaryClassificationMetrics>(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
33+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy));
34+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Auc));
35+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Auprc));
36+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.F1Score));
37+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.NegativePrecision));
38+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.NegativeRecall));
39+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.PositivePrecision));
40+
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.PositiveRecall));
41+
}
42+
43+
[TestMethod]
44+
public void BinaryMetricsPerfectTest()
45+
{
46+
var metrics = CreateInstance<BinaryClassificationMetrics>(1, 1, 1, 1, 1, 1, 1, 1);
47+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy));
48+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Auc));
49+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Auprc));
50+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.F1Score));
51+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.NegativePrecision));
52+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.NegativeRecall));
53+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.PositivePrecision));
54+
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.PositiveRecall));
55+
}
56+
57+
[TestMethod]
58+
public void MulticlassMetricsGetScoreTest()
59+
{
60+
var metrics = CreateInstance<MultiClassClassifierMetrics>(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] {});
61+
Assert.AreEqual(0.1, GetScore(metrics, MulticlassClassificationMetric.AccuracyMicro));
62+
Assert.AreEqual(0.2, GetScore(metrics, MulticlassClassificationMetric.AccuracyMacro));
63+
Assert.AreEqual(0.3, GetScore(metrics, MulticlassClassificationMetric.LogLoss));
64+
Assert.AreEqual(0.4, GetScore(metrics, MulticlassClassificationMetric.LogLossReduction));
65+
Assert.AreEqual(0.5, GetScore(metrics, MulticlassClassificationMetric.TopKAccuracy));
66+
}
67+
68+
[TestMethod]
69+
public void MulticlassMetricsNonPerfectTest()
70+
{
71+
var metrics = CreateInstance<MultiClassClassifierMetrics>(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] { });
72+
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMacro));
73+
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMicro));
74+
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
75+
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLossReduction));
76+
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.TopKAccuracy));
77+
}
78+
79+
[TestMethod]
80+
public void MulticlassMetricsPerfectTest()
81+
{
82+
var metrics = CreateInstance<MultiClassClassifierMetrics>(1, 1, 0, 1, 0, 1, new double[] { });
83+
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMicro));
84+
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMacro));
85+
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
86+
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLossReduction));
87+
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.TopKAccuracy));
88+
}
89+
90+
[TestMethod]
91+
public void RegressionMetricsGetScoreTest()
92+
{
93+
var metrics = CreateInstance<RegressionMetrics>(0.2, 0.3, 0.4, 0.5, 0.6);
94+
Assert.AreEqual(0.2, GetScore(metrics, RegressionMetric.L1));
95+
Assert.AreEqual(0.3, GetScore(metrics, RegressionMetric.L2));
96+
Assert.AreEqual(0.4, GetScore(metrics, RegressionMetric.Rms));
97+
Assert.AreEqual(0.6, GetScore(metrics, RegressionMetric.RSquared));
98+
}
99+
100+
[TestMethod]
101+
public void RegressionMetricsNonPerfectTest()
102+
{
103+
var metrics = CreateInstance<RegressionMetrics>(0.2, 0.3, 0.4, 0.5, 0.6);
104+
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.L1));
105+
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.L2));
106+
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.Rms));
107+
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.RSquared));
108+
}
109+
110+
[TestMethod]
111+
public void RegressionMetricsPerfectTest()
112+
{
113+
var metrics = CreateInstance<RegressionMetrics>(0, 0, 0, 0, 1);
114+
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.L1));
115+
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.L2));
116+
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.Rms));
117+
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.RSquared));
118+
}
119+
120+
[TestMethod]
121+
[ExpectedException(typeof(NotSupportedException))]
122+
public void ThrowNotSupportedMetricException()
123+
{
124+
throw MetricsAgentUtil.BuildMetricNotSupportedException(BinaryClassificationMetric.Accuracy);
125+
}
126+
127+
private static T CreateInstance<T>(params object[] args)
128+
{
129+
var type = typeof(T);
130+
var instance = type.Assembly.CreateInstance(
131+
type.FullName, false,
132+
BindingFlags.Instance | BindingFlags.NonPublic,
133+
null, args, null, null);
134+
return (T)instance;
135+
}
136+
137+
private static double GetScore(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
138+
{
139+
return new BinaryMetricsAgent(metric).GetScore(metrics);
140+
}
141+
142+
private static double GetScore(MultiClassClassifierMetrics metrics, MulticlassClassificationMetric metric)
143+
{
144+
return new MultiMetricsAgent(metric).GetScore(metrics);
145+
}
146+
147+
private static double GetScore(RegressionMetrics metrics, RegressionMetric metric)
148+
{
149+
return new RegressionMetricsAgent(metric).GetScore(metrics);
150+
}
151+
152+
private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
153+
{
154+
return new BinaryMetricsAgent(metric).IsModelPerfect(metrics);
155+
}
156+
157+
private static bool IsPerfectModel(MultiClassClassifierMetrics metrics, MulticlassClassificationMetric metric)
158+
{
159+
return new MultiMetricsAgent(metric).IsModelPerfect(metrics);
160+
}
161+
162+
private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric metric)
163+
{
164+
return new RegressionMetricsAgent(metric).IsModelPerfect(metrics);
165+
}
166+
}
167+
}

0 commit comments

Comments
 (0)