Skip to content

Commit 7b9a6a9

Browse files
authored
Merge pull request #1 from dotnet/master
Added sample for WithOnFitDelegate (dotnet#3738)
2 parents 0337ab4 + 3fb7256 commit 7b9a6a9

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Collections.Immutable;
4+
using System.Linq;
5+
using Microsoft.ML;
6+
using Microsoft.ML.Data;
7+
using Microsoft.ML.Transforms;
8+
using static Microsoft.ML.Transforms.NormalizingTransformer;
9+
10+
namespace Samples.Dynamic
11+
{
12+
public class WithOnFitDelegate
13+
{
14+
public static void Example()
15+
{
16+
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
17+
// as well as the source of randomness.
18+
var mlContext = new MLContext();
19+
var samples = new List<DataPoint>()
20+
{
21+
new DataPoint(){ Features = new float[4] { 8, 1, 3, 0}, Label = true },
22+
new DataPoint(){ Features = new float[4] { 6, 2, 2, 0}, Label = true },
23+
new DataPoint(){ Features = new float[4] { 4, 0, 1, 0}, Label = false },
24+
new DataPoint(){ Features = new float[4] { 2,-1,-1, 1}, Label = false }
25+
};
26+
// Convert training data to IDataView, the general data type used in ML.NET.
27+
var data = mlContext.Data.LoadFromEnumerable(samples);
28+
29+
// Create a pipeline to normalize the features and train a binary classifier.
30+
// We use WithOnFitDelegate for the intermediate binning normalization step,
31+
// so that we can inspect the properties of the normalizer after fitting.
32+
NormalizingTransformer binningTransformer = null;
33+
var pipeline =
34+
mlContext.Transforms.NormalizeBinning("Features", maximumBinCount: 3)
35+
.WithOnFitDelegate(fittedTransformer => binningTransformer = fittedTransformer)
36+
.Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression());
37+
38+
Console.WriteLine(binningTransformer == null);
39+
// Expected Output:
40+
// True
41+
42+
var model = pipeline.Fit(data);
43+
44+
// During fitting binningTransformer will get assigned a new value
45+
Console.WriteLine(binningTransformer == null);
46+
// Expected Output:
47+
// False
48+
49+
// Inspect some of the properties of the binning transformer
50+
var binningParam = binningTransformer.GetNormalizerModelParameters(0) as
51+
BinNormalizerModelParameters<ImmutableArray<float>>;
52+
53+
for (int i = 0; i < binningParam.UpperBounds.Length; i++)
54+
{
55+
var upperBounds = string.Join(", ", binningParam.UpperBounds[i]);
56+
Console.WriteLine(
57+
$"Bin {i}: Density = {binningParam.Density[i]}, " +
58+
$"Upper-bounds = {upperBounds}");
59+
}
60+
// Expected output:
61+
// Bin 0: Density = 2, Upper-bounds = 3, 7, Infinity
62+
// Bin 1: Density = 2, Upper-bounds = -0.5, 1.5, Infinity
63+
// Bin 2: Density = 2, Upper-bounds = 0, 2.5, Infinity
64+
// Bin 3: Density = 1, Upper-bounds = 0.5, Infinity
65+
}
66+
67+
private class DataPoint
68+
{
69+
[VectorType(4)]
70+
public float[] Features { get; set; }
71+
public bool Label { get; set; }
72+
}
73+
}
74+
}
75+

src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs

+7
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,13 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
135135
/// <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> is called. Because <see cref="IEstimator{TTransformer}.Fit(IDataView)"/>
136136
/// may be called multiple times, this delegate may also be called multiple times.</param>
137137
/// <returns>A wrapping estimator that calls the indicated delegate whenever fit is called</returns>
138+
/// <example>
139+
/// <format type="text/markdown">
140+
/// <![CDATA[
141+
/// [!code-csharp[OnFit](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/WithOnFitDelegate.cs)]
142+
/// ]]>
143+
/// </format>
144+
/// </example>
138145
public static IEstimator<TTransformer> WithOnFitDelegate<TTransformer>(this IEstimator<TTransformer> estimator, Action<TTransformer> onFit)
139146
where TTransformer : class, ITransformer
140147
{

0 commit comments

Comments
 (0)