17
17
using Microsoft . ML . Trainers . KMeans ;
18
18
using Microsoft . ML . Training ;
19
19
20
- [ assembly: LoadableClass ( KMeansPlusPlusTrainer . Summary , typeof ( KMeansPlusPlusTrainer ) , typeof ( KMeansPlusPlusTrainer . Arguments ) ,
20
+ [ assembly: LoadableClass ( KMeansPlusPlusTrainer . Summary , typeof ( KMeansPlusPlusTrainer ) , typeof ( KMeansPlusPlusTrainer . Options ) ,
21
21
new [ ] { typeof ( SignatureClusteringTrainer ) , typeof ( SignatureTrainer ) } ,
22
22
KMeansPlusPlusTrainer . UserNameValue ,
23
23
KMeansPlusPlusTrainer . LoadNameValue ,
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Trainers.KMeans
30
30
/// <include file='./doc.xml' path='doc/members/member[@name="KMeans++"]/*' />
31
31
public class KMeansPlusPlusTrainer : TrainerEstimatorBase < ClusteringPredictionTransformer < KMeansModelParameters > , KMeansModelParameters >
32
32
{
33
- public const string LoadNameValue = "KMeansPlusPlus" ;
33
+ internal const string LoadNameValue = "KMeansPlusPlus" ;
34
34
internal const string UserNameValue = "KMeans++ Clustering" ;
35
35
internal const string ShortName = "KM" ;
36
36
internal const string Summary = "K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified "
@@ -45,34 +45,54 @@ public enum InitAlgorithm
45
45
}
46
46
47
47
[ BestFriend ]
48
- internal static class Defaults {
48
+ internal static class Defaults
49
+ {
49
50
/// <value>The number of clusters.</value>
50
- public const int K = 5 ;
51
+ public const int ClustersCount = 5 ;
51
52
}
52
53
53
- public class Arguments : UnsupervisedLearnerInputBaseWithWeight
54
+ public class Options : UnsupervisedLearnerInputBaseWithWeight
54
55
{
55
- [ Argument ( ArgumentType . AtMostOnce , HelpText = "The number of clusters" , SortOrder = 50 ) ]
56
+ /// <summary>
57
+ /// The number of clusters.
58
+ /// </summary>
59
+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "The number of clusters" , SortOrder = 50 , Name = "K" ) ]
56
60
[ TGUI ( SuggestedSweeps = "5,10,20,40" ) ]
57
61
[ TlcModule . SweepableDiscreteParam ( "K" , new object [ ] { 5 , 10 , 20 , 40 } ) ]
58
- public int K = Defaults . K ;
62
+ public int ClustersCount = Defaults . ClustersCount ;
59
63
64
+ /// <summary>
65
+ /// Cluster initialization algorithm.
66
+ /// </summary>
60
67
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Cluster initialization algorithm" , ShortName = "init" ) ]
61
68
public InitAlgorithm InitAlgorithm = InitAlgorithm . KMeansParallel ;
62
69
70
+ /// <summary>
71
+ /// Tolerance parameter for trainer convergence. Low = slower, more accurate.
72
+ /// </summary>
63
73
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Tolerance parameter for trainer convergence. Low = slower, more accurate" ,
64
- ShortName = "ot" ) ]
74
+ Name = "OptTol" , ShortName = "ot" ) ]
65
75
[ TGUI ( Label = "Optimization Tolerance" , Description = "Threshold for trainer convergence" ) ]
66
- public float OptTol = ( float ) 1e-7 ;
76
+ public float OptimizationTolerance = ( float ) 1e-7 ;
67
77
78
+ /// <summary>
79
+ /// Maximum number of iterations.
80
+ /// </summary>
68
81
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Maximum number of iterations." , ShortName = "maxiter" ) ]
69
82
[ TGUI ( Label = "Max Number of Iterations" ) ]
70
83
public int MaxIterations = 1000 ;
71
84
72
- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Memory budget (in MBs) to use for KMeans acceleration" , ShortName = "accelMemBudgetMb" ) ]
85
+ /// <summary>
86
+ /// Memory budget (in MBs) to use for KMeans acceleration.
87
+ /// </summary>
88
+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Memory budget (in MBs) to use for KMeans acceleration" ,
89
+ Name = "AccelMemBudgetMb" , ShortName = "accelMemBudgetMb" ) ]
73
90
[ TGUI ( Label = "Memory Budget (in MBs) for KMeans Acceleration" ) ]
74
- public int AccelMemBudgetMb = 4 * 1024 ; // by default, use at most 4 GB
91
+ public int AccelerationMemoryBudgetMb = 4 * 1024 ; // by default, use at most 4 GB
75
92
93
+ /// <summary>
94
+ /// Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.
95
+ /// </summary>
76
96
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed." , ShortName = "nt,t,threads" , SortOrder = 50 ) ]
77
97
[ TGUI ( Label = "Number of threads" ) ]
78
98
public int ? NumThreads ;
@@ -95,58 +115,31 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
95
115
/// Initializes a new instance of <see cref="KMeansPlusPlusTrainer"/>
96
116
/// </summary>
97
117
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param>
98
- /// <param name="featureColumn">The name of the feature column.</param>
99
- /// <param name="weights">The name for the optional column containing the example weights.</param>
100
- /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
101
- /// <param name="clustersCount">The number of clusters.</param>
102
- public KMeansPlusPlusTrainer ( IHostEnvironment env ,
103
- string featureColumn = DefaultColumnNames . Features ,
104
- int clustersCount = Defaults . K ,
105
- string weights = null ,
106
- Action < Arguments > advancedSettings = null )
107
- : this ( env , new Arguments
108
- {
109
- FeatureColumn = featureColumn ,
110
- WeightColumn = weights ,
111
- K = clustersCount
112
- } , advancedSettings )
113
- {
114
- }
115
-
116
- internal KMeansPlusPlusTrainer ( IHostEnvironment env , Arguments args )
117
- : this ( env , args , null )
118
+ /// <param name="options">The advanced options of the algorithm.</param>
119
+ internal KMeansPlusPlusTrainer ( IHostEnvironment env , Options options )
120
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( options . FeatureColumn ) , default , TrainerUtils . MakeR4ScalarWeightColumn ( options . WeightColumn ) )
118
121
{
122
+ Host . CheckValue ( options , nameof ( options ) ) ;
123
+ Host . CheckUserArg ( options . ClustersCount > 0 , nameof ( options . ClustersCount ) , "Must be positive" ) ;
119
124
120
- }
121
-
122
- private KMeansPlusPlusTrainer ( IHostEnvironment env , Arguments args , Action < Arguments > advancedSettings = null )
123
- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) , default , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ) )
124
- {
125
- Host . CheckValue ( args , nameof ( args ) ) ;
126
-
127
- // override with the advanced settings.
128
- advancedSettings ? . Invoke ( args ) ;
129
-
130
- Host . CheckUserArg ( args . K > 0 , nameof ( args . K ) , "Must be positive" ) ;
131
-
132
- _featureColumn = args . FeatureColumn ;
125
+ _featureColumn = options . FeatureColumn ;
133
126
134
- _k = args . K ;
127
+ _k = options . ClustersCount ;
135
128
136
- Host . CheckUserArg ( args . MaxIterations > 0 , nameof ( args . MaxIterations ) , "Must be positive" ) ;
137
- _maxIterations = args . MaxIterations ;
129
+ Host . CheckUserArg ( options . MaxIterations > 0 , nameof ( options . MaxIterations ) , "Must be positive" ) ;
130
+ _maxIterations = options . MaxIterations ;
138
131
139
- Host . CheckUserArg ( args . OptTol > 0 , nameof ( args . OptTol ) , "Tolerance must be positive" ) ;
140
- _convergenceThreshold = args . OptTol ;
132
+ Host . CheckUserArg ( options . OptimizationTolerance > 0 , nameof ( options . OptimizationTolerance ) , "Tolerance must be positive" ) ;
133
+ _convergenceThreshold = options . OptimizationTolerance ;
141
134
142
- Host . CheckUserArg ( args . AccelMemBudgetMb > 0 , nameof ( args . AccelMemBudgetMb ) , "Must be positive" ) ;
143
- _accelMemBudgetMb = args . AccelMemBudgetMb ;
135
+ Host . CheckUserArg ( options . AccelerationMemoryBudgetMb > 0 , nameof ( options . AccelerationMemoryBudgetMb ) , "Must be positive" ) ;
136
+ _accelMemBudgetMb = options . AccelerationMemoryBudgetMb ;
144
137
145
- _initAlgorithm = args . InitAlgorithm ;
138
+ _initAlgorithm = options . InitAlgorithm ;
146
139
147
- Host . CheckUserArg ( ! args . NumThreads . HasValue || args . NumThreads > 0 , nameof ( args . NumThreads ) ,
140
+ Host . CheckUserArg ( ! options . NumThreads . HasValue || options . NumThreads > 0 , nameof ( options . NumThreads ) ,
148
141
"Must be either null or a positive integer." ) ;
149
- _numThreads = ComputeNumThreads ( Host , args . NumThreads ) ;
142
+ _numThreads = ComputeNumThreads ( Host , options . NumThreads ) ;
150
143
Info = new TrainerInfo ( ) ;
151
144
}
152
145
@@ -247,14 +240,14 @@ private static int ComputeNumThreads(IHost host, int? argNumThreads)
247
240
ShortName = ShortName ,
248
241
XmlInclude = new [ ] { @"<include file='../Microsoft.ML.KMeansClustering/doc.xml' path='doc/members/member[@name=""KMeans++""]/*' />" ,
249
242
@"<include file='../Microsoft.ML.KMeansClustering/doc.xml' path='doc/members/example[@name=""KMeans++""]/*' />" } ) ]
250
- public static CommonOutputs . ClusteringOutput TrainKMeans ( IHostEnvironment env , Arguments input )
243
+ public static CommonOutputs . ClusteringOutput TrainKMeans ( IHostEnvironment env , Options input )
251
244
{
252
245
Contracts . CheckValue ( env , nameof ( env ) ) ;
253
246
var host = env . Register ( "TrainKMeans" ) ;
254
247
host . CheckValue ( input , nameof ( input ) ) ;
255
248
EntryPointUtils . CheckInputArgs ( host , input ) ;
256
249
257
- return LearnerEntryPointsUtils . Train < Arguments , CommonOutputs . ClusteringOutput > ( host , input ,
250
+ return LearnerEntryPointsUtils . Train < Options , CommonOutputs . ClusteringOutput > ( host , input ,
258
251
( ) => new KMeansPlusPlusTrainer ( host , input ) ,
259
252
getWeight : ( ) => LearnerEntryPointsUtils . FindColumn ( host , input . TrainingData . Schema , input . WeightColumn ) ) ;
260
253
}
@@ -749,10 +742,10 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
749
742
host . CheckValue ( ch , nameof ( ch ) ) ;
750
743
ch . CheckValue ( cursorFactory , nameof ( cursorFactory ) ) ;
751
744
ch . CheckValue ( centroids , nameof ( centroids ) ) ;
752
- ch . CheckUserArg ( numThreads > 0 , nameof ( KMeansPlusPlusTrainer . Arguments . NumThreads ) , "Must be positive" ) ;
753
- ch . CheckUserArg ( k > 0 , nameof ( KMeansPlusPlusTrainer . Arguments . K ) , "Must be positive" ) ;
745
+ ch . CheckUserArg ( numThreads > 0 , nameof ( KMeansPlusPlusTrainer . Options . NumThreads ) , "Must be positive" ) ;
746
+ ch . CheckUserArg ( k > 0 , nameof ( KMeansPlusPlusTrainer . Options . ClustersCount ) , "Must be positive" ) ;
754
747
ch . CheckParam ( dimensionality > 0 , nameof ( dimensionality ) , "Must be positive" ) ;
755
- ch . CheckUserArg ( accelMemBudgetMb >= 0 , nameof ( KMeansPlusPlusTrainer . Arguments . AccelMemBudgetMb ) , "Must be non-negative" ) ;
748
+ ch . CheckUserArg ( accelMemBudgetMb >= 0 , nameof ( KMeansPlusPlusTrainer . Options . AccelerationMemoryBudgetMb ) , "Must be non-negative" ) ;
756
749
757
750
int numRounds ;
758
751
int numSamplesPerRound ;
0 commit comments