Skip to content

Commit ab6930c

Browse files
authored
fix dataview take util bug, add dataview skip util, add some UTs to increase code coverage (dotnet#21)
* fix dataview take util bug, add dataview skip util, add some UTs to increase code coverage
* add accuracy threshold on AutoFit test
* add null check to best pipeline on autofit result
1 parent 9f49cf1 commit ab6930c

File tree

5 files changed

+145
-45
lines changed

5 files changed

+145
-45
lines changed

src/AutoML/AutoMlUtils.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ public static void Assert(bool boolVal, string message = null)
2525

2626
public static IDataView Take(this IDataView data, int count)
2727
{
28-
// REVIEW: This should take an env as a parameter, not create one.
29-
var env = new MLContext();
30-
var take = SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments { Count = count }, data);
31-
return new CacheDataView(env, data, Enumerable.Range(0, data.Schema.Count).ToArray());
28+
var context = new MLContext();
29+
var filter = SkipTakeFilter.Create(context, new SkipTakeFilter.TakeArguments { Count = count }, data);
30+
return new CacheDataView(context, filter, Enumerable.Range(0, data.Schema.Count).ToArray());
31+
}
32+
33+
/// <summary>
/// Extension method that applies a <c>SkipTakeFilter</c> built from
/// <c>SkipArguments { Count = count }</c> to <paramref name="data"/> and returns the
/// filtered view wrapped in a <c>CacheDataView</c>.
/// </summary>
/// <param name="data">The source data view to filter.</param>
/// <param name="count">Value passed as <c>SkipArguments.Count</c>.</param>
/// <returns>A <c>CacheDataView</c> constructed over the filter.</returns>
public static IDataView Skip(this IDataView data, int count)
34+
{
35+
// A fresh MLContext is created on every call rather than being supplied by the caller.
var context = new MLContext();
36+
var filter = SkipTakeFilter.Create(context, new SkipTakeFilter.SkipArguments { Count = count }, data);
37+
// The cache wraps the filter (not the raw input) and is given every column index of data's schema.
return new CacheDataView(context, filter, Enumerable.Range(0, data.Schema.Count).ToArray());
3238
}
3339

3440
public static (string, ColumnType, ColumnPurpose, ColumnDimensions)[] GetColumnInfoTuples(MLContext context,

src/Test/AutoFitTests.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
3+
namespace Microsoft.ML.Auto.Test
4+
{
5+
// Test class exercising the binary-classification AutoFit entry point on the UCI Adult dataset.
[TestClass]
6+
public class AutoFitTests
7+
{
8+
// Runs AutoFit with a two-iteration budget and checks that a best pipeline with a model
// is produced and that its reported accuracy exceeds 0.80.
// NOTE(review): the method name "Hello" does not describe what is being tested.
[TestMethod]
9+
public void Hello()
10+
{
11+
var context = new MLContext();
12+
var dataPath = DatasetUtil.DownloadUciAdultDataset();
13+
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel, true);
14+
var textLoader = context.Data.CreateTextReader(columnInference);
15+
var trainData = textLoader.Read(dataPath);
16+
// Take(100) of the loaded data becomes the validation set...
var validationData = trainData.Take(100);
17+
// ...and the training set is reassigned to the same data with Skip(100) applied.
trainData = trainData.Skip(100);
18+
var best = context.BinaryClassification.AutoFit(trainData, DatasetUtil.UciAdultLabel, validationData, settings:
19+
new AutoFitSettings()
20+
{
21+
StoppingCriteria = new ExperimentStoppingCriteria()
22+
{
23+
MaxIterations = 2,
24+
// NOTE(review): the timeout is very large, so MaxIterations is presumably the
// intended stopping condition — confirm.
TimeOutInMinutes = 1000000000
25+
}
26+
}, debugLogger: null);
27+
28+
// The null-conditional chain lets the assertion fail cleanly if any link is null.
Assert.IsNotNull(best?.BestPipeline?.Model);
29+
Assert.IsTrue(best.BestPipeline.Metrics.Accuracy > 0.80);
30+
}
31+
}
32+
}

src/Test/DatasetUtil.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static IDataView GetUciAdultDataView()
2222
}
2323

2424
// downloads the UCI Adult dataset from the ML.Net repo
25-
private static string DownloadUciAdultDataset() =>
25+
public static string DownloadUciAdultDataset() =>
2626
DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/f0e639af5ffdc839aae8e65d19b5a9a1f0db634a/test/data/adult.tiny.with-schema.txt", "uciadult.dataset");
2727

2828
private static string DownloadIfNotExists(string baseGitPath, string dataFile)

