
[ML] Improve boosted tree training initialisation #686

Merged: 24 commits merged into elastic:master on Sep 26, 2019

Conversation

@tveasey tveasey (Contributor) commented Sep 23, 2019

This makes some changes to initialisation:

  1. Measure the gain and sum curvature in the tree directly to estimate upper bounds on good values for gamma and lambda, respectively.
  2. Search a large range of values down from these initial overestimates, looking for a turning point in the test error as the model transitions from underfit to overfit (sketched below).

The hyperparameter search is centred on the values at this transition. I've reduced the number of hyperparameter optimisation rounds as a result of the improved initialisation.
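
A minimal sketch of the line-search idea, assuming a hypothetical testError callback which trains with the candidate value and returns the resulting test error (all names here are illustrative, not the actual implementation):

#include <cstddef>
#include <functional>

// Hypothetical sketch: starting from an overestimate of a regularisation
// parameter, repeatedly scale it down, retraining each time, and stop when
// the test error starts to rise, i.e. at the turning point where the model
// transitions from underfit to overfit.
double lineSearchForTurningPoint(double overestimate,
                                 double scaleFactor, // e.g. 0.5 per step
                                 std::size_t maxSteps,
                                 const std::function<double(double)>& testError) {
    double best{overestimate};
    double bestError{testError(overestimate)};
    for (std::size_t i = 0; i < maxSteps; ++i) {
        double candidate{best * scaleFactor};
        double error{testError(candidate)};
        if (error > bestError) {
            break; // the error has turned: we've crossed into overfitting
        }
        best = candidate;
        bestError = error;
    }
    return best; // centre the hyperparameter search on this value
}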

This also does a better job of monitoring progress. We explicitly account for the cost of initialisation and update progress after each forest is trained, rather than once per round of the hyperparameter optimisation. Since it is useful to share progress monitoring between the boosted tree factory and the implementation, I've migrated to storing the loop progress monitor on the implementation and persisting and restoring it. This incidentally fixes a bug in progress monitoring on resume.
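
Roughly, the accounting works as follows; this is a hypothetical sketch, not the actual class (only m_TrainingProgress and attach appear in the diff below):

#include <cstddef>
#include <functional>
#include <utility>

// Hypothetical sketch: total work is the initialisation cost plus one unit
// per forest trained; progress is reported after every forest rather than
// once per hyperparameter optimisation round.
class CTrainingProgress {
public:
    CTrainingProgress(std::size_t initializationCost, std::size_t numberForests)
        : m_Total{initializationCost + numberForests} {}
    void attach(std::function<void(double)> recordProgress) {
        m_RecordProgress = std::move(recordProgress);
    }
    void increment(std::size_t amount = 1) {
        m_Done += amount;
        if (m_RecordProgress) {
            m_RecordProgress(static_cast<double>(m_Done) /
                             static_cast<double>(m_Total));
        }
    }
private:
    std::size_t m_Done{0};
    std::size_t m_Total;
    std::function<void(double)> m_RecordProgress;
};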

Finally, I've refactored the regularisation parameters into a single object that better encapsulates them. This anticipates depth based regularisation.
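
Schematically, the encapsulation might look like the following (member names beyond the gamma setter are assumptions; the diffs below only show m_Regularization.gamma):

// Hypothetical sketch of grouping the regularisation parameters in one class
// with chainable setters, rather than keeping them as loose members.
class CRegularization {
public:
    CRegularization& gamma(double gamma) {
        m_Gamma = gamma;
        return *this;
    }
    CRegularization& lambda(double lambda) {
        m_Lambda = lambda;
        return *this;
    }
    double gamma() const { return m_Gamma; }
    double lambda() const { return m_Lambda; }
    // A depth based penalty would be added here as a further member.
private:
    double m_Gamma{0.0};  // penalises the number of leaves
    double m_Lambda{0.0}; // penalises the leaf weight magnitudes
};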

@valeriy42 valeriy42 (Contributor) left a comment

Great idea on improving the search intervals for hyperparameters! Since this code is not trivial, I left a couple of comments to improve readability.

@@ -334,6 +482,8 @@ CBoostedTreeFactory::constructFromString(std::istream& jsonStringStream,
     if (treePtr->acceptRestoreTraverser(traverser) == false || traverser.haveBadState()) {
         throw std::runtime_error{"failed to restore boosted tree"};
     }
+    treePtr->m_Impl->m_TrainingProgress.attach(recordProgress);
@valeriy42 (Contributor) commented:

You are using recordProgress after it was moved on line 479. It cannot end well 😉
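
For context, a minimal standalone illustration of the hazard being flagged (recordProgress here just stands in for the callback in the diff):

#include <functional>
#include <iostream>
#include <utility>

int main() {
    std::function<void()> recordProgress = [] { std::cout << "progress\n"; };

    // After the move, recordProgress is left in a valid but unspecified
    // state; in practice it is empty, and invoking an empty std::function
    // throws std::bad_function_call.
    auto sink = std::move(recordProgress);

    sink();            // fine
    // recordProgress(); // would throw std::bad_function_call
    return 0;
}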

@tveasey (Contributor, Author) commented:

And indeed that was what was causing the test failures!

@tveasey tveasey (Contributor, Author) commented Sep 24, 2019:

(I'm mulling over having constructFromString return the factory, as we do when constructing from parameters. Aesthetically, I don't like the asymmetry. I'm just seeing how it'll work out.)

@tveasey (Contributor, Author) commented:

I updated along the lines of the second comment, which incidentally fixes the use of the moved-from function. I feel like this is cleaner. Let me know what you think.

@tveasey (Contributor, Author) commented:

See e42208c

} catch (const std::exception& e) {
HANDLE_FATAL(<< "Input error: '" << e.what() << "'. Check logs for more details.");
throw std::runtime_error{std::string{"Input error: '"} + e.what() + "'"};
@tveasey (Contributor, Author) commented:

Note that I changed this to throw. HANDLE_FATAL will abort the program, but this error is recoverable since we can just start again. The exception is already handled in the runner classes in the api library.
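
A small sketch of the distinction, with a hypothetical caller standing in for the api-layer runner: a thrown exception can be caught and training restarted, whereas HANDLE_FATAL would terminate the process:

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for restoring from a persisted state document.
void restoreBoostedTree(const std::string& state) {
    if (state.empty()) { // stand-in for a corrupt or truncated document
        throw std::runtime_error{"Input error: 'failed to restore boosted tree'"};
    }
}

int main() {
    try {
        restoreBoostedTree("");
    } catch (const std::exception& e) {
        // Recoverable: log and start training from scratch instead of aborting.
        std::cerr << e.what() << ". Restarting training from scratch.\n";
    }
    return 0;
}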

@tveasey tveasey (Contributor, Author) commented Sep 24, 2019

Many thanks for the review, @valeriy42, and good suggestions! I think I've now worked through them all; could you take another look?

@valeriy42 valeriy42 (Contributor) left a comment

Great work on improving the readability and introducing the symmetry in the factory. I left a couple of comments; we also discussed further improvements offline.

@@ -228,8 +230,11 @@ bool CDataFrameBoostedTreeRunner::restoreBoostedTree(
return false;
}

-    m_BoostedTree = maths::CBoostedTreeFactory::constructFromString(
-        *inputStream, frame, progressRecorder(), memoryEstimator(), statePersister());
+    m_BoostedTree = maths::CBoostedTreeFactory::constructFromString(*inputStream)
@valeriy42 (Contributor) commented:

Nice! I like the symmetry now. 👍
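
As a toy illustration of that symmetry (every name here except constructFromString is hypothetical): both construction paths return a factory, so callers build the tree the same way in either case:

#include <istream>
#include <sstream>

// Toy sketch: both entry points return a factory, so the call sites read
// symmetrically whether starting from parameters or from persisted state.
class TreeFactory {
public:
    static TreeFactory constructFromParameters(int numberThreads) {
        return TreeFactory{numberThreads};
    }
    static TreeFactory constructFromString(std::istream& state) {
        int numberThreads{1};
        state >> numberThreads; // restore configuration from the state
        return TreeFactory{numberThreads};
    }
    int build() const { return m_NumberThreads; } // stand-in for the tree
private:
    explicit TreeFactory(int numberThreads) : m_NumberThreads{numberThreads} {}
    int m_NumberThreads;
};

int main() {
    int a = TreeFactory::constructFromParameters(4).build();
    std::istringstream state{"8"};
    int b = TreeFactory::constructFromString(state).build();
    return (a == 4 && b == 8) ? 0 : 1;
}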

TVector fallbackInterval{{MIN_REGULARIZER_SCALE, 1.0, MAX_REGULARIZER_SCALE}};
m_TreeImpl->m_Regularization.gamma(m_GammaSearchInterval(MIN_REGULARIZER_INDEX));

double initialLambda{totalCurvaturePerNode};
@valeriy42 (Contributor) commented:

Nice 👍

@tveasey tveasey (Contributor, Author) commented Sep 25, 2019

I also refactored the line search function along the lines we discussed offline in 937df23. I agree this makes the idea clearer: good suggestion! Can you take another look, @valeriy42?

@valeriy42 valeriy42 (Contributor) left a comment

LGTM. I like how the changes made the core of the algorithm much easier to understand while, at the same time, removing "magic numbers" and improving the prediction quality. I left a couple of comments about updating the code comments. You can merge without the need for me to review again.

// These are scales > bestRegularizerScale hence 1 / multiplier.
interval(MAX_REGULARIZER_INDEX) = std::min(
std::pow(1.0 / multiplier, logScaleAtThreeSigma), MAX_REGULARIZER_SCALE);
double threeSigmaInterval{std::sqrt(3.0 * sigma / curvature)};
@valeriy42 (Contributor) commented:

Nice! Maybe you can add a comment on why you need to divide by the curvature.

@tveasey (Contributor, Author) commented:

Right, I have the comment "In particular, we solve curvature * (x - best)^2 = 3 sigma...", which I thought (maybe slightly tangentially) explained this. I feel like that is maybe enough.
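
For reference, the calculation that comment describes: modelling the test error near the best scale as a quadratic with the measured curvature, the offset at which the error grows by three sigma satisfies

curvature * (x - best)^2 = 3 * sigma
    => |x - best| = sqrt(3 * sigma / curvature)

which is the threeSigmaInterval computed above; dividing by the curvature converts an error tolerance into a width in the parameter.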

@tveasey tveasey merged commit 4c03078 into elastic:master Sep 26, 2019
tveasey added a commit to tveasey/ml-cpp-1 that referenced this pull request Sep 26, 2019
@tveasey tveasey deleted the improved-initialisation branch September 26, 2019 14:07
tveasey added a commit that referenced this pull request Sep 26, 2019