@@ -55,16 +55,16 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
55
55
{
56
56
protected readonly TArgs Args ;
57
57
protected readonly bool AllowGC ;
58
- protected TreeEnsemble TrainedEnsemble ;
58
+ internal TreeEnsemble TrainedEnsemble ;
59
59
protected int FeatureCount ;
60
60
private protected RoleMappedData ValidData ;
61
61
/// <summary>
62
62
/// If not null, it's a test data set passed in from training context. It will be converted to one element in
63
63
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
64
64
/// </summary>
65
65
private protected RoleMappedData TestData ;
66
- protected IParallelTraining ParallelTraining ;
67
- protected OptimizationAlgorithm OptimizationAlgorithm ;
66
+ internal IParallelTraining ParallelTraining ;
67
+ internal OptimizationAlgorithm OptimizationAlgorithm ;
68
68
protected Dataset TrainSet ;
69
69
protected Dataset ValidSet ;
70
70
/// <summary>
@@ -88,7 +88,8 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
88
88
protected double [ ] InitValidScores ;
89
89
protected double [ ] [ ] InitTestScores ;
90
90
//protected int Iteration;
91
- protected TreeEnsemble Ensemble ;
91
+ [ BestFriend ]
92
+ internal TreeEnsemble Ensemble ;
92
93
93
94
protected bool HasValidSet => ValidSet != null ;
94
95
@@ -174,8 +175,8 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
174
175
175
176
protected abstract Test ConstructTestForTrainingData ( ) ;
176
177
177
- protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm ( IChannel ch ) ;
178
- protected abstract TreeLearner ConstructTreeLearner ( IChannel ch ) ;
178
+ internal abstract OptimizationAlgorithm ConstructOptimizationAlgorithm ( IChannel ch ) ;
179
+ internal abstract TreeLearner ConstructTreeLearner ( IChannel ch ) ;
179
180
180
181
protected abstract ObjectiveFunctionBase ConstructObjFunc ( IChannel ch ) ;
181
182
@@ -792,7 +793,7 @@ private float GetMachineAvailableBytes()
792
793
793
794
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
794
795
// Note that this tree can be null if no tree was learnt this iteration.
795
- protected virtual void CustomizedTrainingIteration ( RegressionTree tree )
796
+ internal virtual void CustomizedTrainingIteration ( RegressionTree tree )
796
797
{
797
798
}
798
799
@@ -2809,10 +2810,14 @@ public abstract class TreeEnsembleModelParameters :
2809
2810
ISingleCanSavePfa ,
2810
2811
ISingleCanSaveOnnx
2811
2812
{
2812
- public TreeEnsembleView TrainedTreeEnsembleView { get ; }
2813
- //The below two properties are necessary for tree Visualizer
2813
+ /// <summary>
2814
+ /// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> <see cref="TreeEnsemble"/> in <see cref="TreeRegressorCollection"/>.
2815
+ /// </summary>
2816
+ public TreeRegressorCollection TrainedTreeCollection { get ; }
2817
+
2818
+ // The below two properties are necessary for tree Visualizer
2814
2819
[ BestFriend ]
2815
- internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView . TreeEnsemble ;
2820
+ internal TreeEnsemble TrainedEnsemble => TrainedTreeCollection . TreeEnsemble ;
2816
2821
int ITreeEnsemble . NumTrees => TrainedEnsemble . NumTrees ;
2817
2822
2818
2823
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2854,7 +2859,9 @@ public abstract class TreeEnsembleModelParameters :
2854
2859
/// </summary>
2855
2860
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator ( this ) ;
2856
2861
2857
- public TreeEnsembleModelParameters ( IHostEnvironment env , string name , TreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
2862
+ /// The following function is used in both FastTree and LightGBM so <see cref="BestFriendAttribute"/> is required.
2863
+ [ BestFriend ]
2864
+ internal TreeEnsembleModelParameters ( IHostEnvironment env , string name , TreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
2858
2865
: base ( env , name )
2859
2866
{
2860
2867
Host . CheckValue ( trainedEnsemble , nameof ( trainedEnsemble ) ) ;
@@ -2864,7 +2871,7 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
2864
2871
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2865
2872
// the trained ensemble to, for instance, resize arrays so that they are of the length
2866
2873
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2867
- TrainedTreeEnsembleView = new TreeEnsembleView ( trainedEnsemble ) ;
2874
+ TrainedTreeCollection = new TreeRegressorCollection ( trainedEnsemble ) ;
2868
2875
InnerArgs = innerArgs ;
2869
2876
NumFeatures = numFeatures ;
2870
2877
@@ -2892,7 +2899,7 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
2892
2899
if ( ctx . Header . ModelVerWritten >= VerCategoricalSplitSerialized )
2893
2900
categoricalSplits = true ;
2894
2901
2895
- TrainedTreeEnsembleView = new TreeEnsembleView ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2902
+ TrainedTreeCollection = new TreeRegressorCollection ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2896
2903
MaxSplitFeatIdx = TrainedEnsemble . GetMaxFeatureIndex ( ) ;
2897
2904
2898
2905
InnerArgs = ctx . LoadStringOrNull ( ) ;
0 commit comments