[ML] Improve boosted tree training initialisation #686

Merged
merged 24 commits into from Sep 26, 2019
Changes from 14 commits
4 changes: 3 additions & 1 deletion include/api/CDataFrameBoostedTreeRunner.h
@@ -47,6 +47,7 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
private:
using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::CBoostedTreeFactory>;
using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;
using TMemoryEstimator = std::function<void(std::int64_t)>;

private:
@@ -58,7 +59,8 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
TMemoryEstimator memoryEstimator();

bool restoreBoostedTree(core::CDataFrame& frame,
CDataFrameAnalysisSpecification::TDataSearcherUPtr& restoreSearcher);
std::size_t dependentVariableColumn,
TDataSearcherUPtr& restoreSearcher);

private:
// Note custom config is written directly to the factory object.
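Note: restoreBoostedTree now takes the dependent variable's column index explicitly. Below is a minimal sketch of deriving that index from the analysis field names, mirroring the iterator arithmetic in runImpl further down this diff; the helper name and target-field argument are illustrative, not part of this change.

#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: zero-based column index of the dependent variable,
// computed the same way runImpl does with
// dependentVariableColumn - featureNames.begin(). Returns
// featureNames.size() if the field is absent.
std::size_t dependentVariableIndex(const std::vector<std::string>& featureNames,
                                   const std::string& targetFieldName) {
    auto pos = std::find(featureNames.begin(), featureNames.end(), targetFieldName);
    return static_cast<std::size_t>(pos - featureNames.begin());
}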
2 changes: 1 addition & 1 deletion include/core/CLoopProgress.h
@@ -57,7 +57,7 @@ class CORE_EXPORT CLoopProgress {
double scale = 1.0);

//! Attach a new progress monitor callback.
void attach(const TProgressCallback& recordProgress);
void progressCallback(const TProgressCallback& recordProgress);

//! Increment the progress by \p i.
void increment(std::size_t i = 1);
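Note: the attach to progressCallback rename pairs with resumeRestored() when monitoring is re-established on a deserialized object. A sketch of that usage, modeled on the serialization unit test at the bottom of this diff; the wrapper function and include path are assumptions.

#include <core/CLoopProgress.h>

// Hypothetical wrapper: re-attach a progress monitor to a restored
// CLoopProgress and replay the progress accumulated before persistence.
void resumeMonitoring(ml::core::CLoopProgress& loopProgress, double& progressSink) {
    loopProgress.progressCallback([&progressSink](double p) { progressSink += p; });
    // increment(0) under the hood: reports the progress made so far.
    loopProgress.resumeRestored();
}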
32 changes: 19 additions & 13 deletions include/maths/CBoostedTreeFactory.h
@@ -44,12 +44,9 @@ class MATHS_EXPORT CBoostedTreeFactory final {
TLossFunctionUPtr loss);

//! Construct a boosted tree object from its serialized version.
static TBoostedTreeUPtr
constructFromString(std::istream& jsonStringStream,
core::CDataFrame& frame,
TProgressCallback recordProgress = noopRecordProgress,
TMemoryUsageCallback recordMemoryUsage = noopRecordMemoryUsage,
TTrainingStateCallback recordTrainingState = noopRecordTrainingState);
//!
//! \warning Throws a runtime error if restoring fails.
static CBoostedTreeFactory constructFromString(std::istream& jsonStringStream);

~CBoostedTreeFactory();
CBoostedTreeFactory(CBoostedTreeFactory&) = delete;
@@ -101,14 +98,14 @@ class MATHS_EXPORT CBoostedTreeFactory final {
using TOptionalVector = boost::optional<TVector>;
using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
using TBoostedTreeImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
using TScaleRegularization = std::function<void(double)>;
using TScaleRegularization = std::function<void(CBoostedTreeImpl&, double)>;

private:
static const double MINIMUM_ETA;
static const std::size_t MAXIMUM_NUMBER_TREES;

private:
CBoostedTreeFactory(std::size_t numberThreads, TLossFunctionUPtr loss);
CBoostedTreeFactory(bool restored, std::size_t numberThreads, TLossFunctionUPtr loss);

//! Compute the row masks for the missing values for each feature.
void initializeMissingFeatureMasks(const core::CDataFrame& frame) const;
@@ -138,10 +135,15 @@ class MATHS_EXPORT CBoostedTreeFactory final {
TDoubleDoublePr estimateTreeGainAndCurvature(core::CDataFrame& frame,
const core::CPackedBitVector& trainingRowMask) const;

//! Get the regularizer value at the point the model starts to overfit.
TOptionalVector candidateRegularizerSearchInterval(core::CDataFrame& frame,
core::CPackedBitVector trainingRowMask,
TScaleRegularization scale) const;
//! Perform a line search with a quadratic approximation for the regularizer
//! value at which the model starts to overfit.
//!
//! \note applyScaleToRegularizer Applies a specified scale to the initial
//! value chosen for the tree implementation.
TOptionalVector
lineSearchWithQuadraticApproxToTestError(core::CDataFrame& frame,
core::CPackedBitVector trainingRowMask,
const TScaleRegularization& applyScaleToRegularizer) const;

//! Initialize the state for hyperparameter optimisation.
void initializeHyperparameterOptimisation() const;
@@ -150,7 +152,10 @@ class MATHS_EXPORT CBoostedTreeFactory final {
std::size_t numberHyperparameterTuningRounds() const;

//! Setup monitoring for training progress.
void setupTrainingProgressMonitoring();
void initializeTrainingProgressMonitoring();

//! Refresh progress monitoring after restoring from saved training state.
void resumeRestoredTrainingProgressMonitoring();

static void noopRecordProgress(double);
static void noopRecordMemoryUsage(std::int64_t);
@@ -159,6 +164,7 @@ class MATHS_EXPORT CBoostedTreeFactory final {
private:
TOptionalDouble m_MinimumFrequencyToOneHotEncode;
TOptionalSize m_BayesianOptimisationRestarts;
bool m_Restored = false;
TBoostedTreeImplUPtr m_TreeImpl;
TVector m_GammaSearchInterval;
TVector m_LambdaSearchInterval;
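Note: constructFromString now returns the factory itself rather than a built tree, so the restore path and fresh construction share one buildFor call. A sketch of the restore-side call shape under the new API, attaching only the callbacks whose signatures appear in this diff; the include path and ml namespaces are assumed from the repository layout.

#include <maths/CBoostedTreeFactory.h>

#include <cstddef>
#include <cstdint>
#include <istream>

// Sketch only. constructFromString throws a runtime error if the state
// cannot be restored, so production callers wrap this in a try/catch.
ml::maths::CBoostedTreeFactory::TBoostedTreeUPtr
restoreTree(std::istream& jsonStateStream,
            ml::core::CDataFrame& frame,
            std::size_t dependentVariableColumn) {
    return ml::maths::CBoostedTreeFactory::constructFromString(jsonStateStream)
        .progressCallback([](double /*progress*/) {})
        .memoryUsageCallback([](std::int64_t /*bytes*/) {})
        .buildFor(frame, dependentVariableColumn);
}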
14 changes: 7 additions & 7 deletions include/maths/CBoostedTreeImpl.h
@@ -120,18 +120,18 @@ class MATHS_EXPORT CBoostedTreeImpl final {
using TNodeVec = std::vector<CNode>;
using TNodeVecVec = std::vector<TNodeVec>;

//! \brief Holds the parameters associated with the different types of regulariser
//! \brief Holds the parameters associated with the different types of regularizer
//! terms available.
template<typename T>
class CRegularization final {
public:
//! Set the multiplier of the tree size regularizer.
//! Set the multiplier of the tree size penalty.
CRegularization& gamma(double gamma) {
m_Gamma = gamma;
return *this;
}

//! Set the multiplier of the square leaf weight regularizer.
//! Set the multiplier of the square leaf weight penalty.
CRegularization& lambda(double lambda) {
m_Lambda = lambda;
return *this;
@@ -142,10 +142,10 @@ class MATHS_EXPORT CBoostedTreeImpl final {
return (m_Gamma == T{} ? 1 : 0) + (m_Lambda == T{} ? 1 : 0);
}

//! Multiplier of the tree size regularizer.
//! Multiplier of the tree size penalty.
T gamma() const { return m_Gamma; }

//! Multiplier of the square leaf weight regularizer.
//! Multiplier of the square leaf weight penalty.
T lambda() const { return m_Lambda; }

//! Get description of the regularization parameters.
@@ -674,8 +674,8 @@ class MATHS_EXPORT CBoostedTreeImpl final {
//! the dependent variable.
core::CPackedBitVector allTrainingRowsMask() const;

//! Compute the sum loss for the predictions from \p frame and the leaf
//! count and squared weight sum from \p forest.
//! Compute the \p percentile percentile gain per split and the sum of row
//! curvatures per internal node of \p forest.
TDoubleDoublePr gainAndCurvatureAtPercentile(double percentile,
const TNodeVecVec& forest) const;

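Note: both setters return the regularization object, so configuration chains. A sketch of the fluent usage with illustrative values, assuming the nested CRegularization class and its default constructor are accessible from the caller's context.

#include <maths/CBoostedTreeImpl.h>

#include <cstddef>

void configureRegularization() {
    ml::maths::CBoostedTreeImpl::CRegularization<double> regularization;
    regularization.gamma(0.5)   // multiplier of the tree size penalty
        .lambda(1.0);           // multiplier of the square leaf weight penalty
    // countNotSet() reports how many multipliers are still default-constructed.
    std::size_t notSet{regularization.countNotSet()};
    (void)notSet;
}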
3 changes: 2 additions & 1 deletion lib/api/CDataFrameAnalysisRunner.cc
@@ -84,7 +84,8 @@ void CDataFrameAnalysisRunner::computeAndSaveExecutionStrategy() {
if (memoryUsage <= memoryLimit) {
break;
}
// if we are not allowed to spill over to disk then only one partition is possible
// If we are not allowed to spill over to disk then only one partition
// is possible.
if (m_Spec.diskUsageAllowed() == false) {
LOG_TRACE(<< "stop partition number computation since disk usage is turned off");
break;
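Note: the reworded comment documents the fallback in computeAndSaveExecutionStrategy: grow the partition count until the memory estimate fits the limit, unless spilling to disk is disallowed. A standalone sketch of that decision follows; every name below is illustrative rather than the runner's actual interface.

#include <cstddef>
#include <functional>

// Illustrative only: smallest partition count whose estimated memory use
// fits the limit. With disk usage turned off, partitioning cannot help,
// so the search stops after the first estimate.
std::size_t choosePartitionCount(std::size_t memoryLimit,
                                 bool diskUsageAllowed,
                                 std::size_t maximumPartitions,
                                 const std::function<std::size_t(std::size_t)>& estimateMemoryUsage) {
    for (std::size_t partitions = 1; partitions <= maximumPartitions; ++partitions) {
        if (estimateMemoryUsage(partitions) <= memoryLimit) {
            return partitions;
        }
        if (diskUsageAllowed == false) {
            break; // no spilling to disk: only one partition is possible
        }
    }
    return 1;
}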
17 changes: 11 additions & 6 deletions lib/api/CDataFrameBoostedTreeRunner.cc
@@ -191,7 +191,8 @@ void CDataFrameBoostedTreeRunner::runImpl(const TStrVec& featureNames,
auto restoreSearcher{this->spec().restoreSearcher()};
bool treeRestored{false};
if (restoreSearcher != nullptr) {
treeRestored = this->restoreBoostedTree(frame, restoreSearcher);
treeRestored = this->restoreBoostedTree(
frame, dependentVariableColumn - featureNames.begin(), restoreSearcher);
}

if (treeRestored == false) {
@@ -204,9 +205,10 @@ void CDataFrameBoostedTreeRunner::runImpl(const TStrVec& featureNames,
core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain) = watch.stop();
}

bool CDataFrameBoostedTreeRunner::restoreBoostedTree(
core::CDataFrame& frame,
CDataFrameAnalysisSpecification::TDataSearcherUPtr& restoreSearcher) { // Restore from Elasticsearch compressed data
bool CDataFrameBoostedTreeRunner::restoreBoostedTree(core::CDataFrame& frame,
std::size_t dependentVariableColumn,
TDataSearcherUPtr& restoreSearcher) {
// Restore from Elasticsearch compressed data
try {
core::CStateDecompressor decompressor(*restoreSearcher);
decompressor.setStateRestoreSearch(
@@ -228,8 +230,11 @@ bool CDataFrameBoostedTreeRunner::restoreBoostedTree(
return false;
}

m_BoostedTree = maths::CBoostedTreeFactory::constructFromString(
*inputStream, frame, progressRecorder(), memoryEstimator(), statePersister());
m_BoostedTree = maths::CBoostedTreeFactory::constructFromString(*inputStream)
    .progressCallback(this->progressRecorder())
    .trainingStateCallback(this->statePersister())
    .memoryUsageCallback(this->memoryEstimator())
    .buildFor(frame, dependentVariableColumn);

Contributor: Nice! I like the symmetry now. 👍
} catch (std::exception& e) {
LOG_ERROR(<< "Failed to restore state! " << e.what());
return false;
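Note: the surrounding restore plumbing wraps the searcher in a state decompressor before handing the stream to the factory. A condensed sketch of the happy path, with the search index and state id taken from the unit test below; the include paths, the searcher parameter type, and jobId are assumptions, and error handling is elided.

#include <core/CStateDecompressor.h>
#include <maths/CBoostedTreeFactory.h>

#include <cstddef>
#include <string>

// ML_STATE_INDEX and getRegressionStateId come from the api library.
ml::maths::CBoostedTreeFactory::TBoostedTreeUPtr
restoreFromSearcher(ml::core::CDataSearcher& restoreSearcher,
                    ml::core::CDataFrame& frame,
                    std::size_t dependentVariableColumn,
                    const std::string& jobId) {
    ml::core::CStateDecompressor decompressor{restoreSearcher};
    decompressor.setStateRestoreSearch(ml::api::ML_STATE_INDEX,
                                       ml::api::getRegressionStateId(jobId));
    auto inputStream = decompressor.search(1, 1); // single stored state document
    return ml::maths::CBoostedTreeFactory::constructFromString(*inputStream)
        .buildFor(frame, dependentVariableColumn);
}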
55 changes: 24 additions & 31 deletions lib/api/unittest/CDataFrameAnalyzerTest.cc
@@ -83,7 +83,6 @@ class CTestDataAdder : public core::CDataAdder {
private:
TOStreamP m_Stream;
};
}

std::vector<std::string> streamToStringVector(std::stringstream&& tokenStream) {
std::vector<std::string> results;
@@ -362,6 +361,7 @@ void addRegressionTestData(const TStrVec& fieldNames,
}
});
}
}

void CDataFrameAnalyzerTest::testWithoutControlMessages() {

@@ -653,7 +653,7 @@ void CDataFrameAnalyzerTest::testRunBoostedTreeTraining() {
LOG_DEBUG(<< "time to train = " << core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain)
<< "ms");
CPPUNIT_ASSERT(core::CProgramCounters::counter(
counter_t::E_DFTPMEstimatedPeakMemoryUsage) < 2300000);
counter_t::E_DFTPMEstimatedPeakMemoryUsage) < 2600000);
CPPUNIT_ASSERT(core::CProgramCounters::counter(counter_t::E_DFTPMPeakMemoryUsage) < 1050000);
CPPUNIT_ASSERT(core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain) > 0);
CPPUNIT_ASSERT(core::CProgramCounters::counter(counter_t::E_DFTPMTimeToTrain) <= duration);
@@ -1130,8 +1130,7 @@ void CDataFrameAnalyzerTest::testRunBoostedTreeTrainingWithStateRecoverySubrouti
rng.generateUniformSamples(-10.0, 10.0, weights.size() * numberExamples, values);

auto persistenceStream{std::make_shared<std::ostringstream>()};
CDataFrameAnalyzerTest::TPersisterSupplier persisterSupplier =
[&persistenceStream]() -> TDataAdderUPtr {
TPersisterSupplier persisterSupplier = [&persistenceStream]() -> TDataAdderUPtr {
return std::make_unique<api::CSingleStreamDataAdder>(persistenceStream);
};

@@ -1142,20 +1141,21 @@
numberRoundsPerHyperparameter, 12, {}, lambda, gamma, eta,
maximumNumberTrees, featureBagFraction, &persisterSupplier),
outputWriterFactory};
std::size_t dependentVariable(
std::find(fieldNames.begin(), fieldNames.end(), "c5") - fieldNames.begin());

auto frame{passDataToAnalyzer(fieldNames, fieldValues, analyzer, weights, values)};
analyzer.handleRecord(fieldNames, {"", "", "", "", "", "", "$"});

TStrVec persistedStatesString{
streamToStringVector(std::stringstream(persistenceStream->str()))};
auto expectedTree{getFinalTree(persistedStatesString, frame)};
auto expectedTree{this->getFinalTree(persistedStatesString, frame, dependentVariable)};

// Compute actual tree
persistenceStream->str("");

std::istringstream intermediateStateStream{persistedStatesString[iterationToRestartFrom]};
CDataFrameAnalyzerTest::TRestoreSearcherSupplier restoreSearcherSupplier =
[&intermediateStateStream]() -> TDataSearcherUPtr {
TRestoreSearcherSupplier restoreSearcherSupplier = [&intermediateStateStream]() -> TDataSearcherUPtr {
return std::make_unique<CTestDataSearcher>(intermediateStateStream.str());
};

@@ -1170,49 +1170,42 @@

persistedStatesString =
streamToStringVector(std::stringstream(persistenceStream->str()));
auto actualTree{getFinalTree(persistedStatesString, frame)};
auto actualTree{this->getFinalTree(persistedStatesString, frame, dependentVariable)};

// Compare hyperparameters.

rapidjson::Document expectedResults{treeToJsonDocument(*expectedTree)};
const auto& expectedHyperparameters = expectedResults["best_hyperparameters"];
const auto& expectedRegularizationHyperparameters =
expectedHyperparameters["hyperparam_regularization"];

rapidjson::Document actualResults{treeToJsonDocument(*actualTree)};
const auto& actualHyperparameters = actualResults["best_hyperparameters"];
const auto& actualRegularizationHyperparameters =
actualHyperparameters["hyperparam_regularization"];

auto assertDoublesEqual = [&expectedHyperparameters,
&actualHyperparameters](std::string key) {
for (const auto& key : {"hyperparam_eta", "hyperparam_eta_growth_rate_per_tree",
"hyperparam_feature_bag_fraction"}) {
double expected{std::stod(expectedHyperparameters[key].GetString())};
double actual{std::stod(actualHyperparameters[key].GetString())};
CPPUNIT_ASSERT_DOUBLES_EQUAL(expected, actual, 1e-4 * expected);
};
auto assertDoublesArrayEqual = [&expectedHyperparameters,
&actualHyperparameters](std::string key) {
TDoubleVec expectedVector;
core::CPersistUtils::fromString(expectedHyperparameters[key].GetString(), expectedVector);
TDoubleVec actualVector;
core::CPersistUtils::fromString(actualHyperparameters[key].GetString(), actualVector);
CPPUNIT_ASSERT_EQUAL(expectedVector.size(), actualVector.size());
for (size_t i = 0; i < expectedVector.size(); i++) {
CPPUNIT_ASSERT_DOUBLES_EQUAL(expectedVector[i], actualVector[i],
1e-4 * expectedVector[i]);
}
};
assertDoublesEqual("hyperparam_lambda");
assertDoublesEqual("hyperparam_gamma");
assertDoublesEqual("hyperparam_eta");
assertDoublesEqual("hyperparam_eta_growth_rate_per_tree");
assertDoublesEqual("hyperparam_feature_bag_fraction");
assertDoublesArrayEqual("hyperparam_feature_sample_probabilities");
}
for (const auto& key : {"regularization_gamma", "regularization_lambda"}) {
double expected{std::stod(expectedRegularizationHyperparameters[key].GetString())};
double actual{std::stod(actualRegularizationHyperparameters[key].GetString())};
CPPUNIT_ASSERT_DOUBLES_EQUAL(expected, actual, 1e-4 * expected);
}
}

maths::CBoostedTreeFactory::TBoostedTreeUPtr
CDataFrameAnalyzerTest::getFinalTree(const TStrVec& persistedStates,
std::unique_ptr<core::CDataFrame>& frame) const {
std::unique_ptr<core::CDataFrame>& frame,
std::size_t dependentVariable) const {
CTestDataSearcher dataSearcher(persistedStates.back());
auto decompressor{std::make_unique<core::CStateDecompressor>(dataSearcher)};
decompressor->setStateRestoreSearch(api::ML_STATE_INDEX,
api::getRegressionStateId("testJob"));
auto stream{decompressor->search(1, 1)};
return maths::CBoostedTreeFactory::constructFromString(*stream, *frame);
return maths::CBoostedTreeFactory::constructFromString(*stream).buildFor(
*frame, dependentVariable);
}
4 changes: 3 additions & 1 deletion lib/api/unittest/CDataFrameAnalyzerTest.h
@@ -57,7 +57,7 @@ class CDataFrameAnalyzerTest : public CppUnit::TestFixture {
std::size_t iterationToRestartFrom) const;

ml::maths::CBoostedTreeFactory::TBoostedTreeUPtr
getFinalTree(const TStrVec& persistedStates, TDataFrameUPtr& frame) const;
getFinalTree(const TStrVec& persistedStates,
TDataFrameUPtr& frame,
std::size_t dependentVariable) const;
};

#endif // INCLUDED_CDataFrameAnalyzerTest_h
5 changes: 4 additions & 1 deletion lib/core/CLoopProgress.cc
@@ -35,7 +35,7 @@ CLoopProgress::CLoopProgress(std::size_t size, const TProgressCallback& recordPr
m_StepProgress{scale / static_cast<double>(m_Steps)}, m_RecordProgress{recordProgress} {
}

void CLoopProgress::attach(const TProgressCallback& recordProgress) {
void CLoopProgress::progressCallback(const TProgressCallback& recordProgress) {
m_RecordProgress = recordProgress;
}

@@ -52,6 +52,7 @@ void CLoopProgress::increment(std::size_t i) {
}

void CLoopProgress::resumeRestored() {
// This outputs progress and updates m_LastProgress to the correct value.
this->increment(0);
}

@@ -70,6 +71,8 @@ void CLoopProgress::acceptPersistInserter(CStatePersistInserter& inserter) const
inserter.insertValue(CURRENT_STEP_PROGRESS_TAG, m_StepProgress,
core::CIEEE754::E_DoublePrecision);
inserter.insertValue(LOOP_POS_TAG, m_Pos);
// m_LastProgress is not persisted: immediately after restoring, no
// progress has yet been recorded.
}

bool CLoopProgress::acceptRestoreTraverser(CStateRestoreTraverser& traverser) {
2 changes: 1 addition & 1 deletion lib/core/unittest/CLoopProgressTest.cc
@@ -165,7 +165,7 @@ void CLoopProgressTest::testSerialization() {
auto restoredRecordProgress = [&restoredProgress](double p) {
restoredProgress += p;
};
restoredLoopProgress.attach(restoredRecordProgress);
restoredLoopProgress.progressCallback(restoredRecordProgress);
restoredLoopProgress.resumeRestored();

CPPUNIT_ASSERT_EQUAL(loopProgress.checksum(), restoredLoopProgress.checksum());