Skip to content

Commit ee2ffdb

Browse files
committed
Samples for CustomMapping, IndicateMissingValues, ReplaceMissingValues (dotnet#3216)
1 parent 62a5b34 commit ee2ffdb

10 files changed

+384
-161
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
5+
namespace Samples.Dynamic
6+
{
7+
public static class CustomMapping
{
    // Demonstrates applying a user-defined, row-to-row mapping with the CustomMapping estimator.
    public static void Example()
    {
        // The MLContext is the starting point for all ML.NET operations. It is used for exception
        // tracking and logging, and it is also the source of randomness.
        var mlContext = new MLContext();

        // Build a small in-memory dataset and expose it as an IDataView.
        var samples = new List<InputData>
        {
            new InputData { Age = 26 },
            new InputData { Age = 35 },
            new InputData { Age = 34 },
            new InputData { Age = 28 },
        };
        var data = mlContext.Data.LoadFromEnumerable(samples);

        // The custom mapping reads an input row and fills in the output row; the transformation
        // applies it to every row.
        Action<InputData, CustomMappingOutput> mapping = (source, destination) =>
        {
            destination.IsUnderThirty = source.Age < 30;
        };

        // A custom transformation can be applied to data directly or used as one step in a pipeline
        // of estimators.
        // Note: when contractName is null, a pipeline of estimators that contains this estimator
        // cannot be saved and loaded back.
        var pipeline = mlContext.Transforms.CustomMapping(mapping, contractName: null);

        // Fit the estimator and apply the resulting transformer. Nothing is actually evaluated
        // until the rows are read below.
        var transformedData = pipeline.Fit(data).Transform(data);

        var transformedRows = mlContext.Data.CreateEnumerable<TransformedData>(transformedData, reuseRowObject: true);
        Console.WriteLine("Age\t IsUnderThirty");
        foreach (var transformedRow in transformedRows)
        {
            Console.WriteLine($"{transformedRow.Age}\t {transformedRow.IsUnderThirty}");
        }

        // Expected output:
        // Age      IsUnderThirty
        // 26       True
        // 35       False
        // 34       False
        // 28       True
    }

    // Holds only the column produced by the custom mapping; it is added to the columns already present.
    private class CustomMappingOutput
    {
        public bool IsUnderThirty { get; set; }
    }

    // Schema of the input rows.
    private class InputData
    {
        public float Age { get; set; }
    }

    // Schema of the transformed rows: the input columns plus the new IsUnderThirty column.
    private class TransformedData : InputData
    {
        public bool IsUnderThirty { get; set; }
    }
}
70+
}

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSample.cs

-80
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Transforms;
5+
6+
namespace Samples.Dynamic
7+
{
8+
public static class CustomMappingSaveAndLoad
{
    // Contract name that ties the CustomMapping estimator to its factory. The value passed to the
    // estimator and the value on the CustomMappingFactoryAttribute must be identical, so both
    // places use this single constant.
    private const string ContractName = "IsUnderThirty";

    // File the fitted transformer is saved to and loaded back from.
    private const string ModelPath = "customTransform.zip";

    // Demonstrates a CustomMapping transformation that can be saved to disk and loaded back.
    public static void Example()
    {
        // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
        // as well as the source of randomness.
        var mlContext = new MLContext();

        // Get a small dataset as an IEnumerable and convert it to an IDataView.
        var samples = new List<InputData>
        {
            new InputData { Age = 26 },
            new InputData { Age = 35 },
            new InputData { Age = 34 },
            new InputData { Age = 28 },
        };
        var data = mlContext.Data.LoadFromEnumerable(samples);

        // Custom transformations can be used to transform data directly, or as part of a pipeline of estimators.
        var pipeline = mlContext.Transforms.CustomMapping(new IsUnderThirtyCustomAction().GetMapping(), contractName: ContractName);
        var transformer = pipeline.Fit(data);

        // To save and load the CustomMapping estimator, the assembly in which the custom action is defined
        // needs to be registered in the environment. The following registers the assembly where
        // IsUnderThirtyCustomAction is defined.
        mlContext.ComponentCatalog.RegisterAssembly(typeof(IsUnderThirtyCustomAction).Assembly);

        // Now the transform pipeline can be saved and loaded through the usual MLContext method.
        // The input schema returned by Load is not needed in this sample, so it is discarded.
        mlContext.Model.Save(transformer, data.Schema, ModelPath);
        var loadedTransform = mlContext.Model.Load(ModelPath, out _);

        // Now we can transform the data and look at the output to confirm the behavior of the estimator.
        // This operation doesn't actually evaluate data until we read the data below.
        var transformedData = loadedTransform.Transform(data);

        var dataEnumerable = mlContext.Data.CreateEnumerable<TransformedData>(transformedData, reuseRowObject: true);

        // The header uses the same "\t " separator as the rows so the columns line up.
        Console.WriteLine("Age\t IsUnderThirty");
        foreach (var row in dataEnumerable)
        {
            Console.WriteLine($"{row.Age}\t {row.IsUnderThirty}");
        }

        // Expected output:
        // Age      IsUnderThirty
        // 26       True
        // 35       False
        // 34       False
        // 28       True
    }

    // The custom action needs to implement the abstract class CustomMappingFactory, and needs to have the
    // CustomMappingFactoryAttribute with an argument equal to the contractName used to define the
    // CustomMapping estimator which uses the action.
    [CustomMappingFactoryAttribute(ContractName)]
    private class IsUnderThirtyCustomAction : CustomMappingFactory<InputData, CustomMappingOutput>
    {
        // We define the custom mapping between input and output rows that will be applied by the transformation.
        public static void CustomAction(InputData input, CustomMappingOutput output)
            => output.IsUnderThirty = input.Age < 30;

        // Returns the mapping that the CustomMapping estimator applies to each row.
        public override Action<InputData, CustomMappingOutput> GetMapping()
            => CustomAction;
    }

    // Defines only the column to be generated by the custom mapping transformation in addition to the columns already present.
    private class CustomMappingOutput
    {
        public bool IsUnderThirty { get; set; }
    }

    // Defines the schema of the input data.
    private class InputData
    {
        public float Age { get; set; }
    }

    // Defines the schema of the transformed data, which includes the new column IsUnderThirty.
    private class TransformedData : InputData
    {
        public bool IsUnderThirty { get; set; }
    }
}
87+
}

docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/IndicateMissingValues.cs

+11-27
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,24 @@ namespace Microsoft.ML.Samples.Dynamic
77
{
88
public static class IndicateMissingValues
99
{
10-
1110
public static void Example()
1211
{
1312
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
1413
// as well as the source of randomness.
1514
var mlContext = new MLContext();
1615

16+
// Get a small dataset as an IEnumerable and convert it to an IDataView.
1717
var samples = new List<DataPoint>()
1818
{
19-
new DataPoint(){ Label = 3, Features = new float[3] {1, 1, 0} },
20-
new DataPoint(){ Label = 32, Features = new float[3] {0, float.NaN, 1} },
21-
new DataPoint(){ Label = float.NaN, Features = new float[3] {-1, float.NaN, -3} },
19+
new DataPoint(){ Features = new float[3] {1, 1, 0} },
20+
new DataPoint(){ Features = new float[3] {0, float.NaN, 1} },
21+
new DataPoint(){ Features = new float[3] {-1, float.NaN, -3} },
2222
};
23-
// Convert training data to IDataView, the general data type used in ML.NET.
2423
var data = mlContext.Data.LoadFromEnumerable(samples);
2524

26-
// IndicateMissingValues is used to create a boolean containing
27-
// 'true' where the value in the input column is NaN. This value can be used
28-
// to replace missing values with other values.
29-
IEstimator<ITransformer> pipeline = mlContext.Transforms.IndicateMissingValues("MissingIndicator", "Features");
25+
// IndicateMissingValues is used to create a boolean containing 'true' where the value in the
26+
// input column is missing. For floats and doubles, missing values are represented as NaN.
27+
var pipeline = mlContext.Transforms.IndicateMissingValues("MissingIndicator", "Features");
3028

3129
// Now we can transform the data and look at the output to confirm the behavior of the estimator.
3230
// This operation doesn't actually evaluate data until we read the data below.
@@ -36,32 +34,18 @@ public static void Example()
3634
// We can extract the newly created column as an IEnumerable of SampleDataTransformed, the class we define below.
3735
var rowEnumerable = mlContext.Data.CreateEnumerable<SampleDataTransformed>(transformedData, reuseRowObject: false);
3836

39-
// a small printing utility
40-
Func<object[], string> vectorPrinter = (object[] vector) =>
41-
{
42-
string preview = "[";
43-
foreach (var slot in vector)
44-
preview += $"{slot} ";
45-
return preview += "]";
46-
47-
};
48-
4937
// And finally, we can write out the rows of the dataset, looking at the columns of interest.
5038
foreach (var row in rowEnumerable)
51-
{
52-
Console.WriteLine($"Label: {row.Label} Features: {vectorPrinter(row.Features.Cast<object>().ToArray())} MissingIndicator: {vectorPrinter(row.MissingIndicator.Cast<object>().ToArray())}");
53-
}
39+
Console.WriteLine($"Features: [{string.Join(", ", row.Features)}]\t MissingIndicator: [{string.Join(", ", row.MissingIndicator)}]");
5440

5541
// Expected output:
56-
//
57-
// Label: 3 Features: [1 1 0] MissingIndicator: [False False False]
58-
// Label: 32 Features: [0 NaN 1] MissingIndicator: [False True False]
59-
// Label: NaN Features: [-1 NaN -3 ] MissingIndicator: [False True False]
42+
// Features: [1, 1, 0] MissingIndicator: [False, False, False]
43+
// Features: [0, NaN, 1] MissingIndicator: [False, True, False]
44+
// Features: [-1, NaN, -3] MissingIndicator: [False, True, False]
6045
}
6146

6247
private class DataPoint
6348
{
64-
public float Label { get; set; }
6549
[VectorType(3)]
6650
public float[] Features { get; set; }
6751
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML;
4+
using Microsoft.ML.Data;
5+
6+
namespace Samples.Dynamic
7+
{
8+
public static class IndicateMissingValuesMultiColumn
{
    // Demonstrates flagging missing values in several vector columns with a single estimator.
    public static void Example()
    {
        // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
        // as well as the source of randomness.
        var mlContext = new MLContext();

        // Get a small dataset as an IEnumerable and convert it to an IDataView.
        var samples = new List<DataPoint>()
        {
            new DataPoint(){ Features1 = new float[3] {1, 1, 0}, Features2 = new float[2] {1, 1} },
            new DataPoint(){ Features1 = new float[3] {0, float.NaN, 1}, Features2 = new float[2] {float.NaN, 1} },
            new DataPoint(){ Features1 = new float[3] {-1, float.NaN, -3}, Features2 = new float[2] {1, float.PositiveInfinity} },
        };
        var data = mlContext.Data.LoadFromEnumerable(samples);

        // IndicateMissingValues is used to create a boolean containing 'true' where the value in the
        // input column is missing. For floats and doubles, missing values are NaN; as the expected
        // output below shows, PositiveInfinity is not reported as missing.
        // We can use an array of InputOutputColumnPair to apply the MissingValueIndicatorEstimator
        // to multiple columns in one pass over the data.
        var pipeline = mlContext.Transforms.IndicateMissingValues(new[] {
            new InputOutputColumnPair("MissingIndicator1", "Features1"),
            new InputOutputColumnPair("MissingIndicator2", "Features2")
        });

        // Now we can transform the data and look at the output to confirm the behavior of the estimator.
        // This operation doesn't actually evaluate data until we read the data below.
        var transformer = pipeline.Fit(data);
        var transformedData = transformer.Transform(data);

        // We can extract the newly created columns as an IEnumerable of SampleDataTransformed, the class we define below.
        var rowEnumerable = mlContext.Data.CreateEnumerable<SampleDataTransformed>(transformedData, reuseRowObject: false);

        // And finally, we can write out the rows of the dataset, looking at the columns of interest.
        foreach (var row in rowEnumerable)
        {
            Console.WriteLine($"Features1: [{string.Join(", ", row.Features1)}]\t MissingIndicator1: [{string.Join(", ", row.MissingIndicator1)}]\t " +
                $"Features2: [{string.Join(", ", row.Features2)}]\t MissingIndicator2: [{string.Join(", ", row.MissingIndicator2)}]");
        }

        // Expected output:
        // Features1: [1, 1, 0]     MissingIndicator1: [False, False, False]    Features2: [1, 1]       MissingIndicator2: [False, False]
        // Features1: [0, NaN, 1]   MissingIndicator1: [False, True, False]     Features2: [NaN, 1]     MissingIndicator2: [True, False]
        // Features1: [-1, NaN, -3] MissingIndicator1: [False, True, False]     Features2: [1, ∞]       MissingIndicator2: [False, False]
        // NOTE(review): how PositiveInfinity is printed (e.g. "∞" vs "Infinity") presumably depends on the
        // runtime and culture - confirm on the target platform.
    }

    // Defines the schema of the input data: two fixed-size float vector columns.
    private class DataPoint
    {
        [VectorType(3)]
        public float[] Features1 { get; set; }
        [VectorType(2)]
        public float[] Features2 { get; set; }
    }

    // Defines the schema of the transformed data: the input columns plus one indicator column per input column.
    private sealed class SampleDataTransformed : DataPoint
    {
        public bool[] MissingIndicator1 { get; set; }
        public bool[] MissingIndicator2 { get; set; }
    }
}
68+
}

0 commit comments

Comments
 (0)