Skip to content

Commit 79ad446

Browse files
authored
Null reference exception fix for finding best model when some runs have failed (dotnet#239)
1 parent b5e7e1f commit 79ad446

7 files changed

+130
-27
lines changed

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ public static class BinaryExperimentResultExtensions
9898
public static RunResult<BinaryClassificationMetrics> Best(this IEnumerable<RunResult<BinaryClassificationMetrics>> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy)
9999
{
100100
var metricsAgent = new BinaryMetricsAgent(metric);
101-
double maxScore = results.Select(r => metricsAgent.GetScore(r.Metrics)).Max();
102-
return results.First(r => metricsAgent.GetScore(r.Metrics) == maxScore);
101+
return RunResultUtil.GetBestRunResult(results, metricsAgent);
103102
}
104103
}
105104
}

src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ public static class MulticlassExperimentResultExtensions
9696
public static RunResult<MultiClassClassifierMetrics> Best(this IEnumerable<RunResult<MultiClassClassifierMetrics>> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.AccuracyMicro)
9797
{
9898
var metricsAgent = new MultiMetricsAgent(metric);
99-
double maxScore = results.Select(r => metricsAgent.GetScore(r.Metrics)).Max();
100-
return results.First(r => metricsAgent.GetScore(r.Metrics) == maxScore);
99+
return RunResultUtil.GetBestRunResult(results, metricsAgent);
101100
}
102101
}
103102
}

src/Microsoft.ML.Auto/API/RegressionExperiment.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ public static class RegressionExperimentResultExtensions
9393
public static RunResult<RegressionMetrics> Best(this IEnumerable<RunResult<RegressionMetrics>> results, RegressionMetric metric = RegressionMetric.RSquared)
9494
{
9595
var metricsAgent = new RegressionMetricsAgent(metric);
96-
double maxScore = results.Select(r => metricsAgent.GetScore(r.Metrics)).Max();
97-
return results.First(r => metricsAgent.GetScore(r.Metrics) == maxScore);
96+
return RunResultUtil.GetBestRunResult(results, metricsAgent);
9897
}
9998
}
10099
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
8+
namespace Microsoft.ML.Auto
9+
{
10+
internal class RunResultUtil
11+
{
12+
public static RunResult<T> GetBestRunResult<T>(IEnumerable<RunResult<T>> results,
13+
IMetricsAgent<T> metricsAgent)
14+
{
15+
results = results.Where(r => r.Metrics != null);
16+
if (!results.Any()) { return null; }
17+
double maxScore = results.Select(r => metricsAgent.GetScore(r.Metrics)).Max();
18+
return results.First(r => metricsAgent.GetScore(r.Metrics) == maxScore);
19+
}
20+
}
21+
}

src/Test/MetricsAgentsTests.cs

+10-21
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Reflection;
76
using Microsoft.ML.Data;
87
using Microsoft.VisualStudio.TestTools.UnitTesting;
98

