-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Update calibrator estimators to be more suitable. #2526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
using Microsoft.ML; | ||
using Microsoft.ML.Calibrator; | ||
using Microsoft.ML.CommandLine; | ||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.EntryPoints; | ||
using Microsoft.ML.Internal.Calibration; | ||
|
@@ -85,14 +86,24 @@ namespace Microsoft.ML.Internal.Calibration | |
/// <summary> | ||
/// Signature for the loaders of calibrators. | ||
/// </summary> | ||
public delegate void SignatureCalibrator(); | ||
[BestFriend] | ||
internal delegate void SignatureCalibrator(); | ||
|
||
[BestFriend] | ||
[TlcModule.ComponentKind("CalibratorTrainer")] | ||
public interface ICalibratorTrainerFactory : IComponentFactory<ICalibratorTrainer> | ||
internal interface ICalibratorTrainerFactory : IComponentFactory<ICalibratorTrainer> | ||
{ | ||
} | ||
|
||
public interface ICalibratorTrainer | ||
/// <summary> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Technically, this clarification belongs in <remarks>; <summary> should be a short description of what the type does. |
||
/// This is a legacy interface still used for the command line and entry-points. All applications should transition away | ||
/// from this interface and work instead via <see cref="IEstimator{TTransformer}"/> of <see cref="CalibratorTransformer{TICalibrator}"/>, | ||
/// for example, the subclasses of <see cref="CalibratorEstimatorBase{TICalibrator}"/>. However for now we retain this | ||
/// until such time as those components making use of it can transition to the new way. No public surface should use | ||
/// this, and even new internal code should avoid its use if possible. | ||
/// </summary> | ||
[BestFriend] | ||
internal interface ICalibratorTrainer | ||
{ | ||
/// <summary> | ||
/// True if the calibrator needs training, false otherwise. | ||
|
@@ -107,6 +118,17 @@ public interface ICalibratorTrainer | |
ICalibrator FinishTraining(IChannel ch); | ||
} | ||
|
||
/// <summary> | ||
/// This is a shim interface implemented only by <see cref="CalibratorEstimatorBase{TICalibrator}"/> to enable | ||
/// access to the underlying legacy <see cref="ICalibratorTrainer"/> interface for those components that use | ||
/// that old mechanism that we do not care to change right now. | ||
/// </summary> | ||
[BestFriend] | ||
internal interface IHaveCalibratorTrainer | ||
{ | ||
/// <summary> | ||
/// The legacy <see cref="ICalibratorTrainer"/> exposed by the implementing component. | ||
/// </summary> | ||
ICalibratorTrainer CalibratorTrainer { get; } | ||
} | ||
|
||
/// <summary> | ||
/// An interface for predictors that take care of their own calibration given an input data view. | ||
/// </summary> | ||
|
@@ -842,6 +864,64 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c | |
return CreateCalibratedPredictor(env, (IPredictorProducing<float>)predictor, trainedCalibrator); | ||
} | ||
|
||
/// <summary> | ||
/// Trains a calibrator from data that has already been scored, feeding one (score, label, weight) example per row | ||
/// of <paramref name="scored"/> to <paramref name="caliTrainer"/>. If the trainer reports that it does not need | ||
/// training, no data is read and the calibrator is produced immediately. | ||
/// </summary> | ||
/// <param name="env">The host environment; checked to be non-null.</param> | ||
/// <param name="ch">The channel used for the argument checks and the informational message, and passed to <c>FinishTraining</c>.</param> | ||
/// <param name="caliTrainer">The calibrator trainer that receives the examples.</param> | ||
/// <param name="scored">The scored data to read examples from.</param> | ||
/// <param name="labelColumn">Name of the label column. Only required to be non-blank when <paramref name="caliTrainer"/> needs training.</param> | ||
/// <param name="scoreColumn">Name of the score column; must be non-blank. Values are read as <see cref="Single"/>.</param> | ||
/// <param name="weightColumn">Optional name of the weight column. When null, every example gets weight 1; when non-null but absent from the schema, a schema mismatch exception is thrown.</param> | ||
/// <param name="maxRows">Maximum number of examples passed to the trainer. A value that is not positive disables the cap.</param> | ||
/// <returns>The calibrator returned by <c>caliTrainer.FinishTraining</c>.</returns> | ||
public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Defaults? |
||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckValue(ch, nameof(ch)); | ||
ch.CheckValue(scored, nameof(scored)); | ||
ch.CheckValue(caliTrainer, nameof(caliTrainer)); | ||
// The label column name is only mandatory when the trainer actually needs training examples. | ||
ch.CheckParam(!caliTrainer.NeedsTraining || !string.IsNullOrWhiteSpace(labelColumn), nameof(labelColumn), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Nit: This is a bit confusing to read. Maybe ! the And version? |
||
"If " + nameof(caliTrainer) + " requires training, then " + nameof(labelColumn) + " must have a value."); | ||
ch.CheckNonWhiteSpace(scoreColumn, nameof(scoreColumn)); | ||
|
||
// Nothing to learn from the data: skip column lookup and cursoring entirely. | ||
if (!caliTrainer.NeedsTraining) | ||
return caliTrainer.FinishTraining(ch); | ||
|
||
var labelCol = scored.Schema[labelColumn]; | ||
var scoreCol = scored.Schema[scoreColumn]; | ||
|
||
// A null weight column name means "no weights"; a named but missing weight column is an error. | ||
var weightCol = weightColumn == null ? null : scored.Schema.GetColumnOrNull(weightColumn); | ||
if (weightColumn != null && !weightCol.HasValue) | ||
throw ch.ExceptSchemaMismatch(nameof(weightColumn), "weight", weightColumn); | ||
|
||
ch.Info("Training calibrator."); | ||
|
||
// Request only the columns that are read below. | ||
var cols = weightCol.HasValue ? | ||
new Schema.Column[] { labelCol, scoreCol, weightCol.Value } : | ||
new Schema.Column[] { labelCol, scoreCol }; | ||
|
||
using (var cursor = scored.GetRowCursor(cols)) | ||
{ | ||
var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol.Index); | ||
var scoreGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, scoreCol.Index); | ||
// Without a weight column, every example is given a constant weight of 1. | ||
ValueGetter<Single> weightGetter = !weightCol.HasValue ? (ref float dst) => dst = 1 : | ||
RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, weightCol.Value.Index); | ||
|
||
// Counts only the examples actually passed to the trainer; skipped rows do not count toward maxRows. | ||
int num = 0; | ||
while (cursor.MoveNext()) | ||
{ | ||
// Rows with a non-finite label, score, or weight are skipped. | ||
Single label = 0; | ||
labelGetter(ref label); | ||
if (!FloatUtils.IsFinite(label)) | ||
continue; | ||
Single score = 0; | ||
scoreGetter(ref score); | ||
if (!FloatUtils.IsFinite(score)) | ||
continue; | ||
Single weight = 0; | ||
weightGetter(ref weight); | ||
if (!FloatUtils.IsFinite(weight)) | ||
continue; | ||
|
||
// A strictly positive label is treated as the positive class. | ||
caliTrainer.ProcessTrainingExample(score, label > 0, weight); | ||
|
||
// The cap applies only when maxRows is positive. | ||
if (maxRows > 0 && ++num >= maxRows) | ||
break; | ||
} | ||
} | ||
return caliTrainer.FinishTraining(ch); | ||
} | ||
|
||
/// <summary> | ||
/// Trains a calibrator. | ||
/// </summary> | ||
|
@@ -857,60 +937,14 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa | |
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckValue(ch, nameof(ch)); | ||
ch.CheckValue(caliTrainer, nameof(caliTrainer)); | ||
ch.CheckValue(predictor, nameof(predictor)); | ||
ch.CheckValue(data, nameof(data)); | ||
ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "data must have a Label column"); | ||
|
||
var scored = ScoreUtils.GetScorer(predictor, data, env, null); | ||
|
||
if (caliTrainer.NeedsTraining) | ||
{ | ||
int labelCol; | ||
if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Value.Name, out labelCol)) | ||
throw ch.Except("No label column found"); | ||
int scoreCol; | ||
if (!scored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol)) | ||
throw ch.Except("No score column found"); | ||
int weightCol = -1; | ||
if (data.Schema.Weight?.Name is string weightName && scored.Schema.GetColumnOrNull(weightName)?.Index is int weightIdx) | ||
weightCol = weightIdx; | ||
ch.Info("Training calibrator."); | ||
|
||
var cols = weightCol > -1 ? | ||
new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol], scored.Schema[weightCol] } : | ||
new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol] }; | ||
|
||
using (var cursor = scored.GetRowCursor(cols)) | ||
{ | ||
var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol); | ||
var scoreGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, scoreCol); | ||
ValueGetter<Single> weightGetter = weightCol == -1 ? (ref float dst) => dst = 1 : | ||
RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, weightCol); | ||
|
||
int num = 0; | ||
while (cursor.MoveNext()) | ||
{ | ||
Single label = 0; | ||
labelGetter(ref label); | ||
if (!FloatUtils.IsFinite(label)) | ||
continue; | ||
Single score = 0; | ||
scoreGetter(ref score); | ||
if (!FloatUtils.IsFinite(score)) | ||
continue; | ||
Single weight = 0; | ||
weightGetter(ref weight); | ||
if (!FloatUtils.IsFinite(weight)) | ||
continue; | ||
|
||
caliTrainer.ProcessTrainingExample(score, label > 0, weight); | ||
|
||
if (maxRows > 0 && ++num >= maxRows) | ||
break; | ||
} | ||
} | ||
} | ||
return caliTrainer.FinishTraining(ch); | ||
var scoreColumn = scored.Schema[DefaultColumnNames.Score]; | ||
return TrainCalibrator(env, ch, caliTrainer, scored, data.Schema.Label.Value.Name, DefaultColumnNames.Score, data.Schema.Weight?.Name, maxRows); | ||
} | ||
|
||
public static IPredictorProducing<float> CreateCalibratedPredictor<TSubPredictor, TCalibrator>(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) | ||
|
@@ -953,7 +987,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) | |
/// The probability of belonging to a particular class, for example class 1, is the number of class 1 instances in the bin, divided by the total number | ||
/// of instances in that bin. | ||
/// </summary> | ||
public sealed class NaiveCalibratorTrainer : ICalibratorTrainer | ||
[BestFriend] | ||
internal sealed class NaiveCalibratorTrainer : ICalibratorTrainer | ||
{ | ||
private readonly IHost _host; | ||
|
||
|
@@ -1181,7 +1216,8 @@ public string GetSummary() | |
/// <summary> | ||
/// Base class for calibrator trainers. | ||
/// </summary> | ||
public abstract class CalibratorTrainerBase : ICalibratorTrainer | ||
[BestFriend] | ||
internal abstract class CalibratorTrainerBase : ICalibratorTrainer | ||
{ | ||
protected readonly IHost Host; | ||
protected CalibrationDataStore Data; | ||
|
@@ -1230,7 +1266,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) | |
} | ||
} | ||
|
||
public sealed class PlattCalibratorTrainer : CalibratorTrainerBase | ||
[BestFriend] | ||
internal sealed class PlattCalibratorTrainer : CalibratorTrainerBase | ||
{ | ||
internal const string UserName = "Sigmoid Calibration"; | ||
internal const string LoadName = "PlattCalibration"; | ||
|
@@ -1389,7 +1426,8 @@ public override ICalibrator CreateCalibrator(IChannel ch) | |
} | ||
} | ||
|
||
public sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer | ||
[BestFriend] | ||
internal sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer | ||
{ | ||
[TlcModule.Component(Name = "FixedPlattCalibrator", FriendlyName = "Fixed Platt Calibrator", Aliases = new[] { "FixedPlatt", "FixedSigmoid" })] | ||
public sealed class Arguments : ICalibratorTrainerFactory | ||
|
@@ -1591,7 +1629,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) | |
} | ||
} | ||
|
||
public class PavCalibratorTrainer : CalibratorTrainerBase | ||
[BestFriend] | ||
internal sealed class PavCalibratorTrainer : CalibratorTrainerBase | ||
{ | ||
// a piece of the piecewise function | ||
private readonly struct Piece | ||
|
@@ -1664,6 +1703,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) | |
} | ||
|
||
/// <summary> | ||
/// The pair-adjacent violators calibrator. | ||
/// The function that is implemented by this calibrator is: | ||
/// f(x) = v_i, if minX_i <= x <= maxX_i | ||
/// = linear interpolate between v_i and v_i+1, if maxX_i < x < minX_i+1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you give a bit more detail about what's happening here? The SDCA trainer produces a transformer, which is calibrated, and then you pull the original model out and recalibrate it? If so, this sample is a bit confusing, because that is not necessarily something we would need to do, at least not as outlined here. I would at least spell out the steps precisely.