16
16
using Microsoft . ML . Data ;
17
17
using Microsoft . ML . Data . Conversion ;
18
18
using Microsoft . ML . EntryPoints ;
19
- using Microsoft . ML . Internal . Calibration ;
20
19
using Microsoft . ML . Internal . Internallearn ;
21
20
using Microsoft . ML . Internal . Utilities ;
22
21
using Microsoft . ML . Model ;
@@ -2809,9 +2808,10 @@ public abstract class TreeEnsembleModelParameters :
2809
2808
ISingleCanSavePfa ,
2810
2809
ISingleCanSaveOnnx
2811
2810
{
2811
+ public TreeEnsembleView TrainedTreeEnsembleView { get ; }
2812
2812
//The below two properties are necessary for tree Visualizer
2813
2813
[ BestFriend ]
2814
- internal TreeEnsemble TrainedEnsemble { get ; }
2814
+ internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView . TreeEnsemble ;
2815
2815
int ITreeEnsemble . NumTrees => TrainedEnsemble . NumTrees ;
2816
2816
2817
2817
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2863,11 +2863,11 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
2863
2863
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2864
2864
// the trained ensemble to, for instance, resize arrays so that they are of the length
2865
2865
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2866
- TrainedEnsemble = trainedEnsemble ;
2866
+ TrainedTreeEnsembleView = new TreeEnsembleView ( trainedEnsemble ) ;
2867
2867
InnerArgs = innerArgs ;
2868
2868
NumFeatures = numFeatures ;
2869
2869
2870
- MaxSplitFeatIdx = FindMaxFeatureIndex ( trainedEnsemble ) ;
2870
+ MaxSplitFeatIdx = trainedEnsemble . GetMaxFeatureIndex ( ) ;
2871
2871
Contracts . Assert ( NumFeatures > MaxSplitFeatIdx ) ;
2872
2872
2873
2873
InputType = new VectorType ( NumberType . Float , NumFeatures ) ;
@@ -2891,8 +2891,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
2891
2891
if ( ctx . Header . ModelVerWritten >= VerCategoricalSplitSerialized )
2892
2892
categoricalSplits = true ;
2893
2893
2894
- TrainedEnsemble = new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ;
2895
- MaxSplitFeatIdx = FindMaxFeatureIndex ( TrainedEnsemble ) ;
2894
+ TrainedTreeEnsembleView = new TreeEnsembleView ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2895
+ MaxSplitFeatIdx = TrainedEnsemble . GetMaxFeatureIndex ( ) ;
2896
2896
2897
2897
InnerArgs = ctx . LoadStringOrNull ( ) ;
2898
2898
if ( ctx . Header . ModelVerWritten >= VerNumFeaturesSerialized )
@@ -3258,23 +3258,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
3258
3258
bldr . GetResult ( ref weights ) ;
3259
3259
}
3260
3260
3261
- private static int FindMaxFeatureIndex ( TreeEnsemble ensemble )
3262
- {
3263
- int ifeatMax = 0 ;
3264
- for ( int i = 0 ; i < ensemble . NumTrees ; i ++ )
3265
- {
3266
- var tree = ensemble . GetTreeAt ( i ) ;
3267
- for ( int n = 0 ; n < tree . NumNodes ; n ++ )
3268
- {
3269
- int ifeat = tree . SplitFeature ( n ) ;
3270
- if ( ifeat > ifeatMax )
3271
- ifeatMax = ifeat ;
3272
- }
3273
- }
3274
-
3275
- return ifeatMax ;
3276
- }
3277
-
3278
3261
ITree [ ] ITreeEnsemble . GetTrees ( )
3279
3262
{
3280
3263
return TrainedEnsemble . Trees . Select ( k => new Tree ( k ) ) . ToArray ( ) ;
0 commit comments