Skip to content

Commit 9ffc6c8

Browse files
committed
Strongly-typed public surface
Add a test Fix typo Some docs Seperate TreeEnsemble
1 parent 07053c2 commit 9ffc6c8

19 files changed

+301
-130
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+83-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using Microsoft.ML.Data;
1818
using Microsoft.ML.Data.Conversion;
1919
using Microsoft.ML.EntryPoints;
20+
using Microsoft.ML.FastTree;
2021
using Microsoft.ML.Internal.Internallearn;
2122
using Microsoft.ML.Internal.Utilities;
2223
using Microsoft.ML.Model;
@@ -2810,14 +2811,10 @@ public abstract class TreeEnsembleModelParameters :
28102811
ISingleCanSavePfa,
28112812
ISingleCanSaveOnnx
28122813
{
2813-
/// <summary>
2814-
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.Representation.TreeEnsemble"/>.
2815-
/// </summary>
2816-
public ML.FastTree.Representation.TreeEnsemble TrainedTreeCollection { get; }
2817-
28182814
// The below two properties are necessary for tree Visualizer
28192815
[BestFriend]
2820-
internal InternalTreeEnsemble TrainedEnsemble => TrainedTreeCollection.UnderlyingTreeEnsemble;
2816+
internal InternalTreeEnsemble TrainedEnsemble { get; }
2817+
28212818
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28222819

28232820
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2871,7 +2868,7 @@ internal TreeEnsembleModelParameters(IHostEnvironment env, string name, Internal
28712868
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28722869
// the trained ensemble to, for instance, resize arrays so that they are of the length
28732870
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2874-
TrainedTreeCollection = new ML.FastTree.Representation.TreeEnsemble(trainedEnsemble);
2871+
TrainedEnsemble = trainedEnsemble;
28752872
InnerArgs = innerArgs;
28762873
NumFeatures = numFeatures;
28772874

@@ -2899,7 +2896,7 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28992896
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
29002897
categoricalSplits = true;
29012898

2902-
TrainedTreeCollection = new ML.FastTree.Representation.TreeEnsemble(new InternalTreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2899+
TrainedEnsemble = new InternalTreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
29032900
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
29042901

29052902
InnerArgs = ctx.LoadStringOrNull();
@@ -3380,4 +3377,82 @@ public TreeNode(Dictionary<string, object> keyValues)
33803377
public Dictionary<string, object> KeyValues { get; }
33813378
}
33823379
}
3380+
3381+
/// <summary>
3382+
/// <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is derived from
3383+
/// <see cref="TreeEnsembleModelParameters"/> plus a strongly-typed public attribute,
3384+
/// <see cref="TrainedTreeEnsemble"/>, for exposing trained model's details to users.
3385+
/// Its function, <see cref="CreateTreeEnsembleFromInternalDataStructure"/>, is
3386+
/// called to create <see cref="TrainedTreeEnsemble"/> inside <see cref="TreeEnsembleModelParameters"/>.
3387+
/// Note that the major difference between <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/>
3388+
/// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of
3389+
/// <see cref="TrainedTreeEnsemble"/>.
3390+
/// </summary>
3391+
public abstract class TreeEnsembleModelParametersBasedOnRegressionTree : TreeEnsembleModelParameters
3392+
{
3393+
/// <summary>
3394+
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/>
3395+
/// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
3396+
/// </summary>
3397+
public RegressionTreeEnsemble TrainedTreeEnsemble { get; }
3398+
3399+
[BestFriend]
3400+
internal TreeEnsembleModelParametersBasedOnRegressionTree(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
3401+
: base(env, name, trainedEnsemble, numFeatures, innerArgs)
3402+
{
3403+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3404+
}
3405+
3406+
protected TreeEnsembleModelParametersBasedOnRegressionTree(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
3407+
: base(env, name, ctx, ver)
3408+
{
3409+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3410+
}
3411+
3412+
private RegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure()
3413+
{
3414+
var trees = TrainedEnsemble.Trees.Select(tree => new RegressionTree(tree));
3415+
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight);
3416+
return new RegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias);
3417+
}
3418+
}
3419+
3420+
/// <summary>
3421+
/// <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/> is derived from
3422+
/// <see cref="TreeEnsembleModelParameters"/> plus a strongly-typed public attribute,
3423+
/// <see cref="TrainedTreeEnsemble"/>, for exposing trained model's details to users.
3424+
/// Its function, <see cref="CreateTreeEnsembleFromInternalDataStructure"/>, is
3425+
/// called to create <see cref="TrainedTreeEnsemble"/> inside <see cref="TreeEnsembleModelParameters"/>.
3426+
/// Note that the major difference between <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/>
3427+
/// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of
3428+
/// <see cref="TrainedTreeEnsemble"/>.
3429+
/// </summary>
3430+
public abstract class TreeEnsembleModelParametersBasedOnQuantileRegressionTree : TreeEnsembleModelParameters
3431+
{
3432+
/// <summary>
3433+
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/>
3434+
/// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
3435+
/// </summary>
3436+
public QuantileRegressionTreeEnsemble TrainedTreeEnsemble { get; }
3437+
3438+
[BestFriend]
3439+
internal TreeEnsembleModelParametersBasedOnQuantileRegressionTree(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
3440+
: base(env, name, trainedEnsemble, numFeatures, innerArgs)
3441+
{
3442+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3443+
}
3444+
3445+
protected TreeEnsembleModelParametersBasedOnQuantileRegressionTree(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
3446+
: base(env, name, ctx, ver)
3447+
{
3448+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
3449+
}
3450+
3451+
private QuantileRegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure()
3452+
{
3453+
var trees = TrainedEnsemble.Trees.Select(tree => new QuantileRegressionTree((InternalQuantileRegressionTree)tree));
3454+
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight);
3455+
return new QuantileRegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias);
3456+
}
3457+
}
33833458
}

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
namespace Microsoft.ML.Trainers.FastTree
4545
{
4646
public sealed class FastTreeBinaryModelParameters :
47-
TreeEnsembleModelParameters
47+
TreeEnsembleModelParametersBasedOnRegressionTree
4848
{
4949
internal const string LoaderSignature = "FastTreeBinaryExec";
5050
internal const string RegistrationName = "FastTreeBinaryPredictor";

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,7 @@ private static extern unsafe void GetDerivatives(
11051105
}
11061106
}
11071107

1108-
public sealed class FastTreeRankingModelParameters : TreeEnsembleModelParameters
1108+
public sealed class FastTreeRankingModelParameters : TreeEnsembleModelParametersBasedOnRegressionTree
11091109
{
11101110
internal const string LoaderSignature = "FastTreeRankerExec";
11111111
internal const string RegistrationName = "FastTreeRankingPredictor";

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
441441
}
442442
}
443443

