
Modify API for advanced settings. (FastTree, RandomForest) #2047


Merged
9 commits, merged on Jan 11, 2019
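At a high level, the diff below removes the Action<TArgs> advancedSettings callback from the FastTree/RandomForest trainer constructors: common hyperparameters such as the learning rate become plain constructor parameters, and everything else is configured on the arguments (TArgs) object consumed by the maml-style constructor. A rough before/after sketch of the calling pattern; the trainer name and parameter list are hypothetical, since the public entry points are not part of this excerpt:

// Before (illustrative): advanced options were applied by a delegate that mutated the
// trainer's arguments object after the simple parameters were copied in.
var before = new ExampleFastTreeTrainer(env, labelColumn: "Label", featureColumn: "Features",
    advancedSettings: s => s.LearningRates = 0.2);

// After (illustrative): the common options are ordinary constructor parameters; there is
// no delegate to run against the arguments object.
var after = new ExampleFastTreeTrainer(env, labelColumn: "Label", featureColumn: "Features",
    learningRate: 0.2);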
21 changes: 21 additions & 0 deletions src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -35,15 +35,27 @@ public enum CachingOptions
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
public abstract class LearnerInputBase
{
/// <summary>
/// The data to be used for training.
/// </summary>
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public IDataView TrainingData;

/// <summary>
/// Column to use for features.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string FeatureColumn = DefaultColumnNames.Features;

/// <summary>
/// Normalize option for the feature column.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public NormalizeOption NormalizeFeatures = NormalizeOption.Auto;

/// <summary>
/// Whether learner should cache input training data.
/// </summary>
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public CachingOptions Caching = CachingOptions.Auto;
}
@@ -54,6 +66,9 @@ public abstract class LearnerInputBase
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
public abstract class LearnerInputBaseWithLabel : LearnerInputBase
{
/// <summary>
/// Column to use for labels.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string LabelColumn = DefaultColumnNames.Label;
}
@@ -65,6 +80,9 @@ public abstract class LearnerInputBaseWithLabel : LearnerInputBase
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
{
/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public Optional<string> WeightColumn = Optional<string>.Implicit(DefaultColumnNames.Weight);
}
@@ -95,6 +113,9 @@ public abstract class EvaluateInputBase
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
{
/// <summary>
/// Column to use for example groupId.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public Optional<string> GroupIdColumn = Optional<string>.Implicit(DefaultColumnNames.GroupId);
}
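The new summaries above document the fields that every entry-point trainer input inherits. For reference, a derived input only has to declare its own hyperparameters; a minimal sketch (the class name and the option below are illustrative, not part of this change):

/// <summary>
/// Hypothetical trainer input: inherits TrainingData, FeatureColumn, LabelColumn,
/// WeightColumn, GroupIdColumn, NormalizeFeatures and Caching from the base classes above.
/// </summary>
public sealed class ExampleRankerInput : LearnerInputBaseWithGroupId
{
    [Argument(ArgumentType.AtMostOnce, HelpText = "Total number of decision trees to create in the ensemble", ShortName = "iter", SortOrder = 50)]
    public int NumTrees = 100;
}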
12 changes: 3 additions & 9 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -28,16 +28,10 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env,
int numLeaves,
int numTrees,
int minDatapointsInLeaves,
- double learningRate,
- Action<TArgs> advancedSettings)
- : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves, advancedSettings)
+ double learningRate)
+ : base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves)
{
-
- if (Args.LearningRates != learningRate)
- {
-     using (var ch = Host.Start($"Setting learning rate to: {learningRate} as supplied in the direct arguments."))
-         Args.LearningRates = learningRate;
- }
+ Args.LearningRates = learningRate;
}

protected override void CheckArgs(IChannel ch)
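With the delegate parameter gone, the learning rate is now assigned to Args unconditionally; the old code only overwrote it, with a log message, when it differed from the default. A sketch of what a derived trainer's forwarding constructor looks like after this change; the class name is illustrative and the snippet is not a compilable unit on its own:

internal ExampleBoostedTreesTrainer(IHostEnvironment env, SchemaShape.Column label,
    string featureColumn, string weightColumn, string groupIdColumn,
    int numLeaves, int numTrees, int minDatapointsInLeaves, double learningRate)
    : base(env, label, featureColumn, weightColumn, groupIdColumn,
        numLeaves, numTrees, minDatapointsInLeaves, learningRate)
{
    // No advancedSettings callback to invoke here; any remaining options are expected
    // to arrive through the TArgs-based constructor instead.
}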
8 changes: 2 additions & 6 deletions src/Microsoft.ML.FastTree/FastTree.cs
@@ -114,8 +114,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
string groupIdColumn,
int numLeaves,
int numTrees,
- int minDatapointsInLeaves,
- Action<TArgs> advancedSettings)
+ int minDatapointsInLeaves)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
{
Args = new TArgs();
@@ -126,9 +125,6 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
Args.NumTrees = numTrees;
Args.MinDocumentsInLeafs = minDatapointsInLeaves;

- //apply the advanced args, if the user supplied any
- advancedSettings?.Invoke(Args);
-
Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;
@@ -152,7 +148,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
}

/// <summary>
- /// Legacy constructor that is used when invoking the classes deriving from this, through maml.
+ /// Constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn, args.WeightColumn.IsExplicit))
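Since the TArgs-based constructor is no longer described as the legacy path, fine-grained settings without a dedicated constructor parameter are now expressed on the arguments object itself. A hedged sketch of that pattern; the concrete trainer and arguments type names are illustrative, although the individual properties (LabelColumn, FeatureColumn, NumTrees, MinDocumentsInLeafs, LearningRates) all appear in the diff above:

// Illustrative: build the TArgs object directly instead of mutating it through a delegate.
var args = new ExampleFastTreeArguments
{
    LabelColumn = "Label",
    FeatureColumn = "Features",
    NumTrees = 100,
    MinDocumentsInLeafs = 10,
    LearningRates = 0.2,
};
var trainer = new ExampleFastTreeTrainer(env, args);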