@@ -15,7 +14,7 @@ public class MetricsAgentsTests
1514
[TestMethod]
1615
public void BinaryMetricsGetScoreTest()
1716
{
18-
var metrics = CreateInstance<BinaryClassificationMetrics>(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
17+
var metrics = MetricsUtil.CreateBinaryClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
1918
Assert.AreEqual(0.1, GetScore(metrics, BinaryClassificationMetric.Auc));
2019
Assert.AreEqual(0.2, GetScore(metrics, BinaryClassificationMetric.Accuracy));
2120
Assert.AreEqual(0.3, GetScore(metrics, BinaryClassificationMetric.PositivePrecision));
@@ -29,7 +28,7 @@ public void BinaryMetricsGetScoreTest()
2928
[TestMethod]
3029
public void BinaryMetricsNonPerfectTest()
3130
{
32-
var metrics = CreateInstance<BinaryClassificationMetrics>(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
31+
var metrics = MetricsUtil.CreateBinaryClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8);
3332
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy));
3433
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Auc));
3534
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Auprc));
@@ -43,7 +42,7 @@ public void BinaryMetricsNonPerfectTest()
4342
[TestMethod]
4443
public void BinaryMetricsPerfectTest()
4544
{
46-
var metrics = CreateInstance<BinaryClassificationMetrics>(1, 1, 1, 1, 1, 1, 1, 1);
45+
var metrics = MetricsUtil.CreateBinaryClassificationMetrics(1, 1, 1, 1, 1, 1, 1, 1);
4746
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy));
4847
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Auc));
4948
Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Auprc));
@@ -57,7 +56,7 @@ public void BinaryMetricsPerfectTest()
5756
[TestMethod]
5857
public void MulticlassMetricsGetScoreTest()
5958
{
60-
var metrics = CreateInstance<MultiClassClassifierMetrics>(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] {});
59+
var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] {});
6160
Assert.AreEqual(0.1, GetScore(metrics, MulticlassClassificationMetric.AccuracyMicro));
6261
Assert.AreEqual(0.2, GetScore(metrics, MulticlassClassificationMetric.AccuracyMacro));
6362
Assert.AreEqual(0.3, GetScore(metrics, MulticlassClassificationMetric.LogLoss));
@@ -68,7 +67,7 @@ public void MulticlassMetricsGetScoreTest()
6867
[TestMethod]
6968
public void MulticlassMetricsNonPerfectTest()
7069
{
71-
var metrics = CreateInstance<MultiClassClassifierMetrics>(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] { });
70+
var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] { });
7271
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMacro));
7372
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMicro));
7473
Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
@@ -79,7 +78,7 @@ public void MulticlassMetricsNonPerfectTest()
7978
[TestMethod]
8079
public void MulticlassMetricsPerfectTest()
8180
{
82-
var metrics = CreateInstance<MultiClassClassifierMetrics>(1, 1, 0, 1, 0, 1, new double[] { });
81+
var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(1, 1, 0, 1, 0, 1, new double[] { });
8382
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMicro));
8483
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.AccuracyMacro));
8584
Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss));
@@ -90,7 +89,7 @@ public void MulticlassMetricsPerfectTest()
9089
[TestMethod]
9190
public void RegressionMetricsGetScoreTest()
9291
{
93-
var metrics = CreateInstance<RegressionMetrics>(0.2, 0.3, 0.4, 0.5, 0.6);
92+
var metrics = MetricsUtil.CreateRegressionMetrics(0.2, 0.3, 0.4, 0.5, 0.6);
9493
Assert.AreEqual(0.2, GetScore(metrics, RegressionMetric.L1));
9594
Assert.AreEqual(0.3, GetScore(metrics, RegressionMetric.L2));
9695
Assert.AreEqual(0.4, GetScore(metrics, RegressionMetric.Rms));
@@ -100,7 +99,7 @@ public void RegressionMetricsGetScoreTest()
10099
[TestMethod]
101100
public void RegressionMetricsNonPerfectTest()
102101
{
103-
var metrics = CreateInstance<RegressionMetrics>(0.2, 0.3, 0.4, 0.5, 0.6);
102+
var metrics = MetricsUtil.CreateRegressionMetrics(0.2, 0.3, 0.4, 0.5, 0.6);
104103
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.L1));
105104
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.L2));
106105
Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.Rms));
@@ -110,7 +109,7 @@ public void RegressionMetricsNonPerfectTest()
110109
[TestMethod]
111110
public void RegressionMetricsPerfectTest()
112111
{
113-
var metrics = CreateInstance<RegressionMetrics>(0, 0, 0, 0, 1);
112+
var metrics = MetricsUtil.CreateRegressionMetrics(0, 0, 0, 0, 1);
114113
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.L1));
115114
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.L2));
116115
Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.Rms));
@@ -122,17 +121,7 @@ public void RegressionMetricsPerfectTest()
122121
public void ThrowNotSupportedMetricException()
123122
{
124123
throw MetricsAgentUtil.BuildMetricNotSupportedException(BinaryClassificationMetric.Accuracy);
125-
}
126-
127-
private static T CreateInstance<T>(params object[] args)
128-
{
129-
var type = typeof(T);
130-
var instance = type.Assembly.CreateInstance(
131-
type.FullName, false,
132-
BindingFlags.Instance | BindingFlags.NonPublic,
133-
null, args, null, null);
134-
return (T)instance;
135-
}
124+
}
136125

