5
5
using Microsoft . ML . Data ;
6
6
using Microsoft . ML . Internal . Internallearn ;
7
7
using Microsoft . ML . RunTests ;
8
+ using Microsoft . ML . SamplesUtils ;
8
9
using Microsoft . ML . Trainers ;
9
10
using Microsoft . ML . Trainers . FastTree ;
10
11
using Xunit ;
@@ -56,15 +57,15 @@ public void FastTreeClassificationIntrospectiveTraining()
56
57
57
58
BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > pred = null ;
58
59
59
- var pipeline = ml . Transforms . Text . FeaturizeText ( "SentimentText " , "Features " )
60
+ var pipeline = ml . Transforms . Text . FeaturizeText ( "Features " , "SentimentText " )
60
61
. AppendCacheCheckpoint ( ml )
61
62
. Append ( trainer . WithOnFitDelegate ( p => pred = p ) ) ;
62
63
63
64
// Train.
64
65
var model = pipeline . Fit ( data ) ;
65
66
66
67
// Extract the learned GBDT model.
67
- var treeCollection = ( ( FastTreeBinaryModelParameters ) ( ( Internal . Calibration . FeatureWeightsCalibratedPredictor ) pred . Model ) . SubPredictor ) . TrainedTreeEnsemble ;
68
+ var treeCollection = ( ( FastForestBinaryModelParameters ) ( ( Internal . Calibration . FeatureWeightsCalibratedPredictor ) pred . Model ) . SubPredictor ) . TrainedTreeEnsemble ;
68
69
69
70
// Inspect properties in the extracted model.
70
71
Assert . Equal ( 3 , treeCollection . Trees . Count ) ;
@@ -80,11 +81,63 @@ public void FastTreeClassificationIntrospectiveTraining()
80
81
Assert . Equal ( tree . LteChild , new int [ ] { 2 , - 2 , - 1 , - 3 } ) ;
81
82
Assert . Equal ( tree . GtChild , new int [ ] { 1 , 3 , - 4 , - 5 } ) ;
82
83
Assert . Equal ( tree . NumericalSplitFeatureIndexes , new int [ ] { 14 , 294 , 633 , 266 } ) ;
83
- Assert . Equal ( tree . NumericalSplitThresholds , new float [ ] { 0.0911167f , 0.06509889f , 0.019873254f , 0.0361835f } ) ;
84
+ var expectedThresholds = new float [ ] { 0.0911167f , 0.06509889f , 0.019873254f , 0.0361835f } ;
85
+ for ( int i = 0 ; i < tree . NumNodes ; ++ i )
86
+ Assert . Equal ( expectedThresholds [ i ] , tree . NumericalSplitThresholds [ i ] , 6 ) ;
84
87
Assert . All ( tree . CategoricalSplitFlags , flag => Assert . False ( flag ) ) ;
85
88
86
89
Assert . Equal ( 0 , tree . GetCategoricalSplitFeaturesAt ( 0 ) . Count ) ;
87
90
Assert . Equal ( 0 , tree . GetCategoricalCategoricalSplitFeatureRangeAt ( 0 ) . Count ) ;
88
91
}
92
+
93
+ [ Fact ]
94
+ public void FastForestRegressionIntrospectiveTraining ( )
95
+ {
96
+ var ml = new MLContext ( seed : 1 , conc : 1 ) ;
97
+ var data = DatasetUtils . GenerateFloatLabelFloatFeatureVectorSamples ( 1000 ) ;
98
+ var dataView = ml . Data . ReadFromEnumerable ( data ) ;
99
+
100
+ RegressionPredictionTransformer < FastForestRegressionModelParameters > pred = null ;
101
+ var trainer = ml . Regression . Trainers . FastForest ( numLeaves : 5 , numTrees : 3 ) . WithOnFitDelegate ( p => pred = p ) ;
102
+
103
+ // Train.
104
+ var model = trainer . Fit ( dataView ) ;
105
+
106
+ // Extract the learned RF model.
107
+ var treeCollection = pred . Model . TrainedTreeEnsemble ;
108
+
109
+ // Inspect properties in the extracted model.
110
+ Assert . Equal ( 3 , treeCollection . Trees . Count ) ;
111
+ Assert . Equal ( 3 , treeCollection . TreeWeights . Count ) ;
112
+ Assert . Equal ( 0 , treeCollection . Bias ) ;
113
+ Assert . All ( treeCollection . TreeWeights , weight => Assert . Equal ( 1.0 , weight ) ) ;
114
+
115
+ // Inspect the last tree.
116
+ var tree = treeCollection . Trees [ 2 ] ;
117
+
118
+ Assert . Equal ( 5 , tree . NumLeaves ) ;
119
+ Assert . Equal ( 4 , tree . NumNodes ) ;
120
+ Assert . Equal ( tree . LteChild , new int [ ] { - 1 , - 2 , - 3 , - 4 } ) ;
121
+ Assert . Equal ( tree . GtChild , new int [ ] { 1 , 2 , 3 , - 5 } ) ;
122
+ Assert . Equal ( tree . NumericalSplitFeatureIndexes , new int [ ] { 9 , 0 , 1 , 8 } ) ;
123
+ var expectedThresholds = new float [ ] { 0.208134219f , 0.198336035f , 0.202952743f , 0.205061346f } ;
124
+ for ( int i = 0 ; i < tree . NumNodes ; ++ i )
125
+ Assert . Equal ( expectedThresholds [ i ] , tree . NumericalSplitThresholds [ i ] , 6 ) ;
126
+ Assert . All ( tree . CategoricalSplitFlags , flag => Assert . False ( flag ) ) ;
127
+
128
+ Assert . Equal ( 0 , tree . GetCategoricalSplitFeaturesAt ( 0 ) . Count ) ;
129
+ Assert . Equal ( 0 , tree . GetCategoricalCategoricalSplitFeatureRangeAt ( 0 ) . Count ) ;
130
+
131
+ var samples = new double [ ] { 0.97468354430379744 , 1.0 , 0.97727272727272729 , 0.972972972972973 , 0.26124197002141325 } ;
132
+ for ( int i = 0 ; i < tree . NumLeaves ; ++ i )
133
+ {
134
+ var sample = tree . GetLeafSamplesAt ( i ) ;
135
+ Assert . Single ( sample ) ;
136
+ Assert . Equal ( samples [ i ] , sample [ 0 ] , 6 ) ;
137
+ var weight = tree . GetLeafSampleWeightsAt ( i ) ;
138
+ Assert . Single ( weight ) ;
139
+ Assert . Equal ( 1 , weight [ 0 ] ) ;
140
+ }
141
+ }
89
142
}
90
143
}
0 commit comments