Skip to content

[ML] Improve boosted tree training initialisation #686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Sep 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
For large data sets this change was observed to give a 10% to 20% decrease in
train time. (See {ml-pull}622[#622].)
* Upgrade Boost libraries to version 1.71. (See {ml-pull}638[#638].)
* Improve initialisation of boosted tree training. This generally enables us to
find lower-loss models faster. (See {ml-pull}686[#686].)

== {es} version 7.4.0

Expand Down
4 changes: 3 additions & 1 deletion include/api/CDataFrameBoostedTreeRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
private:
using TBoostedTreeUPtr = std::unique_ptr<maths::CBoostedTree>;
using TBoostedTreeFactoryUPtr = std::unique_ptr<maths::CBoostedTreeFactory>;
using TDataSearcherUPtr = CDataFrameAnalysisSpecification::TDataSearcherUPtr;
using TMemoryEstimator = std::function<void(std::int64_t)>;

private:
Expand All @@ -58,7 +59,8 @@ class API_EXPORT CDataFrameBoostedTreeRunner final : public CDataFrameAnalysisRu
TMemoryEstimator memoryEstimator();

bool restoreBoostedTree(core::CDataFrame& frame,
CDataFrameAnalysisSpecification::TDataSearcherUPtr& restoreSearcher);
std::size_t dependentVariableColumn,
TDataSearcherUPtr& restoreSearcher);

private:
// Note custom config is written directly to the factory object.
Expand Down
27 changes: 25 additions & 2 deletions include/core/CLoopProgress.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

namespace ml {
namespace core {
class CStatePersistInserter;
class CStateRestoreTraverser;

//! \brief Manages recording the progress of a loop.
//!
Expand Down Expand Up @@ -46,14 +48,35 @@ class CORE_EXPORT CLoopProgress {
using TProgressCallback = std::function<void(double)>;

public:
CLoopProgress();
template<typename ITR>
CLoopProgress(ITR begin, ITR end, const TProgressCallback& recordProgress, double scale = 1.0)
CLoopProgress(ITR begin, ITR end, const TProgressCallback& recordProgress = noop, double scale = 1.0)
: CLoopProgress(std::distance(begin, end), recordProgress, scale) {}
CLoopProgress(std::size_t size, const TProgressCallback& recordProgress, double scale = 1.0);
CLoopProgress(std::size_t size,
const TProgressCallback& recordProgress = noop,
double scale = 1.0);

//! Attach a new progress monitor callback.
void progressCallback(const TProgressCallback& recordProgress);

//! Increment the progress by \p i.
void increment(std::size_t i = 1);

//! Resume progress monitoring which was restored.
void resumeRestored();

//! Get a checksum for this object.
std::uint64_t checksum() const;

//! Persist by passing information to \p inserter.
void acceptPersistInserter(CStatePersistInserter& inserter) const;

//! Populate the object from serialized data.
bool acceptRestoreTraverser(CStateRestoreTraverser& traverser);

private:
static void noop(double);

private:
std::size_t m_Size;
std::size_t m_Steps;
Expand Down
53 changes: 42 additions & 11 deletions include/maths/CBoostedTreeFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <core/CDataFrame.h>

#include <maths/CBoostedTree.h>
#include <maths/CLinearAlgebra.h>
#include <maths/ImportExport.h>

#include <boost/optional.hpp>
Expand Down Expand Up @@ -43,12 +44,9 @@ class MATHS_EXPORT CBoostedTreeFactory final {
TLossFunctionUPtr loss);

//! Construct a boosted tree object from its serialized version.
static TBoostedTreeUPtr
constructFromString(std::istream& jsonStringStream,
core::CDataFrame& frame,
TProgressCallback recordProgress = noopRecordProgress,
TMemoryUsageCallback recordMemoryUsage = noopRecordMemoryUsage,
TTrainingStateCallback recordTrainingState = noopRecordTrainingState);
//!
//! \warning Throws runtime error on fail to restore.
static CBoostedTreeFactory constructFromString(std::istream& jsonStringStream);

~CBoostedTreeFactory();
CBoostedTreeFactory(CBoostedTreeFactory&) = delete;
Expand Down Expand Up @@ -93,17 +91,22 @@ class MATHS_EXPORT CBoostedTreeFactory final {
TBoostedTreeUPtr buildFor(core::CDataFrame& frame, std::size_t dependentVariable);

private:
using TDoubleDoublePr = std::pair<double, double>;
using TOptionalDouble = boost::optional<double>;
using TOptionalSize = boost::optional<std::size_t>;
using TVector = CVectorNx1<double, 3>;
using TOptionalVector = boost::optional<TVector>;
using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
using TBoostedTreeImplUPtr = std::unique_ptr<CBoostedTreeImpl>;
using TApplyRegularizerStep =
std::function<void(CBoostedTreeImpl&, double, std::size_t)>;

private:
static const double MINIMUM_ETA;
static const std::size_t MAXIMUM_NUMBER_TREES;

private:
CBoostedTreeFactory(std::size_t numberThreads, TLossFunctionUPtr loss);
CBoostedTreeFactory(bool restored, std::size_t numberThreads, TLossFunctionUPtr loss);

//! Compute the row masks for the missing values for each feature.
void initializeMissingFeatureMasks(const core::CDataFrame& frame) const;
Expand All @@ -121,25 +124,53 @@ class MATHS_EXPORT CBoostedTreeFactory final {
//! Initialize the regressors sample distribution.
bool initializeFeatureSampleDistribution() const;

//! Read overrides for hyperparameters and if necessary estimate the initial
//! values for \f$\lambda\f$ and \f$\gamma\f$ which match the gain from an
//! overfit tree.
void initializeHyperparameters(core::CDataFrame& frame) const;
//! Set the initial values for the various hyperparameters.
void initializeHyperparameters(core::CDataFrame& frame);

//! Estimate a good central value for the regularisation hyperparameters
//! search bounding box.
void initializeUnsetRegularizationHyperparameters(core::CDataFrame& frame);

//! Estimate the reduction in gain from a split and the total curvature of
//! the loss function at a split.
TDoubleDoublePr estimateTreeGainAndCurvature(core::CDataFrame& frame,
const core::CPackedBitVector& trainingRowMask) const;

//! Perform a line search for the test loss w.r.t. a single regularization
//! hyperparameter and apply Newton's method to find the minimum. The plan
//! is to find a value near where the model starts to overfit.
//!
//! \return The interval to search during the main hyperparameter optimisation
//! loop or null if this couldn't be found.
TOptionalVector testLossNewtonLineSearch(core::CDataFrame& frame,
core::CPackedBitVector trainingRowMask,
const TApplyRegularizerStep& applyRegularizerStep,
double returnedIntervalLeftEndOffset,
double returnedIntervalRightEndOffset) const;

//! Initialize the state for hyperparameter optimisation.
void initializeHyperparameterOptimisation() const;

//! Get the number of hyperparameter tuning rounds to use.
std::size_t numberHyperparameterTuningRounds() const;

//! Setup monitoring for training progress.
void initializeTrainingProgressMonitoring();

//! Refresh progress monitoring after restoring from saved training state.
void resumeRestoredTrainingProgressMonitoring();

static void noopRecordProgress(double);
static void noopRecordMemoryUsage(std::int64_t);
static void noopRecordTrainingState(CDataFrameRegressionModel::TPersistFunc);

private:
TOptionalDouble m_MinimumFrequencyToOneHotEncode;
TOptionalSize m_BayesianOptimisationRestarts;
bool m_Restored = false;
TBoostedTreeImplUPtr m_TreeImpl;
TVector m_LogGammaSearchInterval;
TVector m_LogLambdaSearchInterval;
TProgressCallback m_RecordProgress = noopRecordProgress;
TMemoryUsageCallback m_RecordMemoryUsage = noopRecordMemoryUsage;
TTrainingStateCallback m_RecordTrainingState = noopRecordTrainingState;
Expand Down
Loading