Fixes for General Additive Models #743

Merged
merged 25 commits on Sep 17, 2018
Changes from all commits (25 commits)

5339988 (Jul 24, 2018): Adding a test to step through GAM
8c5907e (Jul 24, 2018): Adding weights to GAM calculation; refactor redundant Leaf Split calc…
ea92deb (Jul 24, 2018): Using MinDocsPerLeaf as set in the Arguments class during leaf split.…
07fc292 (Aug 1, 2018): wip
66c6869 (Aug 2, 2018): WIP: rebase
7675b64 (Aug 3, 2018): Centering results and unifying outputs; fixing issues with lookups an…
2d46bde (Aug 7, 2018): wip
7057dfe (Aug 7, 2018): Adding in a validation set and validation set pruning to GAM
a8fd6e4 (Aug 10, 2018): Fixed GAM Classifier to use a small learning rate: Updated the FastTr…
c20e760 (Aug 11, 2018): Sped up tests, added comments, cleaned up test helpers.
da6a966 (Aug 14, 2018): GAM: remove unused calibration parameters.
4a00b92 (Aug 14, 2018): Rename GAM test file
0d6ea2d (Aug 15, 2018): Recentering graph based on mean responses.
ea3ab1c (Aug 15, 2018): Unified the documents to thread calculation across implementations.
6bc05db (Aug 17, 2018): Updating methods and properties to be as close to private as possible.
b70f203 (Aug 17, 2018): Updating methods and properties to be as close to private as possible…
759d597 (Aug 17, 2018): Updating the FastTreeClassification Loss to incorporate the factor of…
8719606 (Aug 24, 2018): Updating the entrypoints for BinaryClassGamPredictor to reflect the u…
dda64f2 (Aug 24, 2018): Fixing an issue with no validation and no progress.
d0d4e6b (Aug 26, 2018): Fixing null pointer bug, where if a split didn't have positive gain, …
b356e43 (Aug 29, 2018): Responding to PR Comments
82218ee (Aug 31, 2018): Responding to PR Comments: Switched sub-graph calculation to structs;…
cd8a554 (Sep 12, 2018): Fixing arithmetic error in the weighted split finding
638d1a7 (Sep 14, 2018): Updating all baseline tests for GAM Classification
e3d896f (Sep 14, 2018): Addressed PR Comments: Comments to follow style guide; removed unnece…
207 changes: 63 additions & 144 deletions src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs
@@ -190,19 +190,56 @@ public void CopyFeatureHistogram(int subfeatureIndex, ref PerBinStats[] hist)

}

public void FillSplitCandidates(
Dataset trainData, double sumTargets,
LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
int globalFeatureIndex, double minDocsInLeaf,
double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient)
public void FillSplitCandidates(LeastSquaresRegressionTreeLearner learner, LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
int flock, int[] featureUseCount, double featureFirstUsePenalty, double featureReusePenalty, double minDocsInLeaf,
bool hasWeights, double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient)
{
int flockIndex;
int subfeatureIndex;
trainData.MapFeatureToFlockAndSubFeature(globalFeatureIndex, out flockIndex, out subfeatureIndex);
int featureMin = learner.TrainData.FlockToFirstFeature(flock);
int featureLim = featureMin + learner.TrainData.Flocks[flock].Count;
foreach (var feature in learner.GetActiveFeatures(featureMin, featureLim))
{
int subfeature = feature - featureMin;
Contracts.Assert(0 <= subfeature && subfeature < Flock.Count);
Contracts.Assert(subfeature <= feature);
Contracts.Assert(learner.TrainData.FlockToFirstFeature(flock) == feature - subfeature);

double trust = trainData.Flocks[flockIndex].Trust(subfeatureIndex);
double minDocsForThis = minDocsInLeaf / trust;
if (!IsSplittable[subfeature])
continue;

Contracts.Assert(featureUseCount[feature] >= 0);

double trust = learner.TrainData.Flocks[flock].Trust(subfeature);
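// A feature's first use in the tree pays featureFirstUsePenalty; each reuse
// instead pays featureReusePenalty scaled by log(useCount + 1).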
double usePenalty = (featureUseCount[feature] == 0) ?
featureFirstUsePenalty : featureReusePenalty * Math.Log(featureUseCount[feature] + 1);
int totalCount = leafSplitCandidates.NumDocsInLeaf;
double sumTargets = leafSplitCandidates.SumTargets;
double sumWeights = leafSplitCandidates.SumWeights;

FindBestSplitForFeature(learner, leafSplitCandidates, totalCount, sumTargets, sumWeights,
feature, flock, subfeature, minDocsInLeaf,
hasWeights, gainConfidenceInSquaredStandardDeviations, entropyCoefficient,
trust, usePenalty);

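// Record the best-gain feature seen so far in this flock, so the learner can
// look up each flock's winning candidate without rescanning its features.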
if (leafSplitCandidates.FlockToBestFeature != null)
{
if (leafSplitCandidates.FlockToBestFeature[flock] == -1 ||
leafSplitCandidates.FeatureSplitInfo[leafSplitCandidates.FlockToBestFeature[flock]].Gain <
leafSplitCandidates.FeatureSplitInfo[feature].Gain)
{
leafSplitCandidates.FlockToBestFeature[flock] = feature;
}
}
}
}

internal void FindBestSplitForFeature(ILeafSplitStatisticsCalculator leafCalculator,
LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
int totalCount, double sumTargets, double sumWeights,
int featureIndex, int flockIndex, int subfeatureIndex, double minDocsInLeaf,
bool hasWeights, double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient,
double trust, double usePenalty)
{
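// Trust scales the minimum-documents requirement: a higher-trust feature may
// split on fewer documents than minDocsInLeaf, and its gain is also scaled
// up by trust when the split is scored below.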
double minDocsForThis = minDocsInLeaf / trust;
double bestSumGTTargets = double.NaN;
double bestSumGTWeights = double.NaN;
double bestShiftedGain = double.NegativeInfinity;
@@ -211,8 +248,8 @@ public void FillSplitCandidates(
double sumGTTargets = 0.0;
double sumGTWeights = eps;
int gtCount = 0;
int totalCount = leafSplitCandidates.Targets.Length;
double gainShift = (sumTargets * sumTargets) / totalCount;
sumWeights += 2 * eps;
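// gainShift is the gain of this leaf left unsplit; a candidate split is only
// attractive if the two children's combined gain exceeds this baseline.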
double gainShift = leafCalculator.GetLeafSplitGain(totalCount, sumTargets, sumWeights);

// We get to this more explicit handling of the zero case since, under the influence of
// numerical error, especially under single precision, the histogram computed values can
@@ -234,6 +271,8 @@ public void FillSplitCandidates(
var binStats = GetBinStats(b);
t--;
sumGTTargets += binStats.SumTargets;
if (hasWeights)
sumGTWeights += binStats.SumWeights;
gtCount += binStats.Count;

// Advance until GTCount is high enough.
@@ -246,8 +285,8 @@ public void FillSplitCandidates(
break;

// Calculate the shifted gain, including the LTE child.
double currentShiftedGain = (sumGTTargets * sumGTTargets) / gtCount
+ ((sumTargets - sumGTTargets) * (sumTargets - sumGTTargets)) / lteCount;
double currentShiftedGain = leafCalculator.GetLeafSplitGain(gtCount, sumGTTargets, sumGTWeights)
+ leafCalculator.GetLeafSplitGain(lteCount, sumTargets - sumGTTargets, sumWeights - sumGTWeights);

// Test whether we are meeting the min shifted gain confidence criteria for this split.
if (currentShiftedGain < minShiftedGain)
@@ -274,137 +313,17 @@ public void FillSplitCandidates(
}
}
// set the appropriate place in the output vectors
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Feature = flockIndex;
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Threshold = bestThreshold;
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].LteOutput = (sumTargets - bestSumGTTargets) / (totalCount - bestGTCount);
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].GTOutput = (bestSumGTTargets - bestSumGTWeights) / bestGTCount;
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].LteCount = totalCount - bestGTCount;
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].GTCount = bestGTCount;

leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Gain = (bestShiftedGain - gainShift) * trust;
leafSplitCandidates.FeatureSplitInfo[featureIndex].CategoricalSplit = false;
leafSplitCandidates.FeatureSplitInfo[featureIndex].Feature = featureIndex;
leafSplitCandidates.FeatureSplitInfo[featureIndex].Threshold = bestThreshold;
leafSplitCandidates.FeatureSplitInfo[featureIndex].LteOutput = leafCalculator.CalculateSplittedLeafOutput(totalCount - bestGTCount, sumTargets - bestSumGTTargets, sumWeights - bestSumGTWeights);
leafSplitCandidates.FeatureSplitInfo[featureIndex].GTOutput = leafCalculator.CalculateSplittedLeafOutput(bestGTCount, bestSumGTTargets, bestSumGTWeights);
leafSplitCandidates.FeatureSplitInfo[featureIndex].LteCount = totalCount - bestGTCount;
leafSplitCandidates.FeatureSplitInfo[featureIndex].GTCount = bestGTCount;

leafSplitCandidates.FeatureSplitInfo[featureIndex].Gain = (bestShiftedGain - gainShift) * trust - usePenalty;
double erfcArg = Math.Sqrt((bestShiftedGain - gainShift) * (totalCount - 1) / (2 * leafSplitCandidates.VarianceTargets * totalCount));
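// GainPValue estimates how likely a shifted gain this large would be under
// pure noise in the targets (variance VarianceTargets); smaller is stronger.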
leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].GainPValue = ProbabilityFunctions.Erfc(erfcArg);
}

public void FillSplitCandidates(LeastSquaresRegressionTreeLearner learner, LeastSquaresRegressionTreeLearner.LeafSplitCandidates leafSplitCandidates,
int flock, int[] featureUseCount, double featureFirstUsePenalty, double featureReusePenalty, double minDocsInLeaf,
bool hasWeights, double gainConfidenceInSquaredStandardDeviations, double entropyCoefficient)
{
int featureMin = learner.TrainData.FlockToFirstFeature(flock);
int featureLim = featureMin + learner.TrainData.Flocks[flock].Count;
foreach (var feature in learner.GetActiveFeatures(featureMin, featureLim))
{
int subfeature = feature - featureMin;
Contracts.Assert(0 <= subfeature && subfeature < Flock.Count);
Contracts.Assert(subfeature <= feature);
Contracts.Assert(learner.TrainData.FlockToFirstFeature(flock) == feature - subfeature);

if (!IsSplittable[subfeature])
continue;

Contracts.Assert(featureUseCount[feature] >= 0);

double trust = learner.TrainData.Flocks[flock].Trust(subfeature);
double minDocsForThis = minDocsInLeaf / trust;
double usePenalty = (featureUseCount[feature] == 0) ?
featureFirstUsePenalty : featureReusePenalty * Math.Log(featureUseCount[feature] + 1);

double bestSumGTTargets = double.NaN;
double bestSumGTWeights = double.NaN;
double bestShiftedGain = double.NegativeInfinity;
const double eps = 1e-10;
int bestGTCount = -1;
double sumGTTargets = 0.0;
double sumGTWeights = eps;
int gtCount = 0;
int totalCount = leafSplitCandidates.NumDocsInLeaf;
double sumTargets = leafSplitCandidates.SumTargets;
double sumWeights = leafSplitCandidates.SumWeights + 2 * eps;
double gainShift = learner.GetLeafSplitGain(totalCount, sumTargets, sumWeights);

// We get to this more explicit handling of the zero case since, under the influence of
// numerical error, especially under single precision, the histogram computed values can
// be wildly inaccurate even to the point where 0 unshifted gain may become a strong
// criteria.
double minShiftedGain = gainConfidenceInSquaredStandardDeviations <= 0 ? 0.0 :
(gainConfidenceInSquaredStandardDeviations * leafSplitCandidates.VarianceTargets
* totalCount / (totalCount - 1) + gainShift);

// re-evaluate if the histogram is splittable
IsSplittable[subfeature] = false;
int t = Flock.BinCount(subfeature);
uint bestThreshold = (uint)t;
t--;
int min = GetMinBorder(subfeature);
int max = GetMaxBorder(subfeature);
for (int b = max; b >= min; --b)
{
var binStats = GetBinStats(b);
t--;
sumGTTargets += binStats.SumTargets;
if (hasWeights)
sumGTWeights += binStats.SumWeights;
gtCount += binStats.Count;

// Advance until GTCount is high enough.
if (gtCount < minDocsForThis)
continue;
int lteCount = totalCount - gtCount;

// If LTECount is too small, we are finished.
if (lteCount < minDocsForThis)
break;

// Calculate the shifted gain, including the LTE child.
double currentShiftedGain = learner.GetLeafSplitGain(gtCount, sumGTTargets, sumGTWeights)
+ learner.GetLeafSplitGain(lteCount, sumTargets - sumGTTargets, sumWeights - sumGTWeights);

// Test whether we are meeting the min shifted gain confidence criteria for this split.
if (currentShiftedGain < minShiftedGain)
continue;

// If this point in the code is reached, the histogram is splittable.
IsSplittable[subfeature] = true;

if (entropyCoefficient > 0)
{
// Consider the entropy of the split.
double entropyGain = (totalCount * Math.Log(totalCount) - lteCount * Math.Log(lteCount) - gtCount * Math.Log(gtCount));
currentShiftedGain += entropyCoefficient * entropyGain;
}

// Is t the best threshold so far?
if (currentShiftedGain > bestShiftedGain)
{
bestGTCount = gtCount;
bestSumGTTargets = sumGTTargets;
bestSumGTWeights = sumGTWeights;
bestThreshold = (uint)t;
bestShiftedGain = currentShiftedGain;
}
}
// set the appropriate place in the output vectors
leafSplitCandidates.FeatureSplitInfo[feature].CategoricalSplit = false;
leafSplitCandidates.FeatureSplitInfo[feature].Feature = feature;
leafSplitCandidates.FeatureSplitInfo[feature].Threshold = bestThreshold;
leafSplitCandidates.FeatureSplitInfo[feature].LteOutput = learner.CalculateSplittedLeafOutput(totalCount - bestGTCount, sumTargets - bestSumGTTargets, sumWeights - bestSumGTWeights);
leafSplitCandidates.FeatureSplitInfo[feature].GTOutput = learner.CalculateSplittedLeafOutput(bestGTCount, bestSumGTTargets, bestSumGTWeights);
leafSplitCandidates.FeatureSplitInfo[feature].LteCount = totalCount - bestGTCount;
leafSplitCandidates.FeatureSplitInfo[feature].GTCount = bestGTCount;

leafSplitCandidates.FeatureSplitInfo[feature].Gain = (bestShiftedGain - gainShift) * trust - usePenalty;
double erfcArg = Math.Sqrt((bestShiftedGain - gainShift) * (totalCount - 1) / (2 * leafSplitCandidates.VarianceTargets * totalCount));
leafSplitCandidates.FeatureSplitInfo[feature].GainPValue = ProbabilityFunctions.Erfc(erfcArg);
if (leafSplitCandidates.FlockToBestFeature != null)
{
if (leafSplitCandidates.FlockToBestFeature[flock] == -1 ||
leafSplitCandidates.FeatureSplitInfo[leafSplitCandidates.FlockToBestFeature[flock]].Gain <
leafSplitCandidates.FeatureSplitInfo[feature].Gain)
{
leafSplitCandidates.FlockToBestFeature[flock] = feature;
}
}
}
leafSplitCandidates.FeatureSplitInfo[featureIndex].GainPValue = ProbabilityFunctions.Erfc(erfcArg);
}

public void FillSplitCandidatesCategorical(LeastSquaresRegressionTreeLearner learner,
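For reference, here is a minimal standalone sketch of the least-squares split gain that this refactor routes through ILeafSplitStatisticsCalculator.GetLeafSplitGain. The unweighted form matches the inline (sumTargets * sumTargets) / totalCount expression removed above; the weighted form, and the class and method names, are illustrative assumptions rather than the actual ML.NET implementation.

// Minimal sketch, assuming a sum-of-squares split criterion. Illustrative
// only; not the ML.NET implementation.
public static class LeafSplitGainSketch
{
    // Gain of keeping `count` documents in a single leaf with the given target
    // sum. Unweighted: sum^2 / count, as in the replaced inline expression.
    // Weighted form (assumption): sum^2 / sumWeights, mirroring the hasWeights
    // accumulation in the diff above.
    public static double GetLeafSplitGain(int count, double sumTargets,
        double sumWeights, bool hasWeights)
    {
        double denominator = hasWeights ? sumWeights : count;
        return sumTargets * sumTargets / denominator;
    }

    // Net gain of splitting into a GT child and an LTE child: the children's
    // combined gain minus the parent's gainShift, i.e. the quantity that
    // (bestShiftedGain - gainShift) measures before trust and use penalties.
    public static double NetSplitGain(
        int totalCount, double sumTargets, double sumWeights,
        int gtCount, double sumGTTargets, double sumGTWeights, bool hasWeights)
    {
        double gainShift = GetLeafSplitGain(totalCount, sumTargets, sumWeights, hasWeights);
        double childGain =
            GetLeafSplitGain(gtCount, sumGTTargets, sumGTWeights, hasWeights) +
            GetLeafSplitGain(totalCount - gtCount, sumTargets - sumGTTargets,
                sumWeights - sumGTWeights, hasWeights);
        return childGain - gainShift;
    }
}

A split with non-positive net gain never reduces squared error, which is why the loop above only accepts a threshold once currentShiftedGain clears the gainShift-derived minimum.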