src/Test/GetNextPipelineTests.cs

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,50 @@ namespace Microsoft.ML.Auto.Test
88
[TestClass]
99
public class GetNextPipelineTests
1010
{
11-
[Ignore]
1211
[TestMethod]
1312
public void GetNextPipeline()
1413
{
1514
var context = new MLContext();
1615
var uciAdult = DatasetUtil.GetUciAdultDataView();
1716
var columns = AutoMlUtils.GetColumnInfoTuples(context, uciAdult, DatasetUtil.UciAdultLabel, null);
1817

18+
// get next pipeline
19+
var pipeline = PipelineSuggester.GetNextPipeline(new List<PipelineRunResult>(), columns, TaskKind.BinaryClassification, 5);
20+
21+
// serialize & deserialize pipeline
22+
var serialized = JsonConvert.SerializeObject(pipeline);
23+
Console.WriteLine(serialized);
24+
var deserialized = JsonConvert.DeserializeObject<Pipeline>(serialized);
25+
26+
// run pipeline
27+
var estimator = deserialized.ToEstimator();
28+
var scoredData = estimator.Fit(uciAdult).Transform(uciAdult);
29+
var score = context.BinaryClassification.EvaluateNonCalibrated(scoredData).Accuracy;
30+
var result = new PipelineRunResult(deserialized, score, true);
31+
32+
Assert.IsNotNull(result);
33+
}
34+
35+
[TestMethod]
36+
public void GetNextPipelineMock()
37+
{
38+
var context = new MLContext();
39+
var uciAdult = DatasetUtil.GetUciAdultDataView();
40+
var columns = AutoMlUtils.GetColumnInfoTuples(context, uciAdult, DatasetUtil.UciAdultLabel, null);
41+
1942
// get next pipeline loop
2043
var history = new List<PipelineRunResult>();
21-
var maxIterations = 2;
44+
var maxIterations = 10;
2245
for (var i = 0; i < maxIterations; i++)
2346
{
2447
// get next pipeline
2548
var pipeline = PipelineSuggester.GetNextPipeline(history, columns, TaskKind.BinaryClassification, maxIterations - i);
26-
var serialized = JsonConvert.SerializeObject(pipeline);
27-
Console.WriteLine(serialized);
28-
var deserialized = JsonConvert.DeserializeObject<Pipeline>(serialized);
29-
30-
// run pipeline
31-
var estimator = deserialized.ToEstimator();
32-
var scoredData = estimator.Fit(uciAdult).Transform(uciAdult);
33-
var score = context.BinaryClassification.EvaluateNonCalibrated(scoredData).Accuracy;
34-
var result = new PipelineRunResult(deserialized, score, true);
3549

50+
var result = new PipelineRunResult(pipeline, AutoMlUtils.Random.NextDouble(), true);
3651
history.Add(result);
3752
}
3853

39-
Assert.AreEqual(2, history.Count);
54+
Assert.AreEqual(maxIterations, history.Count);
4055
}
4156
}
4257
}

src/Test/SweeperTests.cs

Lines changed: 75 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,65 @@
11
using System;
22
using System.Collections.Generic;
3-
using System.IO;
4-
using Microsoft.ML;
53
using Microsoft.VisualStudio.TestTools.UnitTesting;
64

75
namespace Microsoft.ML.Auto.Test
86
{
97
[TestClass]
108
public class SweeperTests
119
{
12-
[Ignore]
1310
[TestMethod]
14-
public void Smac2ParamsTest()
11+
public void Smac3ParamsTest()
1512
{
13+
var numInitialPopulation = 10;
14+
1615
var sweeper = new SmacSweeper(new SmacSweeper.Arguments()
1716
{
1817
SweptParameters = new INumericValueGenerator[] {
19-
new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5}),
20-
new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })
18+
new FloatValueGenerator(new FloatParamArguments() { Name = "x1", Min = 1, Max = 1000}),
19+
new FloatValueGenerator(new FloatParamArguments() { Name = "x2", Min = 1, Max = 1000}),
20+
new FloatValueGenerator(new FloatParamArguments() { Name = "x3", Min = 1, Max = 1000}),
2121
},
22+
NumberInitialPopulation = numInitialPopulation
2223
});
2324

24-
Random rand = new Random(0);
2525
List<RunResult> results = new List<RunResult>();
2626

