Update calibrator estimators to be more suitable. #2526

Merged
merged 2 commits into from
Feb 13, 2019
2 changes: 1 addition & 1 deletion docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs
@@ -75,7 +75,7 @@ public static void Calibration()

// Let's train a calibrator estimator on this scored dataset. The trained calibrator estimator produces a transformer
// that can transform the scored data by adding a new column names "Probability".
-var calibratorEstimator = new PlattCalibratorEstimator(mlContext, model, "Sentiment", "Features");
+var calibratorEstimator = new PlattCalibratorEstimator(mlContext, "Sentiment", "Score");
Review comment (Contributor) on this line:

var calibratorEstimator = new PlattCalibratorEstimator(mlContext, "Sentiment", "Score");

Can you give a bit more detail on what's happening here? The SDCA trainer produces a transformer, which is calibrated, then you pull the original model out and recalibrate it? If so, this sample is a bit confusing, because it's not necessarily something we would need to do, at least not as outlined here. I would at least spell out the steps precisely.

var calibratorTransformer = calibratorEstimator.Fit(scoredData);

// Transform the scored data with a calibrator transfomer by adding a new column names "Probability".
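The review comment above asks for the steps to be spelled out. As a rough illustration of what fitting a calibrator on already scored data and then transforming that data amounts to, the sketch below fits a sigmoid to (score, label) pairs and maps every score to a probability. It does not use ML.NET: the data and the helpers `FitSigmoid` and `Sigmoid` are hypothetical, the form p = 1 / (1 + exp(a * score + b)) is an assumption about Platt-style calibration, and plain gradient descent stands in for whatever optimizer the library uses. It is written as a C# top-level program (C# 9 or later).

```csharp
using System;
using System.Linq;

// Hypothetical scored data: raw classifier scores and their true labels.
float[] scores = { -2.0f, -1.0f, -0.5f, 0.5f, 1.0f, 2.0f };
bool[] labels = { false, false, true, false, true, true };

// "Fit": learn the two sigmoid parameters from (score, label) pairs.
var (fittedA, fittedB) = FitSigmoid(scores, labels);

// "Transform": compute a probability for every score.
double[] probabilities = scores.Select(s => Sigmoid(fittedA, fittedB, s)).ToArray();

for (int i = 0; i < scores.Length; i++)
    Console.WriteLine($"Score={scores[i]:F1} Label={labels[i]} Probability={probabilities[i]:F3}");

// Assumed calibration function: p = 1 / (1 + exp(a * score + b)).
// A negative a makes the probability grow with the score.
double Sigmoid(double a, double b, double score) =>
    1.0 / (1.0 + Math.Exp(a * score + b));

// Batch gradient descent on the log loss (a simple stand-in optimizer).
(double A, double B) FitSigmoid(float[] xs, bool[] ys, int iterations = 2000, double learningRate = 0.1)
{
    double a = 0, b = 0;
    for (int it = 0; it < iterations; it++)
    {
        double gradA = 0, gradB = 0;
        for (int k = 0; k < xs.Length; k++)
        {
            double p = Sigmoid(a, b, xs[k]);
            double y = ys[k] ? 1.0 : 0.0;
            // For this parameterization, d(loss)/d(a * x + b) = y - p.
            gradA += (y - p) * xs[k];
            gradB += y - p;
        }
        a -= learningRate * gradA / xs.Length;
        b -= learningRate * gradB / xs.Length;
    }
    return (a, b);
}
```

The point of the sketch is only the order of operations: the calibrator sees scores and labels, not features, which is why the new constructor call takes a score column rather than a model and a feature column.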
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/BestFriendAttribute.cs
@@ -13,7 +13,7 @@ namespace Microsoft.ML
#endif
{
/// <summary>
-/// Intended to be applied to types and members marked as internal to indicate that friend access of this
+/// Intended to be applied to types and members with internal scope to indicate that friend access of this
/// internal item is OK from another assembly. This restriction applies only to assemblies that declare the
/// <see cref="WantsToBeBestFriendsAttribute"/> assembly level attribute. Note that this attribute is not
/// transferrable: an internal member with this attribute does not somehow make a containing internal type
154 changes: 97 additions & 57 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -12,6 +12,7 @@
using Microsoft.ML;
using Microsoft.ML.Calibrator;
using Microsoft.ML.CommandLine;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Internal.Calibration;
@@ -85,14 +86,24 @@ namespace Microsoft.ML.Internal.Calibration
/// <summary>
/// Signature for the loaders of calibrators.
/// </summary>
-public delegate void SignatureCalibrator();
+[BestFriend]
+internal delegate void SignatureCalibrator();

+[BestFriend]
[TlcModule.ComponentKind("CalibratorTrainer")]
-public interface ICalibratorTrainerFactory : IComponentFactory<ICalibratorTrainer>
+internal interface ICalibratorTrainerFactory : IComponentFactory<ICalibratorTrainer>
{
}

-public interface ICalibratorTrainer
+/// <summary>
Review comment by @rogancarr (Contributor), Feb 13, 2019, on this line:

/// <summary>

Technically, this clarification is a <remark> and <summary> is a short description of what it does.

+/// This is a legacy interface still used for the command line and entry-points. All applications should transition away
+/// from this interface and still work instead via <see cref="IEstimator{TTransformer}"/> of <see cref="CalibratorTransformer{TICalibrator}"/>,
+/// for example, the subclasses of <see cref="CalibratorEstimatorBase{TICalibrator}"/>. However for now we retain this
+/// until such time as those components making use of it can transition to the new way. No public surface should use
+/// this, and even new internal code should avoid its use if possible.
+/// </summary>
+[BestFriend]
+internal interface ICalibratorTrainer
{
/// <summary>
/// True if the calibrator needs training, false otherwise.
@@ -107,6 +118,17 @@ public interface ICalibratorTrainer
ICalibrator FinishTraining(IChannel ch);
}

+/// <summary>
+/// This is a shim interface implemented only by <see cref="CalibratorEstimatorBase{TICalibrator}"/> to enable
+/// access to the underlying legacy <see cref="ICalibratorTrainer"/> interface for those components that use
+/// that old mechanism that we do not care to change right now.
+/// </summary>
+[BestFriend]
+internal interface IHaveCalibratorTrainer
+{
+ICalibratorTrainer CalibratorTrainer { get; }
+}

/// <summary>
/// An interface for predictors that take care of their own calibration given an input data view.
/// </summary>
@@ -842,6 +864,64 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c
return CreateCalibratedPredictor(env, (IPredictorProducing<float>)predictor, trainedCalibrator);
}

+public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples)
Review comment (Contributor) on these parameters:

string labelColumn, string scoreColumn, string weightColumn = null

Defaults?

+{
+Contracts.CheckValue(env, nameof(env));
+env.CheckValue(ch, nameof(ch));
+ch.CheckValue(scored, nameof(scored));
+ch.CheckValue(caliTrainer, nameof(caliTrainer));
+ch.CheckParam(!caliTrainer.NeedsTraining || !string.IsNullOrWhiteSpace(labelColumn), nameof(labelColumn),
Review comment (Contributor) on this condition:

!caliTrainer.NeedsTraining || !string.IsNullOrWhiteSpace(labelColumn)

Nit: This is a bit confusing to read. Maybe ! the And version?
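The suggestion reads as an application of De Morgan's law: `!A || !B` is the same as `!(A && B)`. A small sketch, assuming that reading, checks both forms against every combination of inputs. The variables `needsTraining` and `labelColumn` are hypothetical stand-ins for `caliTrainer.NeedsTraining` and the `labelColumn` parameter; the program is written as C# top-level statements (C# 9 or later).

```csharp
using System;

// Hypothetical stand-ins for caliTrainer.NeedsTraining and labelColumn.
bool[] needsTrainingValues = { false, true };
string[] labelColumns = { null, "", "   ", "Label" };

bool allAgree = true;
foreach (bool needsTraining in needsTrainingValues)
{
    foreach (string labelColumn in labelColumns)
    {
        // The condition as written in the diff.
        bool asWritten = !needsTraining || !string.IsNullOrWhiteSpace(labelColumn);
        // The negated "and" form: fail only when training is needed and the label is missing.
        bool negatedAnd = !(needsTraining && string.IsNullOrWhiteSpace(labelColumn));

        allAgree &= asWritten == negatedAnd;
        Console.WriteLine($"needsTraining={needsTraining}, labelColumn='{labelColumn}': {asWritten} / {negatedAnd}");
    }
}

Console.WriteLine($"Both forms agree on every input: {allAgree}");
```

The two forms are logically identical, so the choice between them is purely about which one states the failing case more directly.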

"If " + nameof(caliTrainer) + " requires training, then " + nameof(labelColumn) + " must have a value.");
+ch.CheckNonWhiteSpace(scoreColumn, nameof(scoreColumn));
+
+if (!caliTrainer.NeedsTraining)
+return caliTrainer.FinishTraining(ch);
+
+var labelCol = scored.Schema[labelColumn];
+var scoreCol = scored.Schema[scoreColumn];
+
+var weightCol = weightColumn == null ? null : scored.Schema.GetColumnOrNull(weightColumn);
+if (weightColumn != null && !weightCol.HasValue)
+throw ch.ExceptSchemaMismatch(nameof(weightColumn), "weight", weightColumn);
+
+ch.Info("Training calibrator.");
+
+var cols = weightCol.HasValue ?
+new Schema.Column[] { labelCol, scoreCol, weightCol.Value } :
+new Schema.Column[] { labelCol, scoreCol };
+
+using (var cursor = scored.GetRowCursor(cols))
+{
+var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol.Index);
+var scoreGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, scoreCol.Index);
+ValueGetter<Single> weightGetter = !weightCol.HasValue ? (ref float dst) => dst = 1 :
+RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, weightCol.Value.Index);
+
+int num = 0;
+while (cursor.MoveNext())
+{
+Single label = 0;
+labelGetter(ref label);
+if (!FloatUtils.IsFinite(label))
+continue;
+Single score = 0;
+scoreGetter(ref score);
+if (!FloatUtils.IsFinite(score))
+continue;
+Single weight = 0;
+weightGetter(ref weight);
+if (!FloatUtils.IsFinite(weight))
+continue;
+
+caliTrainer.ProcessTrainingExample(score, label > 0, weight);
+
+if (maxRows > 0 && ++num >= maxRows)
+break;
+}
+}
+return caliTrainer.FinishTraining(ch);
+}
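The row-selection logic of the loop above can be tried out without ML.NET. The sketch below mirrors what the diff shows: rows whose label, score, or weight is not finite are skipped, a label greater than zero counts as positive, and reading stops once `maxRows` accepted rows have been seen (a `maxRows` of zero or less means no cap). The tuple layout, the sample rows, and the helper `SelectTrainingExamples` are hypothetical; `float.IsFinite` stands in for `FloatUtils.IsFinite`. It is written as a C# top-level program (C# 9 or later).

```csharp
using System;
using System.Collections.Generic;

// Hypothetical scored rows: (label, score, weight).
var rows = new (float Label, float Score, float Weight)[]
{
    (1f, 2.5f, 1f),
    (0f, float.NaN, 1f),                  // skipped: score is not finite
    (float.NaN, 0.3f, 1f),                // skipped: label is not finite
    (0f, -1.2f, float.PositiveInfinity),  // skipped: weight is not finite
    (0f, -0.7f, 2f),
    (1f, 1.1f, 1f),
};

var capped = SelectTrainingExamples(rows, rowLimit: 2);
var uncapped = SelectTrainingExamples(rows, rowLimit: 0);

Console.WriteLine($"With a limit of 2: {capped.Count} examples; without a limit: {uncapped.Count} examples");
foreach (var example in capped)
    Console.WriteLine($"score={example.Score}, positive={example.IsPositive}, weight={example.Weight}");

// Returns the examples that the training loop would hand to the calibrator trainer.
List<(float Score, bool IsPositive, float Weight)> SelectTrainingExamples(
    (float Label, float Score, float Weight)[] input, int rowLimit)
{
    var result = new List<(float Score, bool IsPositive, float Weight)>();
    int num = 0;
    foreach (var row in input)
    {
        // Skip any row with a non-finite label, score, or weight.
        if (!float.IsFinite(row.Label) || !float.IsFinite(row.Score) || !float.IsFinite(row.Weight))
            continue;

        result.Add((row.Score, row.Label > 0, row.Weight));

        // Only accepted rows count toward the limit.
        if (rowLimit > 0 && ++num >= rowLimit)
            break;
    }
    return result;
}
```

Note that skipped rows do not count toward the limit, so the cap applies to usable examples rather than to rows read.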

/// <summary>
/// Trains a calibrator.
/// </summary>
@@ -857,60 +937,14 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
ch.CheckValue(caliTrainer, nameof(caliTrainer));
ch.CheckValue(predictor, nameof(predictor));
ch.CheckValue(data, nameof(data));
ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "data must have a Label column");

var scored = ScoreUtils.GetScorer(predictor, data, env, null);

-if (caliTrainer.NeedsTraining)
-{
-int labelCol;
-if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Value.Name, out labelCol))
-throw ch.Except("No label column found");
-int scoreCol;
-if (!scored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol))
-throw ch.Except("No score column found");
-int weightCol = -1;
-if (data.Schema.Weight?.Name is string weightName && scored.Schema.GetColumnOrNull(weightName)?.Index is int weightIdx)
-weightCol = weightIdx;
-ch.Info("Training calibrator.");
-
-var cols = weightCol > -1 ?
-new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol], scored.Schema[weightCol] } :
-new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol] };
-
-using (var cursor = scored.GetRowCursor(cols))
-{
-var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol);
-var scoreGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, scoreCol);
-ValueGetter<Single> weightGetter = weightCol == -1 ? (ref float dst) => dst = 1 :
-RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, weightCol);
-
-int num = 0;
-while (cursor.MoveNext())
-{
-Single label = 0;
-labelGetter(ref label);
-if (!FloatUtils.IsFinite(label))
-continue;
-Single score = 0;
-scoreGetter(ref score);
-if (!FloatUtils.IsFinite(score))
-continue;
-Single weight = 0;
-weightGetter(ref weight);
-if (!FloatUtils.IsFinite(weight))
-continue;
-
-caliTrainer.ProcessTrainingExample(score, label > 0, weight);
-
-if (maxRows > 0 && ++num >= maxRows)
-break;
-}
-}
-}
-return caliTrainer.FinishTraining(ch);
+var scoreColumn = scored.Schema[DefaultColumnNames.Score];
+return TrainCalibrator(env, ch, caliTrainer, scored, data.Schema.Label.Value.Name, DefaultColumnNames.Score, data.Schema.Weight?.Name, maxRows);
}

public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali)
@@ -953,7 +987,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
/// The probability of belonging to a particular class, for example class 1, is the number of class 1 instances in the bin, divided by the total number
/// of instances in that bin.
/// </summary>
-public sealed class NaiveCalibratorTrainer : ICalibratorTrainer
+[BestFriend]
+internal sealed class NaiveCalibratorTrainer : ICalibratorTrainer
{
private readonly IHost _host;
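The doc comment on this class describes the naive calibrator's probability as the number of class 1 instances in a bin divided by the total number of instances in that bin. A minimal sketch of that rule follows. It does not use `NaiveCalibratorTrainer`; the equal-width bins between the smallest and largest training score, the bin count, the sample data, the handling of empty bins, and the helpers `BinIndex` and `Predict` are all assumptions made for illustration. It is written as a C# top-level program (C# 9 or later).

```csharp
using System;
using System.Linq;

// Hypothetical training data: scores with their true labels.
float[] trainScores = { 0.1f, 0.2f, 0.3f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f };
bool[] trainLabels = { false, false, true, false, true, true, true, true };
const int binCount = 2;

// Assumption: equal-width bins spanning the observed score range.
float min = trainScores.Min();
float max = trainScores.Max();
float binWidth = (max - min) / binCount;

var positives = new int[binCount];
var totals = new int[binCount];
for (int i = 0; i < trainScores.Length; i++)
{
    int bin = BinIndex(trainScores[i], min, binWidth, binCount);
    totals[bin]++;
    if (trainLabels[i])
        positives[bin]++;
}

// Probability per bin = positives in the bin / instances in the bin.
// Assumption: an empty bin gets probability 0.
double[] binProbabilities = Enumerable.Range(0, binCount)
    .Select(b => totals[b] == 0 ? 0.0 : (double)positives[b] / totals[b])
    .ToArray();

Console.WriteLine($"Probability for score 0.25: {Predict(0.25f):F3}");
Console.WriteLine($"Probability for score 0.95: {Predict(0.95f):F3}");

double Predict(float score) => binProbabilities[BinIndex(score, min, binWidth, binCount)];

// Maps a score to its bin, clamping scores outside the training range.
int BinIndex(float score, float lo, float width, int count)
{
    if (width <= 0)
        return 0;
    int index = (int)((score - lo) / width);
    return Math.Clamp(index, 0, count - 1);
}
```

With this data the lower bin holds three instances with one positive and the upper bin holds five with four positives, so low scores map to roughly one third and high scores to four fifths.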

@@ -1181,7 +1216,8 @@ public string GetSummary()
/// <summary>
/// Base class for calibrator trainers.
/// </summary>
-public abstract class CalibratorTrainerBase : ICalibratorTrainer
+[BestFriend]
+internal abstract class CalibratorTrainerBase : ICalibratorTrainer
{
protected readonly IHost Host;
protected CalibrationDataStore Data;
@@ -1230,7 +1266,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
}
}

-public sealed class PlattCalibratorTrainer : CalibratorTrainerBase
+[BestFriend]
+internal sealed class PlattCalibratorTrainer : CalibratorTrainerBase
{
internal const string UserName = "Sigmoid Calibration";
internal const string LoadName = "PlattCalibration";
@@ -1389,7 +1426,8 @@ public override ICalibrator CreateCalibrator(IChannel ch)
}
}

-public sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer
+[BestFriend]
+internal sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer
{
[TlcModule.Component(Name = "FixedPlattCalibrator", FriendlyName = "Fixed Platt Calibrator", Aliases = new[] { "FixedPlatt", "FixedSigmoid" })]
public sealed class Arguments : ICalibratorTrainerFactory
@@ -1591,7 +1629,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env)
}
}

-public class PavCalibratorTrainer : CalibratorTrainerBase
+[BestFriend]
+internal sealed class PavCalibratorTrainer : CalibratorTrainerBase
{
// a piece of the piecwise function
private readonly struct Piece
@@ -1664,6 +1703,7 @@ public override ICalibrator CreateCalibrator(IChannel ch)
}

/// <summary>
+/// The pair-adjacent violators calibrator.
/// The function that is implemented by this calibrator is:
/// f(x) = v_i, if minX_i &lt;= x &lt;= maxX_i
/// = linear interpolate between v_i and v_i+1, if maxX_i &lt; x &lt; minX_i+1
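The two cases of the piecewise function quoted above can be sketched as follows: a score inside a piece's range returns that piece's value, and a score in the gap between two adjacent pieces is linearly interpolated between their values. The visible part of the doc comment does not say what happens outside the covered range, so clamping to the first or last value is an assumption here, as are the sample pieces and the helper `Evaluate`. It is written as a C# top-level program (C# 9 or later).

```csharp
using System;

// Hypothetical pieces, sorted by score: value v_i on [minX_i, maxX_i].
var pieces = new (float MinX, float MaxX, float Value)[]
{
    (-3.0f, -1.0f, 0.10f),
    ( 0.0f,  1.0f, 0.50f),
    ( 2.0f,  4.0f, 0.90f),
};

Console.WriteLine($"f(0.5)  = {Evaluate(pieces, 0.5f)}");   // inside the second piece
Console.WriteLine($"f(-0.5) = {Evaluate(pieces, -0.5f)}");  // halfway between the first two pieces
Console.WriteLine($"f(10)   = {Evaluate(pieces, 10f)}");    // beyond the last piece (assumed clamp)

float Evaluate((float MinX, float MaxX, float Value)[] segments, float x)
{
    // Assumption: scores below the first piece take the first value.
    if (x < segments[0].MinX)
        return segments[0].Value;

    for (int i = 0; i < segments.Length; i++)
    {
        // f(x) = v_i if minX_i <= x <= maxX_i.
        if (x >= segments[i].MinX && x <= segments[i].MaxX)
            return segments[i].Value;

        // Linear interpolation between v_i and v_(i+1) if maxX_i < x < minX_(i+1).
        if (i + 1 < segments.Length && x > segments[i].MaxX && x < segments[i + 1].MinX)
        {
            float t = (x - segments[i].MaxX) / (segments[i + 1].MinX - segments[i].MaxX);
            return segments[i].Value + t * (segments[i + 1].Value - segments[i].Value);
        }
    }

    // Assumption: scores above the last piece take the last value.
    return segments[segments.Length - 1].Value;
}
```

A linear scan keeps the sketch short; with many pieces a binary search over the sorted ranges would be the more natural choice.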