diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs index 118375f134..728afe7802 100644 --- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs +++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs @@ -12,8 +12,7 @@ namespace Microsoft.ML /// and it is still useful, but for things based on /// the idiom, it is inappropriate. /// - [BestFriend] - internal enum PredictionKind + public enum PredictionKind { Unknown = 0, Custom = 1, diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index c4cc248dc7..d3735e6460 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -528,7 +528,7 @@ internal sealed class ParameterMixingCalibratedModelParameters _featureWeights; + private readonly TSubModel _featureWeights; internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) @@ -536,7 +536,7 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); Host.Assert(predictor is IPredictorWithFeatureWeights); - _featureWeights = predictor as IPredictorWithFeatureWeights; + _featureWeights = predictor; } internal const string LoaderSignature = "PMixCaliPredExec"; @@ -558,7 +558,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad { Host.Check(SubModel is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(SubModel is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); - _featureWeights = SubModel as IPredictorWithFeatureWeights; + _featureWeights = SubModel; } private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) @@ -579,7 +579,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx) public void GetFeatureWeights(ref VBuffer weights) { - _featureWeights.GetFeatureWeights(ref weights); + ((IPredictorWithFeatureWeights)_featureWeights).GetFeatureWeights(ref weights); } IParameterMixer IParameterMixer.CombineParameters(IList> models) @@ -879,6 +879,15 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, trainedCalibrator); } + public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, + TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) + where TSubPredictor : class + where TCalibrator : class, ICalibrator + { + var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictor)predictor, data, maxRows) as TCalibrator; + return (CalibratedModelParametersBase)CreateCalibratedPredictor(env, predictor, trainedCalibrator); + } + public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples) { Contracts.CheckValue(env, nameof(env)); @@ -963,12 +972,12 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa } public static IPredictorProducing 
CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) - where TSubPredictor : class, IPredictorProducing + where TSubPredictor : class where TCalibrator : class, ICalibrator { Contracts.Assert(predictor != null); if (cali == null) - return predictor; + return (IPredictorProducing)predictor; for (; ; ) { @@ -980,7 +989,7 @@ public static IPredictorProducing CreateCalibratedPredictor; if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) - return new ParameterMixingCalibratedModelParameters, TCalibrator>(env, predWithFeatureScores, cali); + return new ParameterMixingCalibratedModelParameters(env, predictor, cali); if (predictor is IValueMapper) return new ValueMapperCalibratedModelParameters(env, predictor, cali); diff --git a/src/Microsoft.ML.EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs index 7b1b2afc8a..c44f86974f 100644 --- a/src/Microsoft.ML.EntryPoints/ModelOperations.cs +++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs @@ -155,7 +155,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin return new PredictorModelOutput { PredictorModel = new PredictorModelImpl(env, data, input.TrainingData, - OneVersusAllModelParameters.Create(host, input.UseProbabilities, + OneVersusAllModelParametersBuilder.Create(host, input.UseProbabilities, input.ModelArray.Select(p => p.Predictor as IPredictorProducing).ToArray())) }; } diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 353856e7bd..589be23714 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -198,9 +198,9 @@ private protected override OneVersusAllModelParameters CreatePredictor() } string obj = (string)GetGbmParameters()["objective"]; if (obj == "multiclass") - return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors); + return OneVersusAllModelParametersBuilder.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors); else - return OneVersusAllModelParameters.Create(Host, predictors); + return OneVersusAllModelParametersBuilder.Create(Host, predictors); } private protected override void CheckDataValid(IChannel ch, RoleMappedData data) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 2ae05af908..741923c4a2 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -37,6 +37,7 @@ namespace Microsoft.ML.Trainers using TDistPredictor = IDistPredictorProducing; using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + /// /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. 
/// @@ -82,7 +83,7 @@ namespace Microsoft.ML.Trainers /// /// /// - public sealed class OneVersusAllTrainer : MetaMulticlassTrainer, OneVersusAllModelParameters> + public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer, T> where T : class { internal const string LoadNameValue = "OVA"; internal const string UserNameValue = "One-vs-All"; @@ -90,10 +91,10 @@ public sealed class OneVersusAllTrainer : MetaMulticlassTrainer - /// Options passed to + /// Options passed to /// internal sealed class Options : OptionsBase { @@ -106,27 +107,27 @@ internal sealed class Options : OptionsBase } /// - /// Constructs a trainer supplying a . + /// Constructs a trainer supplying a . /// /// The private for this estimator. /// The legacy - internal OneVersusAllTrainer(IHostEnvironment env, Options options) + internal OneVersusAllTrainerBase(IHostEnvironment env, Options options) : base(env, options, LoadNameValue) { - _options = options; + TrainerOptions = options; } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. - /// The calibrator. If a calibrator is not provided, it will default to + /// /// The calibrator. If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. - /// Number of instances to train the calibrator. + /// /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. - internal OneVersusAllTrainer(IHostEnvironment env, + internal OneVersusAllTrainerBase(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, @@ -142,23 +143,25 @@ internal OneVersusAllTrainer(IHostEnvironment env, LoadNameValue, labelColumnName, binaryEstimator, calibrator) { Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); - _options = (Options)Args; - _options.UseProbabilities = useProbabilities; + TrainerOptions = (Options)Args; + TrainerOptions.UseProbabilities = useProbabilities; } - private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) - { - // Train one-vs-all models. - var predictors = new TScalarPredictor[count]; - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); - predictors[i] = TrainOne(ch, Trainer, data, i).Model; - } - return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); - } - - private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) + /// + /// Training helper method that is called by . This allows the + /// classes that inherit from this class to do any custom training changes needed, such as casting. + /// + /// The instance. + /// Whether probabilities should be used or not. Is pulled from the trainer . + /// The that has the data. + /// The label for the trainer. 
+ /// The used by the trainer + /// + private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer); + + private protected ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) { var view = MapLabels(data, cls); @@ -168,22 +171,7 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel // this is currently unsupported. var transformer = trainer.Fit(view); - if (_options.UseProbabilities) - { - var calibratedModel = transformer.Model as TDistPredictor; - - // REVIEW: restoring the RoleMappedData, as much as we can. - // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); - - if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; - - Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); - return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); - } - - return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + return TrainOneHelper(ch, TrainerOptions.UseProbabilities, view, trainerLabel, transformer); } private IDataView MapLabels(RoleMappedData data, int cls) @@ -202,10 +190,23 @@ private IDataView MapLabels(RoleMappedData data, int cls) throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}"); } + /// + /// Fit helper method that is called by . This allows the + /// classes that inherit from this class to do any custom fit changes needed, such as casting. + /// + /// The . + /// Whether probabilities should be used or not. Is pulled from the trainer . + /// The array of used. + /// The of the transformer. + /// The feature column. + /// The name of the label column. + /// + private protected abstract MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName); + /// Trains a model. /// The input data. /// A model./> - public override MulticlassPredictionTransformer Fit(IDataView input) + public override MulticlassPredictionTransformer Fit(IDataView input) { var roles = new KeyValuePair[1]; roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); @@ -227,19 +228,306 @@ public override MulticlassPredictionTransformer Fit var transformer = TrainOne(ch, Trainer, td, i); featureColumn = transformer.FeatureColumnName; } - predictors[i] = TrainOne(ch, Trainer, td, i).Model; + } } + return FitHelper(Host, TrainerOptions.UseProbabilities, predictors, input.Schema, featureColumn, LabelColumn.Name); + } + } + + /// + /// Implementation of the where T is a + /// to maintain api compatability. + /// + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase + { + /// + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. + /// The legacy + internal OneVersusAllTrainer(IHostEnvironment env, Options options) + : base(env, options) + { + } + + /// + /// Initializes a new instance of . + /// + /// The instance. 
+ /// An instance of a binary used as the base trainer. + /// /// The calibrator. If a calibrator is not provided, it will default to + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// /// Number of instances to train the calibrator. + /// Use probabilities (vs. raw outputs) to identify top-score category. + internal OneVersusAllTrainer(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) + { + } + + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new TScalarPredictor[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors) as OneVersusAllModelParameters; + } - return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + private protected override MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors), schema, featureColumn, LabelColumn.Name); + } + + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) + { + if (useProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + if (calibratedModel == null) + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + + Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } } /// - /// Model parameters for . + /// Strongly typed implementation of the where T is a of type + /// This is used to turn a non calibrated binary classification estimator into its calibrated version. /// - public sealed class OneVersusAllModelParameters : + /// + /// + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase>> + where TSubPredictor : class + where TCalibrator : class, ICalibrator + { + /// + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. 
+ /// The legacy + internal OneVersusAllTrainer(IHostEnvironment env, Options options) + : base(env, options) + { + } + + /// + /// Initializes a new instance of . + /// + /// The instance. + /// An instance of a binary used as the base trainer. + /// The calibrator. If a calibrator is not provided, it will default to + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// Number of instances to train the calibrator. + /// Use probabilities (vs. raw outputs) to identify top-score category. + internal OneVersusAllTrainer(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) + { + } + + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) + { + if (useProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + if (calibratedModel == null) + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, (TSubPredictor)transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + + Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + } + private protected override MulticlassPredictionTransformer>> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast>().ToArray()), schema, featureColumn, LabelColumn.Name); + } + + private protected override OneVersusAllModelParameters> TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new CalibratedModelParametersBase[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (CalibratedModelParametersBase)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors); + } + } + + /// + /// Strongly typed implementation of the where T is a . T can either be + /// a calibrated binary estimator of type , or a non calibrated binary estimary. + /// This cannot be used to turn a non calibrated binary classification estimator into its calibrated version. If that is required, use instead. + /// + /// + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase> where T : class + { + /// + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. 
+ /// The legacy + internal OneVersusAllTrainer(IHostEnvironment env, Options options) + : base(env, options) + { + } + + /// + /// Initializes a new instance of . + /// + /// The instance. + /// An instance of a binary used as the base trainer. + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// Use probabilities (vs. raw outputs) to identify top-score category. + internal OneVersusAllTrainer(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + bool useProbabilities = true) + : base(env: env, binaryEstimator: binaryEstimator, labelColumnName: labelColumnName, imputeMissingLabelsAsNegative: imputeMissingLabelsAsNegative, useProbabilities: useProbabilities) + { + } + + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) + { + if (useProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw. + Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface"); + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + } + + private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name); + } + + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new T[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors); + } + } + + /// + /// Class that holds the static create methods for the classes. + /// + [BestFriend] + internal static class OneVersusAllModelParametersBuilder { + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParameters.OutputFormula outputFormula, T[] predictors) where T : class + { + return new OneVersusAllModelParameters(host, outputFormula, predictors); + } + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) where T : class + { + var outputFormula = useProbability ? OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization : OneVersusAllModelParameters.OutputFormula.Raw; + + return Create(host, outputFormula, predictors); + } + + /// + /// Create a from an array of predictors. 
+ /// + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) where T : class + { + Contracts.CheckValue(host, nameof(host)); + host.CheckNonEmpty(predictors, nameof(predictors)); + return Create(host, OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization, predictors); + } + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParameters.OutputFormula outputFormula, TScalarPredictor[] predictors) + { + return new OneVersusAllModelParameters(host, outputFormula, predictors); + } + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) + { + var outputFormula = useProbability ? OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization : OneVersusAllModelParameters.OutputFormula.Raw; + + return Create(host, outputFormula, predictors); + } + + /// + /// Create a from an array of predictors. This is for backwards API compatability. + /// + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) + { + Contracts.CheckValue(host, nameof(host)); + host.CheckNonEmpty(predictors, nameof(predictors)); + return Create(host, OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization, predictors); + } + + } + + /// + /// Base model parameters for . + /// + public abstract class OneVersusAllModelParametersBase : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, @@ -262,12 +550,7 @@ private static VersionInfo GetVersionInfo() private const string SubPredictorFmt = "SubPredictor_{0:000}"; - private readonly ImplBase _impl; - - /// - /// Retrieves the model parameters. - /// - internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); + private protected readonly ImplBase Impl; /// /// The type of the prediction task. @@ -289,78 +572,52 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 private DataViewType DistType { get; } - bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; + bool ICanSavePfa.CanSavePfa => Impl.CanSavePfa; - [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) + internal OneVersusAllModelParametersBase(IHostEnvironment env, OutputFormula outputFormula, TScalarPredictor[] predictors) + : base(env, RegistrationName) { - ImplBase impl; - - using (var ch = host.Start("Creating OVA predictor")) + using (var ch = env.Start("Creating OVA predictor")) { if (outputFormula == OutputFormula.Softmax) { - impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParameters(host, impl); + Impl = new ImplSoftmax(predictors); } // Caller of this function asks for probability output. We check if input predictor can produce probability. // If that predictor can't produce probability, ivmd will be null. 
- IValueMapperDist ivmd = null; - if (outputFormula == OutputFormula.ProbabilityNormalization && + else + { + IValueMapperDist ivmd = null; + if (outputFormula == OutputFormula.ProbabilityNormalization && ((ivmd = predictors[0] as IValueMapperDist) == null || ivmd.OutputType != NumberDataViewType.Single || ivmd.DistType != NumberDataViewType.Single)) - { - ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); - ivmd = null; - } + { + ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); + ivmd = null; + } - // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. - if (ivmd != null) - { - var dists = new IValueMapperDist[predictors.Length]; - for (int i = 0; i < predictors.Length; ++i) - dists[i] = (IValueMapperDist)predictors[i]; - impl = new ImplDist(dists); + // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. + if (ivmd != null) + { + var dists = new IValueMapperDist[predictors.Length]; + for (int i = 0; i < predictors.Length; ++i) + dists[i] = (IValueMapperDist)predictors[i]; + Impl = new ImplDist(dists); + } + else + Impl = new ImplRaw(predictors); } - else - impl = new ImplRaw(predictors); } - return new OneVersusAllModelParameters(host, impl); - } - - [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) - { - var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; - - return Create(host, outputFormula, predictors); - } + Host.AssertValue(Impl, nameof(Impl)); + Host.Assert(Utils.Size(Impl.Predictors) > 0); - /// - /// Create a from an array of predictors. 
- /// - [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) - { - Contracts.CheckValue(host, nameof(host)); - host.CheckNonEmpty(predictors, nameof(predictors)); - return Create(host, OutputFormula.ProbabilityNormalization, predictors); + DistType = new VectorDataViewType(NumberDataViewType.Single, Impl.Predictors.Length); } - private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) - : base(env, RegistrationName) - { - Host.AssertValue(impl, nameof(impl)); - Host.Assert(Utils.Size(impl.Predictors) > 0); - - _impl = impl; - DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); - } - - private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + private protected OneVersusAllModelParametersBase(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -374,24 +631,16 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) { var predictors = new IValueMapperDist[len]; LoadPredictors(Host, predictors, ctx); - _impl = new ImplDist(predictors); + Impl = new ImplDist(predictors); } else { var predictors = new TScalarPredictor[len]; LoadPredictors(Host, predictors, ctx); - _impl = new ImplRaw(predictors); + Impl = new ImplRaw(predictors); } - DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); - } - - private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParameters(env, ctx); + DistType = new VectorDataViewType(NumberDataViewType.Single, Impl.Predictors.Length); } private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) @@ -406,12 +655,12 @@ private protected override void SaveCore(ModelSaveContext ctx) base.SaveCore(ctx); ctx.SetVersionInfo(GetVersionInfo()); - var preds = _impl.Predictors; + var preds = Impl.Predictors; // *** Binary format *** // bool: useDist // int: predictor count - ctx.Writer.WriteBoolByte(_impl is ImplDist); + ctx.Writer.WriteBoolByte(Impl is ImplDist); ctx.Writer.Write(preds.Length); // Save other streams. 
@@ -423,12 +672,12 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) { Host.CheckValue(ctx, nameof(ctx)); Host.CheckValue(input, nameof(input)); - return _impl.SaveAsPfa(ctx, input); + return Impl.SaveAsPfa(ctx, input); } DataViewType IValueMapper.InputType { - get { return _impl.InputType; } + get { return Impl.InputType; } } DataViewType IValueMapper.OutputType @@ -440,7 +689,7 @@ ValueMapper IValueMapper.GetMapper() Host.Check(typeof(TIn) == typeof(VBuffer)); Host.Check(typeof(TOut) == typeof(VBuffer)); - return (ValueMapper)(Delegate)_impl.GetMapper(); + return (ValueMapper)(Delegate)Impl.GetMapper(); } void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) @@ -448,7 +697,7 @@ void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); - var preds = _impl.Predictors; + var preds = Impl.Predictors; writer.WriteLine("double[] outputs = new double[{0}];", preds.Length); for (int i = 0; i < preds.Length; i++) @@ -468,7 +717,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) Host.CheckValue(writer, nameof(writer)); Host.CheckValue(schema, nameof(schema)); - var preds = _impl.Predictors; + var preds = Impl.Predictors; for (int i = 0; i < preds.Length; i++) { @@ -483,7 +732,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) } } - private abstract class ImplBase : ISingleCanSavePfa + private protected abstract class ImplBase : ISingleCanSavePfa { public abstract DataViewType InputType { get; } public abstract IValueMapper[] Predictors { get; } @@ -759,4 +1008,48 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } } -} \ No newline at end of file + + /// + /// Model parameters for typed versions of . + /// + public sealed class OneVersusAllModelParameters : + OneVersusAllModelParametersBase where T : class + { + internal OneVersusAllModelParameters(IHostEnvironment env, OutputFormula outputFormula, T[] predictors) + : base(env, outputFormula, predictors.Cast>().ToArray()) + { + } + + private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx) + { + } + + /// + /// Retrieves the model parameters. + /// + public ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray(); + } + + /// + /// Model parameters for . + /// + public sealed class OneVersusAllModelParameters : + OneVersusAllModelParametersBase + { + internal OneVersusAllModelParameters(IHostEnvironment env, OutputFormula outputFormula, TScalarPredictor[] predictors) + : base(env, outputFormula, predictors) + { + } + + private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx) + { + } + + /// + /// Retrieves the model parameters. + /// + public ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray(); + } +} diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 3127f61116..a84b258bbd 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -725,7 +725,8 @@ private static ICalibratorTrainer GetCalibratorTrainerOrThrow(IExceptionContext /// /// Create a , which predicts a multiclass target using one-versus-all strategy with - /// the binary classification estimator specified by . 
+ /// the binary classification estimator specified by . If you want to retrieve strongly typed model parameters, + /// use either the or methods instead. /// /// /// @@ -764,6 +765,93 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); } + /// + /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// the binary classification estimator specified by . This method works with binary classifiers that + /// are either already calibrated, or non calibrated ones you don't want calibrated. If you need to have your classifier calibrated, use the + /// method instead. If you want to retrieve strongly typed model parameters, + /// you must use either this method or method. + /// + /// + /// + /// In one-versus-all strategy, a binary classification algorithm is used to train one classifier for each class, + /// which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, + /// and choosing the prediction with the highest confidence score. + /// + /// + /// The multiclass classification catalog trainer object. + /// An instance of a binary used as the base trainer. + /// The name of the label column. + /// Whether to treat missing labels as having negative labels, instead of keeping them missing. + /// Use probabilities (vs. raw outputs) to identify top-score category. + /// The type of the model. This type parameter will usually be inferred automatically from . + /// + /// + /// + /// + public static OneVersusAllTrainer OneVersusAllTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ITrainerEstimator, TModel> binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + bool useProbabilities = true) + where TModel : class + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + } + + /// + /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// the binary classification estimator specified by .This method works with binary classifiers that + /// are not calibrated and need to be calibrated before use. Due to the type of estimator changing (from uncalibrated to calibrated), you must manually + /// specify both the type of the model and the type of the calibrator. If your classifier is already calibrated or it does not need to be, use the + /// method instead. If you want to retrieve strongly typed model parameters, you must either use this method or + /// method. + /// + /// + /// + /// In one-versus-all strategy, a binary classification algorithm is used to train one classifier for each class, + /// which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, + /// and choosing the prediction with the highest confidence score. + /// + /// + /// The multiclass classification catalog trainer object. + /// An instance of a binary used as the base trainer. 
+ /// The calibrator. If a calibrator is not explicitly provided, it will default to + /// The name of the label column. + /// Whether to treat missing labels as having negative labels, instead of keeping them missing. + /// Number of instances to train the calibrator. + /// Use probabilities (vs. raw outputs) to identify top-score category. + /// The type of the model. This type parameter cannot be inferred and must be specified manually. It is usually a . + /// The calibrator for the model. This type parameter cannot be inferred automatically and must be specified manually and must be of type . + /// + /// + /// + /// + public static OneVersusAllTrainer OneVersusAllUnCalibratedToCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ITrainerEstimator, TModelIn> binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + IEstimator> calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + where TModelIn : class + where TCalibrator : class, ICalibrator + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); + } + /// /// Create a , which predicts a multiclass target using pairwise coupling strategy with /// the binary classification estimator specified by . diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index dcfacdec98..1f9bc269f6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
+using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.FastTree; @@ -35,13 +36,22 @@ public void OvaLogisticRegression() // Pipeline var logReg = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(); var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(logReg, useProbabilities: false); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped(logReg, useProbabilities: false); var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.94); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.94); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } [Fact] @@ -69,15 +79,73 @@ public void OvaAveragedPerceptron() // Pipeline var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron( new AveragedPerceptronTrainer.Options { Shuffle = true }); + var apTyped = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap, useProbabilities: false); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped(apTyped, useProbabilities: false); var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.66); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.66); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); + } + + [Fact] + public void OvaCalibratedAveragedPerceptron() + { + string dataPath = GetDataPath("iris.txt"); + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. 
+ var mlContext = new MLContext(seed: 1); + var reader = new TextLoader(mlContext, new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("Label", DataKind.Single, 0), + new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), + } + }); + + // Data + var textData = reader.Load(GetDataPath(dataPath)); + var data = mlContext.Data.Cache(mlContext.Transforms.Conversion.MapValueToKey("Label") + .Fit(textData).Transform(textData)); + + // Pipeline + var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); + var apTyped = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); + + var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibrated(apTyped); + + var model = pipeline.Fit(data); + var predictions = model.Transform(data); + + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + + // Metrics + var metrics = mlContext.MulticlassClassification.Evaluate(predictions); + Assert.True(metrics.MicroAccuracy > 0.95); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.95); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } [Fact] @@ -107,12 +175,24 @@ public void OvaFastTree() mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 }), useProbabilities: false); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped( + mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 }), + useProbabilities: false); + var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.99); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.99); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } [Fact] @@ -123,6 +203,8 @@ public void OvaLinearSvm() // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); + var mlContextTyped = new MLContext(seed: 1); + var reader = new TextLoader(mlContext, new TextLoader.Options() { Columns = new[] @@ -131,22 +213,47 @@ public void OvaLinearSvm() new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), } }); + + // REVIEW: readerTyped and dataTyped aren't used anywhere in this test, but if I take them out + // the test will fail. It seems to me that something is changing state somewhere, maybe in the cache? 
+ var readerTyped = new TextLoader(mlContextTyped, new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("Label", DataKind.Single, 0), + new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), + } + }); // Data var textData = reader.Load(GetDataPath(dataPath)); var data = mlContext.Data.Cache(mlContext.Transforms.Conversion.MapValueToKey("Label") .Fit(textData).Transform(textData)); + var dataTyped = mlContextTyped.Data.Cache(mlContextTyped.Transforms.Conversion.MapValueToKey("Label") + .Fit(textData).Transform(textData)); // Pipeline var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( mlContext.BinaryClassification.Trainers.LinearSvm(new LinearSvmTrainer.Options { NumberOfIterations = 100 }), useProbabilities: false); + var pipelineTyped = mlContextTyped.MulticlassClassification.Trainers.OneVersusAllTyped( + mlContextTyped.BinaryClassification.Trainers.LinearSvm(new LinearSvmTrainer.Options { NumberOfIterations = 100 }), + useProbabilities: false); + var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); - Assert.True(metrics.MicroAccuracy > 0.83); + Assert.True(metrics.MicroAccuracy > 0.95); + + var metricsTyped = mlContextTyped.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.95); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 9f94bc2560..7dc885a336 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -52,6 +52,62 @@ public void OVAUncalibrated() Done(); } + /// + /// Tests passing in a non calibrated trainer + /// + [Fact] + public void OVATypedUncalibrated() + { + var (pipeline, data) = GetMulticlassPipeline(); + var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated( + new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 }); + + pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: false)) + .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); + + TestEstimatorCore(pipeline, data); + Done(); + } + + /// + /// Test passing in a trainer that is already calibrated + /// + [Fact] + public void OVATypedCalibrated() + { + var (pipeline, data) = GetMulticlassPipeline(); + var sdcaTrainer = ML.BinaryClassification.Trainers.SgdCalibrated( + new SgdCalibratedTrainer.Options { Shuffle = true, NumberOfThreads = 1 }); + + pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: true)) + .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); + + TestEstimatorCore(pipeline, data); + Done(); + } + + /// + /// Tests passing an uncalibrated trainer with a calibrator and having it be auto calibrated + /// using the strongly typed API. 
+ /// + [Fact] + public void OVATypedUncalibratedToCalibrated() + { + var (pipeline, data) = GetMulticlassPipeline(); + var calibrator = new PlattCalibratorEstimator(Env); + var averagePerceptron = ML.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); + + var ova = ML.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibrated(averagePerceptron, imputeMissingLabelsAsNegative: true, + calibrator: calibrator, maximumCalibrationExampleCount: 10000, useProbabilities: true); + + pipeline = pipeline.Append(ova) + .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); + + TestEstimatorCore(pipeline, data); + Done(); + } + /// /// Pairwise Coupling trainer ///
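
Note for reviewers (not part of the diff): a minimal sketch of how the new `OneVersusAllTyped` surface added above might be consumed, mirroring the iris pipeline in `OvaTest.cs`. The row class `IrisRow`, the file path, and the inferred model type (`LinearBinaryModelParameters` for `AveragedPerceptron`) are illustrative assumptions, not something the diff itself spells out.

```csharp
using System;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

// Hypothetical row type matching the iris.txt layout used by the tests above.
public sealed class IrisRow
{
    [LoadColumn(0)] public float Label { get; set; }
    [LoadColumn(1, 4), VectorType(4)] public float[] Features { get; set; }
}

public static class TypedOvaSketch
{
    public static void Run(string irisPath)
    {
        var mlContext = new MLContext(seed: 1);
        var data = mlContext.Data.LoadFromTextFile<IrisRow>(irisPath);

        // Binary base learner; AveragedPerceptron produces LinearBinaryModelParameters.
        var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
            new AveragedPerceptronTrainer.Options { Shuffle = true });

        // OneVersusAllTyped flows the binary model type through the OVA wrapper, so the
        // fitted model exposes strongly typed sub-models instead of untyped predictors.
        var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
            .Append(mlContext.MulticlassClassification.Trainers.OneVersusAllTyped(ap, useProbabilities: false));

        var model = pipeline.Fit(data);

        // SubModelParameters is ImmutableArray<LinearBinaryModelParameters> here,
        // giving direct access to per-class weights and bias without casting.
        var ovaModel = model.LastTransformer.Model;
        foreach (var subModel in ovaModel.SubModelParameters)
            Console.WriteLine($"bias={subModel.Bias}, #weights={subModel.Weights.Count}");
    }
}
```

The existing `OneVersusAll` overload keeps working unchanged; the typed overload only changes the static type of the returned model parameters, which is why the tests assert identical `MicroAccuracy` for both pipelines.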
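
Likewise, a hedged sketch of the new `OneVersusAllUnCalibratedToCalibrated` overload, which calibrates an uncalibrated binary learner inside OVA while keeping strong typing. The explicit type arguments (`LinearBinaryModelParameters`, `PlattCalibrator`) are assumptions: the diff's XML docs say they must be supplied manually, and the test code above appears to have lost its generic arguments during extraction.

```csharp
using System;
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Trainers;

public static class TypedCalibratedOvaSketch
{
    // `data` is assumed to already have a key-typed "Label" column (e.g. via MapValueToKey).
    public static void Run(MLContext mlContext, IDataView data)
    {
        var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
            new AveragedPerceptronTrainer.Options { Shuffle = true });

        // The uncalibrated perceptron is wrapped so that each per-class sub-model comes back as
        // CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>.
        var ova = mlContext.MulticlassClassification.Trainers
            .OneVersusAllUnCalibratedToCalibrated<LinearBinaryModelParameters, PlattCalibrator>(ap);

        var model = ova.Fit(data);

        foreach (var calibrated in model.Model.SubModelParameters)
        {
            // SubModel is the per-class linear model; Calibrator is the trained Platt calibrator.
            Console.WriteLine(
                $"{calibrated.SubModel.Weights.Count} weights, calibrator: {calibrated.Calibrator.GetType().Name}");
        }
    }
}
```

If the binary learner is already calibrated (or calibration is not wanted), `OneVersusAllTyped` is the intended entry point instead, as the catalog documentation in the diff notes.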