-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Adding functional tests for all training and evaluation tasks #2646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
2e4d4b0
0b02fb1
54dfc2f
acd2f1a
e9d5bad
7012ea6
94c23d2
12d3dfe
42af510
51ea456
89e44cb
051ba01
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
|
||
using System; | ||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.Functional.Tests.Datasets | ||
{ | ||
/// <summary> | ||
/// A row of the Iris test dataset: a label plus sepal and petal length/width measurements, all read as floats. | ||
/// </summary> | ||
/// <remarks> | ||
/// Each <c>LoadColumn</c> attribute names the text-file column index its property is read from. | ||
/// NOTE(review): the indices used are 0, 1, 2, 4 and 5, so index 3 is never read. Confirm against the | ||
/// data file that this gap is intentional rather than an off-by-one. | ||
/// </remarks> | ||
internal sealed class Iris | ||
{ | ||
[LoadColumn(0)] | ||
public float Label { get; set; } | ||
|
||
[LoadColumn(1)] | ||
public float SepalLength { get; set; } | ||
|
||
[LoadColumn(2)] | ||
public float SepalWidth { get; set; } | ||
|
||
[LoadColumn(4)] | ||
public float PetalLength { get; set; } | ||
|
||
[LoadColumn(5)] | ||
public float PetalWidth { get; set; } | ||
|
||
/// <summary> | ||
/// The list of columns commonly used as features. | ||
/// </summary> | ||
/// <remarks> | ||
/// The names match the four measurement properties of this class; Label is not included. | ||
/// </remarks> | ||
public static readonly string[] Features = new string[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" }; | ||
|
||
/// <summary> | ||
/// Loads the Iris data from a text file and adds a randomly generated GroupId column, converted to a key, | ||
/// so that the data can be used as a ranking problem. | ||
/// </summary> | ||
/// <param name="mlContext">The context used to read the file and to build the transforms.</param> | ||
/// <param name="filePath">Path of the text file to read.</param> | ||
/// <param name="hasHeader">Passed through to the text reader as <c>hasHeader</c>.</param> | ||
/// <param name="separatorChar">Passed through to the text reader as <c>separatorChar</c>.</param> | ||
/// <param name="seed">Seed for the <see cref="Random"/> that draws the group ids.</param> | ||
/// <returns>The loaded data after the GroupId mapping and the value-to-key conversion have been applied.</returns> | ||
public static IDataView LoadAsRankingProblem(MLContext mlContext, string filePath, bool hasHeader, char separatorChar, int seed = 1) | ||
{ | ||
// Load the Iris data. | ||
var data = mlContext.Data.ReadFromTextFile<Iris>(filePath, hasHeader: hasHeader, separatorChar: separatorChar); | ||
|
||
// Create a function that generates a random groupId. | ||
// NOTE(review): a single Random instance is captured by the mapping below. If the mapping is evaluated | ||
// more than once, or from several threads, the GroupId values may not be reproducible -- confirm. | ||
var rng = new Random(seed); | ||
Action<Iris, IrisWithGroup> generateGroupId = (input, output) => | ||
{ | ||
output.Label = input.Label; | ||
// The standard set used in tests has 150 rows | ||
// Random.Next(0, 30) excludes the upper bound, so group ids range over 0..29. | ||
output.GroupId = (ushort)rng.Next(0, 30); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
any reason why it's ushort? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
// The four measurements are copied through unchanged. | ||
output.PetalLength = input.PetalLength; | ||
output.PetalWidth = input.PetalWidth; | ||
output.SepalLength = input.SepalLength; | ||
output.SepalWidth = input.SepalWidth; | ||
}; | ||
|
||
// Describe a pipeline that generates a groupId and converts it to a key. | ||
// NOTE(review): the second CustomMapping argument is null; presumably this is the contract name | ||
// that would only be needed to save the pipeline -- confirm. | ||
var pipeline = mlContext.Transforms.CustomMapping(generateGroupId, null) | ||
.Append(mlContext.Transforms.Conversion.MapValueToKey("GroupId")); | ||
|
||
// Transform the data | ||
// The pipeline is fit on, and then applied to, the same loaded data. | ||
var transformedData = pipeline.Fit(data).Transform(data); | ||
|
||
return transformedData; | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// A class for the Iris dataset with a GroupId column. | ||
/// </summary> | ||
/// <remarks> | ||
/// This is the output type of the custom mapping in <c>Iris.LoadAsRankingProblem</c>. It carries no | ||
/// <c>LoadColumn</c> attributes: its properties are assigned by that mapping, not read from a file. | ||
/// </remarks> | ||
internal sealed class IrisWithGroup | ||
{ | ||
public float Label { get; set; } | ||
// Assigned from (ushort)rng.Next(0, 30) in Iris.LoadAsRankingProblem, then converted to a key there. | ||
public ushort GroupId { get; set; } | ||
public float SepalLength { get; set; } | ||
public float SepalWidth { get; set; } | ||
public float PetalLength { get; set; } | ||
public float PetalWidth { get; set; } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.Functional.Tests.Datasets | ||
{ | ||
/// <summary> | ||
/// A class for reading in the MNIST One Class test dataset. | ||
/// </summary> | ||
/// <remarks> | ||
/// Column 0 is read as the label; columns 1 through 784 are read into a single float vector. | ||
/// </remarks> | ||
internal sealed class MnistOneClass | ||
{ | ||
[LoadColumn(0)] | ||
public float Label { get; set; } | ||
|
||
// The column range 1..784 and the declared vector size of 784 must stay in step with each other. | ||
// NOTE(review): this assumes the LoadColumn range is inclusive at both ends (784 values) -- confirm. | ||
[LoadColumn(1, 784), VectorType(784)] | ||
public float[] Features { get; set; } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.Functional.Tests.Datasets | ||
{ | ||
/// <summary> | ||
/// A class for reading in the Sentiment test dataset. | ||
/// </summary> | ||
/// <remarks> | ||
/// Column 0 is read as a boolean and column 1 as the text. | ||
/// </remarks> | ||
internal sealed class TweetSentiment | ||
{ | ||
// ColumnName("Label") gives this property the column name "Label" instead of "Sentiment". | ||
[LoadColumn(0), ColumnName("Label")] | ||
public bool Sentiment { get; set; } | ||
|
||
[LoadColumn(1)] | ||
public string SentimentText { get; set; } | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
|
||
using System; | ||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.Functional.Tests.Datasets | ||
{ | ||
/// <summary> | ||
/// A class for the trivial matrix factorization test dataset: each row is a (label value, matrix column index, matrix row index) triplet. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't get this sentence. What DataKind has to do with MF? Your data is row+col+value triplets for matrix. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
/// </summary> | ||
/// <remarks> | ||
/// This class has annotations for automatic deserialization from a file, and contains a helper method | ||
/// for reading from a file and converting the two index columns to keys. | ||
/// </remarks> | ||
internal sealed class TrivialMatrixFactorization | ||
{ | ||
[LoadColumn(0)] | ||
public float Label { get; set; } | ||
|
||
[LoadColumn(1)] | ||
public uint MatrixColumnIndex { get; set; } | ||
|
||
[LoadColumn(2)] | ||
public uint MatrixRowIndex { get; set; } | ||
|
||
/// <summary> | ||
/// Loads the dataset from a text file and converts the MatrixColumnIndex and MatrixRowIndex columns to keys. | ||
/// </summary> | ||
/// <param name="mlContext">The context used to read the file and to build the transforms.</param> | ||
/// <param name="filePath">Path of the text file to read.</param> | ||
/// <param name="hasHeader">Passed through to the text reader as <c>hasHeader</c>.</param> | ||
/// <param name="separatorChar">Passed through to the text reader as <c>separatorChar</c>.</param> | ||
/// <returns>The loaded data after both value-to-key conversions have been applied.</returns> | ||
public static IDataView LoadAndFeaturizeFromTextFile(MLContext mlContext, string filePath, bool hasHeader, char separatorChar) | ||
{ | ||
// Load the data from a textfile. | ||
var data = mlContext.Data.ReadFromTextFile<TrivialMatrixFactorization>(filePath, hasHeader: hasHeader, separatorChar: separatorChar); | ||
|
||
// Describe a pipeline to translate the uints to keys. | ||
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("MatrixColumnIndex") | ||
.Append(mlContext.Transforms.Conversion.MapValueToKey("MatrixRowIndex")); | ||
|
||
// Transform the data. | ||
// The pipeline is fit on, and then applied to, the same loaded data. | ||
var transformedData = pipeline.Fit(data).Transform(data); | ||
|
||
return transformedData; | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We added `CalibratedBinaryClassificationMetrics` recently. Strangely, you don't have them in your PR. #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. I hadn't noticed that.
In reply to: 259454652 [](ancestors = 259454652)