Skip to content

Commit b89ce70

Browse files
authored
KMeans and Implicit weight cleanup (#2158)
* KMeans and Implicit weight cleanup
1 parent 0b72301 commit b89ce70

File tree

15 files changed

+192
-91
lines changed

15 files changed

+192
-91
lines changed

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
9393
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
9494
public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
9595
{
96+
/// <summary>
97+
/// Column to use for example weight.
98+
/// </summary>
9699
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
97100
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
98101
}

src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs

+7
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(Sche
122122
return new SchemaShape(outColumns.Values);
123123
}
124124

125+
/// <summary>
126+
/// Fits the scored <see cref="IDataView"/> creating a <see cref="CalibratorTransformer{TICalibrator}"/> that can transform the data by adding a
127+
/// <see cref="DefaultColumnNames.Probability"/> column containing the calibrated <see cref="DefaultColumnNames.Score"/>.
128+
/// </summary>
129+
/// <param name="input">The scored <see cref="IDataView"/> used to fit the calibrator.</param>
130+
/// <returns>A trained <see cref="CalibratorTransformer{TICalibrator}"/> that will transform the data by adding the
131+
/// <see cref="DefaultColumnNames.Probability"/> column.</returns>
125132
public CalibratorTransformer<TICalibrator> Fit(IDataView input)
126133
{
127134
TICalibrator calibrator = null;

src/Microsoft.ML.Data/Training/TrainerUtils.cs

+14-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using Microsoft.ML.Core.Data;
88
using Microsoft.ML.Data;
9+
using Microsoft.ML.EntryPoints;
910
using Microsoft.ML.Internal.Utilities;
1011

1112
namespace Microsoft.ML.Training
@@ -385,10 +386,20 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
385386
/// The <see cref="SchemaShape.Column"/> for the weight column.
386387
/// </summary>
387388
/// <param name="weightColumn">name of the weight column</param>
388-
/// <param name="isExplicit">whether the column is implicitly, or explicitly defined</param>
389-
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
389+
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
390390
{
391-
if (weightColumn == null || !isExplicit)
391+
if (weightColumn == null)
392+
return default;
393+
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
394+
}
395+
396+
/// <summary>
397+
/// The <see cref="SchemaShape.Column"/> for the weight column, or <c>default</c> when the column is null, has a null value, or was not explicitly specified.
398+
/// </summary>
399+
/// <param name="weightColumn">The optional name of the weight column; it is used only when it was set explicitly.</param>
400+
public static SchemaShape.Column MakeR4ScalarWeightColumn(Optional<string> weightColumn)
401+
{
402+
if (weightColumn == null || weightColumn.Value == null || !weightColumn.IsExplicit)
392403
return default;
393404
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
394405
}

src/Microsoft.ML.FastTree/FastTree.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
151151
/// Constructor that is used when invoking the classes deriving from this, through maml.
152152
/// </summary>
153153
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
154-
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
154+
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
155155
{
156156
Host.CheckValue(args, nameof(args));
157157
Args = args;

src/Microsoft.ML.FastTree/GamTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ private protected GamTrainerBase(IHostEnvironment env,
192192

193193
private protected GamTrainerBase(IHostEnvironment env, TArgs args, string name, SchemaShape.Column label)
194194
: base(Contracts.CheckRef(env, nameof(env)).Register(name), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
195-
label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
195+
label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
196196
{
197197
Contracts.CheckValue(env, nameof(env));
198198
Host.CheckValue(args, nameof(args));

src/Microsoft.ML.HalLearners/OlsLinearRegression.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env,
8787
/// </summary>
8888
internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
8989
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
90-
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
90+
TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
9191
{
9292
Host.CheckValue(args, nameof(args));
9393
Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative");
@@ -106,7 +106,7 @@ private static Arguments ArgsInit(string featureColumn,
106106
advancedSettings?.Invoke(args);
107107
args.FeatureColumn = featureColumn;
108108
args.LabelColumn = labelColumn;
109-
args.WeightColumn = weightColumn;
109+
args.WeightColumn = weightColumn != null ? Optional<string>.Explicit(weightColumn) : Optional<string>.Implicit(DefaultColumnNames.Weight);
110110
return args;
111111
}
112112

src/Microsoft.ML.KMeansClustering/KMeansCatalog.cs

+33-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using Microsoft.ML.Data;
7+
using Microsoft.ML.EntryPoints;
78
using Microsoft.ML.Trainers.KMeans;
89

910
namespace Microsoft.ML
@@ -16,20 +17,45 @@ public static class KMeansClusteringExtensions
1617
/// <summary>
1718
/// Train a KMeans++ clustering algorithm.
1819
/// </summary>
19-
/// <param name="ctx">The regression context trainer object.</param>
20-
/// <param name="features">The features, or independent variables.</param>
20+
/// <param name="ctx">The clustering context trainer object.</param>
21+
/// <param name="featureColumn">The features, or independent variables.</param>
2122
/// <param name="weights">The name of the optional column containing the example weights.</param>
2223
/// <param name="clustersCount">The number of clusters to use for KMeans.</param>
23-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
24+
/// <example>
25+
/// <format type="text/markdown">
26+
/// <![CDATA[
27+
/// [!code-csharp[KMeans](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/KMeans.cs)]
28+
/// ]]></format>
29+
/// </example>
2430
public static KMeansPlusPlusTrainer KMeans(this ClusteringContext.ClusteringTrainers ctx,
25-
string features,
31+
string featureColumn = DefaultColumnNames.Features,
2632
string weights = null,
27-
int clustersCount = KMeansPlusPlusTrainer.Defaults.K,
28-
Action<KMeansPlusPlusTrainer.Arguments> advancedSettings = null)
33+
int clustersCount = KMeansPlusPlusTrainer.Defaults.ClustersCount)
2934
{
3035
Contracts.CheckValue(ctx, nameof(ctx));
3136
var env = CatalogUtils.GetEnvironment(ctx);
32-
return new KMeansPlusPlusTrainer(env, features, clustersCount, weights, advancedSettings);
37+
38+
var options = new KMeansPlusPlusTrainer.Options
39+
{
40+
FeatureColumn = featureColumn,
41+
WeightColumn = weights != null ? Optional<string>.Explicit(weights) : Optional<string>.Implicit(DefaultColumnNames.Weight),
42+
ClustersCount = clustersCount
43+
};
44+
return new KMeansPlusPlusTrainer(env, options);
45+
}
46+
47+
/// <summary>
48+
/// Train a KMeans++ clustering algorithm.
49+
/// </summary>
50+
/// <param name="ctx">The clustering context trainer object.</param>
51+
/// <param name="options">Algorithm advanced options.</param>
52+
public static KMeansPlusPlusTrainer KMeans(this ClusteringContext.ClusteringTrainers ctx, KMeansPlusPlusTrainer.Options options)
53+
{
54+
Contracts.CheckValue(ctx, nameof(ctx));
55+
Contracts.CheckValue(options, nameof(options));
56+
57+
var env = CatalogUtils.GetEnvironment(ctx);
58+
return new KMeansPlusPlusTrainer(env, options);
3359
}
3460
}
3561
}

src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs

+52-59
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
using Microsoft.ML.Trainers.KMeans;
1818
using Microsoft.ML.Training;
1919

20-
[assembly: LoadableClass(KMeansPlusPlusTrainer.Summary, typeof(KMeansPlusPlusTrainer), typeof(KMeansPlusPlusTrainer.Arguments),
20+
[assembly: LoadableClass(KMeansPlusPlusTrainer.Summary, typeof(KMeansPlusPlusTrainer), typeof(KMeansPlusPlusTrainer.Options),
2121
new[] { typeof(SignatureClusteringTrainer), typeof(SignatureTrainer) },
2222
KMeansPlusPlusTrainer.UserNameValue,
2323
KMeansPlusPlusTrainer.LoadNameValue,
@@ -30,7 +30,7 @@ namespace Microsoft.ML.Trainers.KMeans
3030
/// <include file='./doc.xml' path='doc/members/member[@name="KMeans++"]/*' />
3131
public class KMeansPlusPlusTrainer : TrainerEstimatorBase<ClusteringPredictionTransformer<KMeansModelParameters>, KMeansModelParameters>
3232
{
33-
public const string LoadNameValue = "KMeansPlusPlus";
33+
internal const string LoadNameValue = "KMeansPlusPlus";
3434
internal const string UserNameValue = "KMeans++ Clustering";
3535
internal const string ShortName = "KM";
3636
internal const string Summary = "K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified "
@@ -45,34 +45,54 @@ public enum InitAlgorithm
4545
}
4646

4747
[BestFriend]
48-
internal static class Defaults{
48+
internal static class Defaults
49+
{
4950
/// <value>The number of clusters.</value>
50-
public const int K = 5;
51+
public const int ClustersCount = 5;
5152
}
5253

53-
public class Arguments : UnsupervisedLearnerInputBaseWithWeight
54+
public class Options : UnsupervisedLearnerInputBaseWithWeight
5455
{
55-
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50)]
56+
/// <summary>
57+
/// The number of clusters.
58+
/// </summary>
59+
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50, Name = "K")]
5660
[TGUI(SuggestedSweeps = "5,10,20,40")]
5761
[TlcModule.SweepableDiscreteParam("K", new object[] { 5, 10, 20, 40 })]
58-
public int K = Defaults.K;
62+
public int ClustersCount = Defaults.ClustersCount;
5963

64+
/// <summary>
65+
/// Cluster initialization algorithm.
66+
/// </summary>
6067
[Argument(ArgumentType.AtMostOnce, HelpText = "Cluster initialization algorithm", ShortName = "init")]
6168
public InitAlgorithm InitAlgorithm = InitAlgorithm.KMeansParallel;
6269

70+
/// <summary>
71+
/// Tolerance parameter for trainer convergence. Lower values make training slower but more accurate.
72+
/// </summary>
6373
[Argument(ArgumentType.AtMostOnce, HelpText = "Tolerance parameter for trainer convergence. Low = slower, more accurate",
64-
ShortName = "ot")]
74+
Name = "OptTol", ShortName = "ot")]
6575
[TGUI(Label = "Optimization Tolerance", Description = "Threshold for trainer convergence")]
66-
public float OptTol = (float)1e-7;
76+
public float OptimizationTolerance = (float)1e-7;
6777

78+
/// <summary>
79+
/// Maximum number of iterations.
80+
/// </summary>
6881
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum number of iterations.", ShortName = "maxiter")]
6982
[TGUI(Label = "Max Number of Iterations")]
7083
public int MaxIterations = 1000;
7184

72-
[Argument(ArgumentType.AtMostOnce, HelpText = "Memory budget (in MBs) to use for KMeans acceleration", ShortName = "accelMemBudgetMb")]
85+
/// <summary>
86+
/// Memory budget (in MBs) to use for KMeans acceleration.
87+
/// </summary>
88+
[Argument(ArgumentType.AtMostOnce, HelpText = "Memory budget (in MBs) to use for KMeans acceleration",
89+
Name = "AccelMemBudgetMb", ShortName = "accelMemBudgetMb")]
7390
[TGUI(Label = "Memory Budget (in MBs) for KMeans Acceleration")]
74-
public int AccelMemBudgetMb = 4 * 1024; // by default, use at most 4 GB
91+
public int AccelerationMemoryBudgetMb = 4 * 1024; // by default, use at most 4 GB
7592

93+
/// <summary>
94+
/// Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.
95+
/// </summary>
7696
[Argument(ArgumentType.AtMostOnce, HelpText = "Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed.", ShortName = "nt,t,threads", SortOrder = 50)]
7797
[TGUI(Label = "Number of threads")]
7898
public int? NumThreads;
@@ -95,58 +115,31 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
95115
/// Initializes a new instance of <see cref="KMeansPlusPlusTrainer"/>
96116
/// </summary>
97117
/// <param name="env">The <see cref="IHostEnvironment"/> to use.</param>
98-
/// <param name="featureColumn">The name of the feature column.</param>
99-
/// <param name="weights">The name for the optional column containing the example weights.</param>
100-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
101-
/// <param name="clustersCount">The number of clusters.</param>
102-
public KMeansPlusPlusTrainer(IHostEnvironment env,
103-
string featureColumn = DefaultColumnNames.Features,
104-
int clustersCount = Defaults.K,
105-
string weights = null,
106-
Action<Arguments> advancedSettings = null)
107-
: this(env, new Arguments
108-
{
109-
FeatureColumn = featureColumn,
110-
WeightColumn = weights,
111-
K = clustersCount
112-
}, advancedSettings)
113-
{
114-
}
115-
116-
internal KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args)
117-
: this(env, args, null)
118+
/// <param name="options">The advanced options of the algorithm.</param>
119+
internal KMeansPlusPlusTrainer(IHostEnvironment env, Options options)
120+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(options.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(options.WeightColumn))
118121
{
122+
Host.CheckValue(options, nameof(options));
123+
Host.CheckUserArg(options.ClustersCount > 0, nameof(options.ClustersCount), "Must be positive");
119124

120-
}
121-
122-
private KMeansPlusPlusTrainer(IHostEnvironment env, Arguments args, Action<Arguments> advancedSettings = null)
123-
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
124-
{
125-
Host.CheckValue(args, nameof(args));
126-
127-
// override with the advanced settings.
128-
advancedSettings?.Invoke(args);
129-
130-
Host.CheckUserArg(args.K > 0, nameof(args.K), "Must be positive");
131-
132-
_featureColumn = args.FeatureColumn;
125+
_featureColumn = options.FeatureColumn;
133126

134-
_k = args.K;
127+
_k = options.ClustersCount;
135128

136-
Host.CheckUserArg(args.MaxIterations > 0, nameof(args.MaxIterations), "Must be positive");
137-
_maxIterations = args.MaxIterations;
129+
Host.CheckUserArg(options.MaxIterations > 0, nameof(options.MaxIterations), "Must be positive");
130+
_maxIterations = options.MaxIterations;
138131

139-
Host.CheckUserArg(args.OptTol > 0, nameof(args.OptTol), "Tolerance must be positive");
140-
_convergenceThreshold = args.OptTol;
132+
Host.CheckUserArg(options.OptimizationTolerance > 0, nameof(options.OptimizationTolerance), "Tolerance must be positive");
133+
_convergenceThreshold = options.OptimizationTolerance;
141134

142-
Host.CheckUserArg(args.AccelMemBudgetMb > 0, nameof(args.AccelMemBudgetMb), "Must be positive");
143-
_accelMemBudgetMb = args.AccelMemBudgetMb;
135+
Host.CheckUserArg(options.AccelerationMemoryBudgetMb > 0, nameof(options.AccelerationMemoryBudgetMb), "Must be positive");
136+
_accelMemBudgetMb = options.AccelerationMemoryBudgetMb;
144137

145-
_initAlgorithm = args.InitAlgorithm;
138+
_initAlgorithm = options.InitAlgorithm;
146139

147-
Host.CheckUserArg(!args.NumThreads.HasValue || args.NumThreads > 0, nameof(args.NumThreads),
140+
Host.CheckUserArg(!options.NumThreads.HasValue || options.NumThreads > 0, nameof(options.NumThreads),
148141
"Must be either null or a positive integer.");
149-
_numThreads = ComputeNumThreads(Host, args.NumThreads);
142+
_numThreads = ComputeNumThreads(Host, options.NumThreads);
150143
Info = new TrainerInfo();
151144
}
152145

@@ -247,14 +240,14 @@ private static int ComputeNumThreads(IHost host, int? argNumThreads)
247240
ShortName = ShortName,
248241
XmlInclude = new[] { @"<include file='../Microsoft.ML.KMeansClustering/doc.xml' path='doc/members/member[@name=""KMeans++""]/*' />",
249242
@"<include file='../Microsoft.ML.KMeansClustering/doc.xml' path='doc/members/example[@name=""KMeans++""]/*' />"})]
250-
public static CommonOutputs.ClusteringOutput TrainKMeans(IHostEnvironment env, Arguments input)
243+
public static CommonOutputs.ClusteringOutput TrainKMeans(IHostEnvironment env, Options input)
251244
{
252245
Contracts.CheckValue(env, nameof(env));
253246
var host = env.Register("TrainKMeans");
254247
host.CheckValue(input, nameof(input));
255248
EntryPointUtils.CheckInputArgs(host, input);
256249

257-
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.ClusteringOutput>(host, input,
250+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.ClusteringOutput>(host, input,
258251
() => new KMeansPlusPlusTrainer(host, input),
259252
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
260253
}
@@ -749,10 +742,10 @@ public static void Initialize(IHost host, int numThreads, IChannel ch, FeatureFl
749742
host.CheckValue(ch, nameof(ch));
750743
ch.CheckValue(cursorFactory, nameof(cursorFactory));
751744
ch.CheckValue(centroids, nameof(centroids));
752-
ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Arguments.NumThreads), "Must be positive");
753-
ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Arguments.K), "Must be positive");
745+
ch.CheckUserArg(numThreads > 0, nameof(KMeansPlusPlusTrainer.Options.NumThreads), "Must be positive");
746+
ch.CheckUserArg(k > 0, nameof(KMeansPlusPlusTrainer.Options.ClustersCount), "Must be positive");
754747
ch.CheckParam(dimensionality > 0, nameof(dimensionality), "Must be positive");
755-
ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Arguments.AccelMemBudgetMb), "Must be non-negative");
748+
ch.CheckUserArg(accelMemBudgetMb >= 0, nameof(KMeansPlusPlusTrainer.Options.AccelerationMemoryBudgetMb), "Must be non-negative");
756749

757750
int numRounds;
758751
int numSamplesPerRound;

0 commit comments

Comments
 (0)