Skip to content

Commit 1f6b3be

Browse files
authored
Add cancellation signal checkpoints in FastTree. (#3028)
* Add cancellation signal checkpoints in FastTree.
* Undo temp changes.
* Fix build break.
* Add checkpoint in FindBestThresholdFromHistogram().
* Add checkpoint for disk transpose.
1 parent b6c5b70 commit 1f6b3be

File tree

5 files changed

+22
-5
lines changed

5 files changed

+22
-5
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch)
6868
FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas,
6969
FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
7070
FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining,
71-
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias);
71+
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit,
72+
FastTreeTrainerOptions.Bias, Host);
7273
}
7374

7475
private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)

src/Microsoft.ML.FastTree/FastTree.cs

+7
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,7 @@ private ValueMapper<VBuffer<T1>, VBuffer<T2>> GetCopier<T1, T2>(DataViewType ite
13341334

13351335
private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxBins, IParallelTraining parallelTraining)
13361336
{
1337+
Host.CheckAlive();
13371338
Host.AssertValue(examples);
13381339
Host.Assert(examples.Schema.Feature.HasValue);
13391340

@@ -1414,6 +1415,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
14141415
pch.SetHeader(new ProgressHeader("features"), e => e.SetProgress(0, iFeature, features.Length));
14151416
while (cursor.MoveNext())
14161417
{
1418+
Host.CheckAlive();
14171419
iFeature = cursor.SlotIndex;
14181420
if (!localConstructBinFeatures[iFeature])
14191421
continue;
@@ -1489,6 +1491,8 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
14891491
int catRangeIndex = 0;
14901492
for (iFeature = 0; iFeature < NumFeatures;)
14911493
{
1494+
Host.CheckAlive();
1495+
14921496
if (catRangeIndex < CategoricalFeatureIndices.Length &&
14931497
CategoricalFeatureIndices[catRangeIndex] == iFeature)
14941498
{
@@ -1565,6 +1569,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB
15651569
{
15661570
for (int i = 0; i < NumFeatures; i++)
15671571
{
1572+
Host.CheckAlive();
15681573
GetFeatureValues(cursor, i, getter, ref temp, ref doubleTemp, copier);
15691574
double[] upperBounds = BinUpperBounds[i];
15701575
Host.AssertValue(upperBounds);
@@ -1919,6 +1924,7 @@ private void InitializeBins(int maxBins, IParallelTraining parallelTraining)
19191924
List<int> trivialFeatures = new List<int>();
19201925
for (iFeature = 0; iFeature < NumFeatures; iFeature++)
19211926
{
1927+
Host.CheckAlive();
19221928
if (!localConstructBinFeatures[iFeature])
19231929
continue;
19241930
// The following strange call will actually sparsify.
@@ -2230,6 +2236,7 @@ private IEnumerable<FeatureFlockBase> CreateFlocksCore(IChannel ch, IProgressCha
22302236

22312237
for (; iFeature < featureLim; ++iFeature)
22322238
{
2239+
Host.CheckAlive();
22332240
double[] bup = BinUpperBounds[iFeature];
22342241
Contracts.Assert(Utils.Size(bup) > 0);
22352242
if (bup.Length == 1)

src/Microsoft.ML.FastTree/RandomForest.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ private protected override TreeLearner ConstructTreeLearner(IChannel ch)
6868
FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit,
6969
FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
7070
FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, _quantileEnabled, FastTreeTrainerOptions.NumberOfQuantileSamples, ParallelTraining,
71-
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias);
71+
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit,
72+
FastTreeTrainerOptions.Bias, Host);
7273
}
7374

7475
internal abstract class RandomForestObjectiveFunction : ObjectiveFunctionBase

src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ internal class RandomForestLeastSquaresTreeLearner : LeastSquaresRegressionTreeL
1515
public RandomForestLeastSquaresTreeLearner(Dataset trainData, int numLeaves, int minDocsInLeaf, Double entropyCoefficient, Double featureFirstUsePenalty,
1616
Double featureReusePenalty, Double softmaxTemperature, int histogramPoolSize, int randomSeed, Double splitFraction, bool allowEmptyTrees,
1717
Double gainConfidenceLevel, int maxCategoricalGroupsPerNode, int maxCategoricalSplitPointsPerNode, bool quantileEnabled, int quantileSampleCount, IParallelTraining parallelTraining,
18-
double minDocsPercentageForCategoricalSplit, Bundle bundling, int minDocsForCategoricalSplit, double bias)
18+
double minDocsPercentageForCategoricalSplit, Bundle bundling, int minDocsForCategoricalSplit, double bias, IHost host)
1919
: base(trainData, numLeaves, minDocsInLeaf, entropyCoefficient, featureFirstUsePenalty, featureReusePenalty, softmaxTemperature, histogramPoolSize,
2020
randomSeed, splitFraction, false, allowEmptyTrees, gainConfidenceLevel, maxCategoricalGroupsPerNode, maxCategoricalSplitPointsPerNode, -1, parallelTraining,
21-
minDocsPercentageForCategoricalSplit, bundling, minDocsForCategoricalSplit, bias)
21+
minDocsPercentageForCategoricalSplit, bundling, minDocsForCategoricalSplit, bias, host)
2222
{
2323
_quantileSampleCount = quantileSampleCount;
2424
_quantileEnabled = quantileEnabled;

src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs

+9-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis
6969
protected readonly bool FilterZeros;
7070
protected readonly double BsrMaxTreeOutput;
7171

72+
protected readonly IHost Host;
73+
7274
// size of reserved memory
7375
private readonly long _sizeOfReservedMemory;
7476

@@ -114,12 +116,13 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis
114116
/// <param name="bundling"></param>
115117
/// <param name="minDocsForCategoricalSplit"></param>
116118
/// <param name="bias"></param>
119+
/// <param name="host">Host</param>
117120
public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int minDocsInLeaf, double entropyCoefficient,
118121
double featureFirstUsePenalty, double featureReusePenalty, double softmaxTemperature, int histogramPoolSize,
119122
int randomSeed, double splitFraction, bool filterZeros, bool allowEmptyTrees, double gainConfidenceLevel,
120123
int maxCategoricalGroupsPerNode, int maxCategoricalSplitPointPerNode,
121124
double bsrMaxTreeOutput, IParallelTraining parallelTraining, double minDocsPercentageForCategoricalSplit,
122-
Bundle bundling, int minDocsForCategoricalSplit, double bias)
125+
Bundle bundling, int minDocsForCategoricalSplit, double bias, IHost host)
123126
: base(trainData, numLeaves)
124127
{
125128
MinDocsInLeaf = minDocsInLeaf;
@@ -135,6 +138,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
135138
MinDocsForCategoricalSplit = minDocsForCategoricalSplit;
136139
Bundling = bundling;
137140
Bias = bias;
141+
Host = host;
138142

139143
_calculateLeafSplitCandidates = ThreadTaskManager.MakeTask(
140144
FindBestThresholdForFlockThreadWorker, TrainData.NumFlocks);
@@ -148,6 +152,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
148152
histogramPool[i] = new SufficientStatsBase[TrainData.NumFlocks];
149153
for (int j = 0; j < TrainData.NumFlocks; j++)
150154
{
155+
Host.CheckAlive();
151156
var ss = histogramPool[i][j] = TrainData.Flocks[j].CreateSufficientStats(HasWeights);
152157
_sizeOfReservedMemory += ss.SizeInBytes();
153158
}
@@ -498,6 +503,7 @@ protected virtual void SetBestFeatureForLeaf(LeafSplitCandidates leafSplitCandid
498503
/// </summary>
499504
private void FindBestThresholdForFlockThreadWorker(int flock)
500505
{
506+
Host.CheckAlive();
501507
int featureMin = TrainData.FlockToFirstFeature(flock);
502508
int featureLim = featureMin + TrainData.Flocks[flock].Count;
503509
// Check if any feature is active.
@@ -649,6 +655,8 @@ public double CalculateSplittedLeafOutput(int count, double sumTargets, double s
649655
protected virtual void FindBestThresholdFromHistogram(SufficientStatsBase histogram,
650656
LeafSplitCandidates leafSplitCandidates, int flock)
651657
{
658+
Host.CheckAlive();
659+
652660
// Cache histograms for the parallel interface.
653661
int featureMin = TrainData.FlockToFirstFeature(flock);
654662
int featureLim = featureMin + TrainData.Flocks[flock].Count;

0 commit comments

Comments (0)