Skip to content

Commit 5123aee

Browse files
authored
Trainer estimator cleanup for FastTrees and LightGBM (#1352)
* adding multiclass and ranking extensions for LightGBM. Adding tests, and refactoring catalog and pigsty statics * removing duplicate method from TrainUtils adding groupid to the trainer estimator base refactoring the catalog and static extensions for trees
1 parent 0f793fe commit 5123aee

36 files changed

+840
-486
lines changed

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

+35-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ public abstract class TrainerEstimatorBase<TTransformer, TModel> : ITrainerEstim
5050

5151
public abstract PredictionKind PredictionKind { get; }
5252

53-
public TrainerEstimatorBase(IHost host, SchemaShape.Column feature, SchemaShape.Column label, SchemaShape.Column weight = null)
53+
public TrainerEstimatorBase(IHost host,
54+
SchemaShape.Column feature,
55+
SchemaShape.Column label,
56+
SchemaShape.Column weight = null)
5457
{
5558
Contracts.CheckValue(host, nameof(host));
5659
Host = host;
@@ -149,9 +152,39 @@ protected TTransformer TrainTransformer(IDataView trainSet,
149152

150153
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);
151154

152-
private RoleMappedData MakeRoles(IDataView data) =>
155+
protected virtual RoleMappedData MakeRoles(IDataView data) =>
153156
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, weight: WeightColumn?.Name);
154157

155158
IPredictor ITrainer.Train(TrainContext context) => Train(context);
156159
}
160+
161+
/// <summary>
162+
/// This represents a basic class for 'simple trainer'.
163+
/// A 'simple trainer' accepts one feature column and one label column, also optionally a weight column and a group ID column.
164+
/// It produces a 'prediction transformer'.
165+
/// </summary>
166+
public abstract class TrainerEstimatorBaseWithGroupId<TTransformer, TModel> : TrainerEstimatorBase<TTransformer, TModel>
167+
where TTransformer : ISingleFeaturePredictionTransformer<TModel>
168+
where TModel : IPredictor
169+
{
170+
/// <summary>
171+
/// The optional group ID column that ranking trainers expect.
172+
/// </summary>
173+
public readonly SchemaShape.Column GroupIdColumn;
174+
175+
public TrainerEstimatorBaseWithGroupId(IHost host,
176+
SchemaShape.Column feature,
177+
SchemaShape.Column label,
178+
SchemaShape.Column weight = null,
179+
SchemaShape.Column groupId = null)
180+
:base(host, feature, label, weight)
181+
{
182+
Host.CheckValueOrNull(groupId);
183+
GroupIdColumn = groupId;
184+
}
185+
186+
protected override RoleMappedData MakeRoles(IDataView data) =>
187+
new RoleMappedData(data, label: LabelColumn?.Name, feature: FeatureColumn.Name, group: GroupIdColumn?.Name, weight: WeightColumn?.Name);
188+
189+
}
157190
}

src/Microsoft.ML.Data/Training/TrainerUtils.cs

+11-62
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,14 @@ public static SchemaShape.Column MakeR4ScalarLabel(string labelColumn)
362362
/// <summary>
363363
/// The <see cref="SchemaShape.Column"/> for a scalar U4 column, such as a label column or a group ID column.
364364
/// </summary>
365-
/// <param name="labelColumn">name of the weight column</param>
366-
public static SchemaShape.Column MakeU4ScalarLabel(string labelColumn)
367-
=> new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
365+
/// <param name="columnName">name of the column; if null, no column is created and null is returned</param>
366+
public static SchemaShape.Column MakeU4ScalarColumn(string columnName)
367+
{
368+
if (columnName == null)
369+
return null;
370+
371+
return new SchemaShape.Column(columnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
372+
}
368373

369374
/// <summary>
370375
/// The <see cref="SchemaShape.Column"/> for the feature column.
@@ -377,69 +382,13 @@ public static SchemaShape.Column MakeR4VecFeature(string featureColumn)
377382
/// The <see cref="SchemaShape.Column"/> for the weight column.
378383
/// </summary>
379384
/// <param name="weightColumn">name of the weight column</param>
380-
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
385+
/// <param name="isExplicit">whether the column was explicitly defined; when false, no weight column is created and null is returned</param>
386+
public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn, bool isExplicit = true)
381387
{
382-
if (weightColumn == null)
388+
if (weightColumn == null || !isExplicit)
383389
return null;
384390
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
385391
}
386-
387-
private static void CheckArgColName(IHostEnvironment host, string defaultColName, string argValue)
388-
{
389-
if (argValue != defaultColName)
390-
throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead instead.");
391-
}
392-
393-
/// <summary>
394-
/// Check that the label, feature, weights, groupId column names are not supplied in the args of the constructor, through the advancedSettings parameter,
395-
/// for cases when the public constructor is called.
396-
/// The recommendation is to set the column names directly.
397-
/// </summary>
398-
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args)
399-
{
400-
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
401-
CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
402-
CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
403-
CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);
404-
405-
if (args.GroupIdColumn != null)
406-
CheckArgColName(host, DefaultColumnNames.GroupId, args.GroupIdColumn);
407-
}
408-
409-
/// <summary>
410-
/// Check that the label, feature, and weights column names are not supplied in the args of the constructor, through the advancedSettings parameter,
411-
/// for cases when the public constructor is called.
412-
/// The recommendation is to set the column names directly.
413-
/// </summary>
414-
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithWeight args)
415-
{
416-
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
417-
CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
418-
CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
419-
CheckArgColName(host, DefaultColumnNames.Weight, args.WeightColumn);
420-
}
421-
422-
/// <summary>
423-
/// Check that the label and feature column names are not supplied in the args of the constructor, through the advancedSettings parameter,
424-
/// for cases when the public constructor is called.
425-
/// The recommendation is to set the column names directly.
426-
/// </summary>
427-
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithLabel args)
428-
{
429-
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
430-
CheckArgColName(host, DefaultColumnNames.Label, args.LabelColumn);
431-
CheckArgColName(host, DefaultColumnNames.Features, args.FeatureColumn);
432-
}
433-
434-
/// <summary>
435-
/// If, after applying the advancedArgs delegate, the args are different that the default value
436-
/// and are also different than the value supplied directly to the xtension method, warn the user.
437-
/// </summary>
438-
public static void CheckArgsAndAdvancedSettingMismatch<T>(IChannel channel, T methodParam, T defaultVal, T setting, string argName)
439-
{
440-
if (!setting.Equals(defaultVal) && !setting.Equals(methodParam))
441-
channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. Using value {setting} for {argName}");
442-
}
443392
}
444393

