Skip to content

Adding the extension methods for FastTree classification and regression #1009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 25, 2018
31 changes: 31 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -382,6 +383,36 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Checks that the label, feature, weight and (when present) group id column names held by
/// <paramref name="args"/> still equal their <see cref="DefaultColumnNames"/> values, i.e. that they were
/// not supplied through the advancedSettings parameter when the public constructor is called.
/// The recommendation is to set the column names directly.
/// </summary>
/// <param name="host">The environment used to create the exception that is thrown on a mismatch.</param>
/// <param name="args">The arguments to inspect.</param>
public static void CheckArgsHaveDefaultColNames(IHostEnvironment host, LearnerInputBaseWithGroupId args)
{
    // Throws when a column name in the args no longer matches its default value.
    void CheckArgColName(string defaultColName, string argValue)
    {
        if (argValue != defaultColName)
            throw host.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead.");
    }

    // Check that the users didn't specify different label, feature, weights in the args, from what they supplied directly.
    CheckArgColName(DefaultColumnNames.Label, args.LabelColumn);
    CheckArgColName(DefaultColumnNames.Features, args.FeatureColumn);
    CheckArgColName(DefaultColumnNames.Weight, args.WeightColumn);

    // The group id column is only compared when it is not null.
    if (args.GroupIdColumn != null)
        CheckArgColName(DefaultColumnNames.GroupId, args.GroupIdColumn);
}

/// <summary>
/// Emits a warning on <paramref name="channel"/> when <paramref name="setting"/> is different from both
/// <paramref name="defaultVal"/> and <paramref name="methodParam"/>.
/// </summary>
/// <typeparam name="T">The type of the compared values.</typeparam>
/// <param name="channel">The channel that receives the warning.</param>
/// <param name="methodParam">The value supplied directly to the extension method or constructor.</param>
/// <param name="defaultVal">The default value of the argument.</param>
/// <param name="setting">The value of the argument after the advanced settings delegate was applied.</param>
/// <param name="argName">The name of the argument, used in the warning text.</param>
public static void CheckArgsAndAdvancedSettingMismatch<T>(IChannel channel, T methodParam, T defaultVal, T setting, string argName)
{
    // EqualityComparer copes with a null setting, where calling setting.Equals(...) directly would throw.
    var comparer = EqualityComparer<T>.Default;

    // If, after applying the advancedArgs delegate, the args are different than the default value
    // and are also different than the value supplied directly to the extension method, warn the user.
    if (!comparer.Equals(setting, defaultVal) && !comparer.Equals(setting, methodParam))
        channel.Warning($"The value supplied to advanced settings , is different than the value supplied directly. Using value {setting} for {argName}");
}
}

/// <summary>
Expand Down
55 changes: 33 additions & 22 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,16 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column label, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<TArgs> advancedSettings = null)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(featureColumn), label, MakeWeightColumn(weightColumn))
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
Args = new TArgs();

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
TrainerUtils.CheckArgsHaveDefaultColNames(Host, Args);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

Expand All @@ -123,7 +127,7 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l
/// Legacy constructor that is used when invoking the classes deriving from this, through maml.
/// </summary>
private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaShape.Column label)
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), MakeFeatureColumn(args.FeatureColumn), label, MakeWeightColumn(args.WeightColumn))
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
{
Host.CheckValue(args, nameof(args));
Args = args;
Expand Down Expand Up @@ -154,16 +158,32 @@ protected virtual Float GetMaxLabel()
return Float.PositiveInfinity;
}

private static SchemaShape.Column MakeWeightColumn(string weightColumn)
/// <summary>
/// If, after applying the advancedSettings delegate, the args are different than the default value
/// and are also different than the value supplied directly to the extension method, warn the user
/// about which value is being used.
/// The parameters that appear here, numTrees, minDocumentsInLeafs, numLeaves, learningRate are the ones the users are most likely to tune.
/// This list should follow the one in the constructor, and the extension methods on the <see cref="TrainContextBase"/>.
/// REVIEW: we should somehow mark the arguments that are set apart in those two places. Currently they stand out by their sort order annotation.
/// </summary>
protected void CheckArgsAndAdvancedSettingMismatch(int numLeaves,
int numTrees,
int minDocumentsInLeafs,
double learningRate,
BoostedTreeArgs snapshot,
BoostedTreeArgs currentArgs)
{
if (weightColumn == null)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}
using (var ch = Host.Start("Comparing advanced settings with the directly provided values."))
{

private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
{
return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);
// Check that the user didn't supply different parameters in the args, from what it specified directly.
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, currentArgs.NumLeaves, nameof(numLeaves));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, currentArgs.NumTrees, nameof(numTrees));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, currentArgs.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, currentArgs.LearningRates, nameof(learningRate));

ch.Done();
}
}

private void Initialize(IHostEnvironment env)
Expand Down Expand Up @@ -244,10 +264,7 @@ protected virtual bool ShouldStop(IChannel ch, ref IEarlyStoppingCriterion early
bestIteration = Ensemble.NumTrees;
return false;
}
protected virtual int GetBestIteration(IChannel ch)
{
return Ensemble.NumTrees;
}
protected virtual int GetBestIteration(IChannel ch) => Ensemble.NumTrees;

protected virtual void InitializeThreads(int numThreads)
{
Expand Down Expand Up @@ -307,21 +324,15 @@ protected virtual void CheckArgs(IChannel ch)
/// it to print specific test graph header.
/// </summary>
/// <returns> string representation of test graph header </returns>
protected virtual string GetTestGraphHeader()
{
return string.Empty;
}
protected virtual string GetTestGraphHeader() => string.Empty;

/// <summary>
/// A virtual method that is used to print a single line of test graph.
/// Applications that need printing test graph are supposed to override
/// it to print a specific line of test graph after a new iteration is finished.
/// </summary>
/// <returns> string representation of a line of test graph </returns>
protected virtual string GetTestGraphLine()
{
return string.Empty;
}
protected virtual string GetTestGraphLine() => string.Empty;

/// <summary>
/// A virtual method that is used to compute test results after each iteration is finished.
Expand Down
16 changes: 12 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ public enum Bundle : Byte
Adjacent = 2
}

/// <summary>
/// Compile-time default values for the tree arguments.
/// </summary>
internal static class Defaults
{
    /// <summary>Default learning rate.</summary>
    internal const double LearningRates = 0.2;

    /// <summary>Default minimal number of documents allowed in a leaf of a regression tree.</summary>
    internal const int MinDocumentsInLeafs = 10;

    /// <summary>Default max number of leaves in each regression tree.</summary>
    internal const int NumLeaves = 20;

    /// <summary>Default total number of decision trees to create in the ensemble.</summary>
    internal const int NumTrees = 100;
}

public abstract class TreeArgs : LearnerInputBaseWithGroupId
{
[Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")]
Expand Down Expand Up @@ -229,20 +237,20 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The max number of leaves in each regression tree", ShortName = "nl", SortOrder = 2)]
[TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")]
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)]
public int NumLeaves = 20;
public int NumLeaves = Defaults.NumLeaves;

// REVIEW: Arrays not supported in GUI
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data", ShortName = "mil", SortOrder = 3)]
[TGUI(Description = "Minimum number of training instances required to form a leaf", SuggestedSweeps = "1,10,50")]
[TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] { 1, 10, 50 })]
public int MinDocumentsInLeafs = 10;
public int MinDocumentsInLeafs = Defaults.MinDocumentsInLeafs;

// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Total number of decision trees to create in the ensemble", ShortName = "iter", SortOrder = 1)]
[TGUI(Description = "Total number of trees constructed", SuggestedSweeps = "20,100,500")]
[TlcModule.SweepableDiscreteParamAttribute("NumTrees", new object[] { 20, 100, 500 })]
public int NumTrees = 100;
public int NumTrees = Defaults.NumTrees;

[Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each iteration", ShortName = "ff")]
public Double FeatureFraction = 1;
Expand Down Expand Up @@ -365,7 +373,7 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)]
[TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")]
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale: true)]
public Double LearningRates = 0.2;
public Double LearningRates = Defaults.LearningRates;

[Argument(ArgumentType.AtMostOnce, HelpText = "Shrinkage", ShortName = "shrk")]
[TGUI(Label = "Shrinkage", SuggestedSweeps = "0.25-4;log")]
Expand Down
27 changes: 23 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,37 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
string groupIdColumn = null, string weightColumn = null, Action<Arguments> advancedSettings = null)
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
/// <param name="learningRate">The learning rate.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
    string labelColumn,
    string featureColumn,
    string weightColumn = null,
    int numLeaves = Defaults.NumLeaves,
    int numTrees = Defaults.NumTrees,
    int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
    double learningRate = Defaults.LearningRates,
    Action<Arguments> advancedSettings = null)
    : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    // Compare the advanced settings with the direct values while Args still holds what the delegate set.
    if (advancedSettings != null)
        CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args);

    // Override with the directly provided values.
    Args.NumLeaves = numLeaves;
    Args.NumTrees = numTrees;
    Args.MinDocumentsInLeafs = minDocumentsInLeafs;
    Args.LearningRates = learningRate;

    // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss.
    // This is done after the override above, so that it is derived from the learning rate stored in Args
    // rather than from the value that was just overwritten.
    _sigmoidParameter = 2.0 * Args.LearningRates;
}

/// <summary>
Expand Down
21 changes: 17 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,28 @@ public sealed partial class FastTreeRegressionTrainer
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<Arguments> advancedSettings = null)
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
/// <param name="learningRate">The learning rate.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
public FastTreeRegressionTrainer(IHostEnvironment env,
    string labelColumn,
    string featureColumn,
    string weightColumn = null,
    int numLeaves = Defaults.NumLeaves,
    int numTrees = Defaults.NumTrees,
    int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
    double learningRate = Defaults.LearningRates,
    Action<Arguments> advancedSettings = null)
    : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    // Compare the advanced settings with the direct values while Args still holds what the delegate set.
    if (advancedSettings != null)
        CheckArgsAndAdvancedSettingMismatch(numLeaves, numTrees, minDocumentsInLeafs, learningRate, new Arguments(), Args);

    // Override with the directly provided values; without this the direct parameters
    // would only feed the warning above and never reach Args.
    Args.NumLeaves = numLeaves;
    Args.NumTrees = numTrees;
    Args.MinDocumentsInLeafs = minDocumentsInLeafs;
    Args.LearningRates = learningRate;
}

/// <summary>
Expand Down
Loading