444-
public sealed class FastTreeRegressionModelParameters : TreeEnsembleModelParameters
444+
public sealed class FastTreeRegressionModelParameters : TreeEnsembleModelParametersBasedOnRegressionTree
445445
{
446446
internal const string LoaderSignature = "FastTreeRegressionExec";
447447
internal const string RegistrationName = "FastTreeRegressionPredictor";

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
446446
}
447447
}
448448

449-
public sealed class FastTreeTweedieModelParameters : TreeEnsembleModelParameters
449+
public sealed class FastTreeTweedieModelParameters : TreeEnsembleModelParametersBasedOnRegressionTree
450450
{
451451
internal const string LoaderSignature = "FastTreeTweedieExec";
452452
internal const string RegistrationName = "FastTreeTweediePredictor";

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public FastForestArgumentsBase()
4848
}
4949

5050
public sealed class FastForestClassificationModelParameters :
51-
TreeEnsembleModelParameters
51+
TreeEnsembleModelParametersBasedOnQuantileRegressionTree
5252
{
5353
internal const string LoaderSignature = "FastForestBinaryExec";
5454
internal const string RegistrationName = "FastForestClassificationPredictor";

src/Microsoft.ML.FastTree/RandomForestRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
namespace Microsoft.ML.Trainers.FastTree
3030
{
3131
public sealed class FastForestRegressionModelParameters :
32-
TreeEnsembleModelParameters,
32+
TreeEnsembleModelParametersBasedOnQuantileRegressionTree,
3333
IQuantileValueMapper,
3434
IQuantileRegressionPredictor
3535
{

src/Microsoft.ML.FastTree/Representation/TreeRegressor.cs

+77-57
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,26 @@
1-
using System.Collections.Generic;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
26
using System.Collections.Immutable;
3-
using System.Linq;
47
using Microsoft.ML.Trainers.FastTree.Internal;
58

69
namespace Microsoft.ML.FastTree
710
{
811
/// <summary>
9-
/// A container class for exposing <see cref="InternalRegressionTree"/>'s attributes to users.
12+
/// A container base class for exposing <see cref="InternalRegressionTree"/>'s and
13+
/// <see cref="InternalQuantileRegressionTree"/>'s attributes to users.
1014
/// This class should not be mutable, so it contains a lot of read-only members.
1115
/// </summary>
12-
public class RegressionTree
16+
public abstract class RegressionTreeBase
1317
{
1418
/// <summary>
15-
/// <see cref="RegressionTree"/> is an immutable wrapper over <see cref="_tree"/> for exposing some tree's
19+
/// <see cref="RegressionTreeBase"/> is an immutable wrapper over <see cref="_tree"/> for exposing some tree's
1620
/// attribute to users.
1721
/// </summary>
1822
private readonly InternalRegressionTree _tree;
1923

20-
/// <summary>
21-
/// Sample labels from training data. <see cref="_leafSamples"/>[i] stores the labels falling into the
22-
/// i-th leaf.
23-
/// </summary>
24-
private readonly double[][] _leafSamples;
25-
/// <summary>
26-
/// Sample labels' weights from training data. <see cref="_leafSampleWeights"/>[i] stores the weights for
27-
/// labels falling into the i-th leaf. <see cref="_leafSampleWeights"/>[i][j] is the weight of
28-
/// <see cref="_leafSamples"/>[i][j].
29-
/// </summary>
30-
private readonly double[][] _leafSampleWeights;
31-
3224
/// <summary>
3325
/// See <see cref="LteChild"/>.
3426
/// </summary>
@@ -133,6 +125,72 @@ public IReadOnlyList<int> GetCategoricalCategoricalSplitFeatureRangeAt(int nodeI
133125
return _tree.CategoricalSplitFeatureRanges[nodeIndex];
134126
}
135127

128+
/// <summary>
129+
/// Number of leaves in the tree. Note that <see cref="NumLeaves"/> does not take non-leaf nodes into account.
130+
/// </summary>
131+
public int NumLeaves => _tree.NumLeaves;
132+
133+
/// <summary>
134+
/// Number of nodes in the tree. This doesn't include any leaves. For example, a tree with node0->node1,
135+
/// node0->leaf3, node1->leaf1, node1->leaf2, <see cref="NumNodes"/> and <see cref="NumLeaves"/> should
136+
/// be 2 and 3, respectively.
137+
/// </summary>
138+
// A visualization of the example mentioned in this doc string.
139+
// node0
140+
// / \
141+
// node1 leaf3
142+
// / \
143+
// leaf1 leaf2
144+
// The index of leaf starts with 1 because interally we use "-1" as the 1st leaf's index, "-2" for the 2nd leaf's index, and so on.
145+
public int NumNodes => _tree.NumNodes;
146+
147+
internal RegressionTreeBase(InternalRegressionTree tree)
148+
{
149+
_tree = tree;
150+
151+
_lteChild = ImmutableArray.Create(_tree.LteChild, 0, _tree.NumNodes);
152+
_gtChild = ImmutableArray.Create(_tree.GtChild, 0, _tree.NumNodes);
153+
154+
_numericalSplitFeatureIndexes = ImmutableArray.Create(_tree.SplitFeatures, 0, _tree.NumNodes);
155+
_numericalSplitThresholds = ImmutableArray.Create(_tree.RawThresholds, 0, _tree.NumNodes);
156+
_categoricalSplitFlags = ImmutableArray.Create(_tree.CategoricalSplit, 0, _tree.NumNodes);
157+
_leafValues = ImmutableArray.Create(_tree.LeafValues, 0, _tree.NumLeaves);
158+
}
159+
}
160+
161+
/// <summary>
162+
/// A container class for exposing <see cref="InternalRegressionTree"/>'s attributes to users.
163+
/// This class should not be mutable, so it contains a lot of read-only members. Note that
164+
/// <see cref="RegressionTree"/> is identical to <see cref="RegressionTreeBase"/> but in
165+
/// another derived class <see cref="QuantileRegressionTree"/> some attributes are added.
166+
/// </summary>
167+
public sealed class RegressionTree : RegressionTreeBase
168+
{
169+
internal RegressionTree(InternalRegressionTree tree) : base(tree) { }
170+
}
171+
172+
/// <summary>
173+
/// A container class for exposing <see cref="InternalQuantileRegressionTree"/>'s attributes to users.
174+
/// This class should not be mutable, so it contains a lot of read-only members. In addition to
175+
/// things inherited from <see cref="RegressionTreeBase"/>, we add <see cref="GetLeafSamplesAt(int)"/>
176+
/// and <see cref="GetLeafSampleWeightsAt(int)"/> to expose (sub-sampled) training labels falling into
177+
/// the leafIndex-th leaf and their weights.
178+
/// </summary>
179+
public sealed class QuantileRegressionTree : RegressionTreeBase
180+
{
181+
/// <summary>
182+
/// Sample labels from training data. <see cref="_leafSamples"/>[i] stores the labels falling into the
183+
/// i-th leaf.
184+
/// </summary>
185+
private readonly double[][] _leafSamples;
186+
187+
/// <summary>
188+
/// Sample labels' weights from training data. <see cref="_leafSampleWeights"/>[i] stores the weights for
189+
/// labels falling into the i-th leaf. <see cref="_leafSampleWeights"/>[i][j] is the weight of
190+
/// <see cref="_leafSamples"/>[i][j].
191+
/// </summary>
192+
private readonly double[][] _leafSampleWeights;
193+
136194
/// <summary>
137195
/// Return the training labels falling into the specified leaf.
138196
/// </summary>
@@ -163,47 +221,9 @@ public IReadOnlyList<double> GetLeafSampleWeightsAt(int leafIndex)
163221
return _leafSampleWeights[leafIndex];
164222
}
165223

166-
/// <summary>
167-
/// Number of leaves in the tree. Note that <see cref="NumLeaves"/> does not take non-leaf nodes into account.
168-
/// </summary>
169-
public int NumLeaves => _tree.NumLeaves;
170-
171-
/// <summary>
172-
/// Number of nodes in the tree. This doesn't include any leaves. For example, a tree with node0->node1,
173-
/// node0->leaf3, node1->leaf1, node1->leaf2, <see cref="NumNodes"/> and <see cref="NumLeaves"/> should
174-
/// be 2 and 3, respectively.
175-
/// </summary>
176-
// A visualization of the example mentioned in this doc string.
177-
// node0
178-
// / \
179-
// node1 leaf3
180-
// / \
181-
// leaf1 leaf2
182-
// The index of leaf starts with 1 because interally we use "-1" as the 1st leaf's index, "-2" for the 2nd leaf's index, and so on.
183-
public int NumNodes => _tree.NumNodes;
184-
185-
internal RegressionTree(InternalRegressionTree tree)
224+
internal QuantileRegressionTree(InternalQuantileRegressionTree tree) : base(tree)
186225
{
187-
_tree = tree;
188-
_leafSamples = null;
189-
_leafSampleWeights = null;
190-
191-
_lteChild = ImmutableArray.Create(_tree.LteChild, 0, _tree.NumNodes);
192-
_gtChild = ImmutableArray.Create(_tree.GtChild, 0, _tree.NumNodes);
193-
194-
_numericalSplitFeatureIndexes = ImmutableArray.Create(_tree.SplitFeatures, 0, _tree.NumNodes);
195-
_numericalSplitThresholds = ImmutableArray.Create(_tree.RawThresholds, 0, _tree.NumNodes);
196-
_categoricalSplitFlags = ImmutableArray.Create(_tree.CategoricalSplit, 0, _tree.NumNodes);
197-
_leafValues = ImmutableArray.Create(_tree.LeafValues, 0, _tree.NumLeaves);
198-
199-
if (tree is QuantileRegressionTree)
200-
((QuantileRegressionTree)tree).ExtractLeafSamplesAndTheirWeights(out _leafSamples, out _leafSampleWeights);
201-
else
202-
{
203-
_leafSamples = tree.LeafValues.Select(value => new double[] { value }).ToArray();
204-
_leafSampleWeights = tree.LeafValues.Select(value => new double[] { 1.0 }).ToArray();
205-
}
226+
tree.ExtractLeafSamplesAndTheirWeights(out _leafSamples, out _leafSampleWeights);
206227
}
207228
}
208-
209229
}

0 commit comments

Comments
 (0)