137126
private static double GetScore(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric)
138127
{

src/Test/MetricsUtil.cs

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Reflection;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.Auto.Test
9+
{
10+
internal static class MetricsUtil
11+
{
12+
public static BinaryClassificationMetrics CreateBinaryClassificationMetrics(
13+
double auc, double accuracy, double positivePrecision,
14+
double positiveRecall, double negativePrecision,
15+
double negativeRecall, double f1Score, double auprc)
16+
{
17+
return CreateInstance<BinaryClassificationMetrics>(auc, accuracy,
18+
positivePrecision, positiveRecall, negativePrecision,
19+
negativeRecall, f1Score, auprc);
20+
}
21+
22+
public static MultiClassClassifierMetrics CreateMulticlassClassificationMetrics(
23+
double accuracyMicro, double accuracyMacro, double logLoss,
24+
double logLossReduction, int topK, double topKAccuracy,
25+
double[] perClassLogLoss)
26+
{
27+
return CreateInstance<MultiClassClassifierMetrics>(accuracyMicro,
28+
accuracyMacro, logLoss, logLossReduction, topK,
29+
topKAccuracy, perClassLogLoss);
30+
}
31+
32+
public static RegressionMetrics CreateRegressionMetrics(double l1,
33+
double l2, double rms, double lossFn, double rSquared)
34+
{
35+
return CreateInstance<RegressionMetrics>(l1, l2,
36+
rms, lossFn, rSquared);
37+
}
38+
39+
private static T CreateInstance<T>(params object[] args)
40+
{
41+
var type = typeof(T);
42+
var instance = type.Assembly.CreateInstance(
43+
type.FullName, false,
44+
BindingFlags.Instance | BindingFlags.NonPublic,
45+
null, args, null, null);
46+
return (T)instance;
47+
}
48+
}
49+
}

src/Test/RunResultTests.cs

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Microsoft.ML.Data;
7+
using Microsoft.VisualStudio.TestTools.UnitTesting;
8+
9+
namespace Microsoft.ML.Auto.Test
10+
{
11+
[TestClass]
12+
public class RunResultTests
13+
{
14+
[TestMethod]
15+
public void FindBestResultWithSomeNullMetrics()
16+
{
17+
var metrics1 = MetricsUtil.CreateRegressionMetrics(0.2, 0.2, 0.2, 0.2, 0.2);
18+
var metrics2 = MetricsUtil.CreateRegressionMetrics(0.3, 0.3, 0.3, 0.3, 0.3);
19+
var metrics3 = MetricsUtil.CreateRegressionMetrics(0.1, 0.1, 0.1, 0.1, 0.1);
20+
21+
var runResults = new List<RunResult<RegressionMetrics>>()
22+
{
23+
new RunResult<RegressionMetrics>(null, null, null, null, 0, 0),
24+
new RunResult<RegressionMetrics>(null, metrics1, null, null, 0, 0),
25+
new RunResult<RegressionMetrics>(null, metrics2, null, null, 0, 0),
26+
new RunResult<RegressionMetrics>(null, metrics3, null, null, 0, 0),
27+
};
28+
29+
var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);
30+
var bestResult = RunResultUtil.GetBestRunResult(runResults, metricsAgent);
31+
Assert.AreEqual(0.3, bestResult.Metrics.RSquared);
32+
}
33+
34+
[TestMethod]
35+
public void FindBestResultWithAllNullMetrics()
36+
{
37+
var runResults = new List<RunResult<RegressionMetrics>>()
38+
{
39+
new RunResult<RegressionMetrics>(null, null, null, null, 0, 0),
40+
};
41+
42+
var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);
43+
var bestResult = RunResultUtil.GetBestRunResult(runResults, metricsAgent);
44+
Assert.AreEqual(null, bestResult);
45+
}
46+
}
47+
}

0 commit comments

Comments
 (0)