Skip to content

Commit 468a1b7

Browse files
authored
add basic autofit regression test (dotnet#28)
1 parent fb6390b commit 468a1b7

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/Test/AutoFitTests.cs

+24
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,29 @@ public void AutoFitMultiTest()
5252
Assert.IsNotNull(best?.BestPipeline?.Model);
5353
Assert.IsTrue(best.BestPipeline.Metrics.AccuracyMicro > 0.80);
5454
}
55+
56+
[TestMethod]
57+
public void AutoFitRegressionTest()
58+
{
59+
var context = new MLContext();
60+
var dataPath = DatasetUtil.DownloadMlNetGeneratedRegressionDataset();
61+
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.MlNetGeneratedRegressionLabel, true);
62+
var textLoader = context.Data.CreateTextReader(columnInference);
63+
var trainData = textLoader.Read(dataPath);
64+
var validationData = trainData.Take(20);
65+
trainData = trainData.Skip(20);
66+
var best = context.Regression.AutoFit(trainData, DatasetUtil.MlNetGeneratedRegressionLabel, validationData, settings:
67+
new AutoFitSettings()
68+
{
69+
StoppingCriteria = new ExperimentStoppingCriteria()
70+
{
71+
MaxIterations = 1,
72+
TimeOutInMinutes = 1000000000
73+
}
74+
}, debugLogger: null);
75+
76+
Assert.IsNotNull(best?.BestPipeline?.Model);
77+
Assert.IsTrue(best.BestPipeline.Metrics.RSquared > 0.9);
78+
}
5579
}
5680
}

src/Test/DatasetUtil.cs

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ internal static class DatasetUtil
99
{
1010
public const string UciAdultLabel = DefaultColumnNames.Label;
1111
public const string TrivialDatasetLabel = DefaultColumnNames.Label;
12+
public const string MlNetGeneratedRegressionLabel = "target";
1213

1314
private static IDataView _uciAdultDataView;
1415

@@ -29,6 +30,9 @@ public static string DownloadUciAdultDataset() =>
2930
public static string DownloadTrivialDataset() =>
3031
DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/eae76959e6714af44caa212e102a5f06f0110e72/test/data/trivial-train.tsv", "trivial.dataset");
3132

33+
public static string DownloadMlNetGeneratedRegressionDataset() =>
34+
DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/e78971ea6fd736038b4c355b840e5cbabae8cb55/test/data/generated_regression_dataset.csv", "mlnet_generated_regression.dataset");
35+
3236
private static string DownloadIfNotExists(string baseGitPath, string dataFile)
3337
{
3438
// if file doesn't already exist, download it

0 commit comments

Comments
 (0)