Skip to content

Commit 345a5c2

Browse files
rogancarr authored and
TomFinley committed
Fixes for General Additive Models (#743)
* Adding weights to GAM calculation; refactor redundant Leaf Split calculations to share one function. * Using MinDocsPerLeaf as set in the Arguments class during leaf split. Adding a validation on the input value. * Centering results and unifying outputs; fixing issues with lookups and sparse offsets * Adding in a validation set and validation set pruning to GAM * Fixed GAM Classifier to use a small learning rate: Updated the FastTreeBinaryClassification Loss to take the sigmoid parameter as input; FastTree uses the default, stays the same. Gam uses Unity. Refactored GamRegressor and GamClassifier into their own files. Added tests to verify Train loss and validation metrics. * Remove unused calibration parameters. * Recentering graph based on mean responses. * Unified the documents to thread calculation across implementations. * Updating methods and properties to be as close to private as possible. Adding XML Docs * Updating the FastTreeClassification Loss to incorporate the factor of 2 into the sigmoid parameter; this allows GAMs to output features on the scale of the logit. * Updating the entrypoints for BinaryClassGamPredictor to reflect the use of the calibrator. * Fixing an issue with no validation and no progress. * Fixing null pointer bug, where if a split didn't have positive gain, no graph would be defined. As the graph is still accessed, it must be defined to zero in such cases. * Switched sub-graph calculation to structs; removed TrainingInfo (and therefore the test to validate training loss); removed post-training scoring update; made the statistics calculator interface internal to FastTree. * Fixing arithmetic error in the weighted split finding
1 parent 8ac5ce8 commit 345a5c2

23 files changed

+3708
-3399
lines changed

src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs

+63-144
Original file line numberDiff line numberDiff line change
@@ -190,19 +190,56 @@ public void CopyFeatureHistogram(int subfeatureIndex, ref PerBinStats[] hist)
190190

191191
}
192192

193-
public void FillSplitCandidates(
194-
Dataset trainData, double sumTargets,
195-
LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
196-
int globalFeatureIndex, double minDocsInLeaf,
197-
double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient)
193+
public void FillSplitCandidates(LeastSquaresRegressionTreeLearner learner, LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
194+
int flock, int[] featureUseCount, double featureFirstUsePenalty, double featureReusePenalty, double minDocsInLeaf,
195+
bool hasWeights, double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient)
198196
{
199-
int flockIndex;
200-
int subfeatureIndex;
201-
trainData.MapFeatureToFlockAndSubFeature(globalFeatureIndex, out flockIndex, out subfeatureIndex);
197+
int featureMin = learner.TrainData.FlockToFirstFeature(flock);
198+
int featureLim = featureMin + learner.TrainData.Flocks[flock].Count;
199+
foreach (var feature in learner.GetActiveFeatures(featureMin, featureLim))
200+
{
201+
int subfeature = feature - featureMin;
202+
Contracts.Assert(0 <= subfeature && subfeature < Flock.Count);
203+
Contracts.Assert(subfeature <= feature);
204+
Contracts.Assert(learner.TrainData.FlockToFirstFeature(flock) == feature - subfeature);
202205

203-
double trust = trainData.Flocks[flockIndex].Trust(subfeatureIndex);
204-
double minDocsForThis = minDocsInLeaf / trust;
206+
if (!IsSplittable[subfeature])
207+
continue;
208+
209+
Contracts.Assert(featureUseCount[feature] >= 0);
210+
211+
double trust = learner.TrainData.Flocks[flock].Trust(subfeature);
212+
double usePenalty = (featureUseCount[feature] == 0) ?
213+
featureFirstUsePenalty : featureReusePenalty * Math.Log(featureUseCount[feature] + 1);
214+
int totalCount = leafSplitCandidates.NumDocsInLeaf;
215+
double sumTargets = leafSplitCandidates.SumTargets;
216+
double sumWeights = leafSplitCandidates.SumWeights;
205217

218+
FindBestSplitForFeature(learner, leafSplitCandidates, totalCount, sumTargets, sumWeights,
219+
feature, flock, subfeature, minDocsInLeaf,
220+
hasWeights, gainConfidenceInSquaredStandardDeviations, entropyCoefficient,
221+
trust, usePenalty);
222+
223+
if (leafSplitCandidates.FlockToBestFeature != null)
224+
{
225+
if (leafSplitCandidates.FlockToBestFeature[flock] == -1 ||
226+
leafSplitCandidates.FeatureSplitInfo[leafSplitCandidates.FlockToBestFeature[flock]].Gain <
227+
leafSplitCandidates.FeatureSplitInfo[feature].Gain)
228+
{
229+
leafSplitCandidates.FlockToBestFeature[flock] = feature;
230+
}
231+
}
232+
}
233+
}
234+
235+
internal void FindBestSplitForFeature(ILeafSplitStatisticsCalculator leafCalculator,
236+
LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
237+
int totalCount, double sumTargets, double sumWeights,
238+
int featureIndex, int flockIndex, int subfeatureIndex, double minDocsInLeaf,
239+
bool hasWeights, double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient,
240+
double trust, double usePenalty)
241+
{
242+
double minDocsForThis = minDocsInLeaf / trust;
206243
double bestSumGTTargets = double.NaN;
207244
double bestSumGTWeights = double.NaN;
208245
double bestShiftedGain = double.NegativeInfinity;
@@ -211,8 +248,8 @@ public void FillSplitCandidates(
211248
double sumGTTargets = 0.0;
212249
double sumGTWeights = eps;
213250
int gtCount = 0;
214-
int totalCount = leafSplitCandidates.Targets.Length;
215-
double gainShift = (sumTargets * sumTargets) / totalCount;
251+
sumWeights += 2 * eps;
252+
double gainShift = leafCalculator.GetLeafSplitGain(totalCount, sumTargets, sumWeights);
216253

217254
// We get to this more explicit handling of the zero case since, under the influence of
218255
// numerical error, especially under single precision, the histogram computed values can
@@ -234,6 +271,8 @@ public void FillSplitCandidates(
234271
var binStats = GetBinStats(b);
235272
t--;
236273
sumGTTargets += binStats.SumTargets;
274+
if (hasWeights)
275+
sumGTWeights += binStats.SumWeights;
237276
gtCount += binStats.Count;
238277

239278
// Advance until GTCount is high enough.
@@ -246,8 +285,8 @@ public void FillSplitCandidates(
246285
break;
247286

248287
// Calculate the shifted gain, including the LTE child.
249-
double currentShiftedGain = (sumGTTargets * sumGTTargets) / gtCount
250-
+ ((sumTargets - sumGTTargets) * (sumTargets - sumGTTargets)) / lteCount;
288+
double currentShiftedGain = leafCalculator.GetLeafSplitGain(gtCount, sumGTTargets, sumGTWeights)
289+
+ leafCalculator.GetLeafSplitGain(lteCount, sumTargets - sumGTTargets, sumWeights - sumGTWeights);
251290

252291
// Test whether we are meeting the min shifted gain confidence criteria for this split.
253292
if (currentShiftedGain < minShiftedGain)
@@ -274,137 +313,17 @@ public void FillSplitCandidates(
274313
}
275314
}
276315
// set the appropriate place in the output vectors
277-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Feature = flockIndex;
278-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Threshold = bestThreshold;
279-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].LteOutput = (sumTargets - bestSumGTTargets) / (totalCount - bestGTCount);
280-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].GTOutput = (bestSumGTTargets - bestSumGTWeights) / bestGTCount;
281-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].LteCount = totalCount - bestGTCount;
282-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].GTCount = bestGTCount;
283-
284-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Gain = (bestShiftedGain - gainShift) * trust;
316+
leafSplitCandidates.FeatureSplitInfo[featureIndex].CategoricalSplit = false;
317+
leafSplitCandidates.FeatureSplitInfo[featureIndex].Feature = featureIndex;
318+
leafSplitCandidates.FeatureSplitInfo[featureIndex].Threshold = bestThreshold;
319+
leafSplitCandidates.FeatureSplitInfo[featureIndex].LteOutput = leafCalculator.CalculateSplittedLeafOutput(totalCount - bestGTCount, sumTargets - bestSumGTTargets, sumWeights - bestSumGTWeights);
320+
leafSplitCandidates.FeatureSplitInfo[featureIndex].GTOutput = leafCalculator.CalculateSplittedLeafOutput(bestGTCount, bestSumGTTargets, bestSumGTWeights);
321+
leafSplitCandidates.FeatureSplitInfo[featureIndex].LteCount = totalCount - bestGTCount;
322+
leafSplitCandidates.FeatureSplitInfo[featureIndex].GTCount = bestGTCount;
323+
324+
leafSplitCandidates.FeatureSplitInfo[featureIndex].Gain = (bestShiftedGain - gainShift) * trust - usePenalty;
285325
double erfcArg = Math.Sqrt((bestShiftedGain - gainShift) * (totalCount - 1) / (2 * leafSplitCandidates.VarianceTargets * totalCount));
286-
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].GainPValue = ProbabilityFunctions.Erfc(erfcArg);
287-
}
288-
289-
public void FillSplitCandidates(LeastSquaresRegressionTreeLearner learner, LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
290-
int flock, int[] featureUseCount, double featureFirstUsePenalty, double featureReusePenalty, double minDocsInLeaf,
291-
bool hasWeights, double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient)
292-
{
293-
int featureMin = learner.TrainData.FlockToFirstFeature(flock);
294-
int featureLim = featureMin + learner.TrainData.Flocks[flock].Count;
295-
foreach (var feature in learner.GetActiveFeatures(featureMin, featureLim))
296-
{
297-
int subfeature = feature - featureMin;
298-
Contracts.Assert(0 <= subfeature && subfeature < Flock.Count);
299-
Contracts.Assert(subfeature <= feature);
300-
Contracts.Assert(learner.TrainData.FlockToFirstFeature(flock) == feature - subfeature);
301-
302-
if (!IsSplittable[subfeature])
303-
continue;
304-
305-
Contracts.Assert(featureUseCount[feature] >= 0);
306-
307-
double trust = learner.TrainData.Flocks[flock].Trust(subfeature);
308-
double minDocsForThis = minDocsInLeaf / trust;
309-
double usePenalty = (featureUseCount[feature] == 0) ?
310-
featureFirstUsePenalty : featureReusePenalty * Math.Log(featureUseCount[feature] + 1);
311-
312-
double bestSumGTTargets = double.NaN;
313-
double bestSumGTWeights = double.NaN;
314-
double bestShiftedGain = double.NegativeInfinity;
315-
const double eps = 1e-10;
316-
int bestGTCount = -1;
317-
double sumGTTargets = 0.0;
318-
double sumGTWeights = eps;
319-
int gtCount = 0;
320-
int totalCount = leafSplitCandidates.NumDocsInLeaf;
321-
double sumTargets = leafSplitCandidates.SumTargets;
322-
double sumWeights = leafSplitCandidates.SumWeights + 2 * eps;
323-
double gainShift = learner.GetLeafSplitGain(totalCount, sumTargets, sumWeights);
324-
325-
// We get to this more explicit handling of the zero case since, under the influence of
326-
// numerical error, especially under single precision, the histogram computed values can
327-
// be wildly inaccurate even to the point where 0 unshifted gain may become a strong
328-
// criteria.
329-
double minShiftedGain = gainConfidenceInSquaredStandardDeviations <= 0 ? 0.0 :
330-
(gainConfidenceInSquaredStandardDeviations * leafSplitCandidates.VarianceTargets
331-
* totalCount / (totalCount - 1) + gainShift);
332-
333-
// re-evaluate if the histogram is splittable
334-
IsSplittable[subfeature] = false;
335-
int t = Flock.BinCount(subfeature);
336-
uint bestThreshold = (uint)t;
337-
t--;
338-
int min = GetMinBorder(subfeature);
339-
int max = GetMaxBorder(subfeature);
340-
for (int b = max; b >= min; --b)
341-
{
342-
var binStats = GetBinStats(b);
343-
t--;
344-
sumGTTargets += binStats.SumTargets;
345-
if (hasWeights)
346-
sumGTWeights += binStats.SumWeights;
347-
gtCount += binStats.Count;
348-
349-
// Advance until GTCount is high enough.
350-
if (gtCount < minDocsForThis)
351-
continue;
352-
int lteCount = totalCount - gtCount;
353-
354-
// If LTECount is too small, we are finished.
355-
if (lteCount < minDocsForThis)
356-
break;
357-
358-
// Calculate the shifted gain, including the LTE child.
359-
double currentShiftedGain = learner.GetLeafSplitGain(gtCount, sumGTTargets, sumGTWeights)
360-
+ learner.GetLeafSplitGain(lteCount, sumTargets - sumGTTargets, sumWeights - sumGTWeights);
361-
362-
// Test whether we are meeting the min shifted gain confidence criteria for this split.
363-
if (currentShiftedGain < minShiftedGain)
364-
continue;
365-
366-
// If this point in the code is reached, the histogram is splittable.
367-
IsSplittable[subfeature] = true;
368-
369-
if (entropyCoefficient > 0)
370-
{
371-
// Consider the entropy of the split.
372-
double entropyGain = (totalCount * Math.Log(totalCount) - lteCount * Math.Log(lteCount) - gtCount * Math.Log(gtCount));
373-
currentShiftedGain += entropyCoefficient * entropyGain;
374-
}
375-
376-
// Is t the best threshold so far?
377-
if (currentShiftedGain > bestShiftedGain)
378-
{
379-
bestGTCount = gtCount;
380-
bestSumGTTargets = sumGTTargets;
381-
bestSumGTWeights = sumGTWeights;
382-
bestThreshold = (uint)t;
383-
bestShiftedGain = currentShiftedGain;
384-
}
385-
}
386-
// set the appropriate place in the output vectors
387-
leafSplitCandidates.FeatureSplitInfo[feature].CategoricalSplit = false;
388-
leafSplitCandidates.FeatureSplitInfo[feature].Feature = feature;
389-
leafSplitCandidates.FeatureSplitInfo[feature].Threshold = bestThreshold;
390-
leafSplitCandidates.FeatureSplitInfo[feature].LteOutput = learner.CalculateSplittedLeafOutput(totalCount - bestGTCount, sumTargets - bestSumGTTargets, sumWeights - bestSumGTWeights);
391-
leafSplitCandidates.FeatureSplitInfo[feature].GTOutput = learner.CalculateSplittedLeafOutput(bestGTCount, bestSumGTTargets, bestSumGTWeights);
392-
leafSplitCandidates.FeatureSplitInfo[feature].LteCount = totalCount - bestGTCount;
393-
leafSplitCandidates.FeatureSplitInfo[feature].GTCount = bestGTCount;
394-
395-
leafSplitCandidates.FeatureSplitInfo[feature].Gain = (bestShiftedGain - gainShift) * trust - usePenalty;
396-
double erfcArg = Math.Sqrt((bestShiftedGain - gainShift) * (totalCount - 1) / (2 * leafSplitCandidates.VarianceTargets * totalCount));
397-
leafSplitCandidates.FeatureSplitInfo[feature].GainPValue = ProbabilityFunctions.Erfc(erfcArg);
398-
if (leafSplitCandidates.FlockToBestFeature != null)
399-
{
400-
if (leafSplitCandidates.FlockToBestFeature[flock] == -1 ||
401-
leafSplitCandidates.FeatureSplitInfo[leafSplitCandidates.FlockToBestFeature[flock]].Gain <
402-
leafSplitCandidates.FeatureSplitInfo[feature].Gain)
403-
{
404-
leafSplitCandidates.FlockToBestFeature[flock] = feature;
405-
}
406-
}
407-
}
326+
leafSplitCandidates.FeatureSplitInfo[featureIndex].GainPValue = ProbabilityFunctions.Erfc(erfcArg);
408327
}
409328

410329
public void FillSplitCandidatesCategorical(LeastSquaresRegressionTreeLearner learner,

0 commit comments

Comments (0)