
Commit ff62d40

Expose advanced options for the NormalizingEstimator (#3052)
* adding 5 APIs
* moving 5 APIs to Experimental nuget, and adding tests
* review comments (fixZero, updated tests)
* review comments (fix order of arguments)
1 parent 75fc055 commit ff62d40

3 files changed: +207, -5 lines

src/Microsoft.ML.Experimental/MLContextExtensions.cs (+103)

@@ -2,6 +2,9 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

+using Microsoft.ML.Data;
+using Microsoft.ML.Transforms;
+
 namespace Microsoft.ML.Experimental
 {
     public static class MLContextExtensions
@@ -11,5 +14,105 @@ public static class MLContextExtensions
         /// </summary>
         /// <param name="ctx"><see cref="MLContext"/> reference.</param>
         public static void CancelExecution(this MLContext ctx) => ctx.CancelExecution();
+
+        /// <summary>
+        /// Normalize (rescale) the column according to the <see cref="NormalizingEstimator.NormalizationMode.MinMax"/> mode.
+        /// It normalizes the data based on the observed minimum and maximum values of the data.
+        /// </summary>
+        /// <param name="catalog">The transform catalog</param>
+        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
+        /// <param name="maximumExampleCount">Maximum number of examples used to train the normalizer.</param>
+        /// <param name="fixZero">Whether to map zero to zero, preserving sparsity.</param>
+        public static NormalizingEstimator NormalizeMinMax(this TransformsCatalog catalog,
+            string outputColumnName, string inputColumnName = null,
+            long maximumExampleCount = NormalizingEstimator.Defaults.MaximumExampleCount,
+            bool fixZero = NormalizingEstimator.Defaults.EnsureZeroUntouched)
+        {
+            var columnOptions = new NormalizingEstimator.MinMaxColumnOptions(outputColumnName, inputColumnName, maximumExampleCount, fixZero);
+            return new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions);
+        }
+
+        /// <summary>
+        /// Normalize (rescale) the column according to the <see cref="NormalizingEstimator.NormalizationMode.MeanVariance"/> mode.
+        /// It normalizes the data based on the computed mean and variance of the data.
+        /// </summary>
+        /// <param name="catalog">The transform catalog</param>
+        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
+        /// <param name="maximumExampleCount">Maximum number of examples used to train the normalizer.</param>
+        /// <param name="fixZero">Whether to map zero to zero, preserving sparsity.</param>
+        /// <param name="useCdf">Whether to use CDF as the output.</param>
+        public static NormalizingEstimator NormalizeMeanVariance(this TransformsCatalog catalog,
+            string outputColumnName, string inputColumnName = null,
+            long maximumExampleCount = NormalizingEstimator.Defaults.MaximumExampleCount,
+            bool fixZero = NormalizingEstimator.Defaults.EnsureZeroUntouched,
+            bool useCdf = NormalizingEstimator.Defaults.MeanVarCdf)
+        {
+            var columnOptions = new NormalizingEstimator.MeanVarianceColumnOptions(outputColumnName, inputColumnName, maximumExampleCount, fixZero, useCdf);
+            return new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions);
+        }
+
+        /// <summary>
+        /// Normalize (rescale) the column according to the <see cref="NormalizingEstimator.NormalizationMode.LogMeanVariance"/> mode.
+        /// It normalizes the data based on the computed mean and variance of the logarithm of the data.
+        /// </summary>
+        /// <param name="catalog">The transform catalog</param>
+        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
+        /// <param name="maximumExampleCount">Maximum number of examples used to train the normalizer.</param>
+        /// <param name="useCdf">Whether to use CDF as the output.</param>
+        public static NormalizingEstimator NormalizeLogMeanVariance(this TransformsCatalog catalog,
+            string outputColumnName, string inputColumnName = null,
+            long maximumExampleCount = NormalizingEstimator.Defaults.MaximumExampleCount,
+            bool useCdf = NormalizingEstimator.Defaults.LogMeanVarCdf)
+        {
+            var columnOptions = new NormalizingEstimator.LogMeanVarianceColumnOptions(outputColumnName, inputColumnName, maximumExampleCount, useCdf);
+            return new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions);
+        }
+
+        /// <summary>
+        /// Normalize (rescale) the column according to the <see cref="NormalizingEstimator.NormalizationMode.Binning"/> mode.
+        /// The values are assigned into bins with equal density.
+        /// </summary>
+        /// <param name="catalog">The transform catalog</param>
+        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
+        /// <param name="maximumExampleCount">Maximum number of examples used to train the normalizer.</param>
+        /// <param name="fixZero">Whether to map zero to zero, preserving sparsity.</param>
+        /// <param name="maximumBinCount">Maximum number of bins (power of 2 recommended).</param>
+        public static NormalizingEstimator NormalizeBinning(this TransformsCatalog catalog,
+            string outputColumnName, string inputColumnName = null,
+            long maximumExampleCount = NormalizingEstimator.Defaults.MaximumExampleCount,
+            bool fixZero = NormalizingEstimator.Defaults.EnsureZeroUntouched,
+            int maximumBinCount = NormalizingEstimator.Defaults.MaximumBinCount)
+        {
+            var columnOptions = new NormalizingEstimator.BinningColumnOptions(outputColumnName, inputColumnName, maximumExampleCount, fixZero, maximumBinCount);
+            return new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions);
+        }
+
+        /// <summary>
+        /// Normalize (rescale) the column according to the <see cref="NormalizingEstimator.NormalizationMode.SupervisedBinning"/> mode.
+        /// The values are assigned into bins based on correlation with the <paramref name="labelColumnName"/> column.
+        /// </summary>
+        /// <param name="catalog">The transform catalog</param>
+        /// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
+        /// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
+        /// <param name="labelColumnName">Name of the label column for supervised binning.</param>
+        /// <param name="maximumExampleCount">Maximum number of examples used to train the normalizer.</param>
+        /// <param name="fixZero">Whether to map zero to zero, preserving sparsity.</param>
+        /// <param name="maximumBinCount">Maximum number of bins (power of 2 recommended).</param>
+        /// <param name="mininimumExamplesPerBin">Minimum number of examples per bin.</param>
+        public static NormalizingEstimator NormalizeSupervisedBinning(this TransformsCatalog catalog,
+            string outputColumnName, string inputColumnName = null,
+            string labelColumnName = DefaultColumnNames.Label,
+            long maximumExampleCount = NormalizingEstimator.Defaults.MaximumExampleCount,
+            bool fixZero = NormalizingEstimator.Defaults.EnsureZeroUntouched,
+            int maximumBinCount = NormalizingEstimator.Defaults.MaximumBinCount,
+            int mininimumExamplesPerBin = NormalizingEstimator.Defaults.MininimumBinSize)
+        {
+            var columnOptions = new NormalizingEstimator.SupervisedBinningColumOptions(outputColumnName, inputColumnName, labelColumnName, maximumExampleCount, fixZero, maximumBinCount, mininimumExamplesPerBin);
+            return new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), columnOptions);
+        }
     }
 }
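For orientation, a minimal usage sketch of the extensions added above. It assumes a project referencing the Microsoft.ML.Experimental package, an existing MLContext, and an input IDataView with a numeric vector column named "Features"; the column names, the 100,000-row cap, and the helper class are illustrative assumptions, not part of the commit. Only the extension-method names and parameters come from the file itself, and namespaces follow current ML.NET.

using Microsoft.ML;
using Microsoft.ML.Experimental;   // brings NormalizeMinMax, NormalizeMeanVariance, ... into scope on mlContext.Transforms

public static class NormalizerUsageSketch
{
    // Hypothetical helper: `data` is assumed to contain a numeric vector column named "Features".
    public static IDataView NormalizeFeatures(MLContext mlContext, IDataView data)
    {
        // With the experimental extensions, the advanced options are plain parameters
        // instead of NormalizingEstimator column-options objects.
        var estimator = mlContext.Transforms.NormalizeMinMax(
            outputColumnName: "FeaturesNorm",
            inputColumnName: "Features",
            maximumExampleCount: 100_000,  // cap on the rows used to fit the normalizer
            fixZero: true);                // keep zero mapped to zero, preserving sparsity

        return estimator.Fit(data).Transform(data);
    }
}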

test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj (+2 -1)

@@ -9,6 +9,7 @@
     <ProjectReference Include="..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
     <ProjectReference Include="..\..\src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj" />
     <ProjectReference Include="..\..\src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj" />
+    <ProjectReference Include="..\..\src\Microsoft.ML.Experimental\Microsoft.ML.Experimental.csproj" />
     <ProjectReference Include="..\..\src\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj" />
     <ProjectReference Include="..\..\src\Microsoft.ML.Mkl.Components\Microsoft.ML.Mkl.Components.csproj" />
     <ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
@@ -38,7 +39,7 @@
     <NativeAssemblyReference Include="SymSgdNative" />
     <NativeAssemblyReference Include="MklProxyNative" />
     <NativeAssemblyReference Include="MklImports" />
-    <NativeAssemblyReference Condition="'$(OS)' == 'Windows_NT'" Include="libiomp5md"/>
+    <NativeAssemblyReference Condition="'$(OS)' == 'Windows_NT'" Include="libiomp5md" />
   </ItemGroup>

   <!-- TensorFlow is 64-bit only -->

test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs (+102 -4)

@@ -7,6 +7,7 @@
 using System.IO;
 using Microsoft.ML.Data;
 using Microsoft.ML.Data.IO;
+using Microsoft.ML.Experimental;
 using Microsoft.ML.Model;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.StaticPipe;
@@ -242,10 +243,10 @@ public void SimpleConstructorsAndExtensions()
             CheckSameValues(data1, data4);
             CheckSameValues(data1, data5);

-            // Tests for SupervisedBinning
-            var est6 = new NormalizingEstimator(Env, NormalizingEstimator.NormalizationMode.SupervisedBinning, ("float4", "float4"));
-            var est7 = new NormalizingEstimator(Env, new NormalizingEstimator.SupervisedBinningColumOptions("float4"));
-            var est8 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.SupervisedBinning, ("float4", "float4"));
+            // Tests for MeanVariance
+            var est6 = new NormalizingEstimator(Env, NormalizingEstimator.NormalizationMode.MeanVariance, ("float4", "float4"));
+            var est7 = new NormalizingEstimator(Env, new NormalizingEstimator.MeanVarianceColumnOptions("float4"));
+            var est8 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.MeanVariance, ("float4", "float4"));

             var data6 = est6.Fit(data).Transform(data);
             var data7 = est7.Fit(data).Transform(data);
@@ -255,6 +256,103 @@ public void SimpleConstructorsAndExtensions()
             CheckSameValues(data6, data7);
             CheckSameValues(data6, data8);

+            // Tests for LogMeanVariance
+            var est9 = new NormalizingEstimator(Env, NormalizingEstimator.NormalizationMode.LogMeanVariance, ("float4", "float4"));
+            var est10 = new NormalizingEstimator(Env, new NormalizingEstimator.LogMeanVarianceColumnOptions("float4"));
+            var est11 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.LogMeanVariance, ("float4", "float4"));
+
+            var data9 = est9.Fit(data).Transform(data);
+            var data10 = est10.Fit(data).Transform(data);
+            var data11 = est11.Fit(data).Transform(data);
+            CheckSameSchemas(data9.Schema, data10.Schema);
+            CheckSameSchemas(data9.Schema, data11.Schema);
+            CheckSameValues(data9, data10);
+            CheckSameValues(data9, data11);
+
+            // Tests for Binning
+            var est12 = new NormalizingEstimator(Env, NormalizingEstimator.NormalizationMode.Binning, ("float4", "float4"));
+            var est13 = new NormalizingEstimator(Env, new NormalizingEstimator.BinningColumnOptions("float4"));
+            var est14 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.Binning, ("float4", "float4"));
+
+            var data12 = est12.Fit(data).Transform(data);
+            var data13 = est13.Fit(data).Transform(data);
+            var data14 = est14.Fit(data).Transform(data);
+            CheckSameSchemas(data12.Schema, data13.Schema);
+            CheckSameSchemas(data12.Schema, data14.Schema);
+            CheckSameValues(data12, data13);
+            CheckSameValues(data12, data14);
+
+            // Tests for SupervisedBinning
+            var est15 = new NormalizingEstimator(Env, NormalizingEstimator.NormalizationMode.SupervisedBinning, ("float4", "float4"));
+            var est16 = new NormalizingEstimator(Env, new NormalizingEstimator.SupervisedBinningColumOptions("float4"));
+            var est17 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.SupervisedBinning, ("float4", "float4"));
+
+            var data15 = est15.Fit(data).Transform(data);
+            var data16 = est16.Fit(data).Transform(data);
+            var data17 = est17.Fit(data).Transform(data);
+            CheckSameSchemas(data15.Schema, data16.Schema);
+            CheckSameSchemas(data15.Schema, data17.Schema);
+            CheckSameValues(data15, data16);
+            CheckSameValues(data15, data17);
+
+            Done();
+        }
+
+        [Fact]
+        public void NormalizerExperimentalExtensions()
+        {
+            string dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+
+            var loader = new TextLoader(Env, new TextLoader.Options
+            {
+                Columns = new[] {
+                    new TextLoader.Column("Label", DataKind.Single, 0),
+                    new TextLoader.Column("float4", DataKind.Single, new[] { new TextLoader.Range(1, 4) }),
+                }
+            });
+
+            var data = loader.Load(dataPath);
+
+            // Normalizer extensions
+            var est1 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.MinMax, ("float4", "float4"));
+            var est2 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.MeanVariance, ("float4", "float4"));
+            var est3 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.LogMeanVariance, ("float4", "float4"));
+            var est4 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.Binning, ("float4", "float4"));
+            var est5 = ML.Transforms.Normalize(NormalizingEstimator.NormalizationMode.SupervisedBinning, ("float4", "float4"));
+
+            // Normalizer extensions (experimental)
+            var est6 = ML.Transforms.NormalizeMinMax("float4", "float4");
+            var est7 = ML.Transforms.NormalizeMeanVariance("float4", "float4");
+            var est8 = ML.Transforms.NormalizeLogMeanVariance("float4", "float4");
+            var est9 = ML.Transforms.NormalizeBinning("float4", "float4");
+            var est10 = ML.Transforms.NormalizeSupervisedBinning("float4", "float4");
+
+            // Fit and transform
+            var data1 = est1.Fit(data).Transform(data);
+            var data2 = est2.Fit(data).Transform(data);
+            var data3 = est3.Fit(data).Transform(data);
+            var data4 = est4.Fit(data).Transform(data);
+            var data5 = est5.Fit(data).Transform(data);
+            var data6 = est6.Fit(data).Transform(data);
+            var data7 = est7.Fit(data).Transform(data);
+            var data8 = est8.Fit(data).Transform(data);
+            var data9 = est9.Fit(data).Transform(data);
+            var data10 = est10.Fit(data).Transform(data);
+
+            // Schema checks
+            CheckSameSchemas(data1.Schema, data6.Schema);
+            CheckSameSchemas(data2.Schema, data7.Schema);
+            CheckSameSchemas(data3.Schema, data8.Schema);
+            CheckSameSchemas(data4.Schema, data9.Schema);
+            CheckSameSchemas(data5.Schema, data10.Schema);
+
+            // Value checks
+            CheckSameValues(data1, data6);
+            CheckSameValues(data2, data7);
+            CheckSameValues(data3, data8);
+            CheckSameValues(data4, data9);
+            CheckSameValues(data5, data10);
+
             Done();
         }
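The new test exercises the experimental extensions with default arguments only. As a hedged sketch of the advanced options this commit exposes (reusing the test's ML context, the `data` produced by the iris loader, and the "float4"/"Label" columns; the specific option values are arbitrary illustrations, not values from the commit):

            // Sketch only: assumes the same context as NormalizerExperimentalExtensions above.
            var minMaxAdv = ML.Transforms.NormalizeMinMax("float4", maximumExampleCount: 1000, fixZero: false);
            var meanVarAdv = ML.Transforms.NormalizeMeanVariance("float4", useCdf: true);
            var logMeanVarAdv = ML.Transforms.NormalizeLogMeanVariance("float4", useCdf: false);
            var binningAdv = ML.Transforms.NormalizeBinning("float4", maximumBinCount: 8);
            var supervisedAdv = ML.Transforms.NormalizeSupervisedBinning("float4",
                labelColumnName: "Label", maximumBinCount: 8, mininimumExamplesPerBin: 5);  // parameter spelled as in the API above

            // Each estimator is then fitted and applied exactly as in the assertions above, e.g.:
            var binnedData = binningAdv.Fit(data).Transform(data);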

0 commit comments
