Skip to content

Commit d9208bf

Browse files
authored
Cleaned and fixed public API surface for KMeans, NaiveBayes, OLS. (#2819)
1 parent c2fabb8 commit d9208bf

File tree

17 files changed

+148
-106
lines changed

17 files changed

+148
-106
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PFIHelper.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private class BinaryOutputRow
4848
private readonly static Action<ContinuousInputRow, BinaryOutputRow> GreaterThanAverage = (input, output)
4949
=> output.AboveAverage = input.MedianHomeValue > 22.6;
5050

51-
public static float[] GetLinearModelWeights(OlsLinearRegressionModelParameters linearModel)
51+
public static float[] GetLinearModelWeights(OrdinaryLeastSquaresRegressionModelParameters linearModel)
5252
{
5353
return linearModel.Weights.ToArray();
5454
}

docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs renamed to docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Clustering/KMeans.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ public static void Example()
99
{
1010
// Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
1111
// as well as the source of randomness.
12-
var ml = new MLContext();
12+
var ml = new MLContext(seed: 1, conc: 1);
1313

1414
// Get a small dataset as an IEnumerable and convert it to an IDataView.
1515
var data = SamplesUtils.DatasetUtils.GetInfertData();
@@ -39,7 +39,10 @@ public static void Example()
3939
modelParams.GetClusterCentroids(ref centroids, out k);
4040

4141
var centroid = centroids[0].GetValues();
42-
Console.WriteLine("The coordinates of centroid 0 are: " + string.Join(", ", centroid.ToArray()));
42+
Console.WriteLine($"The coordinates of centroid 0 are: ({string.Join(", ", centroid.ToArray())})");
43+
44+
// Expected Output:
45+
// The coordinates of centroid 0 are: (26, 6, 1)
4346
}
4447
}
4548
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
using System;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.Samples.Dynamic
{
    /// <summary>
    /// Sample showing how to train a KMeans++ clustering model using
    /// <see cref="KMeansPlusPlusTrainer.Options"/> to configure advanced settings.
    /// </summary>
    public class KMeansWithOptions
    {
        public static void Example()
        {
            // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging,
            // as well as the source of randomness. Fixing seed and concurrency makes the sample deterministic.
            var ml = new MLContext(seed: 1, conc: 1);

            // Get a small dataset as an IEnumerable and convert it to an IDataView.
            var data = SamplesUtils.DatasetUtils.GetInfertData();
            var trainData = ml.Data.LoadFromEnumerable(data);

            // Preview of the data.
            //
            // Age  Case  Education  Induced  Parity  PooledStratum  RowNum  ...
            // 26   1     0-5yrs     1        6       3              1       ...
            // 42   1     0-5yrs     1        1       1              2       ...
            // 39   1     0-5yrs     2        6       4              3       ...
            // 34   1     0-5yrs     2        4       2              4       ...
            // 35   1     6-11yrs    1        3       32             5       ...

            // A pipeline for concatenating the age, parity and induced columns together in the Features column and training a KMeans model on them.
            string outputColumnName = "Features";
            var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Age", "Parity", "Induced" })
                .Append(ml.Clustering.Trainers.KMeans(
                    new KMeansPlusPlusTrainer.Options
                    {
                        FeatureColumnName = outputColumnName,
                        NumberOfClusters = 2,
                        NumberOfIterations = 100,
                        OptimizationTolerance = 1e-6f
                    }
                ));

            var model = pipeline.Fit(trainData);

            // Get cluster centroids and the number of clusters k from KMeansModelParameters.
            VBuffer<float>[] centroids = default;
            int k;

            var modelParams = model.LastTransformer.Model;
            modelParams.GetClusterCentroids(ref centroids, out k);

            var centroid = centroids[0].GetValues();
            // Use the parenthesized, interpolated form so the printed text matches
            // the expected output below (and the companion KMeans.cs sample).
            Console.WriteLine($"The coordinates of centroid 0 are: ({string.Join(", ", centroid.ToArray())})");

            // Expected Output:
            // The coordinates of centroid 0 are: (26, 6, 1)
        }
    }
}

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/OrdinaryLeastSquaresWithOptions.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ public static void Example()
4646
// as data is already processed in a form consumable by the trainer
4747
var pipeline = mlContext.Regression.Trainers.OrdinaryLeastSquares(new OrdinaryLeastSquaresRegressionTrainer.Options()
4848
{
49-
L2Weight = 0.1f,
50-
PerParameterSignificance = false
49+
L2Regularization = 0.1f,
50+
CalculateStatistics = false
5151
});
5252
var model = pipeline.Fit(split.TrainSet);
5353

src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs

+11-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace Microsoft.ML
1515
public static class KMeansClusteringExtensions
1616
{
1717
/// <summary>
18-
/// Train a KMeans++ clustering algorithm.
18+
/// Train a KMeans++ clustering algorithm using <see cref="KMeansPlusPlusTrainer"/>.
1919
/// </summary>
2020
/// <param name="catalog">The clustering catalog trainer object.</param>
2121
/// <param name="featureColumnName">The name of the feature column.</param>
@@ -24,13 +24,13 @@ public static class KMeansClusteringExtensions
2424
/// <example>
2525
/// <format type="text/markdown">
2626
/// <![CDATA[
27-
/// [!code-csharp[KMeans](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs)]
27+
/// [!code-csharp[KMeans](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Clustering/KMeans.cs)]
2828
/// ]]></format>
2929
/// </example>
3030
public static KMeansPlusPlusTrainer KMeans(this ClusteringCatalog.ClusteringTrainers catalog,
3131
string featureColumnName = DefaultColumnNames.Features,
3232
string exampleWeightColumnName = null,
33-
int clustersCount = KMeansPlusPlusTrainer.Defaults.ClustersCount)
33+
int clustersCount = KMeansPlusPlusTrainer.Defaults.NumberOfClusters)
3434
{
3535
Contracts.CheckValue(catalog, nameof(catalog));
3636
var env = CatalogUtils.GetEnvironment(catalog);
@@ -39,16 +39,22 @@ public static KMeansPlusPlusTrainer KMeans(this ClusteringCatalog.ClusteringTrai
3939
{
4040
FeatureColumnName = featureColumnName,
4141
ExampleWeightColumnName = exampleWeightColumnName,
42-
ClustersCount = clustersCount
42+
NumberOfClusters = clustersCount
4343
};
4444
return new KMeansPlusPlusTrainer(env, options);
4545
}
4646

4747
/// <summary>
48-
/// Train a KMeans++ clustering algorithm.
48+
/// Train a KMeans++ clustering algorithm using <see cref="KMeansPlusPlusTrainer"/>.
4949
/// </summary>
5050
/// <param name="catalog">The clustering catalog trainer object.</param>
5151
/// <param name="options">Algorithm advanced options.</param>
52+
/// <example>
53+
/// <format type="text/markdown">
54+
/// <![CDATA[
55+
/// [!code-csharp[KMeans](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Clustering/KMeansWithOptions.cs)]
56+
/// ]]></format>
57+
/// </example>
5258
public static KMeansPlusPlusTrainer KMeans(this ClusteringCatalog.ClusteringTrainers catalog, KMeansPlusPlusTrainer.Options options)
5359
{
5460
Contracts.CheckValue(catalog, nameof(catalog));

src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs

+19-19
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,18 @@ public class KMeansPlusPlusTrainer : TrainerEstimatorBase<ClusteringPredictionTr
3636
+ "number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better "
3737
+ "method for choosing the initial cluster centers.";
3838

39-
public enum InitAlgorithm
39+
public enum InitializationAlgorithm
4040
{
4141
KMeansPlusPlus = 0,
4242
Random = 1,
43-
KMeansParallel = 2
43+
KMeansYinyang = 2
4444
}
4545

4646
[BestFriend]
4747
internal static class Defaults
4848
{
4949
/// <value>The number of clusters.</value>
50-
public const int ClustersCount = 5;
50+
public const int NumberOfClusters = 5;
5151
}
5252

5353
public sealed class Options : UnsupervisedTrainerInputBaseWithWeight
@@ -58,13 +58,13 @@ public sealed class Options : UnsupervisedTrainerInputBaseWithWeight
5858
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50, Name = "K")]
5959
[TGUI(SuggestedSweeps = "5,10,20,40")]
6060
[TlcModule.SweepableDiscreteParam("K", new object[] { 5, 10, 20, 40 })]
61-
public int ClustersCount = Defaults.ClustersCount;
61+
public int NumberOfClusters = Defaults.NumberOfClusters;
6262

6363
/// <summary>
6464
/// Cluster initialization algorithm.
6565
/// </summary>
6666
[Argument(ArgumentType.AtMostOnce, HelpText = "Cluster initialization algorithm", ShortName = "init")]
67-
public InitAlgorithm InitAlgorithm = InitAlgorithm.KMeansParallel;
67+
public InitializationAlgorithm InitializationAlgorithm = InitializationAlgorithm.KMeansYinyang;
6868

6969
/// <summary>
7070
/// Tolerance parameter for trainer convergence. Low = slower, more accurate.
@@ -79,7 +79,7 @@ public sealed class Options : UnsupervisedTrainerInputBaseWithWeight
7979
/// </summary>
8080
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations.", ShortName = "maxiter")]
8181
[TGUI(Label = "Max Number of Iterations")]
82-
public int MaxIterations = 1000;
82+
public int NumberOfIterations = 1000;
8383

