Skip to content

Commit 3d3567c

Browse files
authored
add trainer extension tests, & misc fixes (dotnet#23)
1 parent f7e6376 commit 3d3567c

6 files changed

+84
-25
lines changed

src/AutoML/TrainerExtensions/MultiTrainerExtensions.cs

-6
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
4848

4949
internal class LightGbmMultiExtension : ITrainerExtension
5050
{
51-
private static readonly ITrainerExtension _binaryLearnerCatalogItem = new LightGbmBinaryExtension();
52-
5351
public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
5452
{
5553
return SweepableParams.BuildLightGbmParams();
@@ -80,8 +78,6 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
8078

8179
internal class SdcaMultiExtension : ITrainerExtension
8280
{
83-
private static readonly ITrainerExtension _binaryLearnerCatalogItem = new SdcaBinaryExtension();
84-
8581
public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
8682
{
8783
return SweepableParams.BuildSdcaParams();
@@ -161,8 +157,6 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
161157

162158
internal class LogisticRegressionMultiExtension : ITrainerExtension
163159
{
164-
private static readonly ITrainerExtension _binaryLearnerCatalogItem = new LogisticRegressionBinaryExtension();
165-
166160
public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
167161
{
168162
return SweepableParams.BuildLogisticRegressionParams();

src/AutoML/TrainerExtensions/TrainerExtensionCatalog.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static IEnumerable<ITrainerExtension> GetTrainers(TaskKind task, int maxI
6363
{
6464
return GetBinaryLearners(maxIterations);
6565
}
66-
else if (task == TaskKind.BinaryClassification)
66+
else if (task == TaskKind.MulticlassClassification)
6767
{
6868
return GetMultiLearners(maxIterations);
6969
}

src/AutoML/Utils/SweepableParamAttributes.cs

-8
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,6 @@ public override void SetUsingValueText(string valueText)
7676
RawValue = i;
7777
}
7878

79-
public int IndexOf(object option)
80-
{
81-
for (int i = 0; i < Options.Length; i++)
82-
if (option == Options[i])
83-
return i;
84-
return -1;
85-
}
86-
8779
private static string TranslateOption(object o)
8880
{
8981
switch (o)

src/Test/SweeperTests.cs

+26-10
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,35 @@ namespace Microsoft.ML.Auto.Test
88
public class SweeperTests
99
{
1010
[TestMethod]
11-
public void Smac3ParamsTest()
11+
public void SmacQuickRunTest()
1212
{
1313
var numInitialPopulation = 10;
1414

15+
var floatValueGenerator = new FloatValueGenerator(new FloatParamArguments() { Name = "float", Min = 1, Max = 1000 });
16+
var floatLogValueGenerator = new FloatValueGenerator(new FloatParamArguments() { Name = "floatLog", Min = 1, Max = 1000, LogBase = true });
17+
var longValueGenerator = new LongValueGenerator(new LongParamArguments() { Name = "long", Min = 1, Max = 1000 });
18+
var longLogValueGenerator = new LongValueGenerator(new LongParamArguments() { Name = "longLog", Min = 1, Max = 1000, LogBase = true });
19+
var discreteValueGeneator = new DiscreteValueGenerator(new DiscreteParamArguments() { Name = "discrete", Values = new[] { "200", "400", "600", "800" } });
20+
1521
var sweeper = new SmacSweeper(new SmacSweeper.Arguments()
1622
{
1723
SweptParameters = new IValueGenerator[] {
18-
new FloatValueGenerator(new FloatParamArguments() { Name = "x1", Min = 1, Max = 1000}),
19-
new LongValueGenerator(new LongParamArguments() { Name = "x2", Min = 1, Max = 1000}),
20-
new DiscreteValueGenerator(new DiscreteParamArguments() { Name = "x3", Values = new[] { "200", "400", "600", "800" } }),
24+
floatValueGenerator,
25+
floatLogValueGenerator,
26+
longValueGenerator,
27+
longLogValueGenerator,
28+
discreteValueGeneator
2129
},
2230
NumberInitialPopulation = numInitialPopulation
2331
});
2432

33+
// sanity check grid
34+
Assert.IsNotNull(floatValueGenerator[0].ValueText);
35+
Assert.IsNotNull(floatLogValueGenerator[0].ValueText);
36+
Assert.IsNotNull(longValueGenerator[0].ValueText);
37+
Assert.IsNotNull(longLogValueGenerator[0].ValueText);
38+
Assert.IsNotNull(discreteValueGeneator[0].ValueText);
39+
2540
List<RunResult> results = new List<RunResult>();
2641

2742
RunResult bestResult = null;
@@ -31,12 +46,13 @@ public void Smac3ParamsTest()
3146

3247
foreach (ParameterSet p in pars)
3348
{
34-
float x1 = (p["x1"] as FloatParameterValue).Value;
35-
float x2 = (p["x2"] as LongParameterValue).Value;
36-
float x3 = float.Parse(p["x3"].ValueText);
49+
float x1 = float.Parse(p["float"].ValueText);
50+
float x2 = float.Parse(p["floatLog"].ValueText);
51+
long x3 = long.Parse(p["long"].ValueText);
52+
long x4 = long.Parse(p["longLog"].ValueText);
53+
int x5 = int.Parse(p["discrete"].ValueText);
3754

38-
double metric = -200 * (Math.Abs(100 - x1) +
39-
Math.Abs(300 - x2) + Math.Abs(500 - x3));
55+
double metric = x1 + x2 + x3 + x4 + x5;
4056

4157
RunResult result = new RunResult(p, metric, true);
4258
if (bestResult == null || bestResult.MetricValue < metric)
@@ -53,7 +69,7 @@ public void Smac3ParamsTest()
5369
Console.WriteLine($"Best: {bestResult.MetricValue}");
5470

5571
Assert.IsNotNull(bestResult);
56-
Assert.IsTrue(bestResult.MetricValue != 0);
72+
Assert.IsTrue(bestResult.MetricValue > 0);
5773
}
5874

5975

src/Test/TrainerExtensionsTests.cs

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using System;
2+
using System.Linq;
3+
using Microsoft.VisualStudio.TestTools.UnitTesting;
4+
5+
namespace Microsoft.ML.Auto.Test
6+
{
7+
[TestClass]
8+
public class TrainerExtensionsTests
9+
{
10+
[TestMethod]
11+
public void TrainerExtensionInstanceTests()
12+
{
13+
var context = new MLContext();
14+
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
15+
foreach(var trainerName in trainerNames)
16+
{
17+
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
18+
var instance = extension.CreateInstance(context, null);
19+
Assert.IsNotNull(instance);
20+
var sweepParams = extension.GetHyperparamSweepRanges();
21+
Assert.IsNotNull(sweepParams);
22+
}
23+
}
24+
25+
[TestMethod]
26+
public void GetTrainersByMaxIterations()
27+
{
28+
var tasks = new TaskKind[] { TaskKind.BinaryClassification,
29+
TaskKind.MulticlassClassification, TaskKind.Regression };
30+
31+
foreach(var task in tasks)
32+
{
33+
var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10);
34+
var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50);
35+
var trainerSet100 = TrainerExtensionCatalog.GetTrainers(task, 100);
36+
37+
Assert.IsNotNull(trainerSet10);
38+
Assert.IsNotNull(trainerSet50);
39+
Assert.IsNotNull(trainerSet100);
40+
41+
Assert.IsTrue(trainerSet10.Count() < trainerSet50.Count());
42+
Assert.IsTrue(trainerSet50.Count() < trainerSet100.Count());
43+
}
44+
}
45+
}
46+
}

src/Test/UserInputValidationTests.cs

+11
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ public void ValidateAutoFitArgsPurposeOverrideDuplicateCol()
149149
});
150150
}
151151

152+
[TestMethod]
153+
public void ValidateAutoFitArgsPurposeOverrideSuccess()
154+
{
155+
UserInputValidationUtil.ValidateAutoFitArgs(DatasetUtil.GetUciAdultDataView(),
156+
DatasetUtil.UciAdultLabel, DatasetUtil.GetUciAdultDataView(),
157+
null, new List<(string, ColumnPurpose)>()
158+
{
159+
("Workclass", ColumnPurpose.CategoricalFeature)
160+
});
161+
}
162+
152163
[TestMethod]
153164
[ExpectedException(typeof(ArgumentException))]
154165
public void ValidateAutoFitArgsTrainValidColCountMismatch()

0 commit comments

Comments
 (0)