Skip to content

Commit b27b171

Browse files
committed
Immutable RegressionTree and TreeEnsemble wrappers
Simplify public area
1 parent d8bc32e commit b27b171

File tree

3 files changed

+148
-63
lines changed

3 files changed

+148
-63
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+6-23
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
using Microsoft.ML.Data;
1818
using Microsoft.ML.Data.Conversion;
1919
using Microsoft.ML.EntryPoints;
20-
using Microsoft.ML.Internal.Calibration;
2120
using Microsoft.ML.Internal.Internallearn;
2221
using Microsoft.ML.Internal.Utilities;
2322
using Microsoft.ML.Model;
@@ -2810,9 +2809,10 @@ public abstract class TreeEnsembleModelParameters :
28102809
ISingleCanSavePfa,
28112810
ISingleCanSaveOnnx
28122811
{
2812+
public TreeEnsembleView TrainedTreeEnsembleView { get; }
28132813
//The below two properties are necessary for tree Visualizer
28142814
[BestFriend]
2815-
internal TreeEnsemble TrainedEnsemble { get; }
2815+
internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView.TreeEnsemble;
28162816
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28172817

28182818
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2864,11 +2864,11 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
28642864
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28652865
// the trained ensemble to, for instance, resize arrays so that they are of the length
28662866
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2867-
TrainedEnsemble = trainedEnsemble;
2867+
TrainedTreeEnsembleView = new TreeEnsembleView(trainedEnsemble);
28682868
InnerArgs = innerArgs;
28692869
NumFeatures = numFeatures;
28702870

2871-
MaxSplitFeatIdx = FindMaxFeatureIndex(trainedEnsemble);
2871+
MaxSplitFeatIdx = trainedEnsemble.GetMaxFeatureIndex();
28722872
Contracts.Assert(NumFeatures > MaxSplitFeatIdx);
28732873

28742874
InputType = new VectorType(NumberType.Float, NumFeatures);
@@ -2892,8 +2892,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28922892
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28932893
categoricalSplits = true;
28942894

2895-
TrainedEnsemble = new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
2896-
MaxSplitFeatIdx = FindMaxFeatureIndex(TrainedEnsemble);
2895+
TrainedTreeEnsembleView = new TreeEnsembleView(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2896+
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
28972897

28982898
InnerArgs = ctx.LoadStringOrNull();
28992899
if (ctx.Header.ModelVerWritten >= VerNumFeaturesSerialized)
@@ -3259,23 +3259,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
32593259
bldr.GetResult(ref weights);
32603260
}
32613261

3262-
private static int FindMaxFeatureIndex(TreeEnsemble ensemble)
3263-
{
3264-
int ifeatMax = 0;
3265-
for (int i = 0; i < ensemble.NumTrees; i++)
3266-
{
3267-
var tree = ensemble.GetTreeAt(i);
3268-
for (int n = 0; n < tree.NumNodes; n++)
3269-
{
3270-
int ifeat = tree.SplitFeature(n);
3271-
if (ifeat > ifeatMax)
3272-
ifeatMax = ifeat;
3273-
}
3274-
}
3275-
3276-
return ifeatMax;
3277-
}
3278-
32793262
ITree[] ITreeEnsemble.GetTrees()
32803263
{
32813264
return TrainedEnsemble.Trees.Select(k => new Tree(k)).ToArray();

0 commit comments

Comments
 (0)