27-
int count = 0;
28-
while (true)
27+
RunResult bestResult = null;
28+
for (var i = 0; i < numInitialPopulation + 1; i++)
2929
{
3030
ParameterSet[] pars = sweeper.ProposeSweeps(1, results);
31-
if(pars == null)
32-
{
33-
break;
34-
}
31+
3532
foreach (ParameterSet p in pars)
3633
{
37-
float foo = 0;
38-
long bar = 0;
34+
float x1 = (p["x1"] as FloatParameterValue).Value;
35+
float x2 = (p["x2"] as FloatParameterValue).Value;
36+
float x3 = (p["x3"] as FloatParameterValue).Value;
3937

40-
foo = (p["foo"] as FloatParameterValue).Value;
41-
bar = (p["bar"] as LongParameterValue).Value;
38+
double metric = -200 * (Math.Abs(100 - x1) +
39+
Math.Abs(300 - x2) + Math.Abs(500 - x3));
4240

43-
double metric = ((5 - Math.Abs(4 - foo)) * 200) + (1001 - Math.Abs(33 - bar)) + rand.Next(1, 20);
44-
results.Add(new RunResult(p, metric, true));
45-
count++;
46-
Console.WriteLine("{0}--{1}--{2}--{3}", count, foo, bar, metric);
41+
RunResult result = new RunResult(p, metric, true);
42+
if (bestResult == null || bestResult.MetricValue < metric)
43+
{
44+
bestResult = result;
45+
}
46+
results.Add(result);
47+
48+
Console.WriteLine($"{metric}\t{x1},{x2}");
4749
}
50+
4851
}
52+
53+
Console.WriteLine($"Best: {bestResult.MetricValue}");
54+
55+
Assert.IsNotNull(bestResult);
56+
Assert.IsTrue(bestResult.MetricValue != 0);
4957
}
5058

59+
5160
[Ignore]
5261
[TestMethod]
53-
public void Smac4ParamsTest()
62+
public void Smac4ParamsConvergenceTest()
5463
{
5564
var sweeper = new SmacSweeper(new SmacSweeper.Arguments()
5665
{
@@ -61,15 +70,14 @@ public void Smac4ParamsTest()
6170
new FloatValueGenerator(new FloatParamArguments() { Name = "x4", Min = 1, Max = 1000}),
6271
},
6372
});
64-
65-
Random rand = new Random(0);
73+
6674
List<RunResult> results = new List<RunResult>();
6775

6876
RunResult bestResult = null;
6977
for (var i = 0; i < 300; i++)
7078
{
7179
ParameterSet[] pars = sweeper.ProposeSweeps(1, results);
72-
80+
7381
// if run converged, break
7482
if (pars == null)
7583
{
@@ -82,14 +90,14 @@ public void Smac4ParamsTest()
8290
float x2 = (p["x2"] as FloatParameterValue).Value;
8391
float x3 = (p["x3"] as FloatParameterValue).Value;
8492
float x4 = (p["x4"] as FloatParameterValue).Value;
85-
93+
8694
double metric = -200 * (Math.Abs(100 - x1) +
8795
Math.Abs(300 - x2) +
8896
Math.Abs(500 - x3) +
89-
Math.Abs(700 - x4) );
97+
Math.Abs(700 - x4));
9098

9199
RunResult result = new RunResult(p, metric, true);
92-
if(bestResult == null || bestResult.MetricValue < metric)
100+
if (bestResult == null || bestResult.MetricValue < metric)
93101
{
94102
bestResult = result;
95103
}
@@ -102,5 +110,44 @@ public void Smac4ParamsTest()
102110

103111
Console.WriteLine($"Best: {bestResult.MetricValue}");
104112
}
113+
114+
// Drives a SmacSweeper over two parameters ("foo": float in [1, 5]; "bar": long in [1, 1000]
// with LogBase = true), feeding every proposal back with a synthetic metric until
// ProposeSweeps returns null. Marked [Ignore], so it is not run by default.
// NOTE(review): the method contains no assertions; it only logs each evaluation.
[Ignore]
115+
[TestMethod]
116+
public void Smac2ParamsConvergenceTest()
117+
{
118+
var sweeper = new SmacSweeper(new SmacSweeper.Arguments()
119+
{
120+
SweptParameters = new INumericValueGenerator[] {
121+
new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5}),
122+
new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true })
123+
},
124+
});
125+
126+
// Fixed seed so the noise term added to the metric is the same on every run.
Random rand = new Random(0);
127+
List<RunResult> results = new List<RunResult>();
128+
129+
int count = 0;
130+
// NOTE(review): there is no iteration cap; the loop exits only when ProposeSweeps returns null.
while (true)
131+
{
132+
// The full history of results so far is passed back in on every request.
ParameterSet[] pars = sweeper.ProposeSweeps(1, results);
133+
if(pars == null)
134+
{
135+
break;
136+
}
137+
foreach (ParameterSet p in pars)
138+
{
139+
float foo = 0;
140+
long bar = 0;
141+
142+
// The `as` casts are not null-checked; a parameter of another type would throw here.
foo = (p["foo"] as FloatParameterValue).Value;
143+
bar = (p["bar"] as LongParameterValue).Value;
144+
145+
// Synthetic objective: largest at foo = 4 and bar = 33, plus an integer noise term from rand.Next(1, 20).
double metric = ((5 - Math.Abs(4 - foo)) * 200) + (1001 - Math.Abs(33 - bar)) + rand.Next(1, 20);
146+
results.Add(new RunResult(p, metric, true));
147+
count++;
148+
Console.WriteLine("{0}--{1}--{2}--{3}", count, foo, bar, metric);
149+
}
150+
}
151+
}
105152
}
106153
}

0 commit comments

Comments
 (0)