diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs index 57a7c9120f..5583c66df0 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs @@ -69,6 +69,16 @@ public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); } + /// + /// The base class for all unsupervised learner inputs that support a weight column. + /// + [TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))] + public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase + { + [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] + public Optional WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); + } + /// /// The base class for all evaluators inputs. /// @@ -224,6 +234,14 @@ public interface ITrainerInputWithLabel : ITrainerInput string LabelColumn { get; } } + /// + /// Interface that all API trainer input classes will implement. + /// + public interface IUnsupervisedTrainerWithWeight : ITrainerInput + { + Optional WeightColumn { get; } + } + /// /// Interface that all API trainer input classes will implement. /// diff --git a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs index b833278890..1f09ec850f 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs @@ -44,7 +44,7 @@ public enum InitAlgorithm KMeansParallel = 2 } - public class Arguments : LearnerInputBase + public class Arguments : UnsupervisedLearnerInputBaseWithWeight { [Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50)] [TGUI(SuggestedSweeps = "5,10,20,40")] @@ -162,7 +162,7 @@ private void TrainCore(IChannel ch, RoleMappedData data) long missingFeatureCount; long totalTrainingInstances; - var cursorFactory = new FeatureFloatVectorCursor.Factory(data, CursOpt.Features | CursOpt.Id); + var cursorFactory = new FeatureFloatVectorCursor.Factory(data, CursOpt.Features | CursOpt.Id | CursOpt.Weight); // REVIEW: It would be nice to extract these out into subcomponents in the future. We should // revisit and even consider breaking these all into individual KMeans-flavored trainers, they // all produce a valid set of output centroids with various trade-offs in runtime (with perhaps @@ -234,7 +234,8 @@ public static CommonOutputs.ClusteringOutput TrainKMeans(IHostEnvironment env, A EntryPointUtils.CheckInputArgs(host, input); return LearnerEntryPointsUtils.Train(host, input, - () => new KMeansPlusPlusTrainer(host, input)); + () => new KMeansPlusPlusTrainer(host, input), + getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); } } diff --git a/src/Microsoft.ML.PCA/PcaTrainer.cs b/src/Microsoft.ML.PCA/PcaTrainer.cs index e46cc120fb..1d34db4b21 100644 --- a/src/Microsoft.ML.PCA/PcaTrainer.cs +++ b/src/Microsoft.ML.PCA/PcaTrainer.cs @@ -49,7 +49,7 @@ public sealed class RandomizedPcaTrainer : TrainerBase WeightColumn = Optional.Implicit(DefaultColumnNames.Weight); } private int _dimension; diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 010d4a0afa..7e06c77821 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -9117,6 +9117,18 @@ "IsNullable": false, "Default": "Features" }, + { + "Name": "WeightColumn", + "Type": "String", + "Desc": "Column to use for example weight", + "Aliases": [ + "weight" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": "Weight" + }, { "Name": "NormalizeFeatures", "Type": { @@ -9253,6 +9265,7 @@ } ], "InputKind": [ + "IUnsupervisedTrainerWithWeight", "ITrainerInput" ], "OutputKind": [ @@ -10875,6 +10888,7 @@ } ], "InputKind": [ + "IUnsupervisedTrainerWithWeight", "ITrainerInput" ], "OutputKind": [ @@ -22577,6 +22591,23 @@ "Type": "TransformModel" } ] + }, + { + "Kind": "IUnsupervisedTrainerWithWeight", + "Settings": [ + { + "Name": "WeightColumn", + "Type": "String" + }, + { + "Name": "TrainingData", + "Type": "DataView" + }, + { + "Name": "FeatureColumn", + "Type": "String" + } + ] } ] } \ No newline at end of file diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 8ef51b9de2..5249d59ace 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -1918,7 +1918,7 @@ public void EntryPointTrainTestMacroNoTransformInput() [Fact] public void EntryPointKMeans() { - TestEntryPointRoutine("Train-Tiny-28x28.txt", "Trainers.KMeansPlusPlusClusterer"); + TestEntryPointRoutine("Train-Tiny-28x28.txt", "Trainers.KMeansPlusPlusClusterer", "col=Weight:R4:0 col=Features:R4:1-784", ",'InitAlgorithm':'KMeansPlusPlus'"); } [Fact]