@@ -54,7 +54,6 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
54
54
{
55
55
protected readonly TArgs Args ;
56
56
protected readonly bool AllowGC ;
57
- [ BestFriend ]
58
57
internal TreeEnsemble TrainedEnsemble ;
59
58
protected int FeatureCount ;
60
59
private protected RoleMappedData ValidData ;
@@ -63,8 +62,8 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
63
62
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
64
63
/// </summary>
65
64
private protected RoleMappedData TestData ;
66
- protected IParallelTraining ParallelTraining ;
67
- protected OptimizationAlgorithm OptimizationAlgorithm ;
65
+ internal IParallelTraining ParallelTraining ;
66
+ internal OptimizationAlgorithm OptimizationAlgorithm ;
68
67
protected Dataset TrainSet ;
69
68
protected Dataset ValidSet ;
70
69
/// <summary>
@@ -175,8 +174,8 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
175
174
176
175
protected abstract Test ConstructTestForTrainingData ( ) ;
177
176
178
- protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm ( IChannel ch ) ;
179
- protected abstract TreeLearner ConstructTreeLearner ( IChannel ch ) ;
177
+ internal abstract OptimizationAlgorithm ConstructOptimizationAlgorithm ( IChannel ch ) ;
178
+ internal abstract TreeLearner ConstructTreeLearner ( IChannel ch ) ;
180
179
181
180
protected abstract ObjectiveFunctionBase ConstructObjFunc ( IChannel ch ) ;
182
181
@@ -793,7 +792,7 @@ private float GetMachineAvailableBytes()
793
792
794
793
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
795
794
// Note that this tree can be null if no tree was learnt this iteration.
796
- protected virtual void CustomizedTrainingIteration ( RegressionTree tree )
795
+ internal virtual void CustomizedTrainingIteration ( RegressionTree tree )
797
796
{
798
797
}
799
798
@@ -2810,10 +2809,13 @@ public abstract class TreeEnsembleModelParameters :
2810
2809
ISingleCanSavePfa ,
2811
2810
ISingleCanSaveOnnx
2812
2811
{
2813
- public TreeEnsembleView TrainedTreeEnsembleView { get ; }
2812
+ /// <summary>
2813
+ /// An ensemble of trees exposed to users. It is a wrapper on an <see langword="internal"/> <see cref="TreeEnsemble"/> in <see cref="TreeRegressorCollection"/>.
2814
+ /// </summary>
2815
+ public TreeRegressorCollection TrainedTreeCollection { get ; }
2814
2816
//The below two properties are necessary for tree Visualizer
2815
2817
[ BestFriend ]
2816
- internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView . TreeEnsemble ;
2818
+ internal TreeEnsemble TrainedEnsemble => TrainedTreeCollection . TreeEnsemble ;
2817
2819
int ITreeEnsemble . NumTrees => TrainedEnsemble . NumTrees ;
2818
2820
2819
2821
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2866,7 +2868,7 @@ internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnse
2866
2868
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
2867
2869
// the trained ensemble to, for instance, resize arrays so that they are of the length
2868
2870
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2869
- TrainedTreeEnsembleView = new TreeEnsembleView ( trainedEnsemble ) ;
2871
+ TrainedTreeCollection = new TreeRegressorCollection ( trainedEnsemble ) ;
2870
2872
InnerArgs = innerArgs ;
2871
2873
NumFeatures = numFeatures ;
2872
2874
@@ -2894,7 +2896,7 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
2894
2896
if ( ctx . Header . ModelVerWritten >= VerCategoricalSplitSerialized )
2895
2897
categoricalSplits = true ;
2896
2898
2897
- TrainedTreeEnsembleView = new TreeEnsembleView ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2899
+ TrainedTreeCollection = new TreeRegressorCollection ( new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ) ;
2898
2900
MaxSplitFeatIdx = TrainedEnsemble . GetMaxFeatureIndex ( ) ;
2899
2901
2900
2902
InnerArgs = ctx . LoadStringOrNull ( ) ;
0 commit comments