Skip to content

Commit 18ec5e0

Browse files
committed
Immutable RegressionTree and TreeEnsemble wrappers
1 parent 80b36f0 commit 18ec5e0

File tree

4 files changed

+195
-64
lines changed

4 files changed

+195
-64
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+29-23
Original file line numberDiff line numberDiff line change
@@ -2809,9 +2809,10 @@ public abstract class TreeEnsembleModelParameters :
28092809
ISingleCanSavePfa,
28102810
ISingleCanSaveOnnx
28112811
{
2812+
public TreeEnsembleView TrainedTreeEnsembleView { get; }
28122813
//The below two properties are necessary for tree Visualizer
28132814
[BestFriend]
2814-
internal TreeEnsemble TrainedEnsemble { get; }
2815+
internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView.TreeEnsemble;
28152816
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28162817

28172818
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2853,7 +2854,29 @@ public abstract class TreeEnsembleModelParameters :
28532854
/// </summary>
28542855
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator(this);
28552856

2856-
public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
2857+
public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsembleView trainedEnsembleView,
2858+
int numFeatures)
2859+
: base(env, name)
2860+
{
2861+
Host.CheckValue(trainedEnsembleView, nameof(trainedEnsembleView));
2862+
Host.CheckParam(numFeatures > 0, nameof(numFeatures), "must be positive");
2863+
2864+
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2865+
// the trained ensemble to, for instance, resize arrays so that they are of the length
2866+
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2867+
TrainedTreeEnsembleView = trainedEnsembleView;
2868+
InnerArgs = "";
2869+
NumFeatures = trainedEnsembleView.Trees.Select(tree => tree.ActiveFeatures.Length).Max();
2870+
2871+
MaxSplitFeatIdx = trainedEnsembleView.TreeEnsemble.GetMaxFeatureIndex();
2872+
Contracts.Assert(NumFeatures > MaxSplitFeatIdx);
2873+
2874+
InputType = new VectorType(NumberType.Float, NumFeatures);
2875+
OutputType = NumberType.Float;
2876+
}
2877+
2878+
[BestFriend]
2879+
internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
28572880
: base(env, name)
28582881
{
28592882
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
@@ -2863,11 +2886,11 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
28632886
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28642887
// the trained ensemble to, for instance, resize arrays so that they are of the length
28652888
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2866-
TrainedEnsemble = trainedEnsemble;
2889+
TrainedTreeEnsembleView = new TreeEnsembleView(trainedEnsemble);
28672890
InnerArgs = innerArgs;
28682891
NumFeatures = numFeatures;
28692892

2870-
MaxSplitFeatIdx = FindMaxFeatureIndex(trainedEnsemble);
2893+
MaxSplitFeatIdx = trainedEnsemble.GetMaxFeatureIndex();
28712894
Contracts.Assert(NumFeatures > MaxSplitFeatIdx);
28722895

28732896
InputType = new VectorType(NumberType.Float, NumFeatures);
@@ -2891,8 +2914,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28912914
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28922915
categoricalSplits = true;
28932916

2894-
TrainedEnsemble = new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits);
2895-
MaxSplitFeatIdx = FindMaxFeatureIndex(TrainedEnsemble);
2917+
TrainedTreeEnsembleView = new TreeEnsembleView(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2918+
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
28962919

28972920
InnerArgs = ctx.LoadStringOrNull();
28982921
if (ctx.Header.ModelVerWritten >= VerNumFeaturesSerialized)
@@ -3258,23 +3281,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
32583281
bldr.GetResult(ref weights);
32593282
}
32603283

3261-
private static int FindMaxFeatureIndex(TreeEnsemble ensemble)
3262-
{
3263-
int ifeatMax = 0;
3264-
for (int i = 0; i < ensemble.NumTrees; i++)
3265-
{
3266-
var tree = ensemble.GetTreeAt(i);
3267-
for (int n = 0; n < tree.NumNodes; n++)
3268-
{
3269-
int ifeat = tree.SplitFeature(n);
3270-
if (ifeat > ifeatMax)
3271-
ifeatMax = ifeat;
3272-
}
3273-
}
3274-
3275-
return ifeatMax;
3276-
}
3277-
32783284
ITree[] ITreeEnsemble.GetTrees()
32793285
{
32803286
return TrainedEnsemble.Trees.Select(k => new Tree(k)).ToArray();

0 commit comments

Comments
 (0)