Skip to content

Commit 66bd651

Browse files
committed
Seperate TreeEnsemble
1 parent beb322f commit 66bd651

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+10-26
Original file line numberDiff line numberDiff line change
@@ -2869,7 +2869,6 @@ internal TreeEnsembleModelParameters(IHostEnvironment env, string name, Internal
28692869
// the trained ensemble to, for instance, resize arrays so that they are of the length
28702870
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
28712871
TrainedEnsemble = trainedEnsemble;
2872-
CreateTreeEnsembleFromInternalDataStructure();
28732872
InnerArgs = innerArgs;
28742873
NumFeatures = numFeatures;
28752874

@@ -2898,7 +2897,6 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28982897
categoricalSplits = true;
28992898

29002899
TrainedEnsemble = new InternalTreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
2901-
CreateTreeEnsembleFromInternalDataStructure();
29022900
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
29032901

29042902
InnerArgs = ctx.LoadStringOrNull();
@@ -2919,14 +2917,6 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
29192917
OutputType = NumberType.Float;
29202918
}
29212919

2922-
/// <summary>
2923-
/// This function should be implemented in derived classs to create strongly-typed TreeEnsemble
2924-
/// from <see cref="TrainedEnsemble"/> and possibly other internal attributes in
2925-
/// <see cref="TreeEnsembleModelParameters"/>. This also implies we always call this function
2926-
/// after initializing <see cref="TrainedEnsemble"/>.
2927-
/// </summary>
2928-
protected abstract void CreateTreeEnsembleFromInternalDataStructure();
2929-
29302920
[BestFriend]
29312921
private protected override void SaveCore(ModelSaveContext ctx)
29322922
{
@@ -3400,33 +3390,30 @@ public TreeNode(Dictionary<string, object> keyValues)
34003390
/// </summary>
34013391
public abstract class TreeEnsembleModelParametersBasedOnRegressionTree : TreeEnsembleModelParameters
34023392
{
3403-
private TreeEnsemble<RegressionTree> _trainedTreeEnsemble;
3404-
34053393
/// <summary>
34063394
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/>
34073395
/// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
34083396
/// </summary>
3409-
public TreeEnsemble<RegressionTree> TrainedTreeEnsemble => _trainedTreeEnsemble;
3397+
public RegressionTreeEnsemble TrainedTreeEnsemble { get; }
34103398

34113399
[BestFriend]
34123400
internal TreeEnsembleModelParametersBasedOnRegressionTree(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
34133401
: base(env, name, trainedEnsemble, numFeatures, innerArgs)
34143402
{
3403+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
34153404
}
34163405

34173406
protected TreeEnsembleModelParametersBasedOnRegressionTree(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
34183407
: base(env, name, ctx, ver)
34193408
{
3409+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
34203410
}
34213411

3422-
/// <summary>
3423-
/// See <see cref="TreeEnsembleModelParameters.CreateTreeEnsembleFromInternalDataStructure"/>.
3424-
/// </summary>
3425-
protected override void CreateTreeEnsembleFromInternalDataStructure()
3412+
private RegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure()
34263413
{
34273414
var trees = TrainedEnsemble.Trees.Select(tree => new RegressionTree(tree));
34283415
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight);
3429-
_trainedTreeEnsemble = new TreeEnsemble<RegressionTree>(trees, treeWeights, TrainedEnsemble.Bias);
3416+
return new RegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias);
34303417
}
34313418
}
34323419

@@ -3442,33 +3429,30 @@ protected override void CreateTreeEnsembleFromInternalDataStructure()
34423429
/// </summary>
34433430
public abstract class TreeEnsembleModelParametersBasedOnQuantileRegressionTree : TreeEnsembleModelParameters
34443431
{
3445-
private TreeEnsemble<QuantileRegressionTree> _trainedTreeEnsemble;
3446-
34473432
/// <summary>
34483433
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/>
34493434
/// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
34503435
/// </summary>
3451-
public TreeEnsemble<QuantileRegressionTree> TrainedTreeEnsemble => _trainedTreeEnsemble;
3436+
public QuantileRegressionTreeEnsemble TrainedTreeEnsemble { get; }
34523437

34533438
[BestFriend]
34543439
internal TreeEnsembleModelParametersBasedOnQuantileRegressionTree(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
34553440
: base(env, name, trainedEnsemble, numFeatures, innerArgs)
34563441
{
3442+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
34573443
}
34583444

34593445
protected TreeEnsembleModelParametersBasedOnQuantileRegressionTree(IHostEnvironment env, string name, ModelLoadContext ctx, VersionInfo ver)
34603446
: base(env, name, ctx, ver)
34613447
{
3448+
TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure();
34623449
}
34633450

3464-
/// <summary>
3465-
/// See <see cref="TreeEnsembleModelParameters.CreateTreeEnsembleFromInternalDataStructure"/>.
3466-
/// </summary>
3467-
protected override void CreateTreeEnsembleFromInternalDataStructure()
3451+
private QuantileRegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure()
34683452
{
34693453
var trees = TrainedEnsemble.Trees.Select(tree => new QuantileRegressionTree((InternalQuantileRegressionTree)tree));
34703454
var treeWeights = TrainedEnsemble.Trees.Select(tree => tree.Weight);
3471-
_trainedTreeEnsemble = new TreeEnsemble<QuantileRegressionTree>(trees, treeWeights, TrainedEnsemble.Bias);
3455+
return new QuantileRegressionTreeEnsemble(trees, treeWeights, TrainedEnsemble.Bias);
34723456
}
34733457
}
34743458
}

src/Microsoft.ML.FastTree/TreeEnsemble.cs

+16-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace Microsoft.ML.FastTree
1212
/// <see cref="TreeEnsemble{T}"/>, we need to compute the output values of all trees in <see cref="Trees"/>,
1313
/// scale those values via <see cref="TreeWeights"/>, and finally sum the scaled values and <see cref="Bias"/> up.
1414
/// </summary>
15-
public sealed class TreeEnsemble<T> where T : RegressionTreeBase
15+
public abstract class TreeEnsemble<T> where T : RegressionTreeBase
1616
{
1717
/// <summary>
1818
/// When doing prediction, this is a value added to the weighted sum of all <see cref="Trees"/>' outputs.
@@ -37,4 +37,19 @@ internal TreeEnsemble(IEnumerable<T> trees, IEnumerable<double> treeWeights, dou
3737
}
3838
}
3939

40+
public sealed class RegressionTreeEnsemble : TreeEnsemble<RegressionTree>
41+
{
42+
internal RegressionTreeEnsemble(IEnumerable<RegressionTree> trees, IEnumerable<double> treeWeights, double bias)
43+
: base(trees, treeWeights, bias)
44+
{
45+
}
46+
}
47+
48+
public sealed class QuantileRegressionTreeEnsemble : TreeEnsemble<QuantileRegressionTree>
49+
{
50+
internal QuantileRegressionTreeEnsemble(IEnumerable<QuantileRegressionTree> trees, IEnumerable<double> treeWeights, double bias)
51+
: base(trees, treeWeights, bias)
52+
{
53+
}
54+
}
4055
}

0 commit comments

Comments
 (0)