@@ -69,6 +69,8 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis
69
69
protected readonly bool FilterZeros ;
70
70
protected readonly double BsrMaxTreeOutput ;
71
71
72
+ protected readonly IHost Host ;
73
+
72
74
// size of reserved memory
73
75
private readonly long _sizeOfReservedMemory ;
74
76
@@ -114,12 +116,13 @@ internal class LeastSquaresRegressionTreeLearner : TreeLearner, ILeafSplitStatis
114
116
/// <param name="bundling"></param>
115
117
/// <param name="minDocsForCategoricalSplit"></param>
116
118
/// <param name="bias"></param>
119
+ /// <param name="host">Host</param>
117
120
public LeastSquaresRegressionTreeLearner ( Dataset trainData , int numLeaves , int minDocsInLeaf , double entropyCoefficient ,
118
121
double featureFirstUsePenalty , double featureReusePenalty , double softmaxTemperature , int histogramPoolSize ,
119
122
int randomSeed , double splitFraction , bool filterZeros , bool allowEmptyTrees , double gainConfidenceLevel ,
120
123
int maxCategoricalGroupsPerNode , int maxCategoricalSplitPointPerNode ,
121
124
double bsrMaxTreeOutput , IParallelTraining parallelTraining , double minDocsPercentageForCategoricalSplit ,
122
- Bundle bundling , int minDocsForCategoricalSplit , double bias )
125
+ Bundle bundling , int minDocsForCategoricalSplit , double bias , IHost host )
123
126
: base ( trainData , numLeaves )
124
127
{
125
128
MinDocsInLeaf = minDocsInLeaf ;
@@ -135,6 +138,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
135
138
MinDocsForCategoricalSplit = minDocsForCategoricalSplit ;
136
139
Bundling = bundling ;
137
140
Bias = bias ;
141
+ Host = host ;
138
142
139
143
_calculateLeafSplitCandidates = ThreadTaskManager . MakeTask (
140
144
FindBestThresholdForFlockThreadWorker , TrainData . NumFlocks ) ;
@@ -148,6 +152,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
148
152
histogramPool [ i ] = new SufficientStatsBase [ TrainData . NumFlocks ] ;
149
153
for ( int j = 0 ; j < TrainData . NumFlocks ; j ++ )
150
154
{
155
+ Host . CheckAlive ( ) ;
151
156
var ss = histogramPool [ i ] [ j ] = TrainData . Flocks [ j ] . CreateSufficientStats ( HasWeights ) ;
152
157
_sizeOfReservedMemory += ss . SizeInBytes ( ) ;
153
158
}
@@ -498,6 +503,7 @@ protected virtual void SetBestFeatureForLeaf(LeafSplitCandidates leafSplitCandid
498
503
/// </summary>
499
504
private void FindBestThresholdForFlockThreadWorker ( int flock )
500
505
{
506
+ Host . CheckAlive ( ) ;
501
507
int featureMin = TrainData . FlockToFirstFeature ( flock ) ;
502
508
int featureLim = featureMin + TrainData . Flocks [ flock ] . Count ;
503
509
// Check if any feature is active.
@@ -649,6 +655,8 @@ public double CalculateSplittedLeafOutput(int count, double sumTargets, double s
649
655
protected virtual void FindBestThresholdFromHistogram ( SufficientStatsBase histogram ,
650
656
LeafSplitCandidates leafSplitCandidates , int flock )
651
657
{
658
+ Host . CheckAlive ( ) ;
659
+
652
660
// Cache histograms for the parallel interface.
653
661
int featureMin = TrainData . FlockToFirstFeature ( flock ) ;
654
662
int featureLim = featureMin + TrainData . Flocks [ flock ] . Count ;
0 commit comments