Skip to content

Commit d960da4

Browse files
committed
Immutable RegressionTree and TreeEnsemble wrappers
Simplify public area
1 parent 80b36f0 commit d960da4

File tree

3 files changed

+148
-63
lines changed

3 files changed

+148
-63
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+6-23
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
using Microsoft.ML.Data;
1717
using Microsoft.ML.Data.Conversion;
1818
using Microsoft.ML.EntryPoints;
19-
using Microsoft.ML.Internal.Calibration;
2019
using Microsoft.ML.Internal.Internallearn;
2120
using Microsoft.ML.Internal.Utilities;
2221
using Microsoft.ML.Model;
@@ -2809,9 +2808,10 @@ public abstract class TreeEnsembleModelParameters :
28092808
ISingleCanSavePfa,
28102809
ISingleCanSaveOnnx
28112810
{
2811+
public TreeEnsembleView TrainedTreeEnsembleView { get; }
28122812
//The below two properties are necessary for tree Visualizer
28132813
[BestFriend]
2814-
internal TreeEnsemble TrainedEnsemble { get; }
2814+
internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView.TreeEnsemble;
28152815
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28162816

28172817
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2863,11 +2863,11 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
28632863
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28642864
// the trained ensemble to, for instance, resize arrays so that they are of the length
28652865
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2866-
TrainedEnsemble = trainedEnsemble;
2866+
TrainedTreeEnsembleView = new TreeEnsembleView(trainedEnsemble);
28672867
InnerArgs = innerArgs;
28682868
NumFeatures = numFeatures;
28692869

2870-
MaxSplitFeatIdx = FindMaxFeatureIndex(trainedEnsemble);
2870+
MaxSplitFeatIdx = trainedEnsemble.GetMaxFeatureIndex();
28712871
Contracts.Assert(NumFeatures > MaxSplitFeatIdx);
28722872

28732873
InputType = new VectorType(NumberType.Float, NumFeatures);
@@ -2891,8 +2891,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28912891
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28922892
categoricalSplits = true;
28932893

2894-
TrainedEnsemble = new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
2895-
MaxSplitFeatIdx = FindMaxFeatureIndex(TrainedEnsemble);
2894+
TrainedTreeEnsembleView = new TreeEnsembleView(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2895+
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
28962896

28972897
InnerArgs = ctx.LoadStringOrNull();
28982898
if (ctx.Header.ModelVerWritten >= VerNumFeaturesSerialized)
@@ -3258,23 +3258,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
32583258
bldr.GetResult(ref weights);
32593259
}
32603260

3261-
private static int FindMaxFeatureIndex(TreeEnsemble ensemble)
3262-
{
3263-
int ifeatMax = 0;
3264-
for (int i = 0; i < ensemble.NumTrees; i++)
3265-
{
3266-
var tree = ensemble.GetTreeAt(i);
3267-
for (int n = 0; n < tree.NumNodes; n++)
3268-
{
3269-
int ifeat = tree.SplitFeature(n);
3270-
if (ifeat > ifeatMax)
3271-
ifeatMax = ifeat;
3272-
}
3273-
}
3274-
3275-
return ifeatMax;
3276-
}
3277-
32783261
ITree[] ITreeEnsemble.GetTrees()
32793262
{
32803263
return TrainedEnsemble.Trees.Select(k => new Tree(k)).ToArray();

0 commit comments

Comments
 (0)