@@ -190,19 +190,56 @@ public void CopyFeatureHistogram(int subfeatureIndex, ref PerBinStats[] hist)
190
190
191
191
}
192
192
193
- public void FillSplitCandidates (
194
- Dataset trainData , double sumTargets ,
195
- LeastSquaresRegressionTreeLearner . LeafSplitCandidates leafSplitCandidates ,
196
- int globalFeatureIndex , double minDocsInLeaf ,
197
- double gainConfidenceInSquaredStandardDeviations , double entropyCoefficient )
193
+ public void FillSplitCandidates ( LeastSquaresRegressionTreeLearner learner , LeastSquaresRegressionTreeLearner . LeafSplitCandidates leafSplitCandidates ,
194
+ int flock , int [ ] featureUseCount , double featureFirstUsePenalty , double featureReusePenalty , double minDocsInLeaf ,
195
+ bool hasWeights , double gainConfidenceInSquaredStandardDeviations , double entropyCoefficient )
198
196
{
199
- int flockIndex ;
200
- int subfeatureIndex ;
201
- trainData . MapFeatureToFlockAndSubFeature ( globalFeatureIndex , out flockIndex , out subfeatureIndex ) ;
197
+ int featureMin = learner . TrainData . FlockToFirstFeature ( flock ) ;
198
+ int featureLim = featureMin + learner . TrainData . Flocks [ flock ] . Count ;
199
+ foreach ( var feature in learner . GetActiveFeatures ( featureMin , featureLim ) )
200
+ {
201
+ int subfeature = feature - featureMin ;
202
+ Contracts . Assert ( 0 <= subfeature && subfeature < Flock . Count ) ;
203
+ Contracts . Assert ( subfeature <= feature ) ;
204
+ Contracts . Assert ( learner . TrainData . FlockToFirstFeature ( flock ) == feature - subfeature ) ;
202
205
203
- double trust = trainData . Flocks [ flockIndex ] . Trust ( subfeatureIndex ) ;
204
- double minDocsForThis = minDocsInLeaf / trust ;
206
+ if ( ! IsSplittable [ subfeature ] )
207
+ continue ;
208
+
209
+ Contracts . Assert ( featureUseCount [ feature ] >= 0 ) ;
210
+
211
+ double trust = learner . TrainData . Flocks [ flock ] . Trust ( subfeature ) ;
212
+ double usePenalty = ( featureUseCount [ feature ] == 0 ) ?
213
+ featureFirstUsePenalty : featureReusePenalty * Math . Log ( featureUseCount [ feature ] + 1 ) ;
214
+ int totalCount = leafSplitCandidates . NumDocsInLeaf ;
215
+ double sumTargets = leafSplitCandidates . SumTargets ;
216
+ double sumWeights = leafSplitCandidates . SumWeights ;
205
217
218
+ FindBestSplitForFeature ( learner , leafSplitCandidates , totalCount , sumTargets , sumWeights ,
219
+ feature , flock , subfeature , minDocsInLeaf ,
220
+ hasWeights , gainConfidenceInSquaredStandardDeviations , entropyCoefficient ,
221
+ trust , usePenalty ) ;
222
+
223
+ if ( leafSplitCandidates . FlockToBestFeature != null )
224
+ {
225
+ if ( leafSplitCandidates . FlockToBestFeature [ flock ] == - 1 ||
226
+ leafSplitCandidates . FeatureSplitInfo [ leafSplitCandidates . FlockToBestFeature [ flock ] ] . Gain <
227
+ leafSplitCandidates . FeatureSplitInfo [ feature ] . Gain )
228
+ {
229
+ leafSplitCandidates . FlockToBestFeature [ flock ] = feature ;
230
+ }
231
+ }
232
+ }
233
+ }
234
+
235
+ internal void FindBestSplitForFeature ( ILeafSplitStatisticsCalculator leafCalculator ,
236
+ LeastSquaresRegressionTreeLearner . LeafSplitCandidates leafSplitCandidates ,
237
+ int totalCount , double sumTargets , double sumWeights ,
238
+ int featureIndex , int flockIndex , int subfeatureIndex , double minDocsInLeaf ,
239
+ bool hasWeights , double gainConfidenceInSquaredStandardDeviations , double entropyCoefficient ,
240
+ double trust , double usePenalty )
241
+ {
242
+ double minDocsForThis = minDocsInLeaf / trust ;
206
243
double bestSumGTTargets = double . NaN ;
207
244
double bestSumGTWeights = double . NaN ;
208
245
double bestShiftedGain = double . NegativeInfinity ;
@@ -211,8 +248,8 @@ public void FillSplitCandidates(
211
248
double sumGTTargets = 0.0 ;
212
249
double sumGTWeights = eps ;
213
250
int gtCount = 0 ;
214
- int totalCount = leafSplitCandidates . Targets . Length ;
215
- double gainShift = ( sumTargets * sumTargets ) / totalCount ;
251
+ sumWeights += 2 * eps ;
252
+ double gainShift = leafCalculator . GetLeafSplitGain ( totalCount , sumTargets , sumWeights ) ;
216
253
217
254
// We get to this more explicit handling of the zero case since, under the influence of
218
255
// numerical error, especially under single precision, the histogram computed values can
@@ -234,6 +271,8 @@ public void FillSplitCandidates(
234
271
var binStats = GetBinStats ( b ) ;
235
272
t -- ;
236
273
sumGTTargets += binStats . SumTargets ;
274
+ if ( hasWeights )
275
+ sumGTWeights += binStats . SumWeights ;
237
276
gtCount += binStats . Count ;
238
277
239
278
// Advance until GTCount is high enough.
@@ -246,8 +285,8 @@ public void FillSplitCandidates(
246
285
break ;
247
286
248
287
// Calculate the shifted gain, including the LTE child.
249
- double currentShiftedGain = ( sumGTTargets * sumGTTargets ) / gtCount
250
- + ( ( sumTargets - sumGTTargets ) * ( sumTargets - sumGTTargets ) ) / lteCount ;
288
+ double currentShiftedGain = leafCalculator . GetLeafSplitGain ( gtCount , sumGTTargets , sumGTWeights )
289
+ + leafCalculator . GetLeafSplitGain ( lteCount , sumTargets - sumGTTargets , sumWeights - sumGTWeights ) ;
251
290
252
291
// Test whether we are meeting the min shifted gain confidence criteria for this split.
253
292
if ( currentShiftedGain < minShiftedGain )
@@ -274,137 +313,17 @@ public void FillSplitCandidates(
274
313
}
275
314
}
276
315
// set the appropriate place in the output vectors
277
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . Feature = flockIndex ;
278
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . Threshold = bestThreshold ;
279
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . LteOutput = ( sumTargets - bestSumGTTargets ) / ( totalCount - bestGTCount ) ;
280
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . GTOutput = ( bestSumGTTargets - bestSumGTWeights ) / bestGTCount ;
281
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . LteCount = totalCount - bestGTCount ;
282
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . GTCount = bestGTCount ;
283
-
284
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . Gain = ( bestShiftedGain - gainShift ) * trust ;
316
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . CategoricalSplit = false ;
317
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . Feature = featureIndex ;
318
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . Threshold = bestThreshold ;
319
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . LteOutput = leafCalculator . CalculateSplittedLeafOutput ( totalCount - bestGTCount , sumTargets - bestSumGTTargets , sumWeights - bestSumGTWeights ) ;
320
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . GTOutput = leafCalculator . CalculateSplittedLeafOutput ( bestGTCount , bestSumGTTargets , bestSumGTWeights ) ;
321
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . LteCount = totalCount - bestGTCount ;
322
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . GTCount = bestGTCount ;
323
+
324
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . Gain = ( bestShiftedGain - gainShift ) * trust - usePenalty ;
285
325
double erfcArg = Math . Sqrt ( ( bestShiftedGain - gainShift ) * ( totalCount - 1 ) / ( 2 * leafSplitCandidates . VarianceTargets * totalCount ) ) ;
286
- leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . GainPValue = ProbabilityFunctions . Erfc ( erfcArg ) ;
287
- }
288
-
289
- public void FillSplitCandidates ( LeastSquaresRegressionTreeLearner learner , LeastSquaresRegressionTreeLearner . LeafSplitCandidates leafSplitCandidates ,
290
- int flock , int [ ] featureUseCount , double featureFirstUsePenalty , double featureReusePenalty , double minDocsInLeaf ,
291
- bool hasWeights , double gainConfidenceInSquaredStandardDeviations , double entropyCoefficient )
292
- {
293
- int featureMin = learner . TrainData . FlockToFirstFeature ( flock ) ;
294
- int featureLim = featureMin + learner . TrainData . Flocks [ flock ] . Count ;
295
- foreach ( var feature in learner . GetActiveFeatures ( featureMin , featureLim ) )
296
- {
297
- int subfeature = feature - featureMin ;
298
- Contracts . Assert ( 0 <= subfeature && subfeature < Flock . Count ) ;
299
- Contracts . Assert ( subfeature <= feature ) ;
300
- Contracts . Assert ( learner . TrainData . FlockToFirstFeature ( flock ) == feature - subfeature ) ;
301
-
302
- if ( ! IsSplittable [ subfeature ] )
303
- continue ;
304
-
305
- Contracts . Assert ( featureUseCount [ feature ] >= 0 ) ;
306
-
307
- double trust = learner . TrainData . Flocks [ flock ] . Trust ( subfeature ) ;
308
- double minDocsForThis = minDocsInLeaf / trust ;
309
- double usePenalty = ( featureUseCount [ feature ] == 0 ) ?
310
- featureFirstUsePenalty : featureReusePenalty * Math . Log ( featureUseCount [ feature ] + 1 ) ;
311
-
312
- double bestSumGTTargets = double . NaN ;
313
- double bestSumGTWeights = double . NaN ;
314
- double bestShiftedGain = double . NegativeInfinity ;
315
- const double eps = 1e-10 ;
316
- int bestGTCount = - 1 ;
317
- double sumGTTargets = 0.0 ;
318
- double sumGTWeights = eps ;
319
- int gtCount = 0 ;
320
- int totalCount = leafSplitCandidates . NumDocsInLeaf ;
321
- double sumTargets = leafSplitCandidates . SumTargets ;
322
- double sumWeights = leafSplitCandidates . SumWeights + 2 * eps ;
323
- double gainShift = learner . GetLeafSplitGain ( totalCount , sumTargets , sumWeights ) ;
324
-
325
- // We get to this more explicit handling of the zero case since, under the influence of
326
- // numerical error, especially under single precision, the histogram computed values can
327
- // be wildly inaccurate even to the point where 0 unshifted gain may become a strong
328
- // criteria.
329
- double minShiftedGain = gainConfidenceInSquaredStandardDeviations <= 0 ? 0.0 :
330
- ( gainConfidenceInSquaredStandardDeviations * leafSplitCandidates . VarianceTargets
331
- * totalCount / ( totalCount - 1 ) + gainShift ) ;
332
-
333
- // re-evaluate if the histogram is splittable
334
- IsSplittable [ subfeature ] = false ;
335
- int t = Flock . BinCount ( subfeature ) ;
336
- uint bestThreshold = ( uint ) t ;
337
- t -- ;
338
- int min = GetMinBorder ( subfeature ) ;
339
- int max = GetMaxBorder ( subfeature ) ;
340
- for ( int b = max ; b >= min ; -- b )
341
- {
342
- var binStats = GetBinStats ( b ) ;
343
- t -- ;
344
- sumGTTargets += binStats . SumTargets ;
345
- if ( hasWeights )
346
- sumGTWeights += binStats . SumWeights ;
347
- gtCount += binStats . Count ;
348
-
349
- // Advance until GTCount is high enough.
350
- if ( gtCount < minDocsForThis )
351
- continue ;
352
- int lteCount = totalCount - gtCount ;
353
-
354
- // If LTECount is too small, we are finished.
355
- if ( lteCount < minDocsForThis )
356
- break ;
357
-
358
- // Calculate the shifted gain, including the LTE child.
359
- double currentShiftedGain = learner . GetLeafSplitGain ( gtCount , sumGTTargets , sumGTWeights )
360
- + learner . GetLeafSplitGain ( lteCount , sumTargets - sumGTTargets , sumWeights - sumGTWeights ) ;
361
-
362
- // Test whether we are meeting the min shifted gain confidence criteria for this split.
363
- if ( currentShiftedGain < minShiftedGain )
364
- continue ;
365
-
366
- // If this point in the code is reached, the histogram is splittable.
367
- IsSplittable [ subfeature ] = true ;
368
-
369
- if ( entropyCoefficient > 0 )
370
- {
371
- // Consider the entropy of the split.
372
- double entropyGain = ( totalCount * Math . Log ( totalCount ) - lteCount * Math . Log ( lteCount ) - gtCount * Math . Log ( gtCount ) ) ;
373
- currentShiftedGain += entropyCoefficient * entropyGain ;
374
- }
375
-
376
- // Is t the best threshold so far?
377
- if ( currentShiftedGain > bestShiftedGain )
378
- {
379
- bestGTCount = gtCount ;
380
- bestSumGTTargets = sumGTTargets ;
381
- bestSumGTWeights = sumGTWeights ;
382
- bestThreshold = ( uint ) t ;
383
- bestShiftedGain = currentShiftedGain ;
384
- }
385
- }
386
- // set the appropriate place in the output vectors
387
- leafSplitCandidates . FeatureSplitInfo [ feature ] . CategoricalSplit = false ;
388
- leafSplitCandidates . FeatureSplitInfo [ feature ] . Feature = feature ;
389
- leafSplitCandidates . FeatureSplitInfo [ feature ] . Threshold = bestThreshold ;
390
- leafSplitCandidates . FeatureSplitInfo [ feature ] . LteOutput = learner . CalculateSplittedLeafOutput ( totalCount - bestGTCount , sumTargets - bestSumGTTargets , sumWeights - bestSumGTWeights ) ;
391
- leafSplitCandidates . FeatureSplitInfo [ feature ] . GTOutput = learner . CalculateSplittedLeafOutput ( bestGTCount , bestSumGTTargets , bestSumGTWeights ) ;
392
- leafSplitCandidates . FeatureSplitInfo [ feature ] . LteCount = totalCount - bestGTCount ;
393
- leafSplitCandidates . FeatureSplitInfo [ feature ] . GTCount = bestGTCount ;
394
-
395
- leafSplitCandidates . FeatureSplitInfo [ feature ] . Gain = ( bestShiftedGain - gainShift ) * trust - usePenalty ;
396
- double erfcArg = Math . Sqrt ( ( bestShiftedGain - gainShift ) * ( totalCount - 1 ) / ( 2 * leafSplitCandidates . VarianceTargets * totalCount ) ) ;
397
- leafSplitCandidates . FeatureSplitInfo [ feature ] . GainPValue = ProbabilityFunctions . Erfc ( erfcArg ) ;
398
- if ( leafSplitCandidates . FlockToBestFeature != null )
399
- {
400
- if ( leafSplitCandidates . FlockToBestFeature [ flock ] == - 1 ||
401
- leafSplitCandidates . FeatureSplitInfo [ leafSplitCandidates . FlockToBestFeature [ flock ] ] . Gain <
402
- leafSplitCandidates . FeatureSplitInfo [ feature ] . Gain )
403
- {
404
- leafSplitCandidates . FlockToBestFeature [ flock ] = feature ;
405
- }
406
- }
407
- }
326
+ leafSplitCandidates . FeatureSplitInfo [ featureIndex ] . GainPValue = ProbabilityFunctions . Erfc ( erfcArg ) ;
408
327
}
409
328
410
329
public void FillSplitCandidatesCategorical ( LeastSquaresRegressionTreeLearner learner ,
0 commit comments