Skip to content

Commit a3b7fce

Browse files
Ivanidzo4ka authored and eerhardt committed
introduce IUnsupervisedLearningWithWeights (dotnet#236)
* introduce IUnsupervisedLearningWithWeights
* add a test to check that KMeans doesn't need a label and can handle the presence of a weight column; also extract the real weight value from the cursor.
1 parent 3b5fbfb commit a3b7fce

File tree

5 files changed

+55
-8
lines changed

5 files changed

+55
-8
lines changed

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+18
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
6969
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
7070
}
7171

72+
/// <summary>
73+
/// The base class for all unsupervised learner inputs that support a weight column.
74+
/// </summary>
75+
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
76+
public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
77+
{
78+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
79+
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
80+
}
81+
7282
/// <summary>
7383
/// The base class for all evaluators inputs.
7484
/// </summary>
@@ -224,6 +234,14 @@ public interface ITrainerInputWithLabel : ITrainerInput
224234
string LabelColumn { get; }
225235
}
226236

237+
/// <summary>
238+
/// Interface that all API unsupervised trainer input classes that support a weight column will implement.
239+
/// </summary>
240+
public interface IUnsupervisedTrainerWithWeight : ITrainerInput
241+
{
242+
Optional<string> WeightColumn { get; }
243+
}
244+
227245
/// <summary>
228246
/// Interface that all API trainer input classes will implement.
229247
/// </summary>

src/Microsoft.ML.KMeansClustering/KMeansPlusPlusTrainer.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public enum InitAlgorithm
4444
KMeansParallel = 2
4545
}
4646

47-
public class Arguments : LearnerInputBase
47+
public class Arguments : UnsupervisedLearnerInputBaseWithWeight
4848
{
4949
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of clusters", SortOrder = 50)]
5050
[TGUI(SuggestedSweeps = "5,10,20,40")]
@@ -162,7 +162,7 @@ private void TrainCore(IChannel ch, RoleMappedData data)
162162
long missingFeatureCount;
163163
long totalTrainingInstances;
164164

165-
var cursorFactory = new FeatureFloatVectorCursor.Factory(data, CursOpt.Features | CursOpt.Id);
165+
var cursorFactory = new FeatureFloatVectorCursor.Factory(data, CursOpt.Features | CursOpt.Id | CursOpt.Weight);
166166
// REVIEW: It would be nice to extract these out into subcomponents in the future. We should
167167
// revisit and even consider breaking these all into individual KMeans-flavored trainers, they
168168
// all produce a valid set of output centroids with various trade-offs in runtime (with perhaps
@@ -234,7 +234,8 @@ public static CommonOutputs.ClusteringOutput TrainKMeans(IHostEnvironment env, A
234234
EntryPointUtils.CheckInputArgs(host, input);
235235

236236
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.ClusteringOutput>(host, input,
237-
() => new KMeansPlusPlusTrainer(host, input));
237+
() => new KMeansPlusPlusTrainer(host, input),
238+
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));
238239
}
239240
}
240241

src/Microsoft.ML.PCA/PcaTrainer.cs

+1-4
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public sealed class RandomizedPcaTrainer : TrainerBase<RoleMappedData, PcaPredic
4949
internal const string Summary = "This algorithm trains an approximate PCA using Randomized SVD algorithm. "
5050
+ "This PCA can be made into Kernel PCA by using Random Fourier Features transform.";
5151

52-
public class Arguments : LearnerInputBase
52+
public class Arguments : UnsupervisedLearnerInputBaseWithWeight
5353
{
5454
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k", SortOrder = 50)]
5555
[TGUI(SuggestedSweeps = "10,20,40,80")]
@@ -67,9 +67,6 @@ public class Arguments : LearnerInputBase
6767

6868
[Argument(ArgumentType.AtMostOnce, HelpText = "The seed for random number generation", ShortName = "seed")]
6969
public int? Seed;
70-
71-
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
72-
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
7370
}
7471

7572
private int _dimension;

test/BaselineOutput/Common/EntryPoints/core_manifest.json

+31
Original file line numberDiff line numberDiff line change
@@ -9117,6 +9117,18 @@
91179117
"IsNullable": false,
91189118
"Default": "Features"
91199119
},
9120+
{
9121+
"Name": "WeightColumn",
9122+
"Type": "String",
9123+
"Desc": "Column to use for example weight",
9124+
"Aliases": [
9125+
"weight"
9126+
],
9127+
"Required": false,
9128+
"SortOrder": 4.0,
9129+
"IsNullable": false,
9130+
"Default": "Weight"
9131+
},
91209132
{
91219133
"Name": "NormalizeFeatures",
91229134
"Type": {
@@ -9253,6 +9265,7 @@
92539265
}
92549266
],
92559267
"InputKind": [
9268+
"IUnsupervisedTrainerWithWeight",
92569269
"ITrainerInput"
92579270
],
92589271
"OutputKind": [
@@ -10875,6 +10888,7 @@
1087510888
}
1087610889
],
1087710890
"InputKind": [
10891+
"IUnsupervisedTrainerWithWeight",
1087810892
"ITrainerInput"
1087910893
],
1088010894
"OutputKind": [
@@ -22577,6 +22591,23 @@
2257722591
"Type": "TransformModel"
2257822592
}
2257922593
]
22594+
},
22595+
{
22596+
"Kind": "IUnsupervisedTrainerWithWeight",
22597+
"Settings": [
22598+
{
22599+
"Name": "WeightColumn",
22600+
"Type": "String"
22601+
},
22602+
{
22603+
"Name": "TrainingData",
22604+
"Type": "DataView"
22605+
},
22606+
{
22607+
"Name": "FeatureColumn",
22608+
"Type": "String"
22609+
}
22610+
]
2258022611
}
2258122612
]
2258222613
}

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1918,7 +1918,7 @@ public void EntryPointTrainTestMacroNoTransformInput()
19181918
[Fact]
19191919
public void EntryPointKMeans()
19201920
{
1921-
TestEntryPointRoutine("Train-Tiny-28x28.txt", "Trainers.KMeansPlusPlusClusterer");
1921+
TestEntryPointRoutine("Train-Tiny-28x28.txt", "Trainers.KMeansPlusPlusClusterer", "col=Weight:R4:0 col=Features:R4:1-784", ",'InitAlgorithm':'KMeansPlusPlus'");
19221922
}
19231923

19241924
[Fact]

0 commit comments

Comments
 (0)