445394
/// <summary>

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+17-3
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,24 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
2121
{
2222
}
2323

24-
protected BoostingFastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
25-
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
26-
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
24+
protected BoostingFastTreeTrainerBase(IHostEnvironment env,
25+
SchemaShape.Column label,
26+
string featureColumn,
27+
string weightColumn,
28+
string groupIdColumn,
29+
int numLeaves,
30+
int numTrees,
31+
int minDocumentsInLeafs,
32+
double learningRate,
33+
Action<TArgs> advancedSettings)
34+
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDocumentsInLeafs, advancedSettings)
2735
{
36+
37+
if (Args.LearningRates != learningRate)
38+
{
39+
using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
40+
Args.LearningRates = learningRate;
41+
}
2842
}
2943

3044
protected override void CheckArgs(IChannel ch)

src/Microsoft.ML.FastTree/FastTree.cs

+21-36
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML.Runtime.CommandLine;
88
using Microsoft.ML.Runtime.Data;
99
using Microsoft.ML.Runtime.Data.Conversion;
10+
using Microsoft.ML.Runtime.EntryPoints;
1011
using Microsoft.ML.Runtime.Internal.Calibration;
1112
using Microsoft.ML.Runtime.Internal.Internallearn;
1213
using Microsoft.ML.Runtime.Internal.Utilities;
@@ -45,7 +46,7 @@ internal static class FastTreeShared
4546
}
4647

4748
public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
48-
TrainerEstimatorBase<TTransformer, TModel>
49+
TrainerEstimatorBaseWithGroupId<TTransformer, TModel>
4950
where TTransformer: ISingleFeaturePredictionTransformer<TModel>
5051
where TArgs : TreeArgs, new()
5152
where TModel : IPredictorProducing<Float>
@@ -92,26 +93,36 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
9293
/// <summary>
9394
/// Constructor to use when instantiating the classes deriving from here through the API.
9495
/// </summary>
95-
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
96-
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
97-
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
96+
private protected FastTreeTrainerBase(IHostEnvironment env,
97+
SchemaShape.Column label,
98+
string featureColumn,
99+
string weightColumn,
100+
string groupIdColumn,
101+
int numLeaves,
102+
int numTrees,
103+
int minDocumentsInLeafs,
104+
Action<TArgs> advancedSettings)
105+
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
98106
{
99107
Args = new TArgs();
100108

109+
// set up the directly provided values
110+
// these replace the defaults of the freshly created args, and may in turn be overridden by advancedSettings below.
111+
Args.NumLeaves = numLeaves;
112+
Args.NumTrees = numTrees;
113+
Args.MinDocumentsInLeafs = minDocumentsInLeafs;
114+
101115
//apply the advanced args, if the user supplied any
102116
advancedSettings?.Invoke(Args);
103117

104-
// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
105-
TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args);
106-
107118
Args.LabelColumn = label.Name;
108119
Args.FeatureColumn = featureColumn;
109120

110121
if (weightColumn != null)
111-
Args.WeightColumn = weightColumn;
122+
Args.WeightColumn = Optional<string>.Explicit(weightColumn); ;
112123

113124
if (groupIdColumn != null)
114-
Args.GroupIdColumn = groupIdColumn;
125+
Args.GroupIdColumn = Optional<string>.Explicit(groupIdColumn); ;
115126

116127
// The discretization step renders this trainer non-parametric, and therefore it does not need normalization.
117128
// Also since it builds its own internal discretized columnar structures, it cannot benefit from caching.
@@ -128,7 +139,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
128139
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
129140
/// </summary>
130141
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
131-
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
142+
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
132143
{
133144
Host.CheckValue(args, nameof(args));
134145
Args = args;
@@ -159,32 +170,6 @@ protected virtual Float GetMaxLabel()
159170
return Float.PositiveInfinity;
160171
}
161172

162-
/// <summary>
163-
/// If, after applying the advancedSettings delegate, the args are different that the default value
164-
/// and are also different than the value supplied directly to the xtension method, warn the user
165-
/// about which value is being used.
166-
/// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune.
167-
/// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
168-
/// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
169-
/// </summary>
170-
protected void CheckArgsAndAdvancedSettingMismatch(int numLeaves,
171-
int numTrees,
172-
int minDocumentsInLeafs,
173-
double learningRate,
174-
BoostedTreeArgs snapshot,
175-
BoostedTreeArgs currentArgs)
176-
{
177-
using (var ch = Host.Start("Comparing advanced settings with the directly provided values."))
178-
{
179-
180-
// Check that the user didn't supply different parameters in the args, from what it specified directly.
181-
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
182-
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees));
183-
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
184-
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate));
185-
}
186-
}
187-
188173
private void Initialize(IHostEnvironment env)
189174
{
190175
int numThreads = Args.NumThreads ?? Environment.ProcessorCount;

0 commit comments

Comments
 (0)