Add cancellation checkpoint in logistic regression. #3032

Merged
merged 2 commits into from
Mar 22, 2019
1 change: 1 addition & 0 deletions src/Microsoft.ML.StandardTrainers/Optimizer/Optimizer.cs
@@ -624,6 +624,7 @@ public void Minimize(DifferentiableFunction function, ref VBuffer<float> initial
state.UpdateDir();
while (!finished)
{
Env.CheckAlive();
bool success = state.LineSearch(ch, false);
if (!success)
{
@@ -475,6 +475,7 @@ private protected virtual void TrainCore(IChannel ch, RoleMappedData data)
e => e.SetProgress(0, exCount, totalCount));
while (cursor.MoveNext())
{
Host.CheckAlive();
Member

Host.CheckAlive();

Is it too much to do this on every row fetch? Would it be enough to check every 10 cursor moves, or some other number greater than 1?
(I don't know whether there are best practices for choosing the check frequency, perhaps from the CancellationToken implementations.)

Member Author

And how is that any more efficient than what we have now? You would still end up executing an if condition on every row fetch. Based on my analysis, the current solution doesn’t add any significant overhead.

A cancellation token works differently. You register a callback with it, and when a signal is sent it invokes the callback so you can gracefully shut down the process. Our plan is to implement cancellation tokens post 1.0.
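
For reference, a minimal sketch of the two ways a .NET `CancellationToken` can be consumed: the callback registration described above, and per-iteration polling, which is closer in shape to `CheckAlive()`. The class and method names here (`CancellationSketch`, `SumUntilCancelled`, `CallbackFires`) are hypothetical and are not part of ML.NET or of this PR.

```csharp
using System;
using System.Threading;

// Hypothetical illustration; not ML.NET code.
public static class CancellationSketch
{
    // Polling style: one check per row, similar in shape to calling CheckAlive().
    public static int SumUntilCancelled(int[] rows, CancellationToken token)
    {
        int sum = 0;
        foreach (int row in rows)
        {
            // Throws OperationCanceledException once cancellation has been requested.
            token.ThrowIfCancellationRequested();
            sum += row;
        }
        return sum;
    }

    // Callback style: the registered delegate runs when Cancel() is called.
    public static bool CallbackFires()
    {
        bool cleanedUp = false;
        using (var cts = new CancellationTokenSource())
        using (cts.Token.Register(() => cleanedUp = true))
        {
            // Callbacks registered this way are invoked during this call.
            cts.Cancel();
        }
        return cleanedUp;
    }
}
```

Either style could eventually replace the boolean check; which one ML.NET adopts after 1.0 is not stated in this thread.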

Contributor

We spoke offline. I think this is the best we can do until we get cancellation tokens into the mix. CheckAlive only checks a bool property, so it's probably faster than checking to see if it's the 10th iteration or not.

Contributor

To add some late flavor to this: the branch predictor should do slightly better on the (almost perfectly) constant bool property than on the result of iteration % 10, which also includes a hidden division operation. Checking every 8 iterations instead should let the compiler optimize the modulo to a bitwise AND.

That said, there's the overhead of the CheckAlive() function call, which may be greater if it is not inlined.
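
The two options discussed in this thread can be sketched side by side. This is an illustration only: `HostSketch` is a hypothetical stand-in that assumes, as stated above, that `CheckAlive()` does nothing more than test a bool and throw. It is not the actual ML.NET host implementation, and no timing claims are made here.

```csharp
using System;

// Hypothetical stand-in for the host; assumes CheckAlive() only tests a bool.
public sealed class HostSketch
{
    public volatile bool IsCanceled;

    public void CheckAlive()
    {
        if (IsCanceled)
            throw new OperationCanceledException("Operation was canceled.");
    }
}

public static class CheckFrequencySketch
{
    // Option taken in this PR: check on every row.
    public static long SumCheckingEveryRow(int[] rows, HostSketch host)
    {
        long sum = 0;
        foreach (int row in rows)
        {
            host.CheckAlive();
            sum += row;
        }
        return sum;
    }

    // Alternative raised in review: check every 8th row.
    // The mask (i & 7) replaces i % 8, but a branch still runs on every row.
    public static long SumCheckingEveryEighthRow(int[] rows, HostSketch host)
    {
        long sum = 0;
        for (int i = 0; i < rows.Length; i++)
        {
            if ((i & 7) == 0)
                host.CheckAlive();
            sum += rows[i];
        }
        return sum;
    }
}
```

Both loops execute one conditional per row, which is the point made earlier in the thread; the second only changes what that conditional tests and delays detection of a cancellation by up to seven rows.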

WeightSum += cursor.Weight;
Member

I feel this is not the only place we need a checkpoint.

Member Author

Yep, added one more in the line search loop of the Minimize function.

if (ShowTrainingStats)
ProcessPriorDistribution(cursor.Label, cursor.Weight);