Skip to content

Commit 3eccc93

Browse files
authored
Public Interface of RegressionTree and TreeEnsemble (#2243)
* Immutable RegressionTree and TreeEnsemble wrappers * Address comments and internalize RegressionTree and TreeEnsemble. A lot of things become internal and possibly best-friend because of RegressionTree and TreeEnsemble. * Create new files for public wrapper classes * private protected is good * Doc strings * Update entry point * Add a dynamic and static API tests * Address comments. 1. Rename classes. 2. Use IReadOnlyList instead of ReadOnlySpan. 3. Remove new namespace, Representation. * Strongly-typed public surface Separate TreeEnsemble * Rename files
1 parent 28ae548 commit 3eccc93

File tree

53 files changed

+1009
-270
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1009
-270
lines changed

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

-2
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,6 @@ public abstract class SingleFeaturePredictionTransformerBase<TModel, TScorer> :
171171
public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn)
172172
: base(host, model, trainSchema)
173173
{
174-
FeatureColumn = featureColumn;
175-
176174
FeatureColumn = featureColumn;
177175
if (featureColumn == null)
178176
FeatureColumnType = null;

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ protected override void CheckArgs(IChannel ch)
6262
base.CheckArgs(ch);
6363
}
6464

65-
protected override TreeLearner ConstructTreeLearner(IChannel ch)
65+
private protected override TreeLearner ConstructTreeLearner(IChannel ch)
6666
{
6767
return new LeastSquaresRegressionTreeLearner(
6868
TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,
@@ -73,7 +73,7 @@ protected override TreeLearner ConstructTreeLearner(IChannel ch)
7373
Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias);
7474
}
7575

76-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
76+
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
7777
{
7878
Contracts.CheckValue(ch, nameof(ch));
7979
OptimizationAlgorithm optimizationAlgorithm;

src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public PerBinStats(Double sumTargets, Double sumWeights, int count)
5151
/// per flock. Note that feature indices, whenever present, refer to the feature within the
5252
/// particular flock the same as they do with <see cref="FeatureFlockBase"/>.
5353
/// </summary>
54-
public abstract class SufficientStatsBase
54+
internal abstract class SufficientStatsBase
5555
{
5656
// REVIEW: Holdover from histogram. I really don't like this. Figure out if
5757
// there's a better way.
@@ -929,7 +929,7 @@ public void FillSplitCandidatesCategoricalNeighborBundling(LeastSquaresRegressio
929929
/// </summary>
930930
/// <typeparam name="TSuffStats">The type of sufficient stats that we will be able to do
931931
/// "peer" operations against, like subtract. This will always be the derived class itself.</typeparam>
932-
public abstract class SufficientStatsBase<TSuffStats> : SufficientStatsBase
932+
internal abstract class SufficientStatsBase<TSuffStats> : SufficientStatsBase
933933
where TSuffStats : SufficientStatsBase<TSuffStats>
934934
{
935935
protected SufficientStatsBase(int features)
@@ -1005,7 +1005,7 @@ protected FeatureFlockBase(int count, bool categorical = false)
10051005
/// <param name="hasWeights">Whether structures related to tracking
10061006
/// example weights should be allocated</param>
10071007
/// <returns>A sufficient statistics object</returns>
1008-
public abstract SufficientStatsBase CreateSufficientStats(bool hasWeights);
1008+
internal abstract SufficientStatsBase CreateSufficientStats(bool hasWeights);
10091009

10101010
/// <summary>
10111011
/// Returns a forward indexer for a single feature. This has a default implementation that
@@ -1207,7 +1207,7 @@ public override long SizeInBytes()
12071207
+ sizeof(int) * HotFeatureStarts.Length;
12081208
}
12091209

1210-
public override SufficientStatsBase CreateSufficientStats(bool hasWeights)
1210+
internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
12111211
{
12121212
return new SufficientStats(this, hasWeights);
12131213
}

src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public override long SizeInBytes()
3434
return _bins.SizeInBytes() + sizeof(double) * _binUpperBounds.Length;
3535
}
3636

37-
public override SufficientStatsBase CreateSufficientStats(bool hasWeights)
37+
internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
3838
{
3939
return new SufficientStats(this, hasWeights);
4040
}

src/Microsoft.ML.FastTree/FastTree.cs

+103-38
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
using Microsoft.ML.Data;
1818
using Microsoft.ML.Data.Conversion;
1919
using Microsoft.ML.EntryPoints;
20-
using Microsoft.ML.Internal.Calibration;
20+
using Microsoft.ML.FastTree;
2121
using Microsoft.ML.Internal.Internallearn;
2222
using Microsoft.ML.Internal.Utilities;
2323
using Microsoft.ML.Model;
@@ -56,16 +56,16 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
5656
{
5757
protected readonly TArgs Args;
5858
protected readonly bool AllowGC;
59-
protected TreeEnsemble TrainedEnsemble;
6059
protected int FeatureCount;
60+
private protected InternalTreeEnsemble TrainedEnsemble;
6161
private protected RoleMappedData ValidData;
6262
/// <summary>
6363
/// If not null, it's a test data set passed in from training context. It will be converted to one element in
6464
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
6565
/// </summary>
6666
private protected RoleMappedData TestData;
67-
protected IParallelTraining ParallelTraining;
68-
protected OptimizationAlgorithm OptimizationAlgorithm;
67+
private protected IParallelTraining ParallelTraining;
68+
private protected OptimizationAlgorithm OptimizationAlgorithm;
6969
protected Dataset TrainSet;
7070
protected Dataset ValidSet;
7171
/// <summary>
@@ -89,7 +89,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
8989
protected double[] InitValidScores;
9090
protected double[][] InitTestScores;
9191
//protected int Iteration;
92-
protected TreeEnsemble Ensemble;
92+
private protected InternalTreeEnsemble Ensemble;
9393

9494
protected bool HasValidSet => ValidSet != null;
9595

@@ -175,8 +175,9 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
175175

176176
protected abstract Test ConstructTestForTrainingData();
177177

178-
protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
179-
protected abstract TreeLearner ConstructTreeLearner(IChannel ch);
178+
private protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
179+
180+
private protected abstract TreeLearner ConstructTreeLearner(IChannel ch);
180181

181182
protected abstract ObjectiveFunctionBase ConstructObjFunc(IChannel ch);
182183

@@ -488,7 +489,7 @@ protected bool AreSamplesWeighted(IChannel ch)
488489

489490
private void InitializeEnsemble()
490491
{
491-
Ensemble = new TreeEnsemble();
492+
Ensemble = new InternalTreeEnsemble();
492493
}
493494

494495
/// <summary>
@@ -793,7 +794,7 @@ private float GetMachineAvailableBytes()
793794

794795
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
795796
// Note that this tree can be null if no tree was learnt this iteration.
796-
protected virtual void CustomizedTrainingIteration(RegressionTree tree)
797+
private protected virtual void CustomizedTrainingIteration(InternalRegressionTree tree)
797798
{
798799
}
799800

@@ -924,7 +925,7 @@ internal abstract class DataConverter
924925
/// of features we actually trained on. This can be null in the event that no filtering
925926
/// occurred.
926927
/// </summary>
927-
/// <seealso cref="TreeEnsemble.RemapFeatures"/>
928+
/// <seealso cref="InternalTreeEnsemble.RemapFeatures"/>
928929
public int[] FeatureMap;
929930

930931
protected readonly IHost Host;
@@ -2810,9 +2811,10 @@ public abstract class TreeEnsembleModelParameters :
28102811
ISingleCanSavePfa,
28112812
ISingleCanSaveOnnx
28122813
{
2813-
//The below two properties are necessary for tree Visualizer
2814+
// The below two properties are necessary for tree Visualizer
28142815
[BestFriend]
2815-
internal TreeEnsemble TrainedEnsemble { get; }
2816+
internal InternalTreeEnsemble TrainedEnsemble { get; }
2817+
28162818
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28172819

28182820
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2854,7 +2856,9 @@ public abstract class TreeEnsembleModelParameters :
28542856
/// </summary>
28552857
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator(this);
28562858

2857-
public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
2859+
/// The following function is used in both FastTree and LightGBM so <see cref="BestFriendAttribute"/> is required.
2860+
[BestFriend]
2861+
internal TreeEnsembleModelParameters(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
28582862
: base(env, name)
28592863
{
28602864
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
@@ -2868,7 +2872,7 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
28682872
InnerArgs = innerArgs;
28692873
NumFeatures = numFeatures;
28702874

2871-
MaxSplitFeatIdx = FindMaxFeatureIndex(trainedEnsemble);
2875+
MaxSplitFeatIdx = trainedEnsemble.GetMaxFeatureIndex();
28722876
Contracts.Assert(NumFeatures > MaxSplitFeatIdx);
28732877

28742878
InputType = new VectorType(NumberType.Float, NumFeatures);
@@ -2892,8 +2896,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28922896
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28932897
categoricalSplits = true;
28942898

2895-
TrainedEnsemble = new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
2896-
MaxSplitFeatIdx = FindMaxFeatureIndex(TrainedEnsemble);
2899+
TrainedEnsemble = new InternalTreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
2900+
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
28972901

28982902
InnerArgs = ctx.LoadStringOrNull();
28992903
if (ctx.Header.ModelVerWritten >= VerNumFeaturesSerialized)
@@ -3195,7 +3199,7 @@ private void SaveEnsembleAsCode(TextWriter writer, RoleMappedSchema schema)
31953199
MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names);
31963200

31973201
int i = 0;
3198-
foreach (RegressionTree tree in TrainedEnsemble.Trees)
3202+
foreach (InternalRegressionTree tree in TrainedEnsemble.Trees)
31993203
{
32003204
writer.Write("double treeOutput{0}=", i);
32013205
SaveTreeAsCode(tree, writer, in names);
@@ -3211,13 +3215,13 @@ private void SaveEnsembleAsCode(TextWriter writer, RoleMappedSchema schema)
32113215
/// <summary>
32123216
/// Convert a single tree to code, called recursively
32133217
/// </summary>
3214-
private void SaveTreeAsCode(RegressionTree tree, TextWriter writer, in VBuffer<ReadOnlyMemory<char>> names)
3218+
private void SaveTreeAsCode(InternalRegressionTree tree, TextWriter writer, in VBuffer<ReadOnlyMemory<char>> names)
32153219
{
32163220
ToCSharp(tree, writer, 0, in names);
32173221
}
32183222

32193223
// converts a subtree into a C# expression
3220-
private void ToCSharp(RegressionTree tree, TextWriter writer, int node, in VBuffer<ReadOnlyMemory<char>> names)
3224+
private void ToCSharp(InternalRegressionTree tree, TextWriter writer, int node, in VBuffer<ReadOnlyMemory<char>> names)
32213225
{
32223226
if (node < 0)
32233227
{
@@ -3259,23 +3263,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
32593263
bldr.GetResult(ref weights);
32603264
}
32613265

3262-
private static int FindMaxFeatureIndex(TreeEnsemble ensemble)
3263-
{
3264-
int ifeatMax = 0;
3265-
for (int i = 0; i < ensemble.NumTrees; i++)
3266-
{
3267-
var tree = ensemble.GetTreeAt(i);
3268-
for (int n = 0; n < tree.NumNodes; n++)
3269-
{
3270-
int ifeat = tree.SplitFeature(n);
3271-
if (ifeat > ifeatMax)
3272-
ifeatMax = ifeat;
3273-
}
3274-
}
3275-
3276-
return ifeatMax;
3277-
}
3278-
32793266
ITree[] ITreeEnsemble.GetTrees()
32803267
{
32813268
return TrainedEnsemble.Trees.Select(k => new Tree(k)).ToArray();
@@ -3318,9 +3305,9 @@ Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema)
33183305

33193306
private sealed class Tree : ITree<VBuffer<Float>>
33203307
{
3321-
private readonly RegressionTree _regTree;
3308+
private readonly InternalRegressionTree _regTree;
33223309

3323-
public Tree(RegressionTree regTree)
3310+
public Tree(InternalRegressionTree regTree)
33243311
{
33253312
_regTree = regTree;
33263313
}
@@ -3390,4 +3377,82 @@ public TreeNode(Dictionary<string, object> keyValues)
33903377
public Dictionary<string, object> KeyValues { get; }
33913378
}
33923379
}
3380+
3381+
/// <summary>
3382+
/// <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is derived from
3383+
/// <see cref="TreeEnsembleModelParameters"/> plus a strongly-typed public attribute,
3384+
/// <see cref="TrainedTreeEnsemble"/>, for exposing trained model's details to users.
3385+
/// Its function, <see cref="CreateTreeEnsembleFromInternalDataStructure"/>, is
3386+
/// called to create <see cref="TrainedTreeEnsemble"/> inside <see cref="TreeEnsembleModelParameters"/>.
3387+
/// Note that the major difference between <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/>
3388+
/// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of
3389+
/// <see cref="TrainedTreeEnsemble"/>.
3390+
/// </summary>
3391+
public abstract class TreeEnsembleModelParametersBasedOnRegressionTree : TreeEnsembleModelParameters
3392+
{
3393+
/// <summary>
3394+
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/>
3395+
/// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
3396+
/// </summary>
3397+
public RegressionTreeEnsemble TrainedTreeEnsemble { get; }
3398+
3399+
[BestFriend]
3400+
internal TreeEnsembleModelParametersBasedOnRegressionTree(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
3401+
: base(env, name, trainedEnsemble, numFeatures, innerArgs)
3402+
{
3403+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3404+
}
3405+
3406+
protected TreeEnsembleModelParametersBasedOnRegressionTree(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
3407+
: base(env, name, ctx, ver)
3408+
{
3409+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3410+
}
3411+
3412+
private RegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure()
3413+
{
3414+
var trees = TrainedEnsemble.Trees.Select(tree => new RegressionTree(tree));
3415+
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight);
3416+
return new RegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias);
3417+
}
3418+
}
3419+
3420+
/// <summary>
3421+
/// <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/> is derived from
3422+
/// <see cref="TreeEnsembleModelParameters"/> plus a strongly-typed public attribute,
3423+
/// <see cref="TrainedTreeEnsemble"/>, for exposing trained model's details to users.
3424+
/// Its function, <see cref="CreateTreeEnsembleFromInternalDataStructure"/>, is
3425+
/// called to create <see cref="TrainedTreeEnsemble"/> inside <see cref="TreeEnsembleModelParameters"/>.
3426+
/// Note that the major difference between <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/>
3427+
/// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of
3428+
/// <see cref="TrainedTreeEnsemble"/>.
3429+
/// </summary>
3430+
public abstract class TreeEnsembleModelParametersBasedOnQuantileRegressionTree : TreeEnsembleModelParameters
3431+
{
3432+
/// <summary>
3433+
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/>
3434+
/// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
3435+
/// </summary>
3436+
public QuantileRegressionTreeEnsemble TrainedTreeEnsemble { get; }
3437+
3438+
[BestFriend]
3439+
internal TreeEnsembleModelParametersBasedOnQuantileRegressionTree(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
3440+
: base(env, name, trainedEnsemble, numFeatures, innerArgs)
3441+
{
3442+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3443+
}
3444+
3445+
protected TreeEnsembleModelParametersBasedOnQuantileRegressionTree(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
3446+
: base(env, name, ctx, ver)
3447+
{
3448+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3449+
}
3450+
3451+
private QuantileRegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure()
3452+
{
3453+
var trees = TrainedEnsemble.Trees.Select(tree => new QuantileRegressionTree((InternalQuantileRegressionTree)tree));
3454+
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight);
3455+
return new QuantileRegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias);
3456+
}
3457+
}
33933458
}

src/Microsoft.ML.FastTree/FastTreeArguments.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
156156
/// Allows to choose Parallel FastTree Learning Algorithm.
157157
/// </summary>
158158
[Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")]
159-
public ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();
159+
internal ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();
160160

161161
/// <summary>
162162
/// The number of threads to use.

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
namespace Microsoft.ML.Trainers.FastTree
4545
{
4646
public sealed class FastTreeBinaryModelParameters :
47-
TreeEnsembleModelParameters
47+
TreeEnsembleModelParametersBasedOnRegressionTree
4848
{
4949
internal const string LoaderSignature = "FastTreeBinaryExec";
5050
internal const string RegistrationName = "FastTreeBinaryPredictor";
@@ -70,7 +70,7 @@ private static VersionInfo GetVersionInfo()
7070

7171
protected override uint VerCategoricalSplitSerialized => 0x00010005;
7272

73-
public FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
73+
internal FastTreeBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
7474
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
7575
{
7676
}
@@ -204,7 +204,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
204204
ParallelTraining);
205205
}
206206

207-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
207+
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
208208
{
209209
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
210210
if (Args.UseLineSearch)
@@ -365,7 +365,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
365365
}
366366
}
367367

368-
public void AdjustTreeOutputs(IChannel ch, RegressionTree tree,
368+
public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree,
369369
DocumentPartitioning partitioning, ScoreTracker trainingScores)
370370
{
371371
const double epsilon = 1.4e-45;

0 commit comments

Comments
 (0)