diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs
index 118375f134..728afe7802 100644
--- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs
+++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs
@@ -12,8 +12,7 @@ namespace Microsoft.ML
/// and it is still useful, but for things based on
/// the idiom, it is inappropriate.
///
- [BestFriend]
- internal enum PredictionKind
+ public enum PredictionKind
{
Unknown = 0,
Custom = 1,
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index c4cc248dc7..d3735e6460 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -528,7 +528,7 @@ internal sealed class ParameterMixingCalibratedModelParameters _featureWeights;
+ private readonly TSubModel _featureWeights;
internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator)
: base(env, RegistrationName, predictor, calibrator)
@@ -536,7 +536,7 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode
Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer));
Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer));
Host.Assert(predictor is IPredictorWithFeatureWeights);
- _featureWeights = predictor as IPredictorWithFeatureWeights;
+ _featureWeights = predictor;
}
internal const string LoaderSignature = "PMixCaliPredExec";
@@ -558,7 +558,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad
{
Host.Check(SubModel is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer));
Host.Check(SubModel is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights));
- _featureWeights = SubModel as IPredictorWithFeatureWeights;
+ _featureWeights = SubModel;
}
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -579,7 +579,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
public void GetFeatureWeights(ref VBuffer weights)
{
- _featureWeights.GetFeatureWeights(ref weights);
+ ((IPredictorWithFeatureWeights)_featureWeights).GetFeatureWeights(ref weights);
}
IParameterMixer IParameterMixer.CombineParameters(IList> models)
@@ -879,6 +879,15 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c
return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, trainedCalibrator);
}
+ public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
+ TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples)
+ where TSubPredictor : class
+ where TCalibrator : class, ICalibrator
+ {
+ var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictor)predictor, data, maxRows) as TCalibrator;
+ return (CalibratedModelParametersBase)CreateCalibratedPredictor(env, predictor, trainedCalibrator);
+ }
+
public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples)
{
Contracts.CheckValue(env, nameof(env));
@@ -963,12 +972,12 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa
}
public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali)
- where TSubPredictor : class, IPredictorProducing
+ where TSubPredictor : class
where TCalibrator : class, ICalibrator
{
Contracts.Assert(predictor != null);
if (cali == null)
- return predictor;
+ return (IPredictorProducing)predictor;
for (; ; )
{
@@ -980,7 +989,7 @@ public static IPredictorProducing CreateCalibratedPredictor;
if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer)
- return new ParameterMixingCalibratedModelParameters, TCalibrator>(env, predWithFeatureScores, cali);
+ return new ParameterMixingCalibratedModelParameters(env, predictor, cali);
if (predictor is IValueMapper)
return new ValueMapperCalibratedModelParameters(env, predictor, cali);
diff --git a/src/Microsoft.ML.EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs
index 7b1b2afc8a..c44f86974f 100644
--- a/src/Microsoft.ML.EntryPoints/ModelOperations.cs
+++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs
@@ -155,7 +155,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin
return new PredictorModelOutput
{
PredictorModel = new PredictorModelImpl(env, data, input.TrainingData,
- OneVersusAllModelParameters.Create(host, input.UseProbabilities,
+ OneVersusAllModelParametersBuilder.Create(host, input.UseProbabilities,
input.ModelArray.Select(p => p.Predictor as IPredictorProducing).ToArray()))
};
}
diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
index 353856e7bd..589be23714 100644
--- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
+++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs
@@ -198,9 +198,9 @@ private protected override OneVersusAllModelParameters CreatePredictor()
}
string obj = (string)GetGbmParameters()["objective"];
if (obj == "multiclass")
- return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
+ return OneVersusAllModelParametersBuilder.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors);
else
- return OneVersusAllModelParameters.Create(Host, predictors);
+ return OneVersusAllModelParametersBuilder.Create(Host, predictors);
}
private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
index 2ae05af908..741923c4a2 100644
--- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
+++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs
@@ -37,6 +37,7 @@ namespace Microsoft.ML.Trainers
using TDistPredictor = IDistPredictorProducing;
using TScalarPredictor = IPredictorProducing;
using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>;
+
///
/// The for training a one-versus-all multi-class classifier that uses the specified binary classifier.
///
@@ -82,7 +83,7 @@ namespace Microsoft.ML.Trainers
///
///
///
- public sealed class OneVersusAllTrainer : MetaMulticlassTrainer, OneVersusAllModelParameters>
+ public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer, T> where T : class
{
internal const string LoadNameValue = "OVA";
internal const string UserNameValue = "One-vs-All";
@@ -90,10 +91,10 @@ public sealed class OneVersusAllTrainer : MetaMulticlassTrainer
- /// Options passed to
+ /// Options passed to
///
internal sealed class Options : OptionsBase
{
@@ -106,27 +107,27 @@ internal sealed class Options : OptionsBase
}
///
- /// Constructs a trainer supplying a .
+ /// Constructs a trainer supplying a .
///
/// The private for this estimator.
/// The legacy
- internal OneVersusAllTrainer(IHostEnvironment env, Options options)
+ internal OneVersusAllTrainerBase(IHostEnvironment env, Options options)
: base(env, options, LoadNameValue)
{
- _options = options;
+ TrainerOptions = options;
}
///
- /// Initializes a new instance of .
+ /// Initializes a new instance of .
///
/// The instance.
/// An instance of a binary used as the base trainer.
- /// The calibrator. If a calibrator is not provided, it will default to
+ /// /// The calibrator. If a calibrator is not provided, it will default to
/// The name of the label colum.
/// If true will treat missing labels as negative labels.
- /// Number of instances to train the calibrator.
+ /// /// Number of instances to train the calibrator.
/// Use probabilities (vs. raw outputs) to identify top-score category.
- internal OneVersusAllTrainer(IHostEnvironment env,
+ internal OneVersusAllTrainerBase(IHostEnvironment env,
TScalarTrainer binaryEstimator,
string labelColumnName = DefaultColumnNames.Label,
bool imputeMissingLabelsAsNegative = false,
@@ -142,23 +143,25 @@ internal OneVersusAllTrainer(IHostEnvironment env,
LoadNameValue, labelColumnName, binaryEstimator, calibrator)
{
Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null.");
- _options = (Options)Args;
- _options.UseProbabilities = useProbabilities;
+ TrainerOptions = (Options)Args;
+ TrainerOptions.UseProbabilities = useProbabilities;
}
- private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
- {
- // Train one-vs-all models.
- var predictors = new TScalarPredictor[count];
- for (int i = 0; i < predictors.Length; i++)
- {
- ch.Info($"Training learner {i}");
- predictors[i] = TrainOne(ch, Trainer, data, i).Model;
- }
- return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors);
- }
-
- private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
+ ///
+ /// Training helper method that is called by . This allows the
+ /// classes that inherit from this class to do any custom training changes needed, such as casting.
+ ///
+ /// The instance.
+ /// Whether probabilities should be used or not. Is pulled from the trainer .
+ /// The that has the data.
+ /// The label for the trainer.
+ /// The used by the trainer
+ ///
+ private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch,
+ bool useProbabilities, IDataView view, string trainerLabel,
+ ISingleFeaturePredictionTransformer transformer);
+
+ private protected ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls)
{
var view = MapLabels(data, cls);
@@ -168,22 +171,7 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel
// this is currently unsupported.
var transformer = trainer.Fit(view);
- if (_options.UseProbabilities)
- {
- var calibratedModel = transformer.Model as TDistPredictor;
-
- // REVIEW: restoring the RoleMappedData, as much as we can.
- // not having the weight column on the data passed to the TrainCalibrator should be addressed.
- var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);
-
- if (calibratedModel == null)
- calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
-
- Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");
- return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
- }
-
- return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);
+ return TrainOneHelper(ch, TrainerOptions.UseProbabilities, view, trainerLabel, transformer);
}
private IDataView MapLabels(RoleMappedData data, int cls)
@@ -202,10 +190,23 @@ private IDataView MapLabels(RoleMappedData data, int cls)
throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}");
}
+ ///
+ /// Fit helper method that is called by . This allows the
+ /// classes that inherit from this class to do any custom fit changes needed, such as casting.
+ ///
+ /// The .
+ /// Whether probabilities should be used or not. Is pulled from the trainer .
+ /// The array of used.
+ /// The of the transformer.
+ /// The feature column.
+ /// The name of the label column.
+ ///
+ private protected abstract MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName);
+
/// Trains a model.
/// The input data.
/// A model./>
- public override MulticlassPredictionTransformer Fit(IDataView input)
+ public override MulticlassPredictionTransformer Fit(IDataView input)
{
var roles = new KeyValuePair[1];
roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name);
@@ -227,19 +228,306 @@ public override MulticlassPredictionTransformer Fit
var transformer = TrainOne(ch, Trainer, td, i);
featureColumn = transformer.FeatureColumnName;
}
-
predictors[i] = TrainOne(ch, Trainer, td, i).Model;
+
}
}
+ return FitHelper(Host, TrainerOptions.UseProbabilities, predictors, input.Schema, featureColumn, LabelColumn.Name);
+ }
+ }
+
+ ///
+ /// Implementation of the where T is a
+ /// to maintain api compatability.
+ ///
+ public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase
+ {
+ ///
+ /// Constructs a trainer supplying a .
+ ///
+ /// The private for this estimator.
+ /// The legacy
+ internal OneVersusAllTrainer(IHostEnvironment env, Options options)
+ : base(env, options)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of .
+ ///
+ /// The instance.
+ /// An instance of a binary used as the base trainer.
+ /// /// The calibrator. If a calibrator is not provided, it will default to
+ /// The name of the label colum.
+ /// If true will treat missing labels as negative labels.
+ /// /// Number of instances to train the calibrator.
+ /// Use probabilities (vs. raw outputs) to identify top-score category.
+ internal OneVersusAllTrainer(IHostEnvironment env,
+ TScalarTrainer binaryEstimator,
+ string labelColumnName = DefaultColumnNames.Label,
+ bool imputeMissingLabelsAsNegative = false,
+ ICalibratorTrainer calibrator = null,
+ int maximumCalibrationExampleCount = 1000000000,
+ bool useProbabilities = true)
+ : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities)
+ {
+ }
+
+ private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
+ {
+ // Train one-vs-all models.
+ var predictors = new TScalarPredictor[count];
+ for (int i = 0; i < predictors.Length; i++)
+ {
+ ch.Info($"Training learner {i}");
+ predictors[i] = TrainOne(ch, Trainer, data, i).Model;
+ }
+ return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors) as OneVersusAllModelParameters;
+ }
- return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name);
+ private protected override MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName)
+ {
+ return new MulticlassPredictionTransformer(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors), schema, featureColumn, LabelColumn.Name);
+ }
+
+ private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch,
+ bool useProbabilities, IDataView view, string trainerLabel,
+ ISingleFeaturePredictionTransformer transformer)
+ {
+ if (useProbabilities)
+ {
+ var calibratedModel = transformer.Model as TDistPredictor;
+
+ // REVIEW: restoring the RoleMappedData, as much as we can.
+ // not having the weight column on the data passed to the TrainCalibrator should be addressed.
+ var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);
+
+ if (calibratedModel == null)
+ calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
+
+ Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");
+ return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
+ }
+
+ return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);
}
}
///
- /// Model parameters for .
+ /// Strongly typed implementation of the where T is a of type
+ /// This is used to turn a non calibrated binary classification estimator into its calibrated version.
///
- public sealed class OneVersusAllModelParameters :
+ ///
+ ///
+ public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase>>
+ where TSubPredictor : class
+ where TCalibrator : class, ICalibrator
+ {
+ ///
+ /// Constructs a trainer supplying a .
+ ///
+ /// The private for this estimator.
+ /// The legacy
+ internal OneVersusAllTrainer(IHostEnvironment env, Options options)
+ : base(env, options)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of .
+ ///
+ /// The instance.
+ /// An instance of a binary used as the base trainer.
+ /// The calibrator. If a calibrator is not provided, it will default to
+ /// The name of the label colum.
+ /// If true will treat missing labels as negative labels.
+ /// Number of instances to train the calibrator.
+ /// Use probabilities (vs. raw outputs) to identify top-score category.
+ internal OneVersusAllTrainer(IHostEnvironment env,
+ TScalarTrainer binaryEstimator,
+ string labelColumnName = DefaultColumnNames.Label,
+ bool imputeMissingLabelsAsNegative = false,
+ ICalibratorTrainer calibrator = null,
+ int maximumCalibrationExampleCount = 1000000000,
+ bool useProbabilities = true)
+ : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities)
+ {
+ }
+
+ private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch,
+ bool useProbabilities, IDataView view, string trainerLabel,
+ ISingleFeaturePredictionTransformer transformer)
+ {
+ if (useProbabilities)
+ {
+ var calibratedModel = transformer.Model as TDistPredictor;
+
+ // REVIEW: restoring the RoleMappedData, as much as we can.
+ // not having the weight column on the data passed to the TrainCalibrator should be addressed.
+ var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);
+
+ if (calibratedModel == null)
+ calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, (TSubPredictor)transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor;
+
+ Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface");
+ return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
+ }
+
+ return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);
+ }
+ private protected override MulticlassPredictionTransformer>> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName)
+ {
+ return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast>().ToArray()), schema, featureColumn, LabelColumn.Name);
+ }
+
+ private protected override OneVersusAllModelParameters> TrainCore(IChannel ch, RoleMappedData data, int count)
+ {
+ // Train one-vs-all models.
+ var predictors = new CalibratedModelParametersBase[count];
+ for (int i = 0; i < predictors.Length; i++)
+ {
+ ch.Info($"Training learner {i}");
+ predictors[i] = (CalibratedModelParametersBase)TrainOne(ch, Trainer, data, i).Model;
+ }
+ return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors);
+ }
+ }
+
+ ///
+ /// Strongly typed implementation of the where T is a . T can either be
+ /// a calibrated binary estimator of type , or a non calibrated binary estimary.
+ /// This cannot be used to turn a non calibrated binary classification estimator into its calibrated version. If that is required, use instead.
+ ///
+ ///
+ public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase> where T : class
+ {
+ ///
+ /// Constructs a trainer supplying a .
+ ///
+ /// The private for this estimator.
+ /// The legacy
+ internal OneVersusAllTrainer(IHostEnvironment env, Options options)
+ : base(env, options)
+ {
+ }
+
+ ///
+ /// Initializes a new instance of .
+ ///
+ /// The instance.
+ /// An instance of a binary used as the base trainer.
+ /// The name of the label colum.
+ /// If true will treat missing labels as negative labels.
+ /// Use probabilities (vs. raw outputs) to identify top-score category.
+ internal OneVersusAllTrainer(IHostEnvironment env,
+ TScalarTrainer binaryEstimator,
+ string labelColumnName = DefaultColumnNames.Label,
+ bool imputeMissingLabelsAsNegative = false,
+ bool useProbabilities = true)
+ : base(env: env, binaryEstimator: binaryEstimator, labelColumnName: labelColumnName, imputeMissingLabelsAsNegative: imputeMissingLabelsAsNegative, useProbabilities: useProbabilities)
+ {
+ }
+
+ private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch,
+ bool useProbabilities, IDataView view, string trainerLabel,
+ ISingleFeaturePredictionTransformer transformer)
+ {
+ if (useProbabilities)
+ {
+ var calibratedModel = transformer.Model as TDistPredictor;
+
+ // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw.
+ Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface");
+
+ // REVIEW: restoring the RoleMappedData, as much as we can.
+ // not having the weight column on the data passed to the TrainCalibrator should be addressed.
+ var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName);
+
+ return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName);
+ }
+
+ return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);
+ }
+
+ private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName)
+ {
+ return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name);
+ }
+
+ private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count)
+ {
+ // Train one-vs-all models.
+ var predictors = new T[count];
+ for (int i = 0; i < predictors.Length; i++)
+ {
+ ch.Info($"Training learner {i}");
+ predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model;
+ }
+ return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors);
+ }
+ }
+
+ ///
+ /// Class that holds the static create methods for the classes.
+ ///
+ [BestFriend]
+ internal static class OneVersusAllModelParametersBuilder {
+ [BestFriend]
+ internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParameters.OutputFormula outputFormula, T[] predictors) where T : class
+ {
+ return new OneVersusAllModelParameters(host, outputFormula, predictors);
+ }
+
+ [BestFriend]
+ internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) where T : class
+ {
+ var outputFormula = useProbability ? OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization : OneVersusAllModelParameters.OutputFormula.Raw;
+
+ return Create(host, outputFormula, predictors);
+ }
+
+ ///
+ /// Create a from an array of predictors.
+ ///
+ [BestFriend]
+ internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) where T : class
+ {
+ Contracts.CheckValue(host, nameof(host));
+ host.CheckNonEmpty(predictors, nameof(predictors));
+ return Create(host, OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization, predictors);
+ }
+
+ [BestFriend]
+ internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParameters.OutputFormula outputFormula, TScalarPredictor[] predictors)
+ {
+ return new OneVersusAllModelParameters(host, outputFormula, predictors);
+ }
+
+ [BestFriend]
+ internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors)
+ {
+ var outputFormula = useProbability ? OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization : OneVersusAllModelParameters.OutputFormula.Raw;
+
+ return Create(host, outputFormula, predictors);
+ }
+
+ ///
+ /// Create a from an array of predictors. This is for backwards API compatability.
+ ///
+ [BestFriend]
+ internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors)
+ {
+ Contracts.CheckValue(host, nameof(host));
+ host.CheckNonEmpty(predictors, nameof(predictors));
+ return Create(host, OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization, predictors);
+ }
+
+ }
+
+ ///
+ /// Base model parameters for .
+ ///
+ public abstract class OneVersusAllModelParametersBase :
ModelParametersBase>,
IValueMapper,
ICanSaveInSourceCode,
@@ -262,12 +550,7 @@ private static VersionInfo GetVersionInfo()
private const string SubPredictorFmt = "SubPredictor_{0:000}";
- private readonly ImplBase _impl;
-
- ///
- /// Retrieves the model parameters.
- ///
- internal ImmutableArray