Skip to content

Commit a927479

Browse files
authored
Modify API for advanced settings. (FastTree, RandomForest) (#2047)
* Changes for FastTree & related learners * Removing defaults from some of the newly added APIs * Argument -> Options * Pass objects as arguments instead of delegate * review comments - 1 * review comments - 2. updating comments, help summary etc * review comments - 3. Rename Options objects as options (instead of args or advancedSettings used so far) * making the constructors internal
1 parent d26510f commit a927479

File tree

20 files changed

+634
-250
lines changed

20 files changed

+634
-250
lines changed

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+21
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,27 @@ public enum CachingOptions
3535
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
3636
public abstract class LearnerInputBase
3737
{
38+
/// <summary>
39+
/// The data to be used for training.
40+
/// </summary>
3841
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
3942
public IDataView TrainingData;
4043

44+
/// <summary>
45+
/// Column to use for features.
46+
/// </summary>
4147
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
4248
public string FeatureColumn = DefaultColumnNames.Features;
4349

50+
/// <summary>
51+
/// Normalize option for the feature column.
52+
/// </summary>
4453
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
4554
public NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
4655

56+
/// <summary>
57+
/// Whether learner should cache input training data.
58+
/// </summary>
4759
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
4860
public CachingOptions Caching = CachingOptions.Auto;
4961
}
@@ -54,6 +66,9 @@ public abstract class LearnerInputBase
5466
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
5567
public abstract class LearnerInputBaseWithLabel : LearnerInputBase
5668
{
69+
/// <summary>
70+
/// Column to use for labels.
71+
/// </summary>
5772
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
5873
public string LabelColumn = DefaultColumnNames.Label;
5974
}
@@ -65,6 +80,9 @@ public abstract class LearnerInputBaseWithLabel : LearnerInputBase
6580
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
6681
public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
6782
{
83+
/// <summary>
84+
/// Column to use for example weight.
85+
/// </summary>
6886
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
6987
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
7088
}
@@ -95,6 +113,9 @@ public abstract class EvaluateInputBase
95113
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
96114
public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
97115
{
116+
/// <summary>
117+
/// Column to use for example groupId.
118+
/// </summary>
98119
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
99120
public Optional<string> GroupIdColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
100121
}

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+3-9
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,10 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env,
2828
int numLeaves,
2929
int numTrees,
3030
int minDatapointsInLeaves,
31-
double learningRate,
32-
Action<TArgs> advancedSettings)
33-
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, advancedSettings)
31+
double learningRate)
32+
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves)
3433
{
35-
36-
if (Args.LearningRates != learningRate)
37-
{
38-
using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
39-
Args.LearningRates = learningRate;
40-
}
34+
Args.LearningRates = learningRate;
4135
}
4236

4337
protected override void CheckArgs(IChannel ch)

src/Microsoft.ML.FastTree/FastTree.cs

+2-6
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
114114
string groupIdColumn,
115115
int numLeaves,
116116
int numTrees,
117-
int minDatapointsInLeaves,
118-
Action<TArgs> advancedSettings)
117+
int minDatapointsInLeaves)
119118
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
120119
{
121120
Args = new TArgs();
@@ -126,9 +125,6 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
126125
Args.NumTrees = numTrees;
127126
Args.MinDocumentsInLeafs = minDatapointsInLeaves;
128127

129-
//apply the advanced args, if the user supplied any
130-
advancedSettings?.Invoke(Args);
131-
132128
Args.LabelColumn = label.Name;
133129
Args.FeatureColumn = featureColumn;
134130

@@ -152,7 +148,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
152148
}
153149

154150
/// <summary>
155-
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
151+
/// Constructor that is used when invoking the classes deriving from this, through maml.
156152
/// </summary>
157153
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
158154
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))

0 commit comments

Comments
 (0)