Skip to content

Adding the extension methods for FastTree classification and regression #1009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 25, 2018
Merged
19 changes: 19 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
Expand Down Expand Up @@ -382,6 +383,24 @@ public static SchemaShape.Column MakeR4ScalarWeightColumn(string weightColumn)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Check that the label, feature, weights is not supplied in the args the constructor.
/// Those parameters should be internal if they are not used from the maml help code path.
/// </summary>
/// <summary>
/// Check that the label, feature, weights etc. column names were not supplied through the
/// arguments object of the constructor; they must come from the loader or the constructor
/// parameters instead. Those argument fields should be internal when not used from the
/// maml help code path.
/// </summary>
/// <param name="env">The host environment, used to surface the exception.</param>
/// <param name="defaultColName">The expected default name for the column.</param>
/// <param name="argValue">The column name currently set on the arguments object.</param>
public static void CheckArgsDefaultColNames(IHost env, string defaultColName, string argValue)
{
    // Any deviation from the default means the user set it through the args, which is ignored.
    if (argValue != defaultColName)
        throw env.Except($"Don't supply a value for the {defaultColName} column in the arguments, as it will be ignored. Specify them in the loader, or constructor instead.");
}

/// <summary>
/// Warn the user when the value produced by the advanced-settings delegate for an argument
/// differs both from the argument's default and from the value supplied directly to the
/// extension method; the advanced-settings value wins in that case.
/// </summary>
/// <param name="channel">The channel used to emit the warning.</param>
/// <param name="methodParam">The value supplied directly to the method.</param>
/// <param name="defaultVal">The default value of the argument.</param>
/// <param name="setting">The value of the argument after the advanced-settings delegate ran.</param>
/// <param name="argName">The argument name, used in the warning message.</param>
public static void CheckArgsAndAdvancedSettingMismatch<T>(IChannel channel, T methodParam, T defaultVal, T setting, string argName)
{
    // If, after applying the advancedSettings delegate, the args are different than the default
    // value and also different than the value supplied directly to the extension method, warn the user.
    // EqualityComparer is used instead of setting.Equals so a null 'setting' (reference-type T)
    // does not throw a NullReferenceException.
    var comparer = EqualityComparer<T>.Default;
    if (!comparer.Equals(setting, defaultVal) && !comparer.Equals(setting, methodParam))
        channel.Warning($"The value supplied to advanced settings is different than the value supplied directly. Using value {setting} for {argName}");
}
}

/// <summary>
Expand Down
7 changes: 7 additions & 0 deletions src/Microsoft.ML.FastTree/FastTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ private protected FastTreeTrainerBase(IHostEnvironment env, SchemaShape.Column l

//apply the advanced args, if the user supplied any
advancedSettings?.Invoke(Args);

// check that the users didn't specify different label, group, feature, weights in the args, from what they supplied directly
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.Label, Args.LabelColumn);
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.Features, Args.FeatureColumn);
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.GroupId, Args.GroupIdColumn);
TrainerUtils.CheckArgsDefaultColNames(Host, DefaultColumnNames.Weight, Args.WeightColumn);

Args.LabelColumn = label.Name;
Args.FeatureColumn = featureColumn;

Expand Down
16 changes: 12 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeArguments.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ public enum Bundle : Byte
Adjacent = 2
}

// Shared default hyperparameter values for the FastTree trainers. Kept in one place so the
// Arguments field initializers, the trainer constructors, and the static-pipe extension
// methods all agree on the same defaults.
internal static class Defaults
{
// Total number of decision trees to create in the ensemble.
internal const int NumTrees = 100;
// The max number of leaves in each regression tree.
internal const int NumLeaves = 20;
// The minimal number of documents allowed in a leaf, out of the subsampled data.
internal const int MinDocumentsInLeafs = 10;
// The learning rate.
internal const double LearningRates = 0.2;
}