8484
/// <summary>
8585
/// Memory budget (in MBs) to use for KMeans acceleration.
@@ -94,7 +94,7 @@ public sealed class Options : UnsupervisedTrainerInputBaseWithWeight
9494
/// </summary>
9595
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)]
9696
[TGUI(Label = "Number of threads")]
97-
public int? NumThreads;
97+
public int? NumberOfThreads;
9898
}
9999

100100
private readonly int _k;
@@ -103,7 +103,7 @@ public sealed class Options : UnsupervisedTrainerInputBaseWithWeight
103103
private readonly float _convergenceThreshold; // convergence thresholds
104104

105105
private readonly long _accelMemBudgetMb;
106-
private readonly InitAlgorithm _initAlgorithm;
106+
private readonly InitializationAlgorithm _initAlgorithm;
107107
private readonly int _numThreads;
108108
private readonly string _featureColumn;
109109

@@ -119,26 +119,26 @@ internal KMeansPlusPlusTrainer(IHostEnvironment env, Options options)
119119
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(options.FeatureColumnName), default, TrainerUtils.MakeR4ScalarWeightColumn(options.ExampleWeightColumnName))
120120
{
121121
Host.CheckValue(options, nameof(options));
122-
Host.CheckUserArg(options.ClustersCount > 0, nameof(options.ClustersCount), "Must be positive");
122+
Host.CheckUserArg(options.NumberOfClusters > 0, nameof(options.NumberOfClusters), "Must be positive");
123123

124124
_featureColumn = options.FeatureColumnName;
125125

126-
_k = options.ClustersCount;
126+
_k = options.NumberOfClusters;
127127

128-
Host.CheckUserArg(options.MaxIterations > 0, nameof(options.MaxIterations), "Must be positive");
129-
_maxIterations = options.MaxIterations;
128+
Host.CheckUserArg(options.NumberOfIterations > 0, nameof(options.NumberOfIterations), "Must be positive");
129+
_maxIterations = options.NumberOfIterations;
130130

131131
Host.CheckUserArg(options.OptimizationTolerance > 0, nameof(options.OptimizationTolerance), "Tolerance must be positive");
132132
_convergenceThreshold = options.OptimizationTolerance;
133133

134134
Host.CheckUserArg(options.AccelerationMemoryBudgetMb > 0, nameof(options.AccelerationMemoryBudgetMb), "Must be positive");
135135
_accelMemBudgetMb = options.AccelerationMemoryBudgetMb;
136136

137-
_initAlgorithm = options.InitAlgorithm;
137+
_initAlgorithm = options.InitializationAlgorithm;
138138

139-
Host.CheckUserArg(!options.NumThreads.HasValue || options.NumThreads > 0, nameof(options.NumThreads),
139+
Host.CheckUserArg(!options.NumberOfThreads.HasValue || options.NumberOfThreads > 0, nameof(options.NumberOfThreads),
140140
"Must be either null or a positive integer.");
141-
_numThreads = ComputeNumThreads(Host, options.NumThreads);
141+
_numThreads = ComputeNumThreads(Host, options.NumberOfThreads);
142142
Info = new TrainerInfo();
143143
}
144144

@@ -184,12 +184,12 @@ private KMeansModelParameters TrainCore(IChannel ch, RoleMappedData data, int di
184184
// all produce a valid set of output centroids with various trade-offs in runtime (with perhaps
185185
// random initialization creating a set that's not terribly useful.) They could also be extended to
186186
// pay attention to their incoming set of centroids and incrementally train.
187-
if (_initAlgorithm == InitAlgorithm.KMeansPlusPlus)
187+
if (_initAlgorithm == InitializationAlgorithm.KMeansPlusPlus)
188188
{
189189
KMeansPlusPlusInit.Initialize(Host, _numThreads, ch, cursorFactory, _k, dimensionality,
190190
centroids, out missingFeatureCount, out totalTrainingInstances);
191191
}
192-
else if (_initAlgorithm == InitAlgorithm.Random)
192+
else if (_initAlgorithm == InitializationAlgorithm.Random)
193193
{
194194
KMeansRandomInit.Initialize(Host, _numThreads, ch, cursorFactory, _k,
195195
centroids, out missingFeatureCount, out totalTrainingInstances);
@@ -743,8 +743,8 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
743743
host.CheckValue(ch, nameof(ch));
744744
ch.CheckValue(cursorFactory, nameof(cursorFactory));
745745
ch.CheckValue(centroids, nameof(centroids));
746-
ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Options.NumThreads), "Must be positive");
747-
ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Options.ClustersCount), "Must be positive");
746+
ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Options.NumberOfThreads), "Must be positive");
747+
ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Options.NumberOfClusters), "Must be positive");
748748
ch.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive");
749749
ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Options.AccelerationMemoryBudgetMb), "Must be non-negative");
750750

0 commit comments

Comments
 (0)