@@ -2809,9 +2809,10 @@ public abstract class TreeEnsembleModelParameters :
2809
2809
ISingleCanSavePfa ,
2810
2810
ISingleCanSaveOnnx
2811
2811
{
2812
+ public TreeEnsembleView TrainedTreeEnsembleView { get ; } // Get-only view over the trained ensemble; assigned in each constructor of this class.
2812
2813
//The below two properties are necessary for tree Visualizer
2813
2814
[ BestFriend ]
2814
- internal TreeEnsemble TrainedEnsemble { get ; }
2815
+ internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView . TreeEnsemble ; // No longer stored directly: forwards to the TreeEnsemble held by TrainedTreeEnsembleView.
2815
2816
int ITreeEnsemble . NumTrees => TrainedEnsemble . NumTrees ; // Explicit ITreeEnsemble member; delegates to TrainedEnsemble.
2816
2817
2817
2818
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2853,7 +2854,29 @@ public abstract class TreeEnsembleModelParameters :
2853
2854
/// </summary>
2854
2855
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator ( this ) ;
2855
2856
2856
- public TreeEnsembleModelParameters ( IHostEnvironment env , string name , TreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
2857
+ public TreeEnsembleModelParameters ( IHostEnvironment env , string name , TreeEnsembleView trainedEnsembleView , // Builds the model parameters from an existing TreeEnsembleView instead of a raw TreeEnsemble.
2858
+ int numFeatures ) // NOTE(review): only validated below; never assigned to NumFeatures in this overload.
2859
+ : base ( env , name )
2860
+ {
2861
+ Host . CheckValue ( trainedEnsembleView , nameof ( trainedEnsembleView ) ) ;
2862
+ Host . CheckParam ( numFeatures > 0 , nameof ( numFeatures ) , "must be positive" ) ; // NOTE(review): this is the only use of numFeatures in this constructor - confirm whether it should also feed NumFeatures.
2863
+
2864
+ // REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2865
+ // the trained ensemble to, for instance, resize arrays so that they are of the length
2866
+ // the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2867
+ TrainedTreeEnsembleView = trainedEnsembleView ;
2868
+ InnerArgs = "" ; // This overload takes no innerArgs parameter, so an empty string is stored.
2869
+ NumFeatures = trainedEnsembleView . Trees . Select ( tree => tree . ActiveFeatures . Length ) . Max ( ) ; // Largest ActiveFeatures length over all trees, not the numFeatures argument. NOTE(review): Max() throws on an empty sequence, and a null ActiveFeatures would fail on .Length - confirm neither can happen here.
2870
+
2871
+ MaxSplitFeatIdx = trainedEnsembleView . TreeEnsemble . GetMaxFeatureIndex ( ) ;
2872
+ Contracts . Assert ( NumFeatures > MaxSplitFeatIdx ) ; // The largest split feature index must be a valid index into a NumFeatures-long input.
2873
+
2874
+ InputType = new VectorType ( NumberType . Float , NumFeatures ) ; // Input is a Float vector sized by the NumFeatures computed above.
2875
+ OutputType = NumberType . Float ;
2876
+ }
2877
+
2878
+ [ BestFriend ]
2879
+ internal TreeEnsembleModelParameters ( IHostEnvironment env , string name , TreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
2857
2880
: base ( env , name )
2858
2881
{
2859
2882
Host . CheckValue ( trainedEnsemble , nameof ( trainedEnsemble ) ) ;
@@ -2863,11 +2886,11 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
2863
2886
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2864
2887
// the trained ensemble to, for instance, resize arrays so that they are of the length
2865
2888
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2866
- TrainedEnsemble = trainedEnsemble ;
2889
+ TrainedTreeEnsembleView = new TreeEnsembleView ( trainedEnsemble ) ;
2867
2890
InnerArgs = innerArgs ;
2868
2891
NumFeatures = numFeatures ;
2869
2892
2870
- MaxSplitFeatIdx = FindMaxFeatureIndex ( trainedEnsemble ) ;
2893
+ MaxSplitFeatIdx = trainedEnsemble . GetMaxFeatureIndex ( ) ;
2871
2894
Contracts . Assert ( NumFeatures > MaxSplitFeatIdx ) ;
2872
2895
2873
2896
InputType = new VectorType ( NumberType . Float , NumFeatures ) ;
@@ -2891,8 +2914,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
2891
2914
if ( ctx . Header . ModelVerWritten >= VerCategoricalSplitSerialized )
2892
2915
categoricalSplits = true ;
2893
2916
2894
- TrainedEnsemble = new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ;
2895
- MaxSplitFeatIdx = FindMaxFeatureIndex ( TrainedEnsemble ) ;
2917
+ TrainedTreeEnsembleView = new TreeEnsembleView ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2918
+ MaxSplitFeatIdx = TrainedEnsemble . GetMaxFeatureIndex ( ) ;
2896
2919
2897
2920
InnerArgs = ctx . LoadStringOrNull ( ) ;
2898
2921
if ( ctx . Header . ModelVerWritten >= VerNumFeaturesSerialized )
@@ -3258,23 +3281,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
3258
3281
bldr . GetResult ( ref weights ) ;
3259
3282
}
3260
3283
3261
- private static int FindMaxFeatureIndex ( TreeEnsemble ensemble )
3262
- {
3263
- int ifeatMax = 0 ;
3264
- for ( int i = 0 ; i < ensemble . NumTrees ; i ++ )
3265
- {
3266
- var tree = ensemble . GetTreeAt ( i ) ;
3267
- for ( int n = 0 ; n < tree . NumNodes ; n ++ )
3268
- {
3269
- int ifeat = tree . SplitFeature ( n ) ;
3270
- if ( ifeat > ifeatMax )
3271
- ifeatMax = ifeat ;
3272
- }
3273
- }
3274
-
3275
- return ifeatMax ;
3276
- }
3277
-
3278
3284
ITree [ ] ITreeEnsemble . GetTrees ( )
3279
3285
{
3280
3286
return TrainedEnsemble . Trees . Select ( k => new Tree ( k ) ) . ToArray ( ) ;
0 commit comments