17
17
using Microsoft . ML . Data ;
18
18
using Microsoft . ML . Data . Conversion ;
19
19
using Microsoft . ML . EntryPoints ;
20
- using Microsoft . ML . Internal . Calibration ;
21
20
using Microsoft . ML . Internal . Internallearn ;
22
21
using Microsoft . ML . Internal . Utilities ;
23
22
using Microsoft . ML . Model ;
@@ -2810,9 +2809,10 @@ public abstract class TreeEnsembleModelParameters :
2810
2809
ISingleCanSavePfa ,
2811
2810
ISingleCanSaveOnnx
2812
2811
{
2812
+ public TreeEnsembleView TrainedTreeEnsembleView { get ; }
2813
2813
//The below two properties are necessary for tree Visualizer
2814
2814
[ BestFriend ]
2815
- internal TreeEnsemble TrainedEnsemble { get ; }
2815
+ internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView . TreeEnsemble ;
2816
2816
int ITreeEnsemble . NumTrees => TrainedEnsemble . NumTrees ;
2817
2817
2818
2818
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2864,11 +2864,11 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
2864
2864
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2865
2865
// the trained ensemble to, for instance, resize arrays so that they are of the length
2866
2866
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2867
- TrainedEnsemble = trainedEnsemble ;
2867
+ TrainedTreeEnsembleView = new TreeEnsembleView ( trainedEnsemble ) ;
2868
2868
InnerArgs = innerArgs ;
2869
2869
NumFeatures = numFeatures ;
2870
2870
2871
- MaxSplitFeatIdx = FindMaxFeatureIndex ( trainedEnsemble ) ;
2871
+ MaxSplitFeatIdx = trainedEnsemble . GetMaxFeatureIndex ( ) ;
2872
2872
Contracts . Assert ( NumFeatures > MaxSplitFeatIdx ) ;
2873
2873
2874
2874
InputType = new VectorType ( NumberType . Float , NumFeatures ) ;
@@ -2892,8 +2892,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
2892
2892
if ( ctx . Header . ModelVerWritten >= VerCategoricalSplitSerialized )
2893
2893
categoricalSplits = true ;
2894
2894
2895
- TrainedEnsemble = new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ;
2896
- MaxSplitFeatIdx = FindMaxFeatureIndex ( TrainedEnsemble ) ;
2895
+ TrainedTreeEnsembleView = new TreeEnsembleView ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2896
+ MaxSplitFeatIdx = TrainedEnsemble . GetMaxFeatureIndex ( ) ;
2897
2897
2898
2898
InnerArgs = ctx . LoadStringOrNull ( ) ;
2899
2899
if ( ctx . Header . ModelVerWritten >= VerNumFeaturesSerialized )
@@ -3259,23 +3259,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
3259
3259
bldr . GetResult ( ref weights ) ;
3260
3260
}
3261
3261
3262
- private static int FindMaxFeatureIndex ( TreeEnsemble ensemble )
3263
- {
3264
- int ifeatMax = 0 ;
3265
- for ( int i = 0 ; i < ensemble . NumTrees ; i ++ )
3266
- {
3267
- var tree = ensemble . GetTreeAt ( i ) ;
3268
- for ( int n = 0 ; n < tree . NumNodes ; n ++ )
3269
- {
3270
- int ifeat = tree . SplitFeature ( n ) ;
3271
- if ( ifeat > ifeatMax )
3272
- ifeatMax = ifeat ;
3273
- }
3274
- }
3275
-
3276
- return ifeatMax ;
3277
- }
3278
-
3279
3262
ITree [ ] ITreeEnsemble . GetTrees ( )
3280
3263
{
3281
3264
return TrainedEnsemble . Trees . Select ( k => new Tree ( k ) ) . ToArray ( ) ;
0 commit comments