[7.4][ML] Progress monitoring for regression model training #599

Merged: 1 commit, Aug 19, 2019
6 changes: 3 additions & 3 deletions include/maths/CBoostedTree.h
@@ -132,12 +132,12 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
~CBoostedTree() override;

//! Train on the examples in the data frame supplied to the constructor.
void train(TProgressCallback recordProgress = noop) override;
void train() override;

//! Write the predictions to the data frame supplied to the constructor.
//!
//! \warning This can only be called after train.
void predict(TProgressCallback recordProgress = noop) const override;
void predict() const override;

//! Write the trained model to \p writer.
//!
@@ -160,7 +160,7 @@ class MATHS_EXPORT CBoostedTree final : public CDataFrameRegressionModel {
using TImplUPtr = std::unique_ptr<CBoostedTreeImpl>;

private:
CBoostedTree(core::CDataFrame& frame, TImplUPtr&& impl);
CBoostedTree(core::CDataFrame& frame, TProgressCallback recordProgress, TImplUPtr&& impl);

private:
TImplUPtr m_Impl;
13 changes: 8 additions & 5 deletions include/maths/CBoostedTreeFactory.h
@@ -30,6 +30,7 @@ class CBoostedTreeImpl;
class MATHS_EXPORT CBoostedTreeFactory final {
public:
using TBoostedTreeUPtr = std::unique_ptr<CBoostedTree>;
using TProgressCallback = CBoostedTree::TProgressCallback;

public:
//! Construct a boosted tree object from parameters.
@@ -38,7 +39,8 @@ class MATHS_EXPORT CBoostedTreeFactory final {

//! Construct a boosted tree object from its serialized version.
static TBoostedTreeUPtr constructFromString(std::stringstream& jsonStringStream,
core::CDataFrame& frame);
core::CDataFrame& frame,
TProgressCallback recordProgress = noop);

~CBoostedTreeFactory();
CBoostedTreeFactory(CBoostedTreeFactory&) = delete;
@@ -66,7 +68,7 @@ class MATHS_EXPORT CBoostedTreeFactory final {
//! Set the number of training examples we need per feature we'll include.
CBoostedTreeFactory& rowsPerFeature(std::size_t rowsPerFeature);
//! Set the callback function for progress monitoring.
CBoostedTreeFactory& progressCallback(CBoostedTree::TProgressCallback callback);
CBoostedTreeFactory& progressCallback(TProgressCallback callback);

//! Estimate the maximum bookkeeping memory that training the boosted tree on a data
//! frame with \p numberRows rows and \p numberColumns columns will use.
@@ -103,19 +105,20 @@ class MATHS_EXPORT CBoostedTreeFactory final {
//! Read overrides for hyperparameters and if necessary estimate the initial
//! values for \f$\lambda\f$ and \f$\gamma\f$ which match the gain from an
//! overfit tree.
void initializeHyperparameters(core::CDataFrame& frame,
CBoostedTree::TProgressCallback recordProgress) const;
void initializeHyperparameters(core::CDataFrame& frame) const;

//! Initialize the state for hyperparameter optimisation.
void initializeHyperparameterOptimisation() const;

//! Get the number of hyperparameter tuning rounds to use.
std::size_t numberHyperparameterTuningRounds() const;

static void noop(double);

private:
double m_MinimumFrequencyToOneHotEncode;
TBoostedTreeImplUPtr m_TreeImpl;
CBoostedTree::TProgressCallback m_ProgressCallback;
TProgressCallback m_RecordProgress = noop;
};
}
}
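For context, this factory change means callers install the progress callback once on the factory and then call train() and predict() with no arguments. Below is a minimal usage sketch, not code from this PR: the function name, thread count, and include set are assumptions, and it assumes TProgressCallback behaves like std::function<void(double)> with the callback receiving the fraction of work completed.

```cpp
// Illustrative sketch only: wiring a progress callback through the factory.
#include <core/CDataFrame.h>
#include <maths/CBoostedTree.h>
#include <maths/CBoostedTreeFactory.h>

#include <cstddef>
#include <iostream>
#include <memory>

void trainWithProgress(ml::core::CDataFrame& frame, std::size_t dependentVariableColumn) {
    auto tree = ml::maths::CBoostedTreeFactory::constructFromParameters(
                    1 /*number of threads*/,
                    std::make_unique<ml::maths::boosted_tree::CMse>())
                    .progressCallback([](double fractionDone) {
                        // Progress for both train() and predict() arrives here.
                        std::cout << "progress: " << fractionDone << std::endl;
                    })
                    .buildFor(frame, dependentVariableColumn);

    tree->train();
    tree->predict();
}
```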
11 changes: 6 additions & 5 deletions include/maths/CDataFrameRegressionModel.h
@@ -33,12 +33,12 @@ class MATHS_EXPORT CDataFrameRegressionModel {
CDataFrameRegressionModel& operator=(const CDataFrameRegressionModel&) = delete;

//! Train on the examples in the data frame supplied to the constructor.
virtual void train(TProgressCallback recordProgress = noop) = 0;
virtual void train() = 0;

//! Write the predictions to the data frame supplied to the constructor.
//!
//! \warning This can only be called after train.
virtual void predict(TProgressCallback recordProgress = noop) const = 0;
virtual void predict() const = 0;

//! Write this model to \p writer.
//!
@@ -52,12 +52,13 @@
virtual std::size_t columnHoldingPrediction(std::size_t numberColumns) const = 0;

protected:
CDataFrameRegressionModel(core::CDataFrame& frame) : m_Frame{frame} {}
core::CDataFrame& frame() const { return m_Frame; }
static void noop(double);
CDataFrameRegressionModel(core::CDataFrame& frame, TProgressCallback recordProgress);
core::CDataFrame& frame() const;
const TProgressCallback& progressRecorder() const;

private:
core::CDataFrame& m_Frame;
TProgressCallback m_RecordProgress;
};
}
}
5 changes: 3 additions & 2 deletions lib/api/CDataFrameBoostedTreeRunner.cc
@@ -86,6 +86,7 @@ CDataFrameBoostedTreeRunner::CDataFrameBoostedTreeRunner(const CDataFrameAnalysi
maths::CBoostedTreeFactory::constructFromParameters(
this->spec().numberThreads(), std::make_unique<maths::boosted_tree::CMse>()));

m_BoostedTreeFactory->progressCallback(this->progressRecorder());
if (lambda >= 0.0) {
m_BoostedTreeFactory->lambda(lambda);
}
@@ -139,8 +140,8 @@ void CDataFrameBoostedTreeRunner::runImpl(const TStrVec& featureNames,

m_BoostedTree = m_BoostedTreeFactory->buildFor(
frame, dependentVariableColumn - featureNames.begin());
m_BoostedTree->train(this->progressRecorder());
m_BoostedTree->predict(this->progressRecorder());
m_BoostedTree->train();
m_BoostedTree->predict();
}

std::size_t CDataFrameBoostedTreeRunner::estimateBookkeepingMemoryUsage(
2 changes: 1 addition & 1 deletion lib/api/unittest/CDataFrameAnalyzerTest.cc
@@ -261,7 +261,7 @@ void addRegressionTestData(TStrVec fieldNames,
std::unique_ptr<maths::CBoostedTree> tree =
treeFactory.buildFor(*frame, weights.size());

tree->train(ml::maths::CBoostedTree::TProgressCallback());
tree->train();

frame->readRows(1, [&](TRowItr beginRows, TRowItr endRows) {
for (auto row = beginRows; row != endRows; ++row) {
12 changes: 6 additions & 6 deletions lib/maths/CBoostedTree.cc
@@ -57,18 +57,18 @@ double CArgMinMse::value() const {
}
}

CBoostedTree::CBoostedTree(core::CDataFrame& frame, TImplUPtr&& impl)
: CDataFrameRegressionModel{frame}, m_Impl{std::move(impl)} {
CBoostedTree::CBoostedTree(core::CDataFrame& frame, TProgressCallback recordProgress, TImplUPtr&& impl)
: CDataFrameRegressionModel{frame, std::move(recordProgress)}, m_Impl{std::move(impl)} {
}

CBoostedTree::~CBoostedTree() = default;

void CBoostedTree::train(TProgressCallback recordProgress) {
m_Impl->train(this->frame(), recordProgress);
void CBoostedTree::train() {
m_Impl->train(this->frame(), this->progressRecorder());
}

void CBoostedTree::predict(TProgressCallback recordProgress) const {
m_Impl->predict(this->frame(), recordProgress);
void CBoostedTree::predict() const {
m_Impl->predict(this->frame(), this->progressRecorder());
}

void CBoostedTree::write(core::CRapidJsonConcurrentLineWriter& writer) const {
23 changes: 13 additions & 10 deletions lib/maths/CBoostedTreeFactory.cc
@@ -37,12 +37,12 @@ CBoostedTreeFactory::buildFor(core::CDataFrame& frame, std::size_t dependentVari
this->selectFeaturesAndEncodeCategories(frame);

if (this->initializeFeatureSampleDistribution()) {
this->initializeHyperparameters(frame, m_ProgressCallback);
this->initializeHyperparameters(frame);
this->initializeHyperparameterOptimisation();
}

// TODO can only use factory to create one object since this is moved. This seems trappy.
return TBoostedTreeUPtr{new CBoostedTree(frame, std::move(m_TreeImpl))};
return TBoostedTreeUPtr{new CBoostedTree(frame, m_RecordProgress, std::move(m_TreeImpl))};
}

std::size_t CBoostedTreeFactory::numberHyperparameterTuningRounds() const {
@@ -180,8 +180,7 @@ bool CBoostedTreeFactory::initializeFeatureSampleDistribution() const {
return false;
}

void CBoostedTreeFactory::initializeHyperparameters(core::CDataFrame& frame,
CBoostedTree::TProgressCallback recordProgress) const {
void CBoostedTreeFactory::initializeHyperparameters(core::CDataFrame& frame) const {

m_TreeImpl->m_Lambda = m_TreeImpl->m_LambdaOverride.value_or(0.0);
m_TreeImpl->m_Gamma = m_TreeImpl->m_GammaOverride.value_or(0.0);
@@ -232,7 +231,7 @@ void CBoostedTreeFactory::initializeHyperparameters(core::CDataFrame& frame,
LOG_TRACE(<< "loss = " << L[0] << ", # leaves = " << T[0]
<< ", sum square weights = " << W[0]);

auto forest = m_TreeImpl->trainForest(frame, trainingRowMask, recordProgress);
auto forest = m_TreeImpl->trainForest(frame, trainingRowMask, m_RecordProgress);

std::tie(L[1], T[1], W[1]) =
m_TreeImpl->regularisedLoss(frame, trainingRowMask, forest);
@@ -272,10 +271,12 @@ CBoostedTreeFactory::constructFromParameters(std::size_t numberThreads,

CBoostedTreeFactory::TBoostedTreeUPtr
CBoostedTreeFactory::constructFromString(std::stringstream& jsonStringStream,
core::CDataFrame& frame) {
core::CDataFrame& frame,
TProgressCallback recordProgress) {
try {
TBoostedTreeUPtr treePtr{
new CBoostedTree{frame, TBoostedTreeImplUPtr{new CBoostedTreeImpl{}}}};
new CBoostedTree{frame, std::move(recordProgress),
TBoostedTreeImplUPtr{new CBoostedTreeImpl{}}}};
core::CJsonStateRestoreTraverser traverser(jsonStringStream);
if (treePtr->acceptRestoreTraverser(traverser) == false || traverser.haveBadState()) {
throw std::runtime_error{"failed to restore boosted tree"};
@@ -388,9 +389,8 @@ CBoostedTreeFactory& CBoostedTreeFactory::rowsPerFeature(std::size_t rowsPerFeat
return *this;
}

CBoostedTreeFactory&
CBoostedTreeFactory::progressCallback(CBoostedTree::TProgressCallback callback) {
m_ProgressCallback = callback;
CBoostedTreeFactory& CBoostedTreeFactory::progressCallback(TProgressCallback callback) {
m_RecordProgress = callback;
return *this;
}

@@ -403,6 +403,9 @@ std::size_t CBoostedTreeFactory::numberExtraColumnsForTrain() const {
return m_TreeImpl->numberExtraColumnsForTrain();
}

void CBoostedTreeFactory::noop(double) {
}

const double CBoostedTreeFactory::MINIMUM_ETA{1e-3};
const std::size_t CBoostedTreeFactory::MAXIMUM_NUMBER_TREES{
static_cast<std::size_t>(2.0 / MINIMUM_ETA + 0.5)};
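The noop function added above is the value m_RecordProgress defaults to, so the factory (and constructFromString) can invoke the callback unconditionally even when no caller installed one. A generic sketch of that default-to-no-op pattern follows; the class and member names are purely illustrative, and it assumes the callback type is std::function<void(double)>-like.

```cpp
// Illustrative pattern only: default a callback member to a no-op function so
// it is always safe to call without a null check.
#include <functional>
#include <utility>

class ProgressSource {
public:
    using TProgressCallback = std::function<void(double)>;

    // Install a real callback; this overwrites the no-op default.
    ProgressSource& progressCallback(TProgressCallback callback) {
        m_RecordProgress = std::move(callback);
        return *this;
    }

    void doWork() {
        // Safe even if progressCallback() was never called.
        m_RecordProgress(0.5);
        m_RecordProgress(1.0);
    }

private:
    static void noop(double) {}

    TProgressCallback m_RecordProgress = noop;
};
```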
14 changes: 14 additions & 0 deletions lib/maths/CBoostedTreeImpl.cc
@@ -6,6 +6,7 @@

#include <maths/CBoostedTreeImpl.h>

#include <core/CLoopProgress.h>
#include <core/CPersistUtils.h>

#include <maths/CBayesianOptimisation.h>
@@ -88,6 +89,12 @@ void CBoostedTreeImpl::train(core::CDataFrame& frame,

LOG_TRACE(<< "Main training loop...");

// We account for the cost of setup as one round. The main optimisation loop runs
// for "m_NumberRounds + 1" rounds and training on the chosen hyperparameter
// values is counted as one round. This gives a total of m_NumberRounds + 3.
core::CLoopProgress progress{m_NumberRounds + 3, recordProgress};
progress.increment();

if (this->canTrain() == false) {
// Fallback to using the constant predictor which minimises the loss.

@@ -121,13 +128,20 @@
if (this->selectNextHyperparameters(lossMoments, *m_BayesianOptimization) == false) {
break;
}

progress.increment();

} while (m_CurrentRound++ < m_NumberRounds);

LOG_TRACE(<< "Test loss = " << m_BestForestTestLoss);

this->restoreBestHyperparameters();
m_BestForest = this->trainForest(frame, this->allTrainingRowsMask(), recordProgress);
}

// Force to at least one here because we can have an early exit from the loop
// or take a different path.
recordProgress(1.0);
}

void CBoostedTreeImpl::predict(core::CDataFrame& frame,
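The accounting above sizes the progress meter at m_NumberRounds + 3 steps: one for setup, m_NumberRounds + 1 for the optimisation loop, and one for the final training on the chosen hyperparameters. As an illustration of the idea only, not the project's core::CLoopProgress, and assuming the callback receives the cumulative fraction of work done in [0, 1], a minimal stand-in could look like this:

```cpp
// Hypothetical stand-in for a loop progress meter (not core::CLoopProgress):
// each increment() maps one completed step onto the cumulative fraction of
// work done and forwards it to the callback.
#include <cstddef>
#include <functional>
#include <utility>

class LoopProgressSketch {
public:
    using TProgressCallback = std::function<void(double)>;

    LoopProgressSketch(std::size_t steps, TProgressCallback recordProgress)
        : m_Steps{steps}, m_RecordProgress{std::move(recordProgress)} {}

    void increment() {
        if (m_Current < m_Steps) {
            ++m_Current;
            m_RecordProgress(static_cast<double>(m_Current) /
                             static_cast<double>(m_Steps));
        }
    }

private:
    std::size_t m_Steps;
    std::size_t m_Current = 0;
    TProgressCallback m_RecordProgress;
};
```

Under this reading, the final recordProgress(1.0) in train() simply guarantees the meter reaches one even when the loop exits early or a different code path is taken.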
13 changes: 12 additions & 1 deletion lib/maths/CDataFrameRegressionModel.cc
@@ -9,7 +9,18 @@
namespace ml {
namespace maths {

void CDataFrameRegressionModel::noop(double) {
CDataFrameRegressionModel::CDataFrameRegressionModel(core::CDataFrame& frame,
TProgressCallback recordProgress)
: m_Frame{frame}, m_RecordProgress{recordProgress} {
}

core::CDataFrame& CDataFrameRegressionModel::frame() const {
return m_Frame;
}

const CDataFrameRegressionModel::TProgressCallback&
CDataFrameRegressionModel::progressRecorder() const {
return m_RecordProgress;
}
}
}