diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs index ed75040e77..997684465d 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs @@ -75,7 +75,7 @@ public static void Calibration() // Let's train a calibrator estimator on this scored dataset. The trained calibrator estimator produces a transformer // that can transform the scored data by adding a new column names "Probability". - var calibratorEstimator = new PlattCalibratorEstimator(mlContext, model, "Sentiment", "Features"); + var calibratorEstimator = new PlattCalibratorEstimator(mlContext, "Sentiment", "Score"); var calibratorTransformer = calibratorEstimator.Fit(scoredData); // Transform the scored data with a calibrator transfomer by adding a new column names "Probability". diff --git a/src/Microsoft.ML.Core/BestFriendAttribute.cs b/src/Microsoft.ML.Core/BestFriendAttribute.cs index 19c70922e7..2a786a6e79 100644 --- a/src/Microsoft.ML.Core/BestFriendAttribute.cs +++ b/src/Microsoft.ML.Core/BestFriendAttribute.cs @@ -13,7 +13,7 @@ namespace Microsoft.ML #endif { /// - /// Intended to be applied to types and members marked as internal to indicate that friend access of this + /// Intended to be applied to types and members with internal scope to indicate that friend access of this /// internal item is OK from another assembly. This restriction applies only to assemblies that declare the /// assembly level attribute. 
Note that this attribute is not /// transferrable: an internal member with this attribute does not somehow make a containing internal type diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 2ee036d760..2258622ff6 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -12,6 +12,7 @@ using Microsoft.ML; using Microsoft.ML.Calibrator; using Microsoft.ML.CommandLine; +using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Calibration; @@ -85,14 +86,24 @@ namespace Microsoft.ML.Internal.Calibration /// /// Signature for the loaders of calibrators. /// - public delegate void SignatureCalibrator(); + [BestFriend] + internal delegate void SignatureCalibrator(); + [BestFriend] [TlcModule.ComponentKind("CalibratorTrainer")] - public interface ICalibratorTrainerFactory : IComponentFactory + internal interface ICalibratorTrainerFactory : IComponentFactory { } - public interface ICalibratorTrainer + /// + /// This is a legacy interface still used for the command line and entry-points. All applications should transition away + /// from this interface and still work instead via of , + /// for example, the subclasses of . However for now we retain this + /// until such time as those components making use of it can transition to the new way. No public surface should use + /// this, and even new internal code should avoid its use if possible. + /// + [BestFriend] + internal interface ICalibratorTrainer { /// /// True if the calibrator needs training, false otherwise. @@ -107,6 +118,17 @@ public interface ICalibratorTrainer ICalibrator FinishTraining(IChannel ch); } + /// + /// This is a shim interface implemented only by to enable + /// access to the underlying legacy interface for those components that use + /// that old mechanism that we do not care to change right now. 
+ /// + [BestFriend] + internal interface IHaveCalibratorTrainer + { + ICalibratorTrainer CalibratorTrainer { get; } + } + /// /// An interface for predictors that take care of their own calibration given an input data view. /// @@ -842,6 +864,64 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, trainedCalibrator); } + public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ch, nameof(ch)); + ch.CheckValue(scored, nameof(scored)); + ch.CheckValue(caliTrainer, nameof(caliTrainer)); + ch.CheckParam(!caliTrainer.NeedsTraining || !string.IsNullOrWhiteSpace(labelColumn), nameof(labelColumn), + "If " + nameof(caliTrainer) + " requires training, then " + nameof(labelColumn) + " must have a value."); + ch.CheckNonWhiteSpace(scoreColumn, nameof(scoreColumn)); + + if (!caliTrainer.NeedsTraining) + return caliTrainer.FinishTraining(ch); + + var labelCol = scored.Schema[labelColumn]; + var scoreCol = scored.Schema[scoreColumn]; + + var weightCol = weightColumn == null ? null : scored.Schema.GetColumnOrNull(weightColumn); + if (weightColumn != null && !weightCol.HasValue) + throw ch.ExceptSchemaMismatch(nameof(weightColumn), "weight", weightColumn); + + ch.Info("Training calibrator."); + + var cols = weightCol.HasValue ? + new Schema.Column[] { labelCol, scoreCol, weightCol.Value } : + new Schema.Column[] { labelCol, scoreCol }; + + using (var cursor = scored.GetRowCursor(cols)) + { + var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol.Index); + var scoreGetter = RowCursorUtils.GetGetterAs(NumberType.R4, cursor, scoreCol.Index); + ValueGetter weightGetter = !weightCol.HasValue ? 
(ref float dst) => dst = 1 : + RowCursorUtils.GetGetterAs(NumberType.R4, cursor, weightCol.Value.Index); + + int num = 0; + while (cursor.MoveNext()) + { + Single label = 0; + labelGetter(ref label); + if (!FloatUtils.IsFinite(label)) + continue; + Single score = 0; + scoreGetter(ref score); + if (!FloatUtils.IsFinite(score)) + continue; + Single weight = 0; + weightGetter(ref weight); + if (!FloatUtils.IsFinite(weight)) + continue; + + caliTrainer.ProcessTrainingExample(score, label > 0, weight); + + if (maxRows > 0 && ++num >= maxRows) + break; + } + } + return caliTrainer.FinishTraining(ch); + } + /// /// Trains a calibrator. /// @@ -857,60 +937,14 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ch, nameof(ch)); + ch.CheckValue(caliTrainer, nameof(caliTrainer)); ch.CheckValue(predictor, nameof(predictor)); ch.CheckValue(data, nameof(data)); ch.CheckParam(data.Schema.Label.HasValue, nameof(data), "data must have a Label column"); var scored = ScoreUtils.GetScorer(predictor, data, env, null); - - if (caliTrainer.NeedsTraining) - { - int labelCol; - if (!scored.Schema.TryGetColumnIndex(data.Schema.Label.Value.Name, out labelCol)) - throw ch.Except("No label column found"); - int scoreCol; - if (!scored.Schema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out scoreCol)) - throw ch.Except("No score column found"); - int weightCol = -1; - if (data.Schema.Weight?.Name is string weightName && scored.Schema.GetColumnOrNull(weightName)?.Index is int weightIdx) - weightCol = weightIdx; - ch.Info("Training calibrator."); - - var cols = weightCol > -1 ? 
- new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol], scored.Schema[weightCol] } : - new Schema.Column[] { scored.Schema[labelCol], scored.Schema[scoreCol] }; - - using (var cursor = scored.GetRowCursor(cols)) - { - var labelGetter = RowCursorUtils.GetLabelGetter(cursor, labelCol); - var scoreGetter = RowCursorUtils.GetGetterAs(NumberType.R4, cursor, scoreCol); - ValueGetter weightGetter = weightCol == -1 ? (ref float dst) => dst = 1 : - RowCursorUtils.GetGetterAs(NumberType.R4, cursor, weightCol); - - int num = 0; - while (cursor.MoveNext()) - { - Single label = 0; - labelGetter(ref label); - if (!FloatUtils.IsFinite(label)) - continue; - Single score = 0; - scoreGetter(ref score); - if (!FloatUtils.IsFinite(score)) - continue; - Single weight = 0; - weightGetter(ref weight); - if (!FloatUtils.IsFinite(weight)) - continue; - - caliTrainer.ProcessTrainingExample(score, label > 0, weight); - - if (maxRows > 0 && ++num >= maxRows) - break; - } - } - } - return caliTrainer.FinishTraining(ch); + var scoreColumn = scored.Schema[DefaultColumnNames.Score]; + return TrainCalibrator(env, ch, caliTrainer, scored, data.Schema.Label.Value.Name, DefaultColumnNames.Score, data.Schema.Weight?.Name, maxRows); } public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) @@ -953,7 +987,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) /// The probability of belonging to a particular class, for example class 1, is the number of class 1 instances in the bin, divided by the total number /// of instances in that bin. /// - public sealed class NaiveCalibratorTrainer : ICalibratorTrainer + [BestFriend] + internal sealed class NaiveCalibratorTrainer : ICalibratorTrainer { private readonly IHost _host; @@ -1181,7 +1216,8 @@ public string GetSummary() /// /// Base class for calibrator trainers. 
/// - public abstract class CalibratorTrainerBase : ICalibratorTrainer + [BestFriend] + internal abstract class CalibratorTrainerBase : ICalibratorTrainer { protected readonly IHost Host; protected CalibrationDataStore Data; @@ -1230,7 +1266,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) } } - public sealed class PlattCalibratorTrainer : CalibratorTrainerBase + [BestFriend] + internal sealed class PlattCalibratorTrainer : CalibratorTrainerBase { internal const string UserName = "Sigmoid Calibration"; internal const string LoadName = "PlattCalibration"; @@ -1389,7 +1426,8 @@ public override ICalibrator CreateCalibrator(IChannel ch) } } - public sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer + [BestFriend] + internal sealed class FixedPlattCalibratorTrainer : ICalibratorTrainer { [TlcModule.Component(Name = "FixedPlattCalibrator", FriendlyName = "Fixed Platt Calibrator", Aliases = new[] { "FixedPlatt", "FixedSigmoid" })] public sealed class Arguments : ICalibratorTrainerFactory @@ -1591,7 +1629,8 @@ public ICalibratorTrainer CreateComponent(IHostEnvironment env) } } - public class PavCalibratorTrainer : CalibratorTrainerBase + [BestFriend] + internal sealed class PavCalibratorTrainer : CalibratorTrainerBase { // a piece of the piecwise function private readonly struct Piece @@ -1664,6 +1703,7 @@ public override ICalibrator CreateCalibrator(IChannel ch) } /// + /// The pair-adjacent violators calibrator. /// The function that is implemented by this calibrator is: /// f(x) = v_i, if minX_i <= x <= maxX_i /// = linear interpolate between v_i and v_i+1, if maxX_i < x < minX_i+1 diff --git a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs index f53bb6993e..00c1f4ef4f 100644 --- a/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs +++ b/src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. 
using System; -using System.Collections.Generic; using System.Linq; using Microsoft.Data.DataView; using Microsoft.ML; @@ -36,7 +35,7 @@ public interface ICalibrator } /// - /// Base class for CalibratorEstimators. + /// Base class for calibrator estimators. /// /// /// CalibratorEstimators take an (the output of a ) @@ -49,39 +48,35 @@ public interface ICalibrator /// [!code-csharp[Calibrators](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs)] /// ]]> /// - public abstract class CalibratorEstimatorBase : IEstimator> - where TCalibratorTrainer : ICalibratorTrainer + public abstract class CalibratorEstimatorBase : IEstimator>, IHaveCalibratorTrainer where TICalibrator : class, ICalibrator { - protected readonly IHostEnvironment Host; - protected readonly TCalibratorTrainer CalibratorTrainer; - - protected readonly IPredictor Predictor; - protected readonly SchemaShape.Column ScoreColumn; - protected readonly SchemaShape.Column FeatureColumn; - protected readonly SchemaShape.Column LabelColumn; - protected readonly SchemaShape.Column WeightColumn; - protected readonly SchemaShape.Column PredictedLabel; - - protected CalibratorEstimatorBase(IHostEnvironment env, - TCalibratorTrainer calibratorTrainer, - IPredictor predictor = null, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null) + [BestFriend] + private protected readonly IHostEnvironment Host; + private readonly ICalibratorTrainer _calibratorTrainer; + ICalibratorTrainer IHaveCalibratorTrainer.CalibratorTrainer => _calibratorTrainer; + + [BestFriend] + private protected readonly SchemaShape.Column ScoreColumn; + [BestFriend] + private protected readonly SchemaShape.Column LabelColumn; + [BestFriend] + private protected readonly SchemaShape.Column WeightColumn; + [BestFriend] + private protected readonly SchemaShape.Column PredictedLabel; + + [BestFriend] + private protected 
CalibratorEstimatorBase(IHostEnvironment env, + ICalibratorTrainer calibratorTrainer, string labelColumn, string scoreColumn, string weightColumn) { Host = env; - Predictor = predictor; - CalibratorTrainer = calibratorTrainer; + _calibratorTrainer = calibratorTrainer; - ScoreColumn = TrainerUtils.MakeR4ScalarColumn(DefaultColumnNames.Score); // Do we fantom this being named anything else (renaming column)? Complete metadata? - LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn); - FeatureColumn = TrainerUtils.MakeR4VecFeature(featureColumn); - PredictedLabel = new SchemaShape.Column(DefaultColumnNames.PredictedLabel, - SchemaShape.Column.VectorKind.Scalar, - BoolType.Instance, - false, - new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())); + if (!string.IsNullOrWhiteSpace(labelColumn)) + LabelColumn = TrainerUtils.MakeBoolScalarLabel(labelColumn); + else + env.CheckParam(!calibratorTrainer.NeedsTraining, nameof(labelColumn), "For trained calibrators, " + nameof(labelColumn) + " must be specified."); + ScoreColumn = TrainerUtils.MakeR4ScalarColumn(scoreColumn); // Do we fathom this being named anything else (renaming column)? Complete metadata? if (weightColumn != null) WeightColumn = TrainerUtils.MakeR4ScalarWeightColumn(weightColumn); @@ -105,14 +100,12 @@ SchemaShape IEstimator>.GetOutputSchema(Sche } }; - // check the input schema - checkColumnValid(ScoreColumn, DefaultColumnNames.Score); - checkColumnValid(WeightColumn, DefaultColumnNames.Weight); - checkColumnValid(LabelColumn, DefaultColumnNames.Label); - checkColumnValid(FeatureColumn, DefaultColumnNames.Features); - checkColumnValid(PredictedLabel, DefaultColumnNames.PredictedLabel); + // Check the input schema. + checkColumnValid(ScoreColumn, "score"); + checkColumnValid(WeightColumn, "weight"); + checkColumnValid(LabelColumn, "label"); - //create the new Probability column + // Create the new Probability column. 
var outColumns = inputSchema.ToDictionary(x => x.Name); outColumns[DefaultColumnNames.Probability] = new SchemaShape.Column(DefaultColumnNames.Probability, SchemaShape.Column.VectorKind.Scalar, @@ -132,35 +125,27 @@ SchemaShape IEstimator>.GetOutputSchema(Sche /// column. public CalibratorTransformer Fit(IDataView input) { - TICalibrator calibrator = null; - - var roles = new List>(); - roles.Add(RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, DefaultColumnNames.Score)); - roles.Add(RoleMappedSchema.ColumnRole.Label.Bind(LabelColumn.Name)); - roles.Add(RoleMappedSchema.ColumnRole.Feature.Bind(FeatureColumn.Name)); - if (WeightColumn.IsValid) - roles.Add(RoleMappedSchema.ColumnRole.Weight.Bind(WeightColumn.Name)); - - var roleMappedData = new RoleMappedData(input, opt: false, roles.ToArray()); - using (var ch = Host.Start("Creating calibrator.")) - calibrator = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, ch, CalibratorTrainer, Predictor, roleMappedData); - - return Create(Host, calibrator); + { + var calibrator = (TICalibrator)CalibratorUtils.TrainCalibrator(Host, ch, + _calibratorTrainer, input, LabelColumn.Name, ScoreColumn.Name, WeightColumn.Name); + return Create(Host, calibrator); + } } /// /// Implemented by deriving classes that create a concrete calibrator. /// - protected abstract CalibratorTransformer Create(IHostEnvironment env, TICalibrator calibrator); + [BestFriend] + private protected abstract CalibratorTransformer Create(IHostEnvironment env, TICalibrator calibrator); } /// - /// CalibratorTransfomers, the artifact of calling Fit on a . + /// An instance of this class is the result of calling . /// If you pass a scored data, to the Transform method, it will add the Probability column /// to the dataset. The Probability column is the value of the Score normalized to be a valid probability. 
- /// The CalibratorTransformer is an instance of where score can be viewed as a feature - /// while probability is treated as the label. + /// The is an instance of + /// where score can be viewed as a feature while probability is treated as the label. /// /// The used to transform the data. public abstract class CalibratorTransformer : RowToRowTransformerBase, ISingleFeaturePredictionTransformer @@ -172,8 +157,6 @@ public abstract class CalibratorTransformer : RowToRowTransformerB internal CalibratorTransformer(IHostEnvironment env, TICalibrator calibrator, string loaderSignature) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CalibratorTransformer))) { - Host.CheckRef(calibrator, nameof(calibrator)); - _loaderSignature = loaderSignature; _calibrator = calibrator; } @@ -189,7 +172,7 @@ internal CalibratorTransformer(IHostEnvironment env, ModelLoadContext ctx, strin // *** Binary format *** // model: _calibrator - ctx.LoadModel(env, out _calibrator, @"Calibrator"); + ctx.LoadModel(env, out _calibrator, "Calibrator"); } string ISingleFeaturePredictionTransformer.FeatureColumn => DefaultColumnNames.Score; @@ -224,8 +207,8 @@ protected VersionInfo GetVersionInfo() loaderAssemblyName: typeof(CalibratorTransformer<>).Assembly.FullName); } - private sealed class Mapper : MapperBase - where TCalibrator : class, ICalibrator + private sealed class Mapper : MapperBase + where TCalibrator : class, ICalibrator { private TCalibrator _calibrator; private int _scoreColIndex; @@ -277,68 +260,72 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act } /// - /// The PlattCalibratorEstimator. + /// The Platt calibrator estimator. /// /// - /// For the usage pattern see the example in . + /// For the usage pattern see the example in . /// - public sealed class PlattCalibratorEstimator : CalibratorEstimatorBase + public sealed class PlattCalibratorEstimator : CalibratorEstimatorBase { /// /// Initializes a new instance of /// /// The environment to use. 
- /// The predictor used to train the data. - /// The label column name. - /// The feature column name. - /// The weight column name. + /// The label column name. This is consumed when this estimator is fit, + /// but not consumed by the resulting transformer. + /// The score column name. This is consumed both when this estimator + /// is fit and when the estimator is consumed. + /// The optional weight column name. Note that if specified this is + /// consumed when this estimator is fit, but not consumed by the resulting transformer. public PlattCalibratorEstimator(IHostEnvironment env, - IPredictor predictor, string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null) : base(env, new PlattCalibratorTrainer(env), predictor, labelColumn, featureColumn, weightColumn) + string scoreColumn = DefaultColumnNames.Score, + string weightColumn = null) : base(env, new PlattCalibratorTrainer(env), labelColumn, scoreColumn, weightColumn) { - } - protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) - => new PlattCalibratorTransformer(env, calibrator); + [BestFriend] + private protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) + => new PlattCalibratorTransformer(env, calibrator); } /// - /// Obtains the probability values by fitting the sigmoid: f(x) = 1 / (1 + exp(-slope * x + offset). + /// Obtains the probability values by applying the sigmoid: f(x) = 1 / (1 + exp(-slope * x + offset)). + /// Note that unlike, say, , the fit function here is trivial + /// and just "fits" a calibrator with the provided parameters. /// /// - /// For the usage pattern see the example in . + /// For the usage pattern see the example in . 
/// - public sealed class FixedPlattCalibratorEstimator : CalibratorEstimatorBase + public sealed class FixedPlattCalibratorEstimator : CalibratorEstimatorBase { /// - /// Initializes a new instance of + /// Initializes a new instance of . /// + /// + /// Note that unlike many other calibrator estimators this one has the parameters pre-specified. + /// This means that it does not have a label or weight column specified as an input during training. + /// /// The environment to use. - /// The predictor used to train the data. /// The slope in the function of the exponent of the sigmoid. /// The offset in the function of the exponent of the sigmoid. - /// The label column name. - /// The feature column name. - /// The weight column name. + /// The score column name. This is consumed both when this estimator + /// is fit and when the estimator is consumed. public FixedPlattCalibratorEstimator(IHostEnvironment env, - IPredictor predictor, - double slope = 1, +double slope = 1, double offset = 0, - string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null) : base(env, new FixedPlattCalibratorTrainer(env, new FixedPlattCalibratorTrainer.Arguments() + string scoreColumn = DefaultColumnNames.Score) + : base(env, new FixedPlattCalibratorTrainer(env, new FixedPlattCalibratorTrainer.Arguments() { Slope = slope, Offset = offset - }), predictor, labelColumn, featureColumn, weightColumn) + }), null, scoreColumn, null) { } - protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) + [BestFriend] + private protected override CalibratorTransformer Create(IHostEnvironment env, PlattCalibrator calibrator) => new PlattCalibratorTransformer(env, calibrator); } @@ -352,47 +339,46 @@ public sealed class PlattCalibratorTransformer : CalibratorTransformer - /// The naive binning-based calibratorEstimator. + /// The naive binning-based calibrator estimator. 
/// /// /// It divides the range of the outputs into equally sized bins. In each bin, /// the probability of belonging to class 1, is the number of class 1 instances in the bin, divided by the total number /// of instances in the bin. - /// For the usage pattern see the example in . + /// For the usage pattern see the example in . /// - public sealed class NaiveCalibratorEstimator : CalibratorEstimatorBase + public sealed class NaiveCalibratorEstimator : CalibratorEstimatorBase { /// /// Initializes a new instance of /// /// The environment to use. - /// The predictor used to train the data. - /// The label column name. - /// The feature column name. - /// The weight column name. + /// The label column name. This is consumed when this estimator is fit, + /// but not consumed by the resulting transformer. + /// The score column name. This is consumed both when this estimator + /// is fit and when the estimator is consumed. + /// The optional weight column name. Note that if specified this is + /// consumed when this estimator is fit, but not consumed by the resulting transformer. 
public NaiveCalibratorEstimator(IHostEnvironment env, - IPredictor predictor, string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null) : base(env, new NaiveCalibratorTrainer(env), predictor, labelColumn, featureColumn, weightColumn) + string scoreColumn = DefaultColumnNames.Score, + string weightColumn = null) : base(env, new NaiveCalibratorTrainer(env), labelColumn, scoreColumn, weightColumn) { - } - protected override CalibratorTransformer Create(IHostEnvironment env, NaiveCalibrator calibrator) - => new NaiveCalibratorTransformer(env, calibrator); + [BestFriend] + private protected override CalibratorTransformer Create(IHostEnvironment env, NaiveCalibrator calibrator) + => new NaiveCalibratorTransformer(env, calibrator); } /// @@ -405,43 +391,42 @@ public sealed class NaiveCalibratorTransformer : CalibratorTransformer - /// The PavCalibratorEstimator. + /// The pair-adjacent violators calibrator estimator. /// /// - /// For the usage pattern see the example in . + /// For the usage pattern see the example in . /// - public sealed class PavCalibratorEstimator : CalibratorEstimatorBase + public sealed class PavCalibratorEstimator : CalibratorEstimatorBase { /// /// Initializes a new instance of /// /// The environment to use. - /// The predictor used to train the data. - /// The label column name. - /// The feature column name. - /// The weight column name. + /// The label column name. This is consumed when this estimator is fit, + /// but not consumed by the resulting transformer. + /// The score column name. This is consumed both when this estimator + /// is fit and when the estimator is consumed. + /// The optional weight column name. Note that if specified this is + /// consumed when this estimator is fit, but not consumed by the resulting transformer. 
public PavCalibratorEstimator(IHostEnvironment env, - IPredictor predictor, string labelColumn = DefaultColumnNames.Label, - string featureColumn = DefaultColumnNames.Features, - string weightColumn = null) : base(env, new PavCalibratorTrainer(env), predictor, labelColumn, featureColumn, weightColumn) + string scoreColumn = DefaultColumnNames.Score, + string weightColumn = null) : base(env, new PavCalibratorTrainer(env), labelColumn, scoreColumn, weightColumn) { - } - protected override CalibratorTransformer Create(IHostEnvironment env, PavCalibrator calibrator) + [BestFriend] + private protected override CalibratorTransformer Create(IHostEnvironment env, PavCalibrator calibrator) => new PavCalibratorTransformer(env, calibrator); } @@ -456,14 +441,12 @@ public sealed class PavCalibratorTransformer : CalibratorTransformer public abstract class TrainCatalogBase { - protected internal readonly IHost Host; + [BestFriend] + private protected readonly IHost Host; [BestFriend] internal IHostEnvironment Environment => Host; @@ -192,7 +192,8 @@ protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IE return result; } - protected internal TrainCatalogBase(IHostEnvironment env, string registrationName) + [BestFriend] + private protected TrainCatalogBase(IHostEnvironment env, string registrationName) { Contracts.CheckValue(env, nameof(env)); env.CheckNonEmpty(registrationName, nameof(registrationName)); diff --git a/src/Microsoft.ML.FastTree/RandomForestClassification.cs b/src/Microsoft.ML.FastTree/RandomForestClassification.cs index 13d2261ea9..5988b485b7 100644 --- a/src/Microsoft.ML.FastTree/RandomForestClassification.cs +++ b/src/Microsoft.ML.FastTree/RandomForestClassification.cs @@ -118,10 +118,10 @@ public sealed class Options : FastForestArgumentsBase public Double MaxTreeOutput = 100; [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. 
Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public int MaxCalibrationExamples = 1000000; + internal int MaxCalibrationExamples = 1000000; } internal const string LoadNameValue = "FastForestClassification"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs index 448a078a8b..6124bbd09d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MetaMulticlassTrainer.cs @@ -29,10 +29,10 @@ public abstract class ArgumentsBase internal IComponentFactory PredictorType; [Argument(ArgumentType.Multiple, HelpText = "Output calibrator", ShortName = "cali", SortOrder = 150, NullName = "", SignatureType = typeof(SignatureCalibrator))] - public IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); + internal IComponentFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of instances to train the calibrator", SortOrder = 150, ShortName = "numcali")] - public int MaxCalibrationExamples = 1000000000; + internal int MaxCalibrationExamples = 1000000000; [Argument(ArgumentType.Multiple, HelpText = "Whether to treat missing labels as having negative labels, instead of keeping them missing", SortOrder = 150, ShortName = "missNeg")] public bool ImputeMissingLabelsAsNegative; diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs 
b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs index c3610c7dd1..6f7bb7909d 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/AveragedPerceptron.cs @@ -45,10 +45,10 @@ public sealed class Options : AveragedLinearArguments public ISupportClassificationLossFactory LossFunction = new HingeLoss.Arguments(); [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public int MaxCalibrationExamples = 1000000; + internal int MaxCalibrationExamples = 1000000; internal override IComponentFactory LossFunctionFactory => LossFunction; } diff --git a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs index e1382b3038..09a32d22ec 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/Online/LinearSvm.cs @@ -61,10 +61,10 @@ public sealed class Options : OnlineLinearArguments public bool NoBias = false; [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. 
Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public int MaxCalibrationExamples = 1000000; + internal int MaxCalibrationExamples = 1000000; } private sealed class TrainState : TrainStateBase diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs index ca07f7fbe5..d24be7b68f 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs @@ -1408,10 +1408,10 @@ public sealed class Options : ArgumentsBase public float PositiveInstanceWeight = 1; [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public int MaxCalibrationExamples = 1000000; + internal int MaxCalibrationExamples = 1000000; internal override void Check(IHostEnvironment env) { @@ -1624,10 +1624,10 @@ public sealed class Options : LearnerInputBaseWithWeight public int? CheckFrequency; [Argument(ArgumentType.AtMostOnce, HelpText = "The calibrator kind to apply to the predictor. 
Specify null for no calibration", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); + internal ICalibratorTrainerFactory Calibrator = new PlattCalibratorTrainerFactory(); [Argument(ArgumentType.AtMostOnce, HelpText = "The maximum number of examples to use when training the calibrator", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)] - public int MaxCalibrationExamples = 1000000; + internal int MaxCalibrationExamples = 1000000; internal void Check(IHostEnvironment env) { diff --git a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs index ae03b46bc8..88af8fc3cc 100644 --- a/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs +++ b/src/Microsoft.ML.StandardLearners/StandardLearnersCatalog.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.ML.Calibrator; +using Microsoft.ML.Core.Data; using Microsoft.ML.Data; using Microsoft.ML.Internal.Calibration; using Microsoft.ML.Trainers; @@ -437,6 +439,26 @@ public static MultiClassNaiveBayesTrainer NaiveBayes(this MulticlassClassificati return new MultiClassNaiveBayesTrainer(CatalogUtils.GetEnvironment(catalog), labelColumn, featureColumn); } + /// <summary> + /// Works via the <see cref="IHaveCalibratorTrainer"/> shim interface to extract from the calibrating training + /// estimator the internal <see cref="ICalibratorTrainer"/> object. Note that this should be a temporary measure, + /// since the trainers should really be changed to actually work over estimators. + /// </summary> + /// <param name="ectx">The exception context.</param> + /// <param name="calibratorEstimator">The estimator out of which we should try to extract the calibrator trainer.</param> + /// <returns>The calibrator trainer.</returns>
+ private static ICalibratorTrainer GetCalibratorTrainerOrThrow(IExceptionContext ectx, IEstimator> calibratorEstimator) + { + Contracts.AssertValue(ectx); + ectx.AssertValueOrNull(calibratorEstimator); + if (calibratorEstimator == null) + return null; + if (calibratorEstimator is IHaveCalibratorTrainer haveCalibratorTrainer) + return haveCalibratorTrainer.CalibratorTrainer; + throw ectx.ExceptParam(nameof(calibratorEstimator), + "Calibrator estimator was not of a type usable in this context."); + } + /// <summary> /// Predicts a target using a linear multiclass classification model trained with the <see cref="Ova"/>. /// </summary> @@ -459,7 +481,7 @@ public static Ova OneVersusAll(this MulticlassClassificationCatalog.Mult ITrainerEstimator, TModel> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, - ICalibratorTrainer calibrator = null, + IEstimator> calibrator = null, int maxCalibrationExamples = 1000000000, bool useProbabilities = true) where TModel : class @@ -468,7 +490,7 @@ public static Ova OneVersusAll(this MulticlassClassificationCatalog.Mult var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new Ova(env, est, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples, useProbabilities); + return new Ova(env, est, labelColumn, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maxCalibrationExamples, useProbabilities); } /// @@ -492,7 +514,7 @@ public static Pkpd PairwiseCoupling(this MulticlassClassificationCatalog ITrainerEstimator, TModel> binaryEstimator, string labelColumn = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, - ICalibratorTrainer calibrator = null, + IEstimator> calibrator = null, int maxCalibrationExamples = 1_000_000_000) where TModel : class { @@ -500,7
+522,7 @@ public static Pkpd PairwiseCoupling(this MulticlassClassificationCatalog var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new Pkpd(env, est, labelColumn, imputeMissingLabelsAsNegative, calibrator, maxCalibrationExamples); + return new Pkpd(env, est, labelColumn, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maxCalibrationExamples); } /// diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 1ef0cda99c..4754191d09 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -65,7 +65,7 @@ public void OvaAveragedPerceptron() // Pipeline var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron( - new AveragedPerceptronTrainer.Options { Shuffle = true, Calibrator = null }); + new AveragedPerceptronTrainer.Options { Shuffle = true }); var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap, useProbabilities: false); var model = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs index 3de1a3125a..ebb9a14690 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/CalibratorEstimators.cs @@ -21,15 +21,15 @@ public void PlattCalibratorEstimator() { var calibratorTestData = GetCalibratorTestData(); - // platCalibrator - var platCalibratorEstimator = new PlattCalibratorEstimator(Env, calibratorTestData.transformer.Model, "Label", "Features"); - var platCalibratorTransformer = platCalibratorEstimator.Fit(calibratorTestData.scoredData); + // plattCalibrator + var plattCalibratorEstimator = new 
PlattCalibratorEstimator(Env); + var plattCalibratorTransformer = plattCalibratorEstimator.Fit(calibratorTestData.ScoredData); //testData - checkValidCalibratedData(calibratorTestData.scoredData, platCalibratorTransformer); + CheckValidCalibratedData(calibratorTestData.ScoredData, plattCalibratorTransformer); //test estimator - TestEstimatorCore(platCalibratorEstimator, calibratorTestData.scoredData); + TestEstimatorCore(plattCalibratorEstimator, calibratorTestData.ScoredData); Done(); } @@ -38,18 +38,18 @@ public void PlattCalibratorEstimator() /// OVA and calibrators /// [Fact] - public void FixedPlatCalibratorEstimator() + public void FixedPlattCalibratorEstimator() { var calibratorTestData = GetCalibratorTestData(); - // fixedPlatCalibrator - var fixedPlatCalibratorEstimator = new FixedPlattCalibratorEstimator(Env, calibratorTestData.transformer.Model, labelColumn: "Label", featureColumn: "Features"); - var fixedPlatCalibratorTransformer = fixedPlatCalibratorEstimator.Fit(calibratorTestData.scoredData); + // fixedPlattCalibrator + var fixedPlattCalibratorEstimator = new FixedPlattCalibratorEstimator(Env); + var fixedPlattCalibratorTransformer = fixedPlattCalibratorEstimator.Fit(calibratorTestData.ScoredData); - checkValidCalibratedData(calibratorTestData.scoredData, fixedPlatCalibratorTransformer); + CheckValidCalibratedData(calibratorTestData.ScoredData, fixedPlattCalibratorTransformer); //test estimator - TestEstimatorCore(calibratorTestData.pipeline, calibratorTestData.data); + TestEstimatorCore(calibratorTestData.Pipeline, calibratorTestData.Data); Done(); } @@ -63,14 +63,14 @@ public void NaiveCalibratorEstimator() var calibratorTestData = GetCalibratorTestData(); // naive calibrator - var naiveCalibratorEstimator = new NaiveCalibratorEstimator(Env, calibratorTestData.transformer.Model, "Label", "Features"); - var naiveCalibratorTransformer = naiveCalibratorEstimator.Fit(calibratorTestData.scoredData); + var naiveCalibratorEstimator = new 
NaiveCalibratorEstimator(Env); + var naiveCalibratorTransformer = naiveCalibratorEstimator.Fit(calibratorTestData.ScoredData); // check data - checkValidCalibratedData(calibratorTestData.scoredData, naiveCalibratorTransformer); + CheckValidCalibratedData(calibratorTestData.ScoredData, naiveCalibratorTransformer); //test estimator - TestEstimatorCore(calibratorTestData.pipeline, calibratorTestData.data); + TestEstimatorCore(calibratorTestData.Pipeline, calibratorTestData.Data); Done(); } @@ -83,14 +83,14 @@ public void PavCalibratorEstimator() var calibratorTestData = GetCalibratorTestData(); // pav calibrator - var pavCalibratorEstimator = new PavCalibratorEstimator(Env, calibratorTestData.transformer.Model, "Label", "Features"); - var pavCalibratorTransformer = pavCalibratorEstimator.Fit(calibratorTestData.scoredData); + var pavCalibratorEstimator = new PavCalibratorEstimator(Env); + var pavCalibratorTransformer = pavCalibratorEstimator.Fit(calibratorTestData.ScoredData); //check data - checkValidCalibratedData(calibratorTestData.scoredData, pavCalibratorTransformer); + CheckValidCalibratedData(calibratorTestData.ScoredData, pavCalibratorTransformer); //test estimator - TestEstimatorCore(calibratorTestData.pipeline, calibratorTestData.data); + TestEstimatorCore(calibratorTestData.Pipeline, calibratorTestData.Data); Done(); } @@ -109,24 +109,24 @@ CalibratorTestData GetCalibratorTestData() return new CalibratorTestData { - data = data, - scoredData = scoredData, - pipeline = pipeline, - transformer = ((TransformerChain>)transformer).LastTransformer as BinaryPredictionTransformer, + Data = data, + ScoredData = scoredData, + Pipeline = pipeline, + Transformer = ((TransformerChain>)transformer).LastTransformer as BinaryPredictionTransformer, }; } - private class CalibratorTestData + private sealed class CalibratorTestData { - internal IDataView data { get; set; } - internal IDataView scoredData { get; set; } - internal IEstimator pipeline { get; set; } + public 
IDataView Data { get; set; } + public IDataView ScoredData { get; set; } + public IEstimator Pipeline { get; set; } - internal BinaryPredictionTransformer transformer { get; set; } + public BinaryPredictionTransformer Transformer { get; set; } } - void checkValidCalibratedData (IDataView scoredData, ITransformer transformer){ + private void CheckValidCalibratedData(IDataView scoredData, ITransformer transformer){ var calibratedData = transformer.Transform(scoredData).Preview(); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 7e0a34b1b6..d4f30aaa66 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -2,8 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Calibrator; using Microsoft.ML.Data; -using Microsoft.ML.Internal.Calibration; using Microsoft.ML.RunTests; using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.Online; @@ -22,9 +22,9 @@ public partial class TrainerEstimators public void OVAWithAllConstructorArgs() { var (pipeline, data) = GetMultiClassPipeline(); - var calibrator = new PlattCalibratorTrainer(Env); + var calibrator = new PlattCalibratorEstimator(Env); var averagePerceptron = ML.BinaryClassification.Trainers.AveragedPerceptron( - new AveragedPerceptronTrainer.Options { Shuffle = true, Calibrator = null }); + new AveragedPerceptronTrainer.Options { Shuffle = true }); var ova = ML.MulticlassClassification.Trainers.OneVersusAll(averagePerceptron, imputeMissingLabelsAsNegative: true, calibrator: calibrator, maxCalibrationExamples: 10000, useProbabilities: true); @@ -44,7 +44,7 @@ public void OVAUncalibrated() { var (pipeline, data) = GetMultiClassPipeline(); var sdcaTrainer = 
ML.BinaryClassification.Trainers.StochasticDualCoordinateAscent( - new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1, Calibrator = null }); + new SdcaBinaryTrainer.Options { MaxIterations = 100, Shuffle = true, NumThreads = 1 }); pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer, useProbabilities: false)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));