Skip to content

Commit d0d4e6b

Browse files
author
Rogan Carr
committed
Fixing null pointer bug, where if a split didn't have positive gain, no graph would be defined. As the graph is still accessed, it must be defined to zero in such cases.
1 parent dda64f2 commit d0d4e6b

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

src/Microsoft.ML.FastTree/GamTrainer.cs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,7 @@ private void TrainingIteration(int globalFeatureIndex, double[] gradient, double
331331
TrainSet.Flocks[flockIndex].Trust(subFeatureIndex), 0);
332332

333333
// Adjust the model
334-
if (_leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Gain > 0)
335-
ConvertTreeToGraph(globalFeatureIndex, iteration);
334+
ConvertTreeToGraph(globalFeatureIndex, iteration, _leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex].Gain > 0);
336335
}
337336

338337
/// <summary>
@@ -550,14 +549,19 @@ private void CenterGraph()
550549
}
551550
}
552551

553-
private void ConvertTreeToGraph(int globalFeatureIndex, int iteration)
552+
private void ConvertTreeToGraph(int globalFeatureIndex, int iteration, bool useSplitValues)
554553
{
555-
SplitInfo splitinfo = _leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex];
556-
557-
_splitPoint[globalFeatureIndex][iteration] = splitinfo.Threshold;
554+
// Always define the graph
558555
_splitValue[globalFeatureIndex][iteration] = new double[2]; // Easily extend to variable-length trees
559-
_splitValue[globalFeatureIndex][iteration][0] = Args.LearningRates * splitinfo.LteOutput;
560-
_splitValue[globalFeatureIndex][iteration][1] = Args.LearningRates * splitinfo.GTOutput;
556+
557+
// But only fill it in if some criteria were met
558+
if (useSplitValues)
559+
{
560+
SplitInfo splitinfo = _leafSplitCandidates.FeatureSplitInfo[globalFeatureIndex];
561+
_splitPoint[globalFeatureIndex][iteration] = splitinfo.Threshold;
562+
_splitValue[globalFeatureIndex][iteration][0] = Args.LearningRates * splitinfo.LteOutput;
563+
_splitValue[globalFeatureIndex][iteration][1] = Args.LearningRates * splitinfo.GTOutput;
564+
}
561565
}
562566

563567
private void InitializeGamHistograms()

0 commit comments

Comments
 (0)