Add cancellation checkpoint in logistic regression. #3032
Conversation
Codecov Report
```diff
@@            Coverage Diff            @@
##           master     #3032    +/-   ##
==========================================
+ Coverage    72.41%    72.5%   +0.09%
==========================================
  Files          803      804       +1
  Lines       143851   144080     +229
  Branches     16173    16179       +6
==========================================
+ Hits        104171   104467     +296
+ Misses       35258    35197      -61
+ Partials      4422     4416       -6
```
What are the performance implications here?
```diff
@@ -475,6 +475,7 @@ private protected virtual void TrainCore(IChannel ch, RoleMappedData data)
     e => e.SetProgress(0, exCount, totalCount));
 while (cursor.MoveNext())
 {
+    Host.CheckAlive();
     WeightSum += cursor.Weight;
```
I feel this is not the only place where we need a checkpoint.
Yep, added one more in the line search `Minimize` function.
This looks very suspicious. Could you add some checkpoints to this function? I also feel we need to profile the algorithms before adding checkpoints.

Refers to: src/Microsoft.ML.StandardTrainers/Standard/LogisticRegression/LbfgsPredictorBase.cs:567 in 5540101.
```diff
@@ -475,6 +475,7 @@ private protected virtual void TrainCore(IChannel ch, RoleMappedData data)
     e => e.SetProgress(0, exCount, totalCount));
 while (cursor.MoveNext())
 {
+    Host.CheckAlive();
```
> `Host.CheckAlive();`
Is it too much to do it on every row fetch? Would it be enough to do it every 10 cursor moves, or some other number > 1? (I don't know if there are any best practices on how to determine the frequency of checks, maybe from CancellationToken implementations.)
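For illustration, a throttled variant of the suggestion could look something like the sketch below. The counter variable, the interval of 10, and the surrounding loop shape are illustrative only, not code from this PR:

```csharp
// Hypothetical throttled liveness check (counter name and interval are made up).
int checkCounter = 0;
while (cursor.MoveNext())
{
    // Pay for the liveness check only once every 10 rows.
    if (++checkCounter % 10 == 0)
        Host.CheckAlive();

    WeightSum += cursor.Weight;
    // ... rest of the per-row training work ...
}
```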
And how is that any more efficient than what we have now? You would still end up executing an if condition on every row fetch. Based on my analysis of the current solution, this doesn't add any significant overhead.

CancellationToken works differently: you register a callback with it, and when a signal is sent it invokes that callback so you can do the work to gracefully shut down a process. Our plan is to implement cancellation tokens post 1.0.
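For context, here is a minimal sketch of the callback-based CancellationToken pattern described above. The trainer class and loop are invented for illustration; only the `CancellationToken` API calls are real:

```csharp
using System;
using System.Threading;

class TrainerSketch
{
    public void Train(CancellationToken token)
    {
        // Register a callback that runs when cancellation is signaled,
        // e.g. to flush state or release resources for a graceful shutdown.
        using (token.Register(() => Console.WriteLine("Cancellation requested; shutting down.")))
        {
            for (int i = 0; i < 1_000_000; i++)
            {
                // Cooperative check: throws OperationCanceledException if canceled.
                token.ThrowIfCancellationRequested();
                // ... per-iteration training work ...
            }
        }
    }
}

// Usage: the caller owns the CancellationTokenSource and signals cancellation.
// var cts = new CancellationTokenSource();
// new TrainerSketch().Train(cts.Token);
// cts.Cancel();
```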
We spoke offline. I think this is the best we can do until we get cancellation tokens into the mix. `CheckAlive` only checks a `bool` property, so it's probably faster than checking whether it's the 10th iteration or not.
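For reference, the check being discussed presumably boils down to something like the following simplified sketch. The field name and exception message are made up; the real ML.NET host implementation may differ:

```csharp
// Simplified "check a bool, throw if stopped" liveness check (illustrative only).
private volatile bool _isCanceled;

public void CheckAlive()
{
    // A single, almost always false, bool read per call; throws only once
    // the environment has been marked as canceled.
    if (_isCanceled)
        throw new OperationCanceledException("Host was canceled; stopping training.");
}
```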
To add some late flavor to this: the branch predictor should be slightly better at the (almost perfectly) constant `bool` property than at the return value of `iteration % 10` (which also hides a division; checking every 8 instead would let the compiler optimize it to a bitwise AND). That said, there's the overhead of the `CheckAlive()` function call, which may be greater if it is not inlined.
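To make that micro-optimization concrete, here is a purely illustrative fragment; none of this is in the PR, and the inlining attribute is only an assumption about how the call overhead could be avoided:

```csharp
// Checking "every 10th iteration" needs a modulo, i.e. a hidden division:
if (i % 10 == 0)
    Host.CheckAlive();

// Checking "every 8th iteration" can be compiled to a bitwise AND,
// because 8 is a power of two:
if ((i & 7) == 0)
    Host.CheckAlive();

// If the plain bool check is kept, the call overhead could be reduced by
// hinting the JIT to inline it (requires System.Runtime.CompilerServices):
// [MethodImpl(MethodImplOptions.AggressiveInlining)]
// public void CheckAlive() { ... }
```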
Fixes #3031.
Please read the issue before reviewing this PR.