diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
index 64130ed80e..483ae08de4 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformWrapper.cs
@@ -110,24 +110,24 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
     /// <summary>
     /// Estimator for trained wrapped transformers.
     /// </summary>
-    internal abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
+    public abstract class TrainedWrapperEstimatorBase : IEstimator<TransformWrapper>
     {
-        private readonly IHost _host;
+        protected readonly IHost Host;
 
         protected TrainedWrapperEstimatorBase(IHost host)
         {
             Contracts.CheckValue(host, nameof(host));
-            _host = host;
+            Host = host;
         }
 
         public abstract TransformWrapper Fit(IDataView input);
 
         public SchemaShape GetOutputSchema(SchemaShape inputSchema)
         {
-            _host.CheckValue(inputSchema, nameof(inputSchema));
+            Host.CheckValue(inputSchema, nameof(inputSchema));
 
-            var fakeSchema = new FakeSchema(_host, inputSchema);
-            var transformer = Fit(new EmptyDataView(_host, fakeSchema));
+            var fakeSchema = new FakeSchema(Host, inputSchema);
+            var transformer = Fit(new EmptyDataView(Host, fakeSchema));
             return SchemaShape.Create(transformer.GetOutputSchema(fakeSchema));
         }
     }
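The GetOutputSchema implementation above is the heart of the wrapper pattern: instead of reasoning about the schema symbolically, it fits the estimator against an EmptyDataView fabricated from a FakeSchema and reads the schema off the resulting transformer. A minimal sketch of what that buys a caller (a hedged illustration, not part of the diff; env and data are assumed to exist as in the tests further down, and Whitening is the estimator added later in this PR):

    // Schema propagation without reading any real rows: under the hood,
    // Fit() runs against an empty IDataView built from the input SchemaShape.
    IEstimator<TransformWrapper> est = new Whitening(env, "features", "whitened");
    SchemaShape outputShape = est.GetOutputSchema(SchemaShape.Create(data.Schema));
    // outputShape now describes the "whitened" column even though no
    // training data has been touched yet.

One consequence for derived classes: Fit must tolerate an input with zero rows, because GetOutputSchema calls it that way.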
diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs
index 67a6bbd951..267f680cb2 100644
--- a/src/Microsoft.ML.Transforms/GcnTransform.cs
+++ b/src/Microsoft.ML.Transforms/GcnTransform.cs
@@ -311,7 +311,7 @@ public LpNormNormalizerTransform(IHostEnvironment env, GcnArguments args, IDataV
         /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
         /// <param name="name">Name of the output column.</param>
         /// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
-        /// /// <param name="normKind">The norm to use to normalize each sample.</param>
+        /// <param name="normKind">The norm to use to normalize each sample.</param>
         /// <param name="subMean">Subtract mean from each value before normalizing.</param>
         public static IDataTransform CreateLpNormNormalizer(IHostEnvironment env,
             IDataView input,
diff --git a/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj b/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj
index 5da86f946e..11d96a6cfc 100644
--- a/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj
+++ b/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj
@@ -48,6 +48,7 @@
+
diff --git a/src/Microsoft.ML.Transforms/WhiteningTransform.cs b/src/Microsoft.ML.Transforms/WhiteningTransform.cs
index 5d1b3188b7..99681bfd11 100644
--- a/src/Microsoft.ML.Transforms/WhiteningTransform.cs
+++ b/src/Microsoft.ML.Transforms/WhiteningTransform.cs
@@ -34,16 +34,7 @@ public enum WhiteningKind
         Zca
     }
 
-    /// <summary>
-    /// Implements PCA (Principal Component Analysis) and ZCA (Zero phase Component Analysis) whitening.
-    /// The whitening process consists of 2 steps:
-    /// 1. Decorrelation of the input data. Input data is assumed to have zero mean.
-    /// 2. Rescale decorrelated features to have unit variance.
-    /// That is, PCA whitening is essentially just a PCA + rescale.
-    /// ZCA whitening tries to make resulting data to look more like input data by rotating it back to the
-    /// original input space.
-    /// More information: http://ufldl.stanford.edu/wiki/index.php/Whitening
-    /// </summary>
+    /// <include file='doc.xml' path='doc/members/member[@name="Whitening"]/*'/>
     public sealed class WhiteningTransform : OneToOneTransformBase
     {
         private static class Defaults
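The comment block that moves into doc.xml here can be stated precisely. For zero-mean input x with covariance decomposed as Sigma = U Lambda U^T, the two whitening matrices are (a standard result, written out here for reference; the notation is mine, not the PR's):

    W_{\mathrm{pca}} = \Lambda^{-1/2} U^{\top}, \qquad
    W_{\mathrm{zca}} = U \Lambda^{-1/2} U^{\top} = \Sigma^{-1/2}
    % Both satisfy Cov(W x) = I. ZCA is PCA whitening rotated back into the
    % original basis (the unique symmetric whitening), which is why its output
    % "looks like" the input. The Eps argument regularizes the inversion,
    % effectively computing (\Lambda + \epsilon I)^{-1/2}.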
diff --git a/src/Microsoft.ML.Transforms/WrappedGcnTransformers.cs b/src/Microsoft.ML.Transforms/WrappedGcnTransformers.cs
new file mode 100644
index 0000000000..1eaf5afd1d
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/WrappedGcnTransformers.cs
@@ -0,0 +1,221 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data.StaticPipe.Runtime;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using System.Collections.Generic;
+using System.Linq;
+using static Microsoft.ML.Runtime.Data.LpNormNormalizerTransform;
+
+namespace Microsoft.ML.Transforms
+{
+    /// <include file='doc.xml' path='doc/members/member[@name="LpNormalize"]/*'/>
+    public sealed class LpNormalizer : TrivialWrapperEstimator
+    {
+        /// <param name="env">The environment.</param>
+        /// <param name="inputColumn">The column containing the vectors to normalize.</param>
+        /// <param name="outputColumn">The column containing the normalized vectors. Null means <paramref name="inputColumn"/> is replaced.</param>
+        /// <param name="normKind">Type of norm to use to normalize each sample.</param>
+        /// <param name="subMean">Subtract mean from each value before normalizing.</param>
+        public LpNormalizer(IHostEnvironment env, string inputColumn, string outputColumn = null, NormalizerKind normKind = NormalizerKind.L2Norm, bool subMean = false)
+            : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, normKind, subMean)
+        {
+        }
+
+        /// <param name="env">The environment.</param>
+        /// <param name="columns">Pairs of input and output columns to normalize.</param>
+        /// <param name="normKind">Type of norm to use to normalize each sample.</param>
+        /// <param name="subMean">Subtract mean from each value before normalizing.</param>
+        public LpNormalizer(IHostEnvironment env, (string input, string output)[] columns, NormalizerKind normKind = NormalizerKind.L2Norm, bool subMean = false)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(LpNormalizer)), MakeTransformer(env, columns, normKind, subMean))
+        {
+        }
+
+        private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, NormalizerKind normKind, bool subMean)
+        {
+            Contracts.AssertValue(env);
+            env.CheckNonEmpty(columns, nameof(columns));
+            foreach (var (input, output) in columns)
+            {
+                env.CheckValue(input, nameof(input));
+                env.CheckValue(output, nameof(output));
+            }
+
+            var args = new LpNormNormalizerTransform.Arguments
+            {
+                Column = columns.Select(x => new LpNormNormalizerTransform.Column { Source = x.input, Name = x.output }).ToArray(),
+                SubMean = subMean,
+                NormKind = normKind
+            };
+
+            // Create a valid instance of data.
+            var schema = new SimpleSchema(env, columns.Select(x => new KeyValuePair<string, ColumnType>(x.input, new VectorType(NumberType.R4))).ToArray());
+            var emptyData = new EmptyDataView(env, schema);
+
+            return new TransformWrapper(env, new LpNormNormalizerTransform(env, args, emptyData));
+        }
+    }
+
+    /// <include file='doc.xml' path='doc/members/member[@name="GcNormalize"]/*'/>
+    public sealed class GlobalContrastNormalizer : TrivialWrapperEstimator
+    {
+        /// <param name="env">The environment.</param>
+        /// <param name="inputColumn">The column containing the vectors to normalize.</param>
+        /// <param name="outputColumn">The column containing the normalized vectors. Null means <paramref name="inputColumn"/> is replaced.</param>
+        /// <param name="subMean">Subtract mean from each value before normalizing.</param>
+        /// <param name="useStdDev">Normalize by standard deviation rather than L2 norm.</param>
+        /// <param name="scale">Scale features by this value.</param>
+        public GlobalContrastNormalizer(IHostEnvironment env, string inputColumn, string outputColumn = null, bool subMean = true, bool useStdDev = false, float scale = 1)
+            : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, subMean, useStdDev, scale)
+        {
+        }
+
+        /// <param name="env">The environment.</param>
+        /// <param name="columns">Pairs of input and output columns to normalize.</param>
+        /// <param name="subMean">Subtract mean from each value before normalizing.</param>
+        /// <param name="useStdDev">Normalize by standard deviation rather than L2 norm.</param>
+        /// <param name="scale">Scale features by this value.</param>
+        public GlobalContrastNormalizer(IHostEnvironment env, (string input, string output)[] columns, bool subMean = true, bool useStdDev = false, float scale = 1)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(GlobalContrastNormalizer)), MakeTransformer(env, columns, subMean, useStdDev, scale))
+        {
+        }
+
+        private static TransformWrapper MakeTransformer(IHostEnvironment env, (string input, string output)[] columns, bool subMean, bool useStdDev, float scale)
+        {
+            Contracts.AssertValue(env);
+            env.CheckNonEmpty(columns, nameof(columns));
+            foreach (var (input, output) in columns)
+            {
+                env.CheckValue(input, nameof(input));
+                env.CheckValue(output, nameof(output));
+            }
+
+            var args = new LpNormNormalizerTransform.GcnArguments
+            {
+                Column = columns.Select(x => new LpNormNormalizerTransform.GcnColumn { Source = x.input, Name = x.output }).ToArray(),
+                SubMean = subMean,
+                UseStdDev = useStdDev,
+                Scale = scale
+            };
+
+            // Create a valid instance of data.
+            var schema = new SimpleSchema(env, columns.Select(x => new KeyValuePair<string, ColumnType>(x.input, new VectorType(NumberType.R4))).ToArray());
+            var emptyData = new EmptyDataView(env, schema);
+
+            return new TransformWrapper(env, new LpNormNormalizerTransform(env, args, emptyData));
+        }
+    }
+
+    /// <summary>
+    /// Extensions for statically typed LpNormalizer estimator.
+    /// </summary>
+    public static class LpNormNormalizerExtensions
+    {
+        private sealed class OutPipelineColumn : Vector<float>
+        {
+            public readonly Vector<float> Input;
+
+            public OutPipelineColumn(Vector<float> input, NormalizerKind normKind, bool subMean)
+                : base(new Reconciler(normKind, subMean), input)
+            {
+                Input = input;
+            }
+        }
+
+        private sealed class Reconciler : EstimatorReconciler
+        {
+            private readonly NormalizerKind _normKind;
+            private readonly bool _subMean;
+
+            public Reconciler(NormalizerKind normKind, bool subMean)
+            {
+                _normKind = normKind;
+                _subMean = subMean;
+            }
+
+            public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
+                PipelineColumn[] toOutput,
+                IReadOnlyDictionary<PipelineColumn, string> inputNames,
+                IReadOnlyDictionary<PipelineColumn, string> outputNames,
+                IReadOnlyCollection<string> usedNames)
+            {
+                Contracts.Assert(toOutput.Length == 1);
+
+                var pairs = new List<(string input, string output)>();
+                foreach (var outCol in toOutput)
+                    pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
+
+                return new LpNormalizer(env, pairs.ToArray(), _normKind, _subMean);
+            }
+        }
+
+        /// <param name="input">The column to apply to.</param>
+        /// <param name="normKind">Type of norm to use to normalize each sample.</param>
+        /// <param name="subMean">Subtract mean from each value before normalizing.</param>
+        public static Vector<float> LpNormalize(this Vector<float> input, NormalizerKind normKind = NormalizerKind.L2Norm, bool subMean = false) => new OutPipelineColumn(input, normKind, subMean);
+    }
+
+    /// <summary>
+    /// Extensions for statically typed GcNormalizer estimator.
+    /// </summary>
+    public static class GcNormalizerExtensions
+    {
+        private sealed class OutPipelineColumn : Vector<float>
+        {
+            public readonly Vector<float> Input;
+
+            public OutPipelineColumn(Vector<float> input, bool subMean, bool useStdDev, float scale)
+                : base(new Reconciler(subMean, useStdDev, scale), input)
+            {
+                Input = input;
+            }
+        }
+
+        private sealed class Reconciler : EstimatorReconciler
+        {
+            private readonly bool _subMean;
+            private readonly bool _useStdDev;
+            private readonly float _scale;
+
+            public Reconciler(bool subMean, bool useStdDev, float scale)
+            {
+                _subMean = subMean;
+                _useStdDev = useStdDev;
+                _scale = scale;
+            }
+
+            public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
+                PipelineColumn[] toOutput,
+                IReadOnlyDictionary<PipelineColumn, string> inputNames,
+                IReadOnlyDictionary<PipelineColumn, string> outputNames,
+                IReadOnlyCollection<string> usedNames)
+            {
+                Contracts.Assert(toOutput.Length == 1);
+
+                var pairs = new List<(string input, string output)>();
+                foreach (var outCol in toOutput)
+                    pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
+
+                return new GlobalContrastNormalizer(env, pairs.ToArray(), _subMean, _useStdDev, _scale);
+            }
+        }
+
+        /// <param name="input">The column to apply to.</param>
+        /// <param name="subMean">Subtract mean from each value before normalizing.</param>
+        /// <param name="useStdDev">Normalize by standard deviation rather than L2 norm.</param>
+        /// <param name="scale">Scale features by this value.</param>
+        public static Vector<float> GlobalContrastNormalize(this Vector<float> input,
+            bool subMean = true,
+            bool useStdDev = false,
+            float scale = 1) => new OutPipelineColumn(input, subMean, useStdDev, scale);
+    }
+}
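A short usage sketch for the two wrappers (hedged: the "features" column, env, data, and reader are illustrative stand-ins, set up as in the tests at the end of this diff):

    // Dynamic API: normalize each row of the "features" vector column.
    var pipeline = new LpNormalizer(env, "features", "lpnorm", NormalizerKind.L1Norm)
        .Append(new GlobalContrastNormalizer(env, "features", "gcnorm", subMean: true, useStdDev: true));
    var transformed = pipeline.Fit(data).Transform(data);

    // Static API: the same operations as extension methods on Vector<float>.
    var est = reader.MakeNewEstimator()
        .Append(r => (r.label,
            lpnorm: r.features.LpNormalize(),
            gcnorm: r.features.GlobalContrastNormalize()));

Note the design choice: both classes derive from TrivialWrapperEstimator because the underlying LpNormNormalizerTransform needs no training; MakeTransformer builds it once over an empty data view, so fitting has nothing left to learn.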
diff --git a/src/Microsoft.ML.Transforms/WrappedWhiteningTransformer.cs b/src/Microsoft.ML.Transforms/WrappedWhiteningTransformer.cs
new file mode 100644
index 0000000000..9ef15276d2
--- /dev/null
+++ b/src/Microsoft.ML.Transforms/WrappedWhiteningTransformer.cs
@@ -0,0 +1,165 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data.StaticPipe.Runtime;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Microsoft.ML.Transforms
+{
+    /// <include file='doc.xml' path='doc/members/member[@name="Whitening"]/*'/>
+    public sealed class Whitening : TrainedWrapperEstimatorBase
+    {
+        private readonly (string input, string output)[] _columns;
+        private readonly WhiteningKind _kind;
+        private readonly float _eps;
+        private readonly int _maxRows;
+        private readonly bool _saveInverse;
+        private readonly int _pcaNum;
+
+        /// <param name="env">The environment.</param>
+        /// <param name="inputColumn">The column containing the vectors to whiten.</param>
+        /// <param name="outputColumn">The column containing the whitened vectors. Null means <paramref name="inputColumn"/> is replaced.</param>
+        /// <param name="kind">Whitening kind (PCA/ZCA).</param>
+        /// <param name="eps">Scaling regularizer.</param>
+        /// <param name="maxRows">Max number of rows.</param>
+        /// <param name="saveInverse">Whether to save inverse (recovery) matrix.</param>
+        /// <param name="pcaNum">PCA components to retain.</param>
+        public Whitening(IHostEnvironment env,
+            string inputColumn,
+            string outputColumn = null,
+            WhiteningKind kind = WhiteningKind.Zca,
+            float eps = (float)1e-5,
+            int maxRows = 100 * 1000,
+            bool saveInverse = false,
+            int pcaNum = 0)
+            : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, kind, eps, maxRows, saveInverse, pcaNum)
+        {
+        }
+
+        /// <param name="env">The environment.</param>
+        /// <param name="columns">Pairs of input and output columns to whiten.</param>
+        /// <param name="kind">Whitening kind (PCA/ZCA).</param>
+        /// <param name="eps">Scaling regularizer.</param>
+        /// <param name="maxRows">Max number of rows.</param>
+        /// <param name="saveInverse">Whether to save inverse (recovery) matrix.</param>
+        /// <param name="pcaNum">PCA components to retain.</param>
+        public Whitening(IHostEnvironment env, (string input, string output)[] columns,
+            WhiteningKind kind = WhiteningKind.Zca,
+            float eps = (float)1e-5,
+            int maxRows = 100 * 1000,
+            bool saveInverse = false,
+            int pcaNum = 0)
+            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(Whitening)))
+        {
+            foreach (var (input, output) in columns)
+            {
+                Host.CheckUserArg(Utils.Size(input) > 0, nameof(input));
+                Host.CheckValue(output, nameof(output));
+            }
+
+            _columns = columns;
+            _kind = kind;
+            _eps = eps;
+            _maxRows = maxRows;
+            _saveInverse = saveInverse;
+            _pcaNum = pcaNum;
+        }
+
+        public override TransformWrapper Fit(IDataView input)
+        {
+            var args = new WhiteningTransform.Arguments
+            {
+                Column = _columns.Select(x => new WhiteningTransform.Column { Source = x.input, Name = x.output }).ToArray(),
+                Kind = _kind,
+                Eps = _eps,
+                MaxRows = _maxRows,
+                SaveInverse = _saveInverse,
+                PcaNum = _pcaNum
+            };
+
+            return new TransformWrapper(Host, new WhiteningTransform(Host, args, input));
+        }
+    }
+
+    /// <summary>
+    /// Extensions for statically typed Whitening estimator.
+    /// </summary>
+    public static class WhiteningExtensions
+    {
+        private sealed class OutPipelineColumn : Vector<float>
+        {
+            public readonly Vector<float> Input;
+
+            public OutPipelineColumn(Vector<float> input, WhiteningKind kind, float eps, int maxRows, bool saveInverse, int pcaNum)
+                : base(new Reconciler(kind, eps, maxRows, saveInverse, pcaNum), input)
+            {
+                Input = input;
+            }
+        }
+
+        private sealed class Reconciler : EstimatorReconciler
+        {
+            private readonly WhiteningKind _kind;
+            private readonly float _eps;
+            private readonly int _maxRows;
+            private readonly bool _saveInverse;
+            private readonly int _pcaNum;
+
+            public Reconciler(WhiteningKind kind, float eps, int maxRows, bool saveInverse, int pcaNum)
+            {
+                _kind = kind;
+                _eps = eps;
+                _maxRows = maxRows;
+                _saveInverse = saveInverse;
+                _pcaNum = pcaNum;
+            }
+
+            public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
+                PipelineColumn[] toOutput,
+                IReadOnlyDictionary<PipelineColumn, string> inputNames,
+                IReadOnlyDictionary<PipelineColumn, string> outputNames,
+                IReadOnlyCollection<string> usedNames)
+            {
+                Contracts.Assert(toOutput.Length == 1);
+
+                var pairs = new List<(string input, string output)>();
+                foreach (var outCol in toOutput)
+                    pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
+
+                return new Whitening(env, pairs.ToArray(), _kind, _eps, _maxRows, _saveInverse, _pcaNum);
+            }
+        }
+
+        /// <param name="input">The column to apply to.</param>
+        /// <param name="eps">Scaling regularizer.</param>
+        /// <param name="maxRows">Max number of rows.</param>
+        /// <param name="saveInverse">Whether to save inverse (recovery) matrix.</param>
+        /// <param name="pcaNum">PCA components to retain.</param>
+        public static Vector<float> PcaWhitening(this Vector<float> input,
+            float eps = (float)1e-5,
+            int maxRows = 100 * 1000,
+            bool saveInverse = false,
+            int pcaNum = 0) => new OutPipelineColumn(input, WhiteningKind.Pca, eps, maxRows, saveInverse, pcaNum);
+
+        /// <param name="input">The column to apply to.</param>
+        /// <param name="eps">Scaling regularizer.</param>
+        /// <param name="maxRows">Max number of rows.</param>
+        /// <param name="saveInverse">Whether to save inverse (recovery) matrix.</param>
+        /// <param name="pcaNum">PCA components to retain.</param>
+        public static Vector<float> ZcaWhitening(this Vector<float> input,
+            float eps = (float)1e-5,
+            int maxRows = 100 * 1000,
+            bool saveInverse = false,
+            int pcaNum = 0) => new OutPipelineColumn(input, WhiteningKind.Zca, eps, maxRows, saveInverse, pcaNum);
+    }
+}
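Usage mirrors the normalizers above, with one difference worth calling out: Whitening derives from TrainedWrapperEstimatorBase rather than TrivialWrapperEstimator, because the whitening matrix is genuinely estimated in Fit from (up to maxRows of) training rows. A sketch under the same assumptions as before:

    // Dynamic API: estimate a PCA whitening over "features", keeping 4 components.
    var whiten = new Whitening(env, "features", "whitened", WhiteningKind.Pca, pcaNum: 4);
    var whitened = whiten.Fit(data).Transform(data);

    // Static API equivalents:
    //   r.features.PcaWhitening(pcaNum: 4)
    //   r.features.ZcaWhitening()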
diff --git a/src/Microsoft.ML.Transforms/doc.xml b/src/Microsoft.ML.Transforms/doc.xml
index 92788e52c0..79cc10f135 100644
--- a/src/Microsoft.ML.Transforms/doc.xml
+++ b/src/Microsoft.ML.Transforms/doc.xml
@@ -217,7 +217,7 @@
         Scaling inputs to unit norms is a common operation for text classification or clustering.
         For more information see:
-
+
         pipeline.Add(new LpNormalizer("FeatureCol")
@@ -341,6 +341,19 @@
+
+    <member name="Whitening">
+      <summary>
+        Implements PCA (Principal Component Analysis) and ZCA (Zero phase Component Analysis) whitening.
+        The whitening process consists of 2 steps:
+        1. Decorrelation of the input data. Input data is assumed to have zero mean.
+        2. Rescale decorrelated features to have unit variance.
+        That is, PCA whitening is essentially just a PCA + rescale.
+        ZCA whitening tries to make the resulting data look more like the input data by rotating it back to the
+        original input space.
+        More information: http://ufldl.stanford.edu/wiki/index.php/Whitening
+      </summary>
+    </member>
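For reference, the operations these doc.xml members describe, written out (standard definitions in my notation; the GCN form is my reading of LpNormNormalizerTransform's GcnArguments and should be checked against the code):

    % Lp normalization (LpNormalizer); with subMean, x is mean-centered first:
    y = \frac{x}{\lVert x \rVert_p}
    % Global contrast normalization (GlobalContrastNormalizer) with scale s
    % (the numerator is just x when subMean = false):
    y = s \,\frac{x - \mu}{d}, \qquad
    d = \begin{cases} \sigma(x) & \text{if useStdDev} \\ \lVert x - \mu \rVert_2 & \text{otherwise} \end{cases}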
diff --git a/test/BaselineOutput/SingleDebug/Text/lpnorm_gcnorm_whitened.tsv b/test/BaselineOutput/SingleDebug/Text/lpnorm_gcnorm_whitened.tsv
new file mode 100644
index 0000000000..fad6d5db4d
--- /dev/null
+++ b/test/BaselineOutput/SingleDebug/Text/lpnorm_gcnorm_whitened.tsv
@@ -0,0 +1,10 @@
+#@ TextLoader{
+#@   sep=tab
+#@   col=lpnorm:R4:0-10
+#@   col=gcnorm:R4:11-21
+#@   col=whitened:R4:22-32
+#@ }
+-0.686319232 0.192169383 -0.152238086 0.03493989 0.346903175 0.09483684 -0.132272437 -0.124785319 -0.5315855 -0.0973325446 0.114802495 -0.626524031 0.289601743 -0.0695612058 0.125636056 0.4509648 0.188099176 -0.0487401523 -0.04093227 -0.465160966 -0.0123033375 0.208920211 -2.604605 0.829638362 -0.5992434 0.19860521 1.33247662 0.369197041 -0.5760094 -0.5490271 -1.94509208 -0.393351972 0.507488966
+-0.20306389 -0.1231699 -0.039946992 0.183090389 -0.3328916 0.279628932 -0.0066578323 0.432759076 -0.0798939839 -0.1664458 -0.7057302 -0.137441739 -0.055349838 0.0301625486 0.259335726 -0.270841062 0.3585301 0.0643675 0.5158729 -0.0108833946 -0.09981628 -0.653936446 -0.5923902 -0.324390084 -0.114805378 0.6855182 -1.055579 0.8767955 -0.0392023772 1.21807373 -0.160801888 -0.47570774 -2.22817
+-0.268398017 -0.28734377 0.571529865 0.006315247 -0.246294647 -0.445224941 -0.344181 -0.20524554 0.284186125 -0.116832078 -0.06946772 -0.176903129 -0.19703348 0.715542555 0.114987023 -0.153417692 -0.3647864 -0.257424533 -0.109801926 0.410232216 -0.0158602837 0.0344656035 -0.9132714 -0.911281645 1.814283 0.07471426 -0.8969923 -1.44387519 -1.19571114 -0.6542767 0.887983143 -0.4604767 -0.17543222
+0.117021732 0.438831449 -0.100304335 0.125380427 -0.413755417 0.0794076 0.133739114 -0.397038 -0.497342378 -0.2632989 0.313451052 0.160775661 0.485780418 -0.0587080531 0.169217348 -0.3752711 0.122788094 0.17765902 -0.358387738 -0.459687948 -0.223320842 0.3591552 0.236966148 1.004758 -0.233154371 0.3862052 -1.02724624 0.240614042 0.299898773 -1.03102541 -1.13852251 -0.6675951 0.766793966
diff --git a/test/BaselineOutput/SingleRelease/Text/lpnorm_gcnorm_whitened.tsv b/test/BaselineOutput/SingleRelease/Text/lpnorm_gcnorm_whitened.tsv
new file mode 100644
index 0000000000..fad6d5db4d
--- /dev/null
+++ b/test/BaselineOutput/SingleRelease/Text/lpnorm_gcnorm_whitened.tsv
@@ -0,0 +1,10 @@
+#@ TextLoader{
+#@   sep=tab
+#@   col=lpnorm:R4:0-10
+#@   col=gcnorm:R4:11-21
+#@   col=whitened:R4:22-32
+#@ }
+-0.686319232 0.192169383 -0.152238086 0.03493989 0.346903175 0.09483684 -0.132272437 -0.124785319 -0.5315855 -0.0973325446 0.114802495 -0.626524031 0.289601743 -0.0695612058 0.125636056 0.4509648 0.188099176 -0.0487401523 -0.04093227 -0.465160966 -0.0123033375 0.208920211 -2.604605 0.829638362 -0.5992434 0.19860521 1.33247662 0.369197041 -0.5760094 -0.5490271 -1.94509208 -0.393351972 0.507488966
+-0.20306389 -0.1231699 -0.039946992 0.183090389 -0.3328916 0.279628932 -0.0066578323 0.432759076 -0.0798939839 -0.1664458 -0.7057302 -0.137441739 -0.055349838 0.0301625486 0.259335726 -0.270841062 0.3585301 0.0643675 0.5158729 -0.0108833946 -0.09981628 -0.653936446 -0.5923902 -0.324390084 -0.114805378 0.6855182 -1.055579 0.8767955 -0.0392023772 1.21807373 -0.160801888 -0.47570774 -2.22817
+-0.268398017 -0.28734377 0.571529865 0.006315247 -0.246294647 -0.445224941 -0.344181 -0.20524554 0.284186125 -0.116832078 -0.06946772 -0.176903129 -0.19703348 0.715542555 0.114987023 -0.153417692 -0.3647864 -0.257424533 -0.109801926 0.410232216 -0.0158602837 0.0344656035 -0.9132714 -0.911281645 1.814283 0.07471426 -0.8969923 -1.44387519 -1.19571114 -0.6542767 0.887983143 -0.4604767 -0.17543222
+0.117021732 0.438831449 -0.100304335 0.125380427 -0.413755417 0.0794076 0.133739114 -0.397038 -0.497342378 -0.2632989 0.313451052 0.160775661 0.485780418 -0.0587080531 0.169217348 -0.3752711 0.122788094 0.17765902 -0.358387738 -0.459687948 -0.223320842 0.3591552 0.236966148 1.004758 -0.233154371 0.3862052 -1.02724624 0.240614042 0.299898773 -1.03102541 -1.13852251 -0.6675951 0.766793966
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index 8d7fb18cd4..aad42add36 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -7,6 +7,7 @@
 using Microsoft.ML.Runtime.Data.IO;
 using Microsoft.ML.Runtime.Internal.Utilities;
 using Microsoft.ML.TestFramework;
+using Microsoft.ML.Transforms;
 using Microsoft.ML.Transforms.Text;
 using System;
 using System.Collections.Generic;
@@ -431,5 +432,44 @@ public void Tokenize()
             Assert.True(type.IsVector && !type.IsKnownSizeVector && type.ItemType.IsKey);
             Assert.True(type.ItemType.AsKey.RawKind == DataKind.U2);
         }
+
+        [Fact]
+        public void LpGcNormAndWhitening()
+        {
+            var env = new ConsoleEnvironment(seed: 0);
+            var dataPath = GetDataPath("generated_regression_dataset.csv");
+            var dataSource = new MultiFileSource(dataPath);
+
+            var reader = TextLoader.CreateReader(env,
+                c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+                separator: ';', hasHeader: true);
+            var data = reader.Read(dataSource);
+
+            var est = reader.MakeNewEstimator()
+                .Append(r => (r.label,
+                    lpnorm: r.features.LpNormalize(),
+                    gcnorm: r.features.GlobalContrastNormalize(),
+                    zcawhitened: r.features.ZcaWhitening(),
+                    pcawhitened: r.features.PcaWhitening()));
+            var tdata = est.Fit(data).Transform(data);
+            var schema = tdata.AsDynamic.Schema;
+
+            Assert.True(schema.TryGetColumnIndex("lpnorm", out int lpnormCol));
+            var type = schema.GetColumnType(lpnormCol);
+            Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+
+            Assert.True(schema.TryGetColumnIndex("gcnorm", out int gcnormCol));
+            type = schema.GetColumnType(gcnormCol);
+            Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+
+            Assert.True(schema.TryGetColumnIndex("zcawhitened", out int zcawhitenedCol));
+            type = schema.GetColumnType(zcawhitenedCol);
+            Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+
+            Assert.True(schema.TryGetColumnIndex("pcawhitened", out int pcawhitenedCol));
+            type = schema.GetColumnType(pcawhitenedCol);
+            Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
+        }
     }
 }
diff --git a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs
index eefae183e2..8f6cd18428 100644
--- a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs
@@ -5,6 +5,7 @@
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Data.IO;
 using Microsoft.ML.Runtime.RunTests;
+using Microsoft.ML.Transforms;
 using System.IO;
 using Xunit;
 using Xunit.Abstractions;
@@ -103,5 +104,40 @@ public void SimpleConstructors()
             Done();
         }
+
+        [Fact]
+        public void LpGcNormAndWhiteningWorkout()
+        {
+            var env = new ConsoleEnvironment(seed: 0);
+            string dataSource = GetDataPath("generated_regression_dataset.csv");
+            var data = TextLoader.CreateReader(env,
+                c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
+                separator: ';', hasHeader: true)
+                .Read(new MultiFileSource(dataSource));
+
+            var invalidData = TextLoader.CreateReader(env,
+                c => (label: c.LoadFloat(11), features: c.LoadText(0, 10)),
+                separator: ';', hasHeader: true)
+                .Read(new MultiFileSource(dataSource));
+
+            var est = new LpNormalizer(env, "features", "lpnorm")
+                .Append(new GlobalContrastNormalizer(env, "features", "gcnorm"))
+                .Append(new Whitening(env, "features", "whitened"));
+            TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic);
+
+            var outputPath = GetOutputPath("Text", "lpnorm_gcnorm_whitened.tsv");
+            using (var ch = Env.Start("save"))
+            {
+                var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true, OutputHeader = false });
+                IDataView savedData = TakeFilter.Create(Env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4);
+                savedData = new ChooseColumnsTransform(Env, savedData, "lpnorm", "gcnorm", "whitened");
+
+                using (var fs = File.Create(outputPath))
+                    DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true);
+            }
+
+            CheckEquality("Text", "lpnorm_gcnorm_whitened.tsv");
+            Done();
+        }
     }
 }
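A rough sketch of consuming the fitted chain outside the test harness (hedged: the cursor API names below are from this era of ML.NET and worth double-checking against the repository):

    var model = est.Fit(data.AsDynamic);
    var output = model.Transform(data.AsDynamic);
    output.Schema.TryGetColumnIndex("whitened", out int col);
    using (var cursor = output.GetRowCursor(c => c == col))
    {
        var getter = cursor.GetGetter<VBuffer<float>>(col);
        var buffer = default(VBuffer<float>);
        while (cursor.MoveNext())
            getter(ref buffer); // buffer now holds one whitened row
    }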