[ML] Improve boosted tree training initialisation #686

Conversation
Great idea on improving the search intervals for hyperparameters! Since this code is not trivial, I left a couple of comments to improve readability.
lib/maths/CBoostedTreeFactory.cc
Outdated
@@ -334,6 +482,8 @@ CBoostedTreeFactory::constructFromString(std::istream& jsonStringStream,
    if (treePtr->acceptRestoreTraverser(traverser) == false || traverser.haveBadState()) {
        throw std::runtime_error{"failed to restore boosted tree"};
    }
    treePtr->m_Impl->m_TrainingProgress.attach(recordProgress);
You are using `recordProgress` after it was moved in Line 479. It cannot end well 😉
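For readers following along, a minimal sketch of the hazard being pointed out (names illustrative, not the PR's actual code): a moved-from `std::function` is left in a valid but unspecified, in practice empty, state, so using it afterwards is a bug.

```cpp
#include <functional>
#include <iostream>
#include <utility>

int main() {
    std::function<void(double)> recordProgress = [](double progress) {
        std::cout << "progress: " << progress << '\n';
    };

    // Ownership of the callable transfers here, leaving recordProgress in a
    // valid but unspecified (in practice empty) state.
    std::function<void(double)> attached{std::move(recordProgress)};

    // A later use of recordProgress, like the attach call in the diff above,
    // may then silently do nothing or throw std::bad_function_call.
    if (!recordProgress) {
        std::cout << "recordProgress is empty after the move\n";
    }
    attached(0.5);
    return 0;
}
```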
And indeed that was what was causing the test failures!
(I'm mulling over having `constructFromString` return the factory, as we do when constructing from parameters. Aesthetically, I don't like the asymmetry. I'm just seeing how it'll work out.)
I updated along the lines of the second comment, which incidentally fixes the use of the moved-from function. I feel like this is cleaner. Let me know what you think.
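A toy sketch of the symmetry this settles on (all names hypothetical, not the library's API): both static constructors return the factory itself, so configuration and building read the same on the fresh-train and restore paths.

```cpp
#include <iostream>
#include <memory>
#include <string>

// Toy factory (all names hypothetical): fromParameters and fromString both
// return the factory, so both code paths chain configuration identically.
class Factory {
public:
    static Factory fromParameters(int parameter) { return Factory{parameter}; }
    static Factory fromString(const std::string& state) {
        return Factory{static_cast<int>(state.size())};
    }
    Factory& progressCallback(void (*callback)(double)) {
        m_Callback = callback;
        return *this;
    }
    std::unique_ptr<int> build() {
        if (m_Callback != nullptr) {
            m_Callback(1.0);
        }
        return std::make_unique<int>(m_Parameter);
    }

private:
    explicit Factory(int parameter) : m_Parameter{parameter} {}
    int m_Parameter;
    void (*m_Callback)(double) = nullptr;
};

int main() {
    auto report = [](double progress) { std::cout << "progress " << progress << '\n'; };
    auto fresh = Factory::fromParameters(3).progressCallback(report).build();
    auto restored = Factory::fromString("persisted state").progressCallback(report).build();
    std::cout << *fresh << ' ' << *restored << '\n';
    return 0;
}
```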
See e42208c
} catch (const std::exception& e) {
-   HANDLE_FATAL(<< "Input error: '" << e.what() << "'. Check logs for more details.");
+   throw std::runtime_error{std::string{"Input error: '"} + e.what() + "'"};
Note I changed this to throw. `HANDLE_FATAL` will abort the program, but this is recoverable since we can just start again. The exception is already handled in the runner classes in the api library.
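A sketch of the recoverable pattern described here (the restore function and caller are hypothetical, standing in for the api library's runner classes):

```cpp
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

void restoreState(std::istream& stateStream) {
    std::string state;
    // HANDLE_FATAL would log and abort the whole process; a corrupt persisted
    // state does not warrant that, because training can simply restart from
    // scratch, so we throw instead and let the caller decide.
    if (!std::getline(stateStream, state) || state != "valid") {
        throw std::runtime_error{"Input error: 'bad state'"};
    }
}

int main() {
    std::istringstream corruptState{"garbage"};
    try {
        restoreState(corruptState);
    } catch (const std::exception& e) {
        // Recoverable: log and fall through to retraining from scratch.
        std::cerr << e.what() << "; restarting training from scratch\n";
    }
    return 0;
}
```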
Many thanks for the review @valeriy42 and good suggestions! I think I've now worked through them all; could you take another look?
Great work on improving the readability and introducing the symmetry in the factory. I left a couple of comments; we also discussed further improvements offline.
@@ -228,8 +230,11 @@ bool CDataFrameBoostedTreeRunner::restoreBoostedTree(
        return false;
    }

-   m_BoostedTree = maths::CBoostedTreeFactory::constructFromString(
-       *inputStream, frame, progressRecorder(), memoryEstimator(), statePersister());
+   m_BoostedTree = maths::CBoostedTreeFactory::constructFromString(*inputStream)
Nice! I like the symmetry now. 👍
lib/maths/CBoostedTreeFactory.cc
Outdated
TVector fallbackInterval{{MIN_REGULARIZER_SCALE, 1.0, MAX_REGULARIZER_SCALE}};
m_TreeImpl->m_Regularization.gamma(m_GammaSearchInterval(MIN_REGULARIZER_INDEX));

double initialLambda{totalCurvaturePerNode};
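For context on why the total curvature per node is a natural initial scale for lambda (my gloss, not text from the PR): in gradient boosting the optimal weight of a leaf with gradient sum G and curvature sum H is w* = -G / (H + lambda), so a lambda of the same order as the per-node curvature shrinks leaf weights appreciably without drowning the signal. A toy illustration with hypothetical numbers:

```cpp
#include <iostream>

// With lambda of the order of the per-leaf curvature H, the optimal leaf
// weight w* = -G / (H + lambda) is roughly halved rather than barely or
// totally shrunk.
int main() {
    double gradientSum{-8.0};  // G: sum of loss gradients at a leaf
    double curvatureSum{4.0};  // H: sum of loss curvatures at the leaf
    double lambda{curvatureSum};

    double unregularisedWeight{-gradientSum / curvatureSum};          // 2.0
    double regularisedWeight{-gradientSum / (curvatureSum + lambda)}; // 1.0
    std::cout << unregularisedWeight << " -> " << regularisedWeight << '\n';
    return 0;
}
```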
Nice 👍
I also refactored the line search function along the lines we discussed offline, in 937df23. I agree this makes the idea clearer: good suggestion! Can you take another look, @valeriy42?
LGTM. I like how the changes made the core of the algorithm much easier to understand while at the same time you were able to remove "magic numbers" and improve the prediction quality. I left a couple of comments regarding updating the code comments. You can merge without the need for me to review it again.
// These are scales > bestRegularizerScale hence 1 / multiplier.
interval(MAX_REGULARIZER_INDEX) = std::min(
    std::pow(1.0 / multiplier, logScaleAtThreeSigma), MAX_REGULARIZER_SCALE);
double threeSigmaInterval{std::sqrt(3.0 * sigma / curvature)};
Nice! Maybe you can add a comment why you need to divide by the curvature.
Right, I have the comment `In particular, we solve curvature * (x - best)^2 = 3 sigma...`, which I thought (maybe slightly tangentially) explained this. I feel like this is maybe enough.
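Spelling out the algebra in that comment (my gloss, not text from the PR): the test loss is modelled as quadratic in the regulariser scale x around the best value, and any x whose loss lies within the noise level 3 sigma of the minimum is treated as indistinguishable from it. Solving for the boundary of that region,

    curvature * (x - best)^2 = 3 * sigma
    =>  |x - best| = sqrt(3 * sigma / curvature)

which is exactly the `threeSigmaInterval` expression above: dividing by the curvature converts the loss tolerance into a distance in x.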
This makes some changes to initialisation: we estimate search intervals for the regularisers `gamma` and `lambda`, respectively, and the hyperparameter search is centred on the values at this transition. I've reduced the number of hyperparameter optimisation rounds as a result of the improved initialisation.
This also does a better job of monitoring progress. We explicitly account for the cost of initialisation and update progress after each forest is trained, rather than once per round of the hyperparameter optimisation. Since it is useful to share progress monitoring between the boosted tree factory and the implementation, I've migrated to storing the loop progress monitor on the implementation and persisting and restoring it. This incidentally fixes a bug in progress monitoring on resume.
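A minimal sketch of the monitoring shape this describes (all names hypothetical, assuming a counter that can be attached to a callback and round-tripped through persist/restore):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>

// Hypothetical loop progress monitor: the factory and the implementation share
// one instance, the counters are persisted with the model state, and a fresh
// callback is attached after restore so resumed jobs report correct progress.
class LoopProgress {
public:
    explicit LoopProgress(std::uint64_t totalSteps) : m_TotalSteps{totalSteps} {}

    void attach(std::function<void(double)> recordProgress) {
        m_RecordProgress = std::move(recordProgress);
    }
    void increment() {
        ++m_CompletedSteps;
        if (m_RecordProgress) {
            m_RecordProgress(static_cast<double>(m_CompletedSteps) /
                             static_cast<double>(m_TotalSteps));
        }
    }
    // Persist/restore only the counter; callbacks are re-attached by the caller.
    std::uint64_t completedSteps() const { return m_CompletedSteps; }
    void completedSteps(std::uint64_t steps) { m_CompletedSteps = steps; }

private:
    std::uint64_t m_TotalSteps;
    std::uint64_t m_CompletedSteps = 0;
    std::function<void(double)> m_RecordProgress;
};

int main() {
    LoopProgress progress{4};
    progress.attach([](double p) { std::cout << "progress " << p << '\n'; });
    progress.increment(); // one forest trained -> 0.25
    progress.increment(); // -> 0.5

    // Simulated resume: restore the counter, then attach a new callback.
    LoopProgress resumed{4};
    resumed.completedSteps(progress.completedSteps());
    resumed.attach([](double p) { std::cout << "resumed progress " << p << '\n'; });
    resumed.increment(); // -> 0.75, not starting over from zero
    return 0;
}
```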
Finally, I've refactored the regularisation parameters to encapsulate them better. This anticipates depth-based regularisation.
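A sketch of the kind of encapsulation this describes (my reading; apart from the `gamma` and `lambda` setters visible in the diffs above, the names are illustrative, not the PR's actual code): grouping the regularisers behind one type with fluent setters leaves room to add a depth penalty later.

```cpp
#include <iostream>

// Hypothetical container for the regularisation hyperparameters. Grouping them
// means new penalties (such as one on tree depth) extend this type instead of
// adding loose members to the tree implementation.
class CRegularization {
public:
    CRegularization& gamma(double gamma) {
        m_Gamma = gamma;
        return *this;
    }
    CRegularization& lambda(double lambda) {
        m_Lambda = lambda;
        return *this;
    }
    double gamma() const { return m_Gamma; }
    double lambda() const { return m_Lambda; }

    // Total penalty for a tree: gamma scales with the number of leaves,
    // lambda with the squared leaf weights; a depth-based term would slot
    // in alongside these.
    double penalty(double numberLeaves, double sumSquareLeafWeights) const {
        return m_Gamma * numberLeaves + 0.5 * m_Lambda * sumSquareLeafWeights;
    }

private:
    double m_Gamma = 0.0;
    double m_Lambda = 0.0;
};

int main() {
    CRegularization regularization;
    regularization.gamma(0.1).lambda(2.0);
    std::cout << regularization.penalty(31.0, 4.2) << '\n'; // 7.3
    return 0;
}
```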