public abstract class TreeArgs : LearnerInputBaseWithGroupId
{
[Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")]
Expand Down Expand Up @@ -229,20 +237,20 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The max number of leaves in each regression tree", ShortName = "nl", SortOrder = 2)]
[TGUI(Description = "The maximum number of leaves per tree", SuggestedSweeps = "2-128;log;inc:4")]
[TlcModule.SweepableLongParamAttribute("NumLeaves", 2, 128, isLogScale: true, stepSize: 4)]
public int NumLeaves = 20;
public int NumLeaves = Defaults.NumLeaves;

// REVIEW: Arrays not supported in GUI
// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data", ShortName = "mil", SortOrder = 3)]
[TGUI(Description = "Minimum number of training instances required to form a leaf", SuggestedSweeps = "1,10,50")]
[TlcModule.SweepableDiscreteParamAttribute("MinDocumentsInLeafs", new object[] { 1, 10, 50 })]
public int MinDocumentsInLeafs = 10;
public int MinDocumentsInLeafs = Defaults.MinDocumentsInLeafs;

// REVIEW: Different shortname than FastRank module. Same as the TLC FRWrapper.
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Total number of decision trees to create in the ensemble", ShortName = "iter", SortOrder = 1)]
[TGUI(Description = "Total number of trees constructed", SuggestedSweeps = "20,100,500")]
[TlcModule.SweepableDiscreteParamAttribute("NumTrees", new object[] { 20, 100, 500 })]
public int NumTrees = 100;
public int NumTrees = Defaults.NumTrees;

[Argument(ArgumentType.AtMostOnce, HelpText = "The fraction of features (chosen randomly) to use on each iteration", ShortName = "ff")]
public Double FeatureFraction = 1;
Expand Down Expand Up @@ -365,7 +373,7 @@ public enum OptimizationAlgorithmType { GradientDescent, AcceleratedGradientDesc
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The learning rate", ShortName = "lr", SortOrder = 4)]
[TGUI(Label = "Learning Rate", SuggestedSweeps = "0.025-0.4;log")]
[TlcModule.SweepableFloatParamAttribute("LearningRates", 0.025f, 0.4f, isLogScale: true)]
public Double LearningRates = 0.2;
public Double LearningRates = Defaults.LearningRates;

[Argument(ArgumentType.AtMostOnce, HelpText = "Shrinkage", ShortName = "shrk")]
[TGUI(Label = "Shrinkage", SuggestedSweeps = "0.25-4;log")]
Expand Down
41 changes: 37 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,51 @@ public sealed partial class FastTreeBinaryClassificationTrainer :
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public FastTreeBinaryClassificationTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
string groupIdColumn = null, string weightColumn = null, Action<Arguments> advancedSettings = null)
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
/// <param name="learningRate">The learning rate.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
    string labelColumn,
    string featureColumn,
    string weightColumn = null,
    int numLeaves = Defaults.NumLeaves,
    int numTrees = Defaults.NumTrees,
    int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
    double learningRate = Defaults.LearningRates,
    Action<Arguments> advancedSettings = null)
    : base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    if (advancedSettings != null)
    {
        using (var ch = Host.Start("Validating advanced settings."))
        {
            // Take a quick snapshot of the defaults, for comparison with the current args values.
            var snapshot = new Arguments();

            // Check that the user didn't supply different parameters in the args, from what it specified directly.
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, Args.NumLeaves, nameof(numLeaves));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, Args.NumTrees, nameof(numTrees));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, Args.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRate, snapshot.LearningRates, Args.LearningRates, nameof(learningRate));

            ch.Done();
        }
    }

    // Override with the directly provided values.
    Args.NumLeaves = numLeaves;
    Args.NumTrees = numTrees;
    Args.MinDocumentsInLeafs = minDocumentsInLeafs;
    Args.LearningRates = learningRate;

    // Set the sigmoid parameter to 2 * learning rate, for traditional FastTreeClassification loss.
    // Computed AFTER the override above: previously it was computed first, so a learning rate
    // supplied directly through the 'learningRate' parameter never reached the sigmoid parameter.
    _sigmoidParameter = 2.0 * Args.LearningRates;
}

/// <summary>
Expand Down
34 changes: 30 additions & 4 deletions src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,41 @@ public sealed partial class FastTreeRegressionTrainer
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
string weightColumn = null, string groupIdColumn = null, Action<Arguments> advancedSettings = null)
: base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
/// <param name="learningRates">The learning rate.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
public FastTreeRegressionTrainer(IHostEnvironment env,
    string labelColumn,
    string featureColumn,
    string weightColumn = null,
    int numLeaves = Defaults.NumLeaves,
    int numTrees = Defaults.NumTrees,
    int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
    double learningRates = Defaults.LearningRates,
    Action<Arguments> advancedSettings = null)
    : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    if (advancedSettings != null)
    {
        using (var ch = Host.Start("Validating advanced settings."))
        {
            // Take a quick snapshot of the defaults, for comparison with the current args values.
            var snapshot = new Arguments();

            // Check that the user didn't supply different parameters in the args, from what it specified directly.
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, Args.NumLeaves, nameof(numLeaves));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, Args.NumTrees, nameof(numTrees));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, Args.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
            // BUG FIX: this line previously re-checked numLeaves; it must check the learning rate.
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRates, snapshot.LearningRates, Args.LearningRates, nameof(learningRates));
            ch.Done();
        }
    }

    // Override with the directly provided values. Previously missing here (present in the
    // binary classification trainer), so the direct hyperparameters were silently ignored.
    Args.NumLeaves = numLeaves;
    Args.NumTrees = numTrees;
    Args.MinDocumentsInLeafs = minDocumentsInLeafs;
    Args.LearningRates = learningRates;
}

/// <summary>
Expand Down
86 changes: 86 additions & 0 deletions src/Microsoft.ML.FastTree/FastTreeStatic.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Training;
using System;

namespace Microsoft.ML.Trainers
{
/// <summary>
/// Static-pipeline extension methods for training FastTree regression and binary classification models.
/// </summary>
public static class FastTreeStatic
{
    /// <summary>
    /// FastTree regression extension method for the static pipeline.
    /// </summary>
    /// <param name="ctx">The regression trainer context.</param>
    /// <param name="label">The label column.</param>
    /// <param name="features">The features column.</param>
    /// <param name="weights">The optional example-weights column.</param>
    /// <param name="numLeaves">The max number of leaves in each regression tree.</param>
    /// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
    /// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
    /// <param name="learningRate">The learning rate.</param>
    /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
    /// <param name="onFit">A delegate invoked with the trained model once fitting completes.</param>
    /// <returns>The predicted score column.</returns>
    public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers ctx,
        Scalar<float> label, Vector<float> features, Scalar<float> weights = null,
        int numLeaves = Defaults.NumLeaves,
        int numTrees = Defaults.NumTrees,
        int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
        double learningRate = Defaults.LearningRates,
        Action<FastTreeRegressionTrainer.Arguments> advancedSettings = null,
        Action<FastTreeRegressionPredictor> onFit = null)
    {
        CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit);

        var rec = new TrainerEstimatorReconciler.Regression(
            (env, labelName, featuresName, weightsName) =>
            {
                var trainer = new FastTreeRegressionTrainer(env, labelName, featuresName, weightsName, numLeaves,
                    numTrees, minDocumentsInLeafs, learningRate, advancedSettings);
                if (onFit != null)
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                return trainer;
            }, label, features, weights);

        return rec.Score;
    }

    /// <summary>
    /// FastTree binary classification extension method for the static pipeline.
    /// </summary>
    /// <param name="ctx">The binary classification trainer context.</param>
    /// <param name="label">The boolean label column.</param>
    /// <param name="features">The features column.</param>
    /// <param name="weights">The optional example-weights column.</param>
    /// <param name="numLeaves">The max number of leaves in each regression tree.</param>
    /// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
    /// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
    /// <param name="learningRate">The learning rate.</param>
    /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
    /// <param name="onFit">A delegate invoked with the trained model once fitting completes.</param>
    /// <returns>The score, probability, and predicted-label columns.</returns>
    public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
        Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
        int numLeaves = Defaults.NumLeaves,
        int numTrees = Defaults.NumTrees,
        int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
        double learningRate = Defaults.LearningRates,
        Action<FastTreeBinaryClassificationTrainer.Arguments> advancedSettings = null,
        Action<IPredictorWithFeatureWeights<float>> onFit = null)
    {
        CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit);

        var rec = new TrainerEstimatorReconciler.BinaryClassifier(
            (env, labelName, featuresName, weightsName) =>
            {
                var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, numLeaves,
                    numTrees, minDocumentsInLeafs, learningRate, advancedSettings);
                if (onFit != null)
                    return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                return trainer;
            }, label, features, weights);

        return rec.Output;
    }

    // Shared argument validation for both extension methods above.
    private static void CheckUserValues<TVal, TArgs, TPred>(Scalar<TVal> label, Vector<float> features, Scalar<float> weights = null,
        int numLeaves = Defaults.NumLeaves,
        int numTrees = Defaults.NumTrees,
        int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
        double learningRate = Defaults.LearningRates,
        Action<TArgs> advancedSettings = null,
        Action<TPred> onFit = null)
    {
        Contracts.CheckValue(label, nameof(label));
        Contracts.CheckValue(features, nameof(features));
        Contracts.CheckValueOrNull(weights);
        Contracts.CheckParam(numLeaves >= 2, nameof(numLeaves), "Must be at least 2.");
        Contracts.CheckParam(numTrees > 0, nameof(numTrees), "Must be positive");
        Contracts.CheckParam(minDocumentsInLeafs > 0, nameof(minDocumentsInLeafs), "Must be positive");
        Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive");
        Contracts.CheckValueOrNull(advancedSettings);
        Contracts.CheckValueOrNull(onFit);
    }
}
}
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
<Compile Include="FastTreeRanking.cs" />
<Compile Include="FastTreeRegression.cs" />
<Compile Include="FastTree.cs" />
<Compile Include="FastTreeStatic.cs" />
<Compile Include="FastTreeTweedie.cs" />
<Compile Include="GamClassification.cs" />
<Compile Include="GamRegression.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFacto
var trainer = new FieldAwareFactorizationMachineTrainer(env, labelCol, featureCols, advancedSettings:
args =>
{
advancedSettings?.Invoke(args);
Copy link
Contributor

@Zruty0 Zruty0 Sep 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

advancedSettings [](start = 24, length = 16)

my order was intentional :) #Pending

args.LearningRate = learningRate;
args.Iters = numIterations;
args.LatentDim = numLatentDimensions;
advancedSettings?.Invoke(args);
});
if (onFit != null)
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ namespace Microsoft.ML.Runtime.Learners
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' />
public sealed class SdcaRegressionTrainer : SdcaTrainerBase<RegressionPredictionTransformer<LinearRegressionPredictor>, LinearRegressionPredictor>
{
public const string LoadNameValue = "SDCAR";
public const string UserNameValue = "Fast Linear Regression (SA-SDCA)";
public const string ShortName = "sasdcar";
internal const string LoadNameValue = "SDCAR";
internal const string UserNameValue = "Fast Linear Regression (SA-SDCA)";
internal const string ShortName = "sasdcar";
internal const string Summary = "The SDCA linear regression trainer.";

public sealed class Arguments : ArgumentsBase
Expand Down
Loading