17
17
using Microsoft . ML . Data ;
18
18
using Microsoft . ML . Data . Conversion ;
19
19
using Microsoft . ML . EntryPoints ;
20
- using Microsoft . ML . Internal . Calibration ;
20
+ using Microsoft . ML . FastTree ;
21
21
using Microsoft . ML . Internal . Internallearn ;
22
22
using Microsoft . ML . Internal . Utilities ;
23
23
using Microsoft . ML . Model ;
@@ -56,16 +56,16 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
56
56
{
57
57
protected readonly TArgs Args ;
58
58
protected readonly bool AllowGC ;
59
- protected TreeEnsemble TrainedEnsemble ;
60
59
protected int FeatureCount ;
60
+ private protected InternalTreeEnsemble TrainedEnsemble ;
61
61
private protected RoleMappedData ValidData ;
62
62
/// <summary>
63
63
/// If not null, it's a test data set passed in from training context. It will be converted to one element in
64
64
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
65
65
/// </summary>
66
66
private protected RoleMappedData TestData ;
67
- protected IParallelTraining ParallelTraining ;
68
- protected OptimizationAlgorithm OptimizationAlgorithm ;
67
+ private protected IParallelTraining ParallelTraining ;
68
+ private protected OptimizationAlgorithm OptimizationAlgorithm ;
69
69
protected Dataset TrainSet ;
70
70
protected Dataset ValidSet ;
71
71
/// <summary>
@@ -89,7 +89,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
89
89
protected double [ ] InitValidScores ;
90
90
protected double [ ] [ ] InitTestScores ;
91
91
//protected int Iteration;
92
- protected TreeEnsemble Ensemble ;
92
+ private protected InternalTreeEnsemble Ensemble ;
93
93
94
94
protected bool HasValidSet => ValidSet != null ;
95
95
@@ -175,8 +175,9 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
175
175
176
176
protected abstract Test ConstructTestForTrainingData ( ) ;
177
177
178
- protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm ( IChannel ch ) ;
179
- protected abstract TreeLearner ConstructTreeLearner ( IChannel ch ) ;
178
+ private protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm ( IChannel ch ) ;
179
+
180
+ private protected abstract TreeLearner ConstructTreeLearner ( IChannel ch ) ;
180
181
181
182
protected abstract ObjectiveFunctionBase ConstructObjFunc ( IChannel ch ) ;
182
183
@@ -488,7 +489,7 @@ protected bool AreSamplesWeighted(IChannel ch)
488
489
489
490
private void InitializeEnsemble ( )
490
491
{
491
- Ensemble = new TreeEnsemble ( ) ;
492
+ Ensemble = new InternalTreeEnsemble ( ) ;
492
493
}
493
494
494
495
/// <summary>
@@ -793,7 +794,7 @@ private float GetMachineAvailableBytes()
793
794
794
795
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
795
796
// Note that this tree can be null if no tree was learnt this iteration.
796
- protected virtual void CustomizedTrainingIteration ( RegressionTree tree )
797
+ private protected virtual void CustomizedTrainingIteration ( InternalRegressionTree tree )
797
798
{
798
799
}
799
800
@@ -924,7 +925,7 @@ internal abstract class DataConverter
924
925
/// of features we actually trained on. This can be null in the event that no filtering
925
926
/// occurred.
926
927
/// </summary>
927
- /// <seealso cref="TreeEnsemble .RemapFeatures"/>
928
+ /// <seealso cref="InternalTreeEnsemble .RemapFeatures"/>
928
929
public int [ ] FeatureMap ;
929
930
930
931
protected readonly IHost Host ;
@@ -2810,9 +2811,10 @@ public abstract class TreeEnsembleModelParameters :
2810
2811
ISingleCanSavePfa ,
2811
2812
ISingleCanSaveOnnx
2812
2813
{
2813
- //The below two properties are necessary for tree Visualizer
2814
+ // The below two properties are necessary for tree Visualizer
2814
2815
[ BestFriend ]
2815
- internal TreeEnsemble TrainedEnsemble { get ; }
2816
+ internal InternalTreeEnsemble TrainedEnsemble { get ; }
2817
+
2816
2818
int ITreeEnsemble . NumTrees => TrainedEnsemble . NumTrees ;
2817
2819
2818
2820
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2854,7 +2856,9 @@ public abstract class TreeEnsembleModelParameters :
2854
2856
/// </summary>
2855
2857
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator ( this ) ;
2856
2858
2857
- public TreeEnsembleModelParameters ( IHostEnvironment env , string name , TreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
2859
+ /// <summary>This constructor is used by both FastTree and LightGBM, so <see cref="BestFriendAttribute"/> is required.</summary>
2860
+ [ BestFriend ]
2861
+ internal TreeEnsembleModelParameters ( IHostEnvironment env , string name , InternalTreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
2858
2862
: base ( env , name )
2859
2863
{
2860
2864
Host . CheckValue ( trainedEnsemble , nameof ( trainedEnsemble ) ) ;
@@ -2868,7 +2872,7 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
2868
2872
InnerArgs = innerArgs ;
2869
2873
NumFeatures = numFeatures ;
2870
2874
2871
- MaxSplitFeatIdx = FindMaxFeatureIndex ( trainedEnsemble ) ;
2875
+ MaxSplitFeatIdx = trainedEnsemble . GetMaxFeatureIndex ( ) ;
2872
2876
Contracts . Assert ( NumFeatures > MaxSplitFeatIdx ) ;
2873
2877
2874
2878
InputType = new VectorType ( NumberType . Float , NumFeatures ) ;
@@ -2892,8 +2896,8 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
2892
2896
if ( ctx . Header . ModelVerWritten >= VerCategoricalSplitSerialized )
2893
2897
categoricalSplits = true ;
2894
2898
2895
- TrainedEnsemble = new TreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ;
2896
- MaxSplitFeatIdx = FindMaxFeatureIndex ( TrainedEnsemble ) ;
2899
+ TrainedEnsemble = new InternalTreeEnsemble ( ctx , usingDefaultValues , categoricalSplits ) ;
2900
+ MaxSplitFeatIdx = TrainedEnsemble . GetMaxFeatureIndex ( ) ;
2897
2901
2898
2902
InnerArgs = ctx . LoadStringOrNull ( ) ;
2899
2903
if ( ctx . Header . ModelVerWritten >= VerNumFeaturesSerialized )
@@ -3195,7 +3199,7 @@ private void SaveEnsembleAsCode(TextWriter writer, RoleMappedSchema schema)
3195
3199
MetadataUtils . GetSlotNames ( schema , RoleMappedSchema . ColumnRole . Feature , NumFeatures , ref names ) ;
3196
3200
3197
3201
int i = 0 ;
3198
- foreach ( RegressionTree tree in TrainedEnsemble . Trees )
3202
+ foreach ( InternalRegressionTree tree in TrainedEnsemble . Trees )
3199
3203
{
3200
3204
writer . Write ( "double treeOutput{0}=" , i ) ;
3201
3205
SaveTreeAsCode ( tree , writer , in names ) ;
@@ -3211,13 +3215,13 @@ private void SaveEnsembleAsCode(TextWriter writer, RoleMappedSchema schema)
3211
3215
/// <summary>
3212
3216
/// Convert a single tree to code, called recursively
3213
3217
/// </summary>
3214
- private void SaveTreeAsCode ( RegressionTree tree , TextWriter writer , in VBuffer < ReadOnlyMemory < char > > names )
3218
+ private void SaveTreeAsCode ( InternalRegressionTree tree , TextWriter writer , in VBuffer < ReadOnlyMemory < char > > names )
3215
3219
{
3216
3220
ToCSharp ( tree , writer , 0 , in names ) ;
3217
3221
}
3218
3222
3219
3223
// converts a subtree into a C# expression
3220
- private void ToCSharp ( RegressionTree tree , TextWriter writer , int node , in VBuffer < ReadOnlyMemory < char > > names )
3224
+ private void ToCSharp ( InternalRegressionTree tree , TextWriter writer , int node , in VBuffer < ReadOnlyMemory < char > > names )
3221
3225
{
3222
3226
if ( node < 0 )
3223
3227
{
@@ -3259,23 +3263,6 @@ public void GetFeatureWeights(ref VBuffer<Float> weights)
3259
3263
bldr . GetResult ( ref weights ) ;
3260
3264
}
3261
3265
3262
- private static int FindMaxFeatureIndex ( TreeEnsemble ensemble )
3263
- {
3264
- int ifeatMax = 0 ;
3265
- for ( int i = 0 ; i < ensemble . NumTrees ; i ++ )
3266
- {
3267
- var tree = ensemble . GetTreeAt ( i ) ;
3268
- for ( int n = 0 ; n < tree . NumNodes ; n ++ )
3269
- {
3270
- int ifeat = tree . SplitFeature ( n ) ;
3271
- if ( ifeat > ifeatMax )
3272
- ifeatMax = ifeat ;
3273
- }
3274
- }
3275
-
3276
- return ifeatMax ;
3277
- }
3278
-
3279
3266
ITree [ ] ITreeEnsemble . GetTrees ( )
3280
3267
{
3281
3268
return TrainedEnsemble . Trees . Select ( k => new Tree ( k ) ) . ToArray ( ) ;
@@ -3318,9 +3305,9 @@ Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema)
3318
3305
3319
3306
private sealed class Tree : ITree < VBuffer < Float > >
3320
3307
{
3321
- private readonly RegressionTree _regTree ;
3308
+ private readonly InternalRegressionTree _regTree ;
3322
3309
3323
- public Tree ( RegressionTree regTree )
3310
+ public Tree ( InternalRegressionTree regTree )
3324
3311
{
3325
3312
_regTree = regTree ;
3326
3313
}
@@ -3390,4 +3377,82 @@ public TreeNode(Dictionary<string, object> keyValues)
3390
3377
public Dictionary < string , object > KeyValues { get ; }
3391
3378
}
3392
3379
}
3380
+
3381
+ /// <summary>
3382
+ /// <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is derived from
3383
+ /// <see cref="TreeEnsembleModelParameters"/> and adds a strongly-typed public property,
3384
+ /// <see cref="TrainedTreeEnsemble"/>, that exposes the trained model's details to users.
3385
+ /// Its private helper, <see cref="CreateTreeEnsembleFromInternalDataStructure"/>, is called by both
3386
+ /// constructors to build <see cref="TrainedTreeEnsemble"/> from the internal ensemble held by <see cref="TreeEnsembleModelParameters"/>.
3387
+ /// Note that the major difference between <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/>
3388
+ /// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of
3389
+ /// <see cref="TrainedTreeEnsemble"/>.
3390
+ /// </summary>
3391
+ public abstract class TreeEnsembleModelParametersBasedOnRegressionTree : TreeEnsembleModelParameters
3392
+ {
3393
+ /// <summary>
3394
+ /// An ensemble of trees exposed to users. It is a wrapper around the <see langword="internal"/>
3395
+ /// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
3396
+ /// </summary>
3397
+ public RegressionTreeEnsemble TrainedTreeEnsemble { get ; }
3398
+
3399
+ [ BestFriend ]
3400
+ internal TreeEnsembleModelParametersBasedOnRegressionTree ( IHostEnvironment env , string name , InternalTreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
3401
+ : base ( env , name , trainedEnsemble , numFeatures , innerArgs )
3402
+ {
3403
+ TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure ( ) ;
3404
+ }
3405
+
3406
+ protected TreeEnsembleModelParametersBasedOnRegressionTree ( IHostEnvironment env , string name , ModelLoadContext ctx , VersionInfo ver )
3407
+ : base ( env , name , ctx , ver )
3408
+ {
3409
+ TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure ( ) ;
3410
+ }
3411
+
3412
+ private RegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure ( )
3413
+ {
3414
+ var trees = TrainedEnsemble . Trees . Select ( tree => new RegressionTree ( tree ) ) ; // one public wrapper per internal tree
3415
+ var treeWeights = TrainedEnsemble . Trees . Select ( tree => tree . Weight ) ; // weights read from the same Trees sequence as the wrappers above
3416
+ return new RegressionTreeEnsemble ( trees , treeWeights , TrainedEnsemble . Bias ) ;
3417
+ }
3418
+ }
3419
+
3420
+ /// <summary>
3421
+ /// <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/> is derived from
3422
+ /// <see cref="TreeEnsembleModelParameters"/> and adds a strongly-typed public property,
3423
+ /// <see cref="TrainedTreeEnsemble"/>, that exposes the trained model's details to users.
3424
+ /// Its private helper, <see cref="CreateTreeEnsembleFromInternalDataStructure"/>, is called by both
3425
+ /// constructors to build <see cref="TrainedTreeEnsemble"/> from the internal ensemble held by <see cref="TreeEnsembleModelParameters"/>.
3426
+ /// Note that the major difference between <see cref="TreeEnsembleModelParametersBasedOnQuantileRegressionTree"/>
3427
+ /// and <see cref="TreeEnsembleModelParametersBasedOnRegressionTree"/> is the type of
3428
+ /// <see cref="TrainedTreeEnsemble"/>.
3429
+ /// </summary>
3430
+ public abstract class TreeEnsembleModelParametersBasedOnQuantileRegressionTree : TreeEnsembleModelParameters
3431
+ {
3432
+ /// <summary>
3433
+ /// An ensemble of trees exposed to users. It is a wrapper around the <see langword="internal"/>
3434
+ /// <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.TreeEnsemble{T}"/>.
3435
+ /// </summary>
3436
+ public QuantileRegressionTreeEnsemble TrainedTreeEnsemble { get ; }
3437
+
3438
+ [ BestFriend ]
3439
+ internal TreeEnsembleModelParametersBasedOnQuantileRegressionTree ( IHostEnvironment env , string name , InternalTreeEnsemble trainedEnsemble , int numFeatures , string innerArgs )
3440
+ : base ( env , name , trainedEnsemble , numFeatures , innerArgs )
3441
+ {
3442
+ TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure ( ) ;
3443
+ }
3444
+
3445
+ protected TreeEnsembleModelParametersBasedOnQuantileRegressionTree ( IHostEnvironment env , string name , ModelLoadContext ctx , VersionInfo ver )
3446
+ : base ( env , name , ctx , ver )
3447
+ {
3448
+ TrainedTreeEnsemble = CreateTreeEnsembleFromInternalDataStructure ( ) ;
3449
+ }
3450
+
3451
+ private QuantileRegressionTreeEnsemble CreateTreeEnsembleFromInternalDataStructure ( )
3452
+ {
3453
+ var trees = TrainedEnsemble . Trees . Select ( tree => new QuantileRegressionTree ( ( InternalQuantileRegressionTree ) tree ) ) ; // explicit cast assumes every tree in the ensemble is an InternalQuantileRegressionTree
3454
+ var treeWeights = TrainedEnsemble . Trees . Select ( tree => tree . Weight ) ; // weights read from the same Trees sequence as the wrappers above
3455
+ return new QuantileRegressionTreeEnsemble ( trees , treeWeights , TrainedEnsemble . Bias ) ;
3456
+ }
3457
+ }
3393
3458
}
0 commit comments