Skip to content

Commit 5a02b59

Browse files
authored
[7.8][ML] Eagerly discard node statistics for leaves which we will never split (elastic#1148)
Backport elastic#1125.
1 parent 4abacd7 commit 5a02b59

File tree

4 files changed

+54
-22
lines changed

4 files changed

+54
-22
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@
4141
(See {ml-pull}1113[#1113].)
4242
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)
4343
* Adds new `num_matches` and `preferred_to_categories` fields to category output.
44-
(See {ml-pull}1062[#1062])
44+
(See {ml-pull}1062[#1062].)
4545
* Adds mean squared logarithmic error (MSLE) for regression. (See {ml-pull}1101[#1101].)
4646
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)
47+
* Reduce peak memory usage and memory estimates for classification and regression.
48+
(See {ml-pull}1125[#1125].)
4749
* Switched data frame analytics model memory estimates from kilobytes to megabytes.
4850
(See {ml-pull}1126[#1126], issue: {issue}54506[#54506].)
4951
* Added a {ml} native code build for Linux on AArch64. (See {ml-pull}1132[#1132] and

lib/api/unittest/CDataFrameAnalyzerTrainingTest.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ BOOST_AUTO_TEST_CASE(testRunBoostedTreeRegressionTrainingMse) {
330330
<< "ms");
331331

332332
BOOST_TEST_REQUIRE(core::CProgramCounters::counter(
333-
counter_t::E_DFTPMEstimatedPeakMemoryUsage) < 6000000);
334-
BOOST_TEST_REQUIRE(core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) < 1500000);
333+
counter_t::E_DFTPMEstimatedPeakMemoryUsage) < 4500000);
334+
BOOST_TEST_REQUIRE(core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) < 1600000);
335335
BOOST_TEST_REQUIRE(
336336
core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) <
337337
core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));
@@ -720,8 +720,8 @@ BOOST_AUTO_TEST_CASE(testRunBoostedTreeClassifierTraining) {
720720
<< "ms");
721721

722722
BOOST_TEST_REQUIRE(core::CProgramCounters::counter(
723-
counter_t::E_DFTPMEstimatedPeakMemoryUsage) < 6000000);
724-
BOOST_TEST_REQUIRE(core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) < 1500000);
723+
counter_t::E_DFTPMEstimatedPeakMemoryUsage) < 4500000);
724+
BOOST_TEST_REQUIRE(core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) < 1600000);
725725
BOOST_TEST_REQUIRE(
726726
core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) <
727727
core::CProgramCounters::counter(counter_t::E_DFTPMEstimatedPeakMemoryUsage));

lib/api/unittest/CDataFrameMockAnalysisRunner.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313

1414
#include <test/CRandomNumbers.h>
1515

16-
#include <functional>
16+
#include <cinttypes>
17+
#include <string>
1718

1819
class CDataFrameMockAnalysisState final : public ml::api::CDataFrameAnalysisInstrumentation {
1920
public:
2021
CDataFrameMockAnalysisState(const std::string& jobId)
2122
: ml::api::CDataFrameAnalysisInstrumentation(jobId) {}
22-
void writeAnalysisStats(std::int64_t /* timestamp */) override{};
23+
void writeAnalysisStats(std::int64_t /* timestamp */) override {}
2324

2425
protected:
2526
ml::counter_t::ECounterTypes memoryCounterType() override;

lib/maths/CBoostedTreeImpl.cc

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
* you may not use this file except in compliance with the Elastic License.
55
*/
66

7-
#include "maths/CBoostedTreeUtils.h"
87
#include <maths/CBoostedTreeImpl.h>
98

109
#include <core/CContainerPrinter.h>
@@ -20,13 +19,18 @@
2019
#include <maths/CBoostedTree.h>
2120
#include <maths/CBoostedTreeLeafNodeStatistics.h>
2221
#include <maths/CBoostedTreeLoss.h>
22+
#include <maths/CBoostedTreeUtils.h>
2323
#include <maths/CDataFrameAnalysisInstrumentationInterface.h>
2424
#include <maths/CDataFrameCategoryEncoder.h>
2525
#include <maths/CQuantileSketch.h>
2626
#include <maths/CSampling.h>
2727
#include <maths/CSetTools.h>
2828
#include <maths/CTreeShapFeatureImportance.h>
2929

30+
#include <boost/circular_buffer.hpp>
31+
32+
#include <algorithm>
33+
3034
namespace ml {
3135
namespace maths {
3236
using namespace boosted_tree;
@@ -310,10 +314,14 @@ std::size_t CBoostedTreeImpl::estimateMemoryUsage(std::size_t numberRows,
310314
std::size_t foldRoundLossMemoryUsage{m_NumberFolds * m_NumberRounds *
311315
sizeof(TOptionalDouble)};
312316
std::size_t hyperparametersMemoryUsage{numberColumns * sizeof(double)};
317+
// We only maintain statistics for leaves we know we may possibly split this
318+
// halves the peak number of statistics we maintain.
313319
std::size_t leafNodeStatisticsMemoryUsage{
314-
maximumNumberLeaves * CBoostedTreeLeafNodeStatistics::estimateMemoryUsage(
315-
numberRows, maximumNumberFeatures, m_NumberSplitsPerFeature,
316-
m_Loss->numberParameters())};
320+
maximumNumberLeaves *
321+
CBoostedTreeLeafNodeStatistics::estimateMemoryUsage(
322+
numberRows, maximumNumberFeatures, m_NumberSplitsPerFeature,
323+
m_Loss->numberParameters()) /
324+
2};
317325
std::size_t dataTypeMemoryUsage{maximumNumberFeatures * sizeof(CDataFrameUtils::SDataType)};
318326
std::size_t featureSampleProbabilities{maximumNumberFeatures * sizeof(double)};
319327
std::size_t missingFeatureMaskMemoryUsage{
@@ -721,14 +729,13 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
721729
LOG_TRACE(<< "Training one tree...");
722730

723731
using TLeafNodeStatisticsPtr = CBoostedTreeLeafNodeStatistics::TPtr;
724-
using TLeafNodeStatisticsPtrQueue =
725-
std::priority_queue<TLeafNodeStatisticsPtr, std::vector<TLeafNodeStatisticsPtr>, COrderings::SLess>;
732+
using TLeafNodeStatisticsPtrQueue = boost::circular_buffer<TLeafNodeStatisticsPtr>;
726733

727734
TNodeVec tree(1);
728735
tree.reserve(2 * maximumTreeSize + 1);
729736

730-
TLeafNodeStatisticsPtrQueue leaves;
731-
leaves.push(std::make_shared<CBoostedTreeLeafNodeStatistics>(
737+
TLeafNodeStatisticsPtrQueue leaves(maximumTreeSize / 2 + 3);
738+
leaves.push_back(std::make_shared<CBoostedTreeLeafNodeStatistics>(
732739
0 /*root*/, m_NumberInputColumns, m_Loss->numberParameters(),
733740
m_NumberThreads, frame, *m_Encoder, m_Regularization, candidateSplits,
734741
this->featureBag(), 0 /*depth*/, trainingRowMask));
@@ -752,10 +759,16 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
752759

753760
double totalGain{0.0};
754761

762+
COrderings::SLess less;
763+
755764
for (std::size_t i = 0; i < maximumTreeSize; ++i) {
756765

757-
auto leaf = leaves.top();
758-
leaves.pop();
766+
if (leaves.empty()) {
767+
break;
768+
}
769+
770+
auto leaf = leaves.back();
771+
leaves.pop_back();
759772

760773
scopeMemoryUsage.remove(leaf);
761774

@@ -764,7 +777,8 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
764777
}
765778

766779
totalGain += leaf->gain();
767-
LOG_TRACE(<< "splitting " << leaf->id() << " total gain = " << totalGain);
780+
LOG_TRACE(<< "splitting " << leaf->id() << " leaf gain = " << leaf->gain()
781+
<< " total gain = " << totalGain);
768782

769783
std::size_t splitFeature;
770784
double splitValue;
@@ -783,11 +797,26 @@ CBoostedTreeImpl::trainTree(core::CDataFrame& frame,
783797
leftChildId, rightChildId, m_NumberThreads, frame, *m_Encoder,
784798
m_Regularization, candidateSplits, this->featureBag(), tree[leaf->id()]);
785799

786-
scopeMemoryUsage.add(leftChild);
787-
scopeMemoryUsage.add(rightChild);
800+
if (less(rightChild, leftChild)) {
801+
std::swap(leftChild, rightChild);
802+
}
788803

789-
leaves.push(std::move(leftChild));
790-
leaves.push(std::move(rightChild));
804+
std::size_t n{leaves.size()};
805+
if (leftChild->gain() >= MINIMUM_RELATIVE_GAIN_PER_SPLIT * totalGain) {
806+
scopeMemoryUsage.add(leftChild);
807+
leaves.push_back(std::move(leftChild));
808+
}
809+
if (rightChild->gain() >= MINIMUM_RELATIVE_GAIN_PER_SPLIT * totalGain) {
810+
scopeMemoryUsage.add(rightChild);
811+
leaves.push_back(std::move(rightChild));
812+
}
813+
std::inplace_merge(leaves.begin(), leaves.begin() + n, leaves.end(), less);
814+
815+
// Drop any leaves which can't possibly be split.
816+
while (leaves.size() + i + 1 > maximumTreeSize) {
817+
scopeMemoryUsage.remove(leaves.front());
818+
leaves.pop_front();
819+
}
791820
}
792821

793822
tree.shrink_to_fit();

0 commit comments

Comments
 (0)