Skip to content

Commit 1b18db5

Browse files
author
Shahab Moradi
committed
Addressed PR comments dotnet#2
1 parent db8d690 commit 1b18db5

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

src/Microsoft.ML.PCA/PcaTransform.cs

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using Microsoft.ML.Runtime.Numeric;
1818
using Microsoft.ML.StaticPipe;
1919
using Microsoft.ML.StaticPipe.Runtime;
20+
using Microsoft.ML.Transforms;
2021

2122
[assembly: LoadableClass(PcaTransform.Summary, typeof(IDataTransform), typeof(PcaTransform), typeof(PcaTransform.Arguments), typeof(SignatureDataTransform),
2223
PcaTransform.UserName, PcaTransform.LoaderSignature, PcaTransform.ShortName)]
@@ -32,39 +33,30 @@
3233

3334
[assembly: LoadableClass(typeof(void), typeof(PcaTransform), null, typeof(SignatureEntryPointModule), PcaTransform.LoaderSignature)]
3435

35-
namespace Microsoft.ML.Runtime.Data
36+
namespace Microsoft.ML.Transforms
3637
{
3738
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
3839
public sealed class PcaTransform : OneToOneTransformerBase
3940
{
40-
internal static class Defaults
41-
{
42-
public const string WeightColumn = null;
43-
public const int Rank = 20;
44-
public const int Oversampling = 20;
45-
public const bool Center = true;
46-
public const int Seed = 0;
47-
}
48-
4941
public sealed class Arguments : TransformInputBase
5042
{
5143
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
5244
public Column[] Column;
5345

5446
[Argument(ArgumentType.Multiple, HelpText = "The name of the weight column", ShortName = "weight", Purpose = SpecialPurpose.ColumnName)]
55-
public string WeightColumn = Defaults.WeightColumn;
47+
public string WeightColumn = PcaEstimator.Defaults.WeightColumn;
5648

5749
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k")]
58-
public int Rank = Defaults.Rank;
50+
public int Rank = PcaEstimator.Defaults.Rank;
5951

6052
[Argument(ArgumentType.AtMostOnce, HelpText = "Oversampling parameter for randomized PCA training", ShortName = "over")]
61-
public int Oversampling = Defaults.Oversampling;
53+
public int Oversampling = PcaEstimator.Defaults.Oversampling;
6254

6355
[Argument(ArgumentType.AtMostOnce, HelpText = "If enabled, data is centered to be zero mean")]
64-
public bool Center = Defaults.Center;
56+
public bool Center = PcaEstimator.Defaults.Center;
6557

6658
[Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation")]
67-
public int Seed = Defaults.Seed;
59+
public int Seed = PcaEstimator.Defaults.Seed;
6860
}
6961

7062
public class Column : OneToOneColumn
@@ -121,10 +113,10 @@ public sealed class ColumnInfo
121113
/// </summary>
122114
public ColumnInfo(string input,
123115
string output,
124-
string weightColumn = Defaults.WeightColumn,
125-
int rank = Defaults.Rank,
126-
int overSampling = Defaults.Oversampling,
127-
bool center = Defaults.Center,
116+
string weightColumn = PcaEstimator.Defaults.WeightColumn,
117+
int rank = PcaEstimator.Defaults.Rank,
118+
int overSampling = PcaEstimator.Defaults.Oversampling,
119+
bool center = PcaEstimator.Defaults.Center,
128120
int? seed = null)
129121
{
130122
Input = input;
@@ -134,6 +126,7 @@ public ColumnInfo(string input,
134126
Oversampling = overSampling;
135127
Center = center;
136128
Seed = seed;
129+
Contracts.CheckUserArg(Oversampling >= 0, nameof(Oversampling), "Oversampling must be non-negative.");
137130
}
138131

139132
// The following functions and properties are all internal and used for simplifying the
@@ -312,7 +305,6 @@ public PcaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns)
312305
var col = columns[i];
313306
col.SetSchema(input.Schema);
314307
ValidatePcaInput(Host, col.Input, col.InputType);
315-
Host.CheckUserArg(col.Oversampling >= 0, nameof(col.Oversampling), "Oversampling must be non-negative");
316308
_transformInfos[i] = new TransformInfo(col.Rank, col.InputType.ValueCount);
317309
}
318310

@@ -614,8 +606,8 @@ internal static void ValidatePcaInput(IHost host, string name, ColumnType type)
614606
throw host.Except($"Pca transform can only be applied to vector columns. Column ${name} is of size ${type.VectorSize}");
615607

616608
var itemType = type.ItemType;
617-
if (!itemType.IsNumber)
618-
throw host.Except($"Pca transform can only be applied to vector of numeric items. Column ${name} contains type ${itemType}");
609+
if (itemType.RawKind != DataKind.R4)
610+
throw host.Except($"Pca transform can only be applied to vector of float items. Column ${name} contains type ${itemType}");
619611
}
620612

621613
private sealed class Mapper : MapperBase
@@ -707,6 +699,15 @@ public static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Argu
707699

708700
public sealed class PcaEstimator : IEstimator<PcaTransform>
709701
{
702+
internal static class Defaults
703+
{
704+
public const string WeightColumn = null;
705+
public const int Rank = 20;
706+
public const int Oversampling = 20;
707+
public const bool Center = true;
708+
public const int Seed = 0;
709+
}
710+
710711
private readonly IHost _host;
711712
private readonly PcaTransform.ColumnInfo[] _columns;
712713

@@ -721,8 +722,8 @@ public sealed class PcaEstimator : IEstimator<PcaTransform>
721722
/// <param name="center">If enabled, data is centered to be zero mean.</param>
722723
/// <param name="seed">The seed for random number generation</param>
723724
public PcaEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null,
724-
string weightColumn = PcaTransform.Defaults.WeightColumn, int rank = PcaTransform.Defaults.Rank,
725-
int overSampling = PcaTransform.Defaults.Oversampling, bool center = PcaTransform.Defaults.Center,
725+
string weightColumn = Defaults.WeightColumn, int rank = Defaults.Rank,
726+
int overSampling = Defaults.Oversampling, bool center = Defaults.Center,
726727
int? seed = null)
727728
: this(env, new PcaTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, weightColumn, rank, overSampling, center, seed))
728729
{
@@ -746,7 +747,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
746747
if (!inputSchema.TryFindColumn(colInfo.Input, out var col))
747748
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
748749

749-
if (!(col.Kind == SchemaShape.Column.VectorKind.Vector && col.ItemType.IsNumber))
750+
if (col.Kind != SchemaShape.Column.VectorKind.Vector || col.ItemType.RawKind != DataKind.R4)
750751
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input);
751752

752753
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output,
@@ -808,10 +809,10 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
808809
/// <param name="seed">The seed for random number generation</param>
809810
/// <returns>Vector containing the principal components.</returns>
810811
public static Vector<float> ToPrincipalComponents(this Vector<float> input,
811-
string weightColumn = PcaTransform.Defaults.WeightColumn,
812-
int rank = PcaTransform.Defaults.Rank,
813-
int overSampling = PcaTransform.Defaults.Oversampling,
814-
bool center = PcaTransform.Defaults.Center,
812+
string weightColumn = PcaEstimator.Defaults.WeightColumn,
813+
int rank = PcaEstimator.Defaults.Rank,
814+
int overSampling = PcaEstimator.Defaults.Oversampling,
815+
bool center = PcaEstimator.Defaults.Center,
815816
int? seed = null) => new OutPipelineColumn(input, weightColumn, rank, overSampling, center, seed);
816817
}
817818
}

test/Microsoft.ML.Tests/Transformers/PcaTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.IO;
56
using Microsoft.ML.Runtime.Data;
67
using Microsoft.ML.Runtime.Data.IO;
78
using Microsoft.ML.Runtime.RunTests;
8-
using System.IO;
9+
using Microsoft.ML.Transforms;
910
using Xunit;
1011
using Xunit.Abstractions;
1112

0 commit comments

Comments
 (0)