Skip to content

Commit afd3220

Browse files
george-microsofteerhardt
authored andcommitted
Fix SupportedMetric.ByName() method (dotnet#280)
* Fix for SupportedMetric.ByName() method. Include new unit test for function. * Fix for SupportedMetric.ByName() method. Include new unit test for function. * Fix for SupportedMetric.ByName() method. Include new unit test for function. * Removed unnecessary field filter, per review comment.
1 parent 2617e2b commit afd3220

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

src/Microsoft.ML.PipelineInference/AutoInference.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ private SupportedMetric(string name, bool isMaximizing)
9898
public static SupportedMetric ByName(string name)
9999
{
100100
var fields =
101-
typeof(SupportedMetric).GetMembers(BindingFlags.Static | BindingFlags.Public)
102-
.Where(s => s.MemberType == MemberTypes.Field);
101+
typeof(SupportedMetric).GetFields(BindingFlags.Static | BindingFlags.Public);
102+
103103
foreach (var field in fields)
104104
{
105-
if (name.Equals(field.Name, StringComparison.OrdinalIgnoreCase))
106-
return (SupportedMetric)typeof(SupportedMetric).GetField(field.Name).GetValue(null);
105+
var metric = (SupportedMetric)field.GetValue(Auc);
106+
if (name.Equals(metric.Name, StringComparison.OrdinalIgnoreCase))
107+
return metric;
107108
}
108109
throw new NotSupportedException($"Metric '{name}' not supported.");
109110
}

test/Microsoft.ML.Predictor.Tests/TestAutoInference.cs

+22
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Linq;
6+
using System.Collections.Generic;
67
using Newtonsoft.Json.Linq;
78
using Microsoft.ML.Runtime.Data;
89
using Microsoft.ML.Runtime.EntryPoints;
@@ -411,6 +412,27 @@ public void TestPipelineNodeCloning()
411412
}
412413
}
413414

415+
[Fact]
416+
public void TestSupportedMetricsByName()
417+
{
418+
var names = new List<string>()
419+
{
420+
AutoInference.SupportedMetric.AccuracyMacro.Name,
421+
AutoInference.SupportedMetric.AccuracyMicro.Name,
422+
AutoInference.SupportedMetric.Auc.Name,
423+
AutoInference.SupportedMetric.AuPrc.Name,
424+
AutoInference.SupportedMetric.Dbi.Name,
425+
AutoInference.SupportedMetric.F1.Name,
426+
AutoInference.SupportedMetric.LogLossReduction.Name
427+
};
428+
429+
foreach (var name in names)
430+
{
431+
var metric = AutoInference.SupportedMetric.ByName(name);
432+
Assert.Equal(metric.Name, name);
433+
}
434+
}
435+
414436
[Fact]
415437
public void TestHyperparameterFreezing()
416438
{

0 commit comments

Comments
 (0)