17
17
using Microsoft . ML . Runtime . Numeric ;
18
18
using Microsoft . ML . StaticPipe ;
19
19
using Microsoft . ML . StaticPipe . Runtime ;
20
+ using Microsoft . ML . Transforms ;
20
21
21
22
[ assembly: LoadableClass ( PcaTransform . Summary , typeof ( IDataTransform ) , typeof ( PcaTransform ) , typeof ( PcaTransform . Arguments ) , typeof ( SignatureDataTransform ) ,
22
23
PcaTransform . UserName , PcaTransform . LoaderSignature , PcaTransform . ShortName ) ]
32
33
33
34
[ assembly: LoadableClass ( typeof ( void ) , typeof ( PcaTransform ) , null , typeof ( SignatureEntryPointModule ) , PcaTransform . LoaderSignature ) ]
34
35
35
- namespace Microsoft . ML . Runtime . Data
36
+ namespace Microsoft . ML . Transforms
36
37
{
37
38
/// <include file='doc.xml' path='doc/members/member[@name="PCA"]/*' />
38
39
public sealed class PcaTransform : OneToOneTransformerBase
39
40
{
40
- internal static class Defaults
41
- {
42
- public const string WeightColumn = null ;
43
- public const int Rank = 20 ;
44
- public const int Oversampling = 20 ;
45
- public const bool Center = true ;
46
- public const int Seed = 0 ;
47
- }
48
-
49
41
/// <summary>
/// Argument set for PcaTransform; this is the arguments type named in the assembly-level
/// LoadableClass registration of the transform.
/// In this change each field initializer switches from the nested Defaults class (the '-' lines)
/// to PcaEstimator.Defaults (the '+' lines); the constant names are the same on both sides.
/// </summary>
public sealed class Arguments : TransformInputBase
50
42
{
51
43
// Marked ArgumentType.Required: the column definition(s) to create.
[ Argument ( ArgumentType . Multiple | ArgumentType . Required , HelpText = "New column definition(s) (optional form: name:src)" , ShortName = "col" , SortOrder = 1 ) ]
52
44
public Column [ ] Column ;
53
45
54
46
// Default comes from PcaEstimator.Defaults.WeightColumn, which is null.
[ Argument ( ArgumentType . Multiple , HelpText = "The name of the weight column" , ShortName = "weight" , Purpose = SpecialPurpose . ColumnName ) ]
55
- public string WeightColumn = Defaults . WeightColumn ;
47
+ public string WeightColumn = PcaEstimator . Defaults . WeightColumn ;
56
48
57
49
// Default comes from PcaEstimator.Defaults.Rank, which is 20.
[ Argument ( ArgumentType . AtMostOnce , HelpText = "The number of components in the PCA" , ShortName = "k" ) ]
58
- public int Rank = Defaults . Rank ;
50
+ public int Rank = PcaEstimator . Defaults . Rank ;
59
51
60
52
// Default comes from PcaEstimator.Defaults.Oversampling, which is 20.
// The ColumnInfo constructor rejects negative oversampling values via CheckUserArg.
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Oversampling parameter for randomized PCA training" , ShortName = "over" ) ]
61
- public int Oversampling = Defaults . Oversampling ;
53
+ public int Oversampling = PcaEstimator . Defaults . Oversampling ;
62
54
63
55
// Default comes from PcaEstimator.Defaults.Center, which is true.
[ Argument ( ArgumentType . AtMostOnce , HelpText = "If enabled, data is centered to be zero mean" ) ]
64
- public bool Center = Defaults . Center ;
56
+ public bool Center = PcaEstimator . Defaults . Center ;
65
57
66
58
// Default comes from PcaEstimator.Defaults.Seed, which is 0. Note that the ColumnInfo
// constructor instead takes a nullable seed whose default is null.
[ Argument ( ArgumentType . AtMostOnce , HelpText = "The seed for random number generation" ) ]
67
- public int Seed = Defaults . Seed ;
59
+ public int Seed = PcaEstimator . Defaults . Seed ;
68
60
}
69
61
70
62
public class Column : OneToOneColumn
@@ -121,10 +113,10 @@ public sealed class ColumnInfo
121
113
/// </summary>
122
114
public ColumnInfo ( string input ,
123
115
string output ,
124
- string weightColumn = Defaults . WeightColumn ,
125
- int rank = Defaults . Rank ,
126
- int overSampling = Defaults . Oversampling ,
127
- bool center = Defaults . Center ,
116
+ string weightColumn = PcaEstimator . Defaults . WeightColumn ,
117
+ int rank = PcaEstimator . Defaults . Rank ,
118
+ int overSampling = PcaEstimator . Defaults . Oversampling ,
119
+ bool center = PcaEstimator . Defaults . Center ,
128
120
int ? seed = null )
129
121
{
130
122
Input = input ;
@@ -134,6 +126,7 @@ public ColumnInfo(string input,
134
126
Oversampling = overSampling ;
135
127
Center = center ;
136
128
Seed = seed ;
129
+ Contracts . CheckUserArg ( Oversampling >= 0 , nameof ( Oversampling ) , "Oversampling must be non-negative." ) ;
137
130
}
138
131
139
132
// The following functions and properties are all internal and used for simplifying the
@@ -312,7 +305,6 @@ public PcaTransform(IHostEnvironment env, IDataView input, ColumnInfo[] columns)
312
305
var col = columns [ i ] ;
313
306
col . SetSchema ( input . Schema ) ;
314
307
ValidatePcaInput ( Host , col . Input , col . InputType ) ;
315
- Host . CheckUserArg ( col . Oversampling >= 0 , nameof ( col . Oversampling ) , "Oversampling must be non-negative" ) ;
316
308
_transformInfos [ i ] = new TransformInfo ( col . Rank , col . InputType . ValueCount ) ;
317
309
}
318
310
@@ -614,8 +606,8 @@ internal static void ValidatePcaInput(IHost host, string name, ColumnType type)
614
606
throw host . Except ( $ "Pca transform can only be applied to vector columns. Column ${ name } is of size ${ type . VectorSize } ") ;
615
607
616
608
var itemType = type . ItemType ;
617
- if ( ! itemType . IsNumber )
618
- throw host . Except ( $ "Pca transform can only be applied to vector of numeric items. Column ${ name } contains type ${ itemType } ") ;
609
+ if ( itemType . RawKind != DataKind . R4 )
610
+ throw host . Except ( $ "Pca transform can only be applied to vector of float items. Column ${ name } contains type ${ itemType } ") ;
619
611
}
620
612
621
613
private sealed class Mapper : MapperBase
@@ -707,6 +699,15 @@ public static CommonOutputs.TransformOutput Calculate(IHostEnvironment env, Argu
707
699
708
700
public sealed class PcaEstimator : IEstimator < PcaTransform >
709
701
{
702
/// <summary>
/// Default PCA parameter values, added to PcaEstimator by this change (all lines are '+').
/// In the visible code they are referenced by the PcaEstimator constructor,
/// PcaTransform.Arguments, the PcaTransform.ColumnInfo constructor and the
/// ToPrincipalComponents extension method.
/// </summary>
+ internal static class Defaults
703
+ {
704
// Name of the weight column; null by default.
+ public const string WeightColumn = null ;
705
// Number of components in the PCA.
+ public const int Rank = 20 ;
706
// Oversampling parameter for randomized PCA training; ColumnInfo requires it to be non-negative.
+ public const int Oversampling = 20 ;
707
// Whether the data is centered to be zero mean.
+ public const bool Center = true ;
708
// Seed for random number generation. In the visible code only Arguments.Seed uses this constant;
// the ColumnInfo and PcaEstimator constructors take a nullable seed defaulting to null.
+ public const int Seed = 0 ;
709
+ }
710
+
710
711
private readonly IHost _host ;
711
712
private readonly PcaTransform . ColumnInfo [ ] _columns ;
712
713
@@ -721,8 +722,8 @@ public sealed class PcaEstimator : IEstimator<PcaTransform>
721
722
/// <param name="center">If enabled, data is centered to be zero mean.</param>
722
723
/// <param name="seed">The seed for random number generation</param>
723
724
public PcaEstimator ( IHostEnvironment env , string inputColumn , string outputColumn = null ,
724
- string weightColumn = PcaTransform . Defaults . WeightColumn , int rank = PcaTransform . Defaults . Rank ,
725
- int overSampling = PcaTransform . Defaults . Oversampling , bool center = PcaTransform . Defaults . Center ,
725
+ string weightColumn = Defaults . WeightColumn , int rank = Defaults . Rank ,
726
+ int overSampling = Defaults . Oversampling , bool center = Defaults . Center ,
726
727
int ? seed = null )
727
728
: this ( env , new PcaTransform . ColumnInfo ( inputColumn , outputColumn ?? inputColumn , weightColumn , rank , overSampling , center , seed ) )
728
729
{
@@ -746,7 +747,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
746
747
if ( ! inputSchema . TryFindColumn ( colInfo . Input , out var col ) )
747
748
throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , colInfo . Input ) ;
748
749
749
- if ( ! ( col . Kind == SchemaShape . Column . VectorKind . Vector && col . ItemType . IsNumber ) )
750
+ if ( col . Kind != SchemaShape . Column . VectorKind . Vector || col . ItemType . RawKind != DataKind . R4 )
750
751
throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , colInfo . Input ) ;
751
752
752
753
result [ colInfo . Output ] = new SchemaShape . Column ( colInfo . Output ,
@@ -808,10 +809,10 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
808
809
/// <param name="seed">The seed for random number generation</param>
809
810
/// <returns>Vector containing the principal components.</returns>
810
811
public static Vector < float > ToPrincipalComponents ( this Vector < float > input ,
811
// Removed side of the diff: optional parameters defaulted from PcaTransform.Defaults.
- string weightColumn = PcaTransform . Defaults . WeightColumn ,
812
- int rank = PcaTransform . Defaults . Rank ,
813
- int overSampling = PcaTransform . Defaults . Oversampling ,
814
- bool center = PcaTransform . Defaults . Center ,
812
// Added side of the diff: same parameters in the same order, now defaulted from PcaEstimator.Defaults.
+ string weightColumn = PcaEstimator . Defaults . WeightColumn ,
813
+ int rank = PcaEstimator . Defaults . Rank ,
814
+ int overSampling = PcaEstimator . Defaults . Oversampling ,
815
+ bool center = PcaEstimator . Defaults . Center ,
815
816
// seed stays nullable with a null default on both sides of the diff. The expression body only
// constructs an OutPipelineColumn from the arguments.
int ? seed = null ) => new OutPipelineColumn ( input , weightColumn , rank , overSampling , center , seed ) ;
816
817
}
817
818
}
0 commit comments