Skip to content

Commit b87ae02

Browse files
zeahmedshauheen
authored andcommitted
Converted PcaTransform into Transformer using TransformerWrapper. (#1017)
1 parent d42963c commit b87ae02

File tree

7 files changed

+230
-5
lines changed

7 files changed

+230
-5
lines changed

src/Microsoft.ML.PCA/PcaTransform.cs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,34 @@ namespace Microsoft.ML.Runtime.Data
2929
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
3030
public sealed class PcaTransform : OneToOneTransformBase
3131
{
32+
internal static class Defaults
33+
{
34+
public const string WeightColumn = null;
35+
public const int Rank = 20;
36+
public const int Oversampling = 20;
37+
public const bool Center = true;
38+
public const int Seed = 0;
39+
}
40+
3241
public sealed class Arguments : TransformInputBase
3342
{
3443
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
3544
public Column[] Column;
3645

3746
[Argument(ArgumentType.Multiple, HelpText = "The name of the weight column", ShortName = "weight", Purpose = SpecialPurpose.ColumnName)]
38-
public string WeightColumn;
47+
public string WeightColumn = Defaults.WeightColumn;
3948

4049
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k")]
41-
public int Rank = 20;
50+
public int Rank = Defaults.Rank;
4251

4352
[Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", ShortName = "over")]
44-
public int Oversampling = 20;
53+
public int Oversampling = Defaults.Oversampling;
4554

4655
[Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean")]
47-
public bool Center = true;
56+
public bool Center = Defaults.Center;
4857

4958
[Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation")]
50-
public int Seed = 0;
59+
public int Seed = Defaults.Seed;
5160
}
5261

5362
public class Column : OneToOneColumn
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Internal.Utilities;
7+
using Microsoft.ML.StaticPipe;
8+
using Microsoft.ML.StaticPipe.Runtime;
9+
using System;
10+
using System.Collections.Generic;
11+
using System.Linq;
12+
13+
namespace Microsoft.ML.Runtime.Data
14+
{
15+
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
16+
public sealed class PcaEstimator : TrainedWrapperEstimatorBase
17+
{
18+
private readonly PcaTransform.Arguments _args;
19+
20+
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
21+
/// <param name="env">The environment.</param>
22+
/// <param name="inputColumn">Input column to apply PCA on.</param>
23+
/// <param name="outputColumn">Output column. Null means <paramref name="inputColumn"/> is replaced.</param>
24+
/// <param name="rank">The number of components in the PCA.</param>
25+
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
26+
public PcaEstimator(IHostEnvironment env,
27+
string inputColumn,
28+
string outputColumn = null,
29+
int rank = PcaTransform.Defaults.Rank,
30+
Action<PcaTransform.Arguments> advancedSettings = null)
31+
: this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, rank, advancedSettings)
32+
{
33+
}
34+
35+
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
36+
/// <param name="env">The environment.</param>
37+
/// <param name="columns">Pairs of columns to run the PCA on.</param>
38+
/// <param name="rank">The number of components in the PCA.</param>
39+
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
40+
public PcaEstimator(IHostEnvironment env, (string input, string output)[] columns,
41+
int rank = PcaTransform.Defaults.Rank,
42+
Action<PcaTransform.Arguments> advancedSettings = null)
43+
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(PcaEstimator)))
44+
{
45+
foreach (var (input, output) in columns)
46+
{
47+
Host.CheckUserArg(Utils.Size(input) > 0, nameof(input));
48+
Host.CheckValue(output, nameof(input));
49+
}
50+
51+
_args = new PcaTransform.Arguments();
52+
_args.Column = columns.Select(x => new PcaTransform.Column { Source = x.input, Name = x.output }).ToArray();
53+
_args.Rank = rank;
54+
55+
advancedSettings?.Invoke(_args);
56+
}
57+
58+
public override TransformWrapper Fit(IDataView input)
59+
{
60+
return new TransformWrapper(Host, new PcaTransform(Host, _args, input));
61+
}
62+
}
63+
64+
/// <summary>
65+
/// Extensions for statically typed <see cref="PcaEstimator"/>.
66+
/// </summary>
67+
public static class PcaEstimatorExtensions
68+
{
69+
private sealed class OutPipelineColumn : Vector<float>
70+
{
71+
public readonly Vector<float> Input;
72+
73+
public OutPipelineColumn(Vector<float> input, int rank, Action<PcaTransform.Arguments> advancedSettings)
74+
: base(new Reconciler(null, rank, advancedSettings), input)
75+
{
76+
Input = input;
77+
}
78+
}
79+
80+
private sealed class Reconciler : EstimatorReconciler
81+
{
82+
private readonly int _rank;
83+
private readonly Action<PcaTransform.Arguments> _advancedSettings;
84+
85+
public Reconciler(PipelineColumn weightColumn, int rank, Action<PcaTransform.Arguments> advancedSettings)
86+
{
87+
_rank = rank;
88+
_advancedSettings = advancedSettings;
89+
}
90+
91+
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
92+
PipelineColumn[] toOutput,
93+
IReadOnlyDictionary<PipelineColumn, string> inputNames,
94+
IReadOnlyDictionary<PipelineColumn, string> outputNames,
95+
IReadOnlyCollection<string> usedNames)
96+
{
97+
Contracts.Assert(toOutput.Length == 1);
98+
99+
var pairs = new List<(string input, string output)>();
100+
foreach (var outCol in toOutput)
101+
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol]));
102+
103+
return new PcaEstimator(env, pairs.ToArray(), _rank, _advancedSettings);
104+
}
105+
}
106+
107+
/// <include file='doc.xml' path='doc/members/member[@name="Whitening"]/*'/>
108+
/// <param name="input">The column to apply PCA to.</param>
109+
/// <param name="rank">The number of components in the PCA.</param>
110+
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
111+
public static Vector<float> ToPrincipalComponents(this Vector<float> input,
112+
int rank = PcaTransform.Defaults.Rank,
113+
Action<PcaTransform.Arguments> advancedSettings = null) => new OutPipelineColumn(input, rank, advancedSettings);
114+
}
115+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#@ TextLoader{
2+
#@ sep=tab
3+
#@ col=pca:R4:0-4
4+
#@ }
5+
2.085487 0.09400085 2.58366132 -1.721405 -0.732070744
6+
0.9069792 0.7748574 0.6097196 1.07868779 0.453838825
7+
-0.167718172 -0.92723 -0.19140324 0.243479848 -1.060547
8+
0.548309 0.5576686 -0.587472439 -1.38610959 0.9422219
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#@ TextLoader{
2+
#@ sep=tab
3+
#@ col=pca:R4:0-4
4+
#@ }
5+
2.085487 0.09400085 2.58366132 -1.721405 -0.732070744
6+
0.9069792 0.7748574 0.6097196 1.07868779 0.453838825
7+
-0.167718172 -0.92723 -0.19140324 0.243479848 -1.060547
8+
0.548309 0.5576686 -0.587472439 -1.38610959 0.9422219

test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
99
<ProjectReference Include="..\..\src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1010
<ProjectReference Include="..\Microsoft.ML.TestFramework\Microsoft.ML.TestFramework.csproj" />
11+
<ProjectReference Include="..\..\src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
12+
1113
<ProjectReference Include="..\..\src\Microsoft.ML.Analyzer\Microsoft.ML.Analyzer.csproj">
1214
<ReferenceOutputAssembly>false</ReferenceOutputAssembly>
1315
<OutputItemType>Analyzer</OutputItemType>

test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,5 +645,28 @@ public void TrainTestSplit()
645645
Assert.True(testLabels.Count() > 0);
646646
Assert.False(trainLabels.Intersect(testLabels).Any());
647647
}
648+
649+
[Fact]
650+
public void PrincipalComponentAnalysis()
651+
{
652+
var env = new ConsoleEnvironment(seed: 0);
653+
var dataPath = GetDataPath("generated_regression_dataset.csv");
654+
var dataSource = new MultiFileSource(dataPath);
655+
656+
var reader = TextLoader.CreateReader(env,
657+
c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
658+
separator: ';', hasHeader: true);
659+
var data = reader.Read(dataSource);
660+
661+
var est = reader.MakeNewEstimator()
662+
.Append(r => (r.label,
663+
pca: r.features.ToPrincipalComponents(rank: 5)));
664+
var tdata = est.Fit(data).Transform(data);
665+
var schema = tdata.AsDynamic.Schema;
666+
667+
Assert.True(schema.TryGetColumnIndex("pca", out int pcaCol));
668+
var type = schema.GetColumnType(pcaCol);
669+
Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
670+
}
648671
}
649672
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Data;
6+
using Microsoft.ML.Runtime.Data.IO;
7+
using Microsoft.ML.Runtime.RunTests;
8+
using Microsoft.ML.Transforms;
9+
using System.IO;
10+
using Xunit;
11+
using Xunit.Abstractions;
12+
13+
namespace Microsoft.ML.Tests.Transformers
14+
{
15+
public sealed class PcaTests : TestDataPipeBase
16+
{
17+
public PcaTests(ITestOutputHelper helper)
18+
: base(helper)
19+
{
20+
}
21+
22+
[Fact]
23+
public void PcaWorkout()
24+
{
25+
var env = new ConsoleEnvironment(seed: 1, conc: 1);
26+
string dataSource = GetDataPath("generated_regression_dataset.csv");
27+
var data = TextLoader.CreateReader(env,
28+
c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)),
29+
separator: ';', hasHeader: true)
30+
.Read(new MultiFileSource(dataSource));
31+
32+
var invalidData = TextLoader.CreateReader(env,
33+
c => (label: c.LoadFloat(11), features: c.LoadText(0, 10)),
34+
separator: ';', hasHeader: true)
35+
.Read(new MultiFileSource(dataSource));
36+
37+
var est = new PcaEstimator(env, "features", "pca", rank: 5, advancedSettings: s => {
38+
s.Seed = 1;
39+
});
40+
41+
// The following call fails because of the following issue
42+
// https://github.com/dotnet/machinelearning/issues/969
43+
// TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic);
44+
45+
var outputPath = GetOutputPath("PCA", "pca.tsv");
46+
using (var ch = env.Start("save"))
47+
{
48+
var saver = new TextSaver(env, new TextSaver.Arguments { Silent = true, OutputHeader = false });
49+
IDataView savedData = TakeFilter.Create(env, est.Fit(data.AsDynamic).Transform(data.AsDynamic), 4);
50+
savedData = new ChooseColumnsTransform(env, savedData, "pca");
51+
52+
using (var fs = File.Create(outputPath))
53+
DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true);
54+
}
55+
56+
CheckEquality("PCA", "pca.tsv");
57+
Done();
58+
}
59+
}
60+
}

0 commit comments

Comments
 (0)