4
4
* you may not use this file except in compliance with the Elastic License.
5
5
*/
6
6
7
- #include " maths/CBoostedTreeUtils.h"
8
7
#include < maths/CBoostedTreeImpl.h>
9
8
10
9
#include < core/CContainerPrinter.h>
20
19
#include < maths/CBoostedTree.h>
21
20
#include < maths/CBoostedTreeLeafNodeStatistics.h>
22
21
#include < maths/CBoostedTreeLoss.h>
22
+ #include < maths/CBoostedTreeUtils.h>
23
23
#include < maths/CDataFrameAnalysisInstrumentationInterface.h>
24
24
#include < maths/CDataFrameCategoryEncoder.h>
25
25
#include < maths/CQuantileSketch.h>
26
26
#include < maths/CSampling.h>
27
27
#include < maths/CSetTools.h>
28
28
#include < maths/CTreeShapFeatureImportance.h>
29
29
30
+ #include < boost/circular_buffer.hpp>
31
+
32
+ #include < algorithm>
33
+
30
34
namespace ml {
31
35
namespace maths {
32
36
using namespace boosted_tree ;
@@ -310,10 +314,14 @@ std::size_t CBoostedTreeImpl::estimateMemoryUsage(std::size_t numberRows,
310
314
std::size_t foldRoundLossMemoryUsage{m_NumberFolds * m_NumberRounds *
311
315
sizeof (TOptionalDouble)};
312
316
std::size_t hyperparametersMemoryUsage{numberColumns * sizeof (double )};
317
+ // We only maintain statistics for leaves we know we may possibly split; this
318
+ // halves the peak number of statistics we maintain.
313
319
std::size_t leafNodeStatisticsMemoryUsage{
314
- maximumNumberLeaves * CBoostedTreeLeafNodeStatistics::estimateMemoryUsage (
315
- numberRows, maximumNumberFeatures, m_NumberSplitsPerFeature,
316
- m_Loss->numberParameters ())};
320
+ maximumNumberLeaves *
321
+ CBoostedTreeLeafNodeStatistics::estimateMemoryUsage (
322
+ numberRows, maximumNumberFeatures, m_NumberSplitsPerFeature,
323
+ m_Loss->numberParameters ()) /
324
+ 2 };
317
325
std::size_t dataTypeMemoryUsage{maximumNumberFeatures * sizeof (CDataFrameUtils::SDataType)};
318
326
std::size_t featureSampleProbabilities{maximumNumberFeatures * sizeof (double )};
319
327
std::size_t missingFeatureMaskMemoryUsage{
@@ -721,14 +729,13 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
721
729
LOG_TRACE (<< " Training one tree..." );
722
730
723
731
using TLeafNodeStatisticsPtr = CBoostedTreeLeafNodeStatistics::TPtr;
724
- using TLeafNodeStatisticsPtrQueue =
725
- std::priority_queue<TLeafNodeStatisticsPtr, std::vector<TLeafNodeStatisticsPtr>, COrderings::SLess>;
732
+ using TLeafNodeStatisticsPtrQueue = boost::circular_buffer<TLeafNodeStatisticsPtr>;
726
733
727
734
TNodeVec tree (1 );
728
735
tree.reserve (2 * maximumTreeSize + 1 );
729
736
730
- TLeafNodeStatisticsPtrQueue leaves;
731
- leaves.push (std::make_shared<CBoostedTreeLeafNodeStatistics>(
737
+ TLeafNodeStatisticsPtrQueue leaves (maximumTreeSize / 2 + 3 ) ;
738
+ leaves.push_back (std::make_shared<CBoostedTreeLeafNodeStatistics>(
732
739
0 /* root*/ , m_NumberInputColumns, m_Loss->numberParameters (),
733
740
m_NumberThreads, frame, *m_Encoder, m_Regularization, candidateSplits,
734
741
this ->featureBag (), 0 /* depth*/ , trainingRowMask));
@@ -752,10 +759,16 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
752
759
753
760
double totalGain{0.0 };
754
761
762
+ COrderings::SLess less;
763
+
755
764
for (std::size_t i = 0 ; i < maximumTreeSize; ++i) {
756
765
757
- auto leaf = leaves.top ();
758
- leaves.pop ();
766
+ if (leaves.empty ()) {
767
+ break ;
768
+ }
769
+
770
+ auto leaf = leaves.back ();
771
+ leaves.pop_back ();
759
772
760
773
scopeMemoryUsage.remove (leaf);
761
774
@@ -764,7 +777,8 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
764
777
}
765
778
766
779
totalGain += leaf->gain ();
767
- LOG_TRACE (<< " splitting " << leaf->id () << " total gain = " << totalGain);
780
+ LOG_TRACE (<< " splitting " << leaf->id () << " leaf gain = " << leaf->gain ()
781
+ << " total gain = " << totalGain);
768
782
769
783
std::size_t splitFeature;
770
784
double splitValue;
@@ -783,11 +797,26 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
783
797
leftChildId, rightChildId, m_NumberThreads, frame, *m_Encoder,
784
798
m_Regularization, candidateSplits, this ->featureBag (), tree[leaf->id ()]);
785
799
786
- scopeMemoryUsage.add (leftChild);
787
- scopeMemoryUsage.add (rightChild);
800
+ if (less (rightChild, leftChild)) {
801
+ std::swap (leftChild, rightChild);
802
+ }
788
803
789
- leaves.push (std::move (leftChild));
790
- leaves.push (std::move (rightChild));
804
+ std::size_t n{leaves.size ()};
805
+ if (leftChild->gain () >= MINIMUM_RELATIVE_GAIN_PER_SPLIT * totalGain) {
806
+ scopeMemoryUsage.add (leftChild);
807
+ leaves.push_back (std::move (leftChild));
808
+ }
809
+ if (rightChild->gain () >= MINIMUM_RELATIVE_GAIN_PER_SPLIT * totalGain) {
810
+ scopeMemoryUsage.add (rightChild);
811
+ leaves.push_back (std::move (rightChild));
812
+ }
813
+ std::inplace_merge (leaves.begin (), leaves.begin () + n, leaves.end (), less);
814
+
815
+ // Drop any leaves which can't possibly be split.
816
+ while (leaves.size () + i + 1 > maximumTreeSize) {
817
+ scopeMemoryUsage.remove (leaves.front ());
818
+ leaves.pop_front ();
819
+ }
791
820
}
792
821
793
822
tree.shrink_to_fit ();
0 commit comments