From c43302eab4df90d44729e572d2dff0d16aa0b915 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 11 Jul 2019 09:30:24 -0700 Subject: [PATCH 01/19] base class in place --- .../OneVersusAllTrainer.cs | 702 ++++++++++++++++-- 1 file changed, 620 insertions(+), 82 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 2ae05af908..bea36a4701 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -37,52 +37,8 @@ namespace Microsoft.ML.Trainers using TDistPredictor = IDistPredictorProducing; using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - /// - /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. - /// - /// - /// , - /// can be different from , which develops a multi-class classifier directly. - /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always - /// request caching, as it will be performing multiple passes over the data set. - /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. - /// - /// This can allow you to exploit trainers that do not naturally have a - /// multiclass option, for example, using the - /// to solve a multiclass problem. - /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases - /// where the trainer has a multiclass option, but using it directly is not - /// practical due to, usually, memory constraints. For example, while a multiclass - /// logistic regression is a more principled way to solve a multiclass problem, it - /// requires that the trainer store a lot more intermediate state in the form of - /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one - /// as would be needed for a one-versus-all classification model. - /// - /// Check the See Also section for links to usage examples. - /// ]]> - /// - /// - /// - public sealed class OneVersusAllTrainer : MetaMulticlassTrainer, OneVersusAllModelParameters> + + public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer>, OneVersusAllModelParameters> where T : class { internal const string LoadNameValue = "OVA"; internal const string UserNameValue = "One-vs-All"; @@ -93,7 +49,7 @@ public sealed class OneVersusAllTrainer : MetaMulticlassTrainer - /// Options passed to + /// Options passed to /// internal sealed class Options : OptionsBase { @@ -106,56 +62,51 @@ internal sealed class Options : OptionsBase } /// - /// Constructs a trainer supplying a . + /// Constructs a trainer supplying a . /// /// The private for this estimator. /// The legacy - internal OneVersusAllTrainer(IHostEnvironment env, Options options) + internal OneVersusAllTrainerBase(IHostEnvironment env, Options options) : base(env, options, LoadNameValue) { _options = options; } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. - /// The calibrator. If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. - /// Number of instances to train the calibrator. 
/// Use probabilities (vs. raw outputs) to identify top-score category. - internal OneVersusAllTrainer(IHostEnvironment env, + internal OneVersusAllTrainerBase(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, - ICalibratorTrainer calibrator = null, - int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) : base(env, new Options { - ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, - MaxCalibrationExamples = maximumCalibrationExampleCount, + ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative }, - LoadNameValue, labelColumnName, binaryEstimator, calibrator) + LoadNameValue, labelColumnName, binaryEstimator) { Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); _options = (Options)Args; _options.UseProbabilities = useProbabilities; } - private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) { // Train one-vs-all models. - var predictors = new TScalarPredictor[count]; + var predictors = new T[count]; for (int i = 0; i < predictors.Length; i++) { ch.Info($"Training learner {i}"); - predictors[i] = TrainOne(ch, Trainer, data, i).Model; + predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; } - return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); + return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); } private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) @@ -168,24 +119,13 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel // this is currently unsupported. var transformer = trainer.Fit(view); - if (_options.UseProbabilities) - { - var calibratedModel = transformer.Model as TDistPredictor; - - // REVIEW: restoring the RoleMappedData, as much as we can. - // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); - - if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; - - Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); - return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); - } - - return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + return TrainOneHelper(ch, _options.UseProbabilities, view, trainerLabel, transformer); } + private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer); + private IDataView MapLabels(RoleMappedData data, int cls) { var label = data.Schema.Label.Value; @@ -205,7 +145,7 @@ private IDataView MapLabels(RoleMappedData data, int cls) /// Trains a model. /// The input data. 
/// A model./> - public override MulticlassPredictionTransformer Fit(IDataView input) + public override MulticlassPredictionTransformer> Fit(IDataView input) { var roles = new KeyValuePair[1]; roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); @@ -213,7 +153,7 @@ public override MulticlassPredictionTransformer Fit td.CheckMulticlassLabel(out var numClasses); - var predictors = new TScalarPredictor[numClasses]; + var predictors = new T[numClasses]; string featureColumn = null; using (var ch = Host.Start("Fitting")) @@ -227,12 +167,610 @@ public override MulticlassPredictionTransformer Fit var transformer = TrainOne(ch, Trainer, td, i); featureColumn = transformer.FeatureColumnName; } + predictors[i] = (T)TrainOne(ch, Trainer, td, i).Model; + + } + } + + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + } + } + + /// + /// Model parameters for . + /// + public sealed class OneVersusAllModelParameters : + ModelParametersBase>, + IValueMapper, + ICanSaveInSourceCode, + ICanSaveInTextFormat, + ISingleCanSavePfa + where T : class + { + internal const string LoaderSignature = "OVAExec"; + internal const string RegistrationName = "OVAPredictor"; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "TLC OVA ", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); + } + + private const string SubPredictorFmt = "SubPredictor_{0:000}"; + + private readonly ImplBase _impl; + + /// + /// Retrieves the model parameters. + /// + internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); + + /// + /// The type of the prediction task. + /// + private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification; + + /// + /// Function applied to output of predictors. Assume that we have n predictors (one per class) and for the i-th predictor, + /// y_i is its raw output and p_i is its probability output. Note that not all predictors are able to produce probability output. + /// + /// : output the result of predictors without post-processing. Output is [y_1, ..., y_n]. + /// : fetch probability output of each class probability from provided predictors and make sure the sume of class probabilities is one. + /// Output is [p_1 / (p_1 + ... + p_n), ..., p_n / (p_1 + ... + p_n)]. + /// : Generate probability by feeding raw outputs to softmax function. Output is [z_1, ..., z_n], where z_i is exp(y_i) / (exp(y_1) + ... + exp(y_n)). + /// + /// + [BestFriend] + internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 }; + + private DataViewType DistType { get; } + + bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, T[] predictors) + { + ImplBase impl; + + using (var ch = host.Start("Creating OVA predictor")) + { + if (outputFormula == OutputFormula.Softmax) + { + impl = new ImplSoftmax(predictors); + return new OneVersusAllModelParameters(host, impl); + } + + // Caller of this function asks for probability output. We check if input predictor can produce probability. + // If that predictor can't produce probability, ivmd will be null. 
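+            // For intuition, with hypothetical raw scores y = [1.0, 2.0, 0.5] from three per-class predictors:
+            //   Raw leaves the output as [1.0, 2.0, 0.5];
+            //   ProbabilityNormalization divides each predictor's probability output by the sum of all of them;
+            //   Softmax maps y_i to exp(y_i) / (exp(y_1) + exp(y_2) + exp(y_3)), here roughly [0.23, 0.63, 0.14].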
+ IValueMapperDist ivmd = null; + if (outputFormula == OutputFormula.ProbabilityNormalization && + ((ivmd = predictors[0] as IValueMapperDist) == null || + ivmd.OutputType != NumberDataViewType.Single || + ivmd.DistType != NumberDataViewType.Single)) + { + ch.Warning($"{nameof(OneVersusAllTrainerBase.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerBase.Options.PredictorType)} that can't produce probabilities."); + ivmd = null; + } - predictors[i] = TrainOne(ch, Trainer, td, i).Model; + // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. + if (ivmd != null) + { + impl = new ImplDist(predictors); } + else + impl = new ImplRaw(predictors); + } + + return new OneVersusAllModelParameters(host, impl); + } + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) + { + var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; + + return Create(host, outputFormula, predictors); + } + + /// + /// Create a from an array of predictors. + /// + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) + { + Contracts.CheckValue(host, nameof(host)); + host.CheckNonEmpty(predictors, nameof(predictors)); + return Create(host, OutputFormula.ProbabilityNormalization, predictors); + } + + private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) + : base(env, RegistrationName) + { + Host.AssertValue(impl, nameof(impl)); + Host.Assert(Utils.Size(impl.Predictors) > 0); + + _impl = impl; + DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); + } + + private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, RegistrationName, ctx) + { + // *** Binary format *** + // bool: useDist + // int: predictor count + bool useDist = ctx.Reader.ReadBoolByte(); + int len = ctx.Reader.ReadInt32(); + Host.CheckDecode(len > 0); + + if (useDist) + { + var predictors = new T[len]; + LoadPredictors(Host, predictors, ctx); + _impl = new ImplDist(predictors); + } + else + { + var predictors = new T[len]; + LoadPredictors(Host, predictors, ctx); + _impl = new ImplRaw(predictors); } - return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); + } + + private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + return new OneVersusAllModelParameters(env, ctx); + } + + private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) + where TPredictor : class + { + for (int i = 0; i < predictors.Length; i++) + ctx.LoadModel(env, out predictors[i], string.Format(SubPredictorFmt, i)); + } + + private protected override void SaveCore(ModelSaveContext ctx) + { + base.SaveCore(ctx); + ctx.SetVersionInfo(GetVersionInfo()); + + var preds = _impl.Predictors; + + // *** Binary format *** + // bool: useDist + // int: predictor count + ctx.Writer.WriteBoolByte(_impl is ImplDist); + ctx.Writer.Write(preds.Length); + + // Save other streams. 
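+            // Each sub-model is written under its own name ("SubPredictor_000", "SubPredictor_001", ...),
+            // matching SubPredictorFmt so that LoadPredictors can read the models back by index.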
+ for (int i = 0; i < preds.Length; i++) + ctx.SaveModel(preds[i], string.Format(SubPredictorFmt, i)); + } + + JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) + { + Host.CheckValue(ctx, nameof(ctx)); + Host.CheckValue(input, nameof(input)); + return _impl.SaveAsPfa(ctx, input); + } + + DataViewType IValueMapper.InputType + { + get { return _impl.InputType; } + } + + DataViewType IValueMapper.OutputType + { + get { return DistType; } + } + ValueMapper IValueMapper.GetMapper() + { + Host.Check(typeof(TIn) == typeof(VBuffer)); + Host.Check(typeof(TOut) == typeof(VBuffer)); + + return (ValueMapper)(Delegate)_impl.GetMapper(); + } + + void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) + { + Host.CheckValue(writer, nameof(writer)); + Host.CheckValue(schema, nameof(schema)); + + var preds = _impl.Predictors; + writer.WriteLine("double[] outputs = new double[{0}];", preds.Length); + + for (int i = 0; i < preds.Length; i++) + { + var saveInSourceCode = preds[i] as ICanSaveInSourceCode; + Host.Check(saveInSourceCode != null, "Saving in code is not supported."); + + writer.WriteLine("{"); + saveInSourceCode.SaveAsCode(writer, schema); + writer.WriteLine("outputs[{0}] = output;", i); + writer.WriteLine("}"); + } + } + + void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) + { + Host.CheckValue(writer, nameof(writer)); + Host.CheckValue(schema, nameof(schema)); + + var preds = _impl.Predictors; + + for (int i = 0; i < preds.Length; i++) + { + var saveInText = preds[i] as ICanSaveInTextFormat; + Host.Check(saveInText != null, "Saving in text is not supported."); + + writer.WriteLine("#region: class-{0} classifier", i); + saveInText.SaveAsText(writer, schema); + + writer.WriteLine("#endregion: class-{0} classifier", i); + writer.WriteLine(); + } + } + + private abstract class ImplBase : ISingleCanSavePfa + { + public abstract DataViewType InputType { get; } + public abstract IValueMapper[] Predictors { get; } + public abstract bool CanSavePfa { get; } + public abstract ValueMapper, VBuffer> GetMapper(); + public abstract JToken SaveAsPfa(BoundPfaContext ctx, JToken input); + + protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType) + { + Contracts.AssertValueOrNull(mapper); + Contracts.AssertValueOrNull(inputType); + + if (mapper == null) + return false; + if (mapper.OutputType != NumberDataViewType.Single) + return false; + if (!(mapper.InputType is VectorDataViewType mapperVectorType) || mapperVectorType.ItemType != NumberDataViewType.Single) + return false; + if (inputType == null) + inputType = mapperVectorType; + else if (inputType.Size != mapperVectorType.Size) + { + if (inputType.Size == 0) + inputType = mapperVectorType; + else if (mapperVectorType.Size != 0) + return false; + } + return true; + } + } + + private sealed class ImplRaw : ImplBase + { + public override DataViewType InputType { get; } + public override IValueMapper[] Predictors { get; } + public override bool CanSavePfa { get; } + + internal ImplRaw(T[] predictors) + { + Contracts.CheckNonEmpty(predictors, nameof(predictors)); + + Predictors = new IValueMapper[predictors.Length]; + VectorDataViewType inputType = null; + for (int i = 0; i < predictors.Length; i++) + { + var vm = predictors[i] as IValueMapper; + Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); + Predictors[i] = vm; + } + CanSavePfa = Predictors.All(m => (m as ISingleCanSavePfa)?.CanSavePfa == true); + 
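+                // PFA export of the combined model is only possible when every per-class sub-predictor
+                // can itself be saved as PFA; a single non-exportable sub-model disables it.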
Contracts.AssertValue(inputType); + InputType = inputType; + } + + public override ValueMapper, VBuffer> GetMapper() + { + var maps = new ValueMapper, float>[Predictors.Length]; + for (int i = 0; i < Predictors.Length; i++) + maps[i] = Predictors[i].GetMapper, float>(); + + var buffer = new float[maps.Length]; + return + (in VBuffer src, ref VBuffer dst) => + { + int inputSize = InputType.GetVectorSize(); + if (inputSize > 0) + Contracts.Check(src.Length == inputSize); + + var tmp = src; + Parallel.For(0, maps.Length, i => maps[i](in tmp, ref buffer[i])); + + var editor = VBufferEditor.Create(ref dst, maps.Length); + buffer.CopyTo(editor.Values); + dst = editor.Commit(); + }; + } + + public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + { + Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(input, nameof(input)); + Contracts.Assert(CanSavePfa); + + JArray rootObjects = new JArray(); + for (int i = 0; i < Predictors.Length; ++i) + { + var pred = (ISingleCanSavePfa)Predictors[i]; + Contracts.Assert(pred.CanSavePfa); + rootObjects.Add(ctx.DeclareVar(null, pred.SaveAsPfa(ctx, input))); + } + JObject jobj = null; + return jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects); + } + } + + private sealed class ImplDist : ImplBase + { + private readonly IValueMapperDist[] _mappers; + public override DataViewType InputType { get; } + public override IValueMapper[] Predictors => _mappers; + public override bool CanSavePfa { get; } + + internal ImplDist(T[] predictors) + { + Contracts.Check(Utils.Size(predictors) > 0); + + _mappers = new IValueMapperDist[predictors.Length]; + VectorDataViewType inputType = null; + for (int i = 0; i < predictors.Length; i++) + { + var vm = predictors[i] as IValueMapperDist; + Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); + _mappers[i] = vm; + } + CanSavePfa = Predictors.All(m => (m as IDistCanSavePfa)?.CanSavePfa == true); + Contracts.AssertValue(inputType); + InputType = inputType; + } + + private bool IsValid(IValueMapperDist mapper, ref VectorDataViewType inputType) + { + return base.IsValid(mapper, ref inputType) && mapper.DistType == NumberDataViewType.Single; + } + + /// + /// Each predictor produces a probability of a class. All classes' probabilities are normalized so that + /// their sum is one. + /// + public override ValueMapper, VBuffer> GetMapper() + { + var maps = new ValueMapper, float, float>[Predictors.Length]; + for (int i = 0; i < Predictors.Length; i++) + maps[i] = _mappers[i].GetMapper, float, float>(); + + var buffer = new float[maps.Length]; + return + (in VBuffer src, ref VBuffer dst) => + { + int inputSize = InputType.GetVectorSize(); + if (inputSize > 0) + Contracts.Check(src.Length == inputSize); + + var tmp = src; + Parallel.For(0, maps.Length, + i => + { + float score = 0; + // buffer[i] is the probability of the i-th class. + // score is the raw prediction score. + maps[i](in tmp, ref score, ref buffer[i]); + }); + + // buffer[i] is the probability of the i-th class. + // score is the raw prediction score. + NormalizeSumToOne(buffer, maps.Length); + + var editor = VBufferEditor.Create(ref dst, maps.Length); + buffer.CopyTo(editor.Values); + dst = editor.Commit(); + }; + } + + private void NormalizeSumToOne(float[] output, int count) + { + // Clamp to zero and normalize. 
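+                // Worked example (hypothetical values): for per-class outputs [0.2, -0.1, 0.7] the negative
+                // entry is clamped to 0, the sum is 0.9, and the normalized result is roughly [0.22, 0.00, 0.78].
+                // NaN entries are skipped, so they neither contribute to the sum nor get clamped.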
+ Double sum = 0; + for (int i = 0; i < count; i++) + { + var value = output[i]; + if (float.IsNaN(value)) + continue; + + if (value >= 0) + sum += value; + else + output[i] = 0; + } + + if (sum > 0) + { + for (int i = 0; i < count; i++) + output[i] = (float)(output[i] / sum); + } + } + + public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + { + Contracts.CheckValue(ctx, nameof(ctx)); + Contracts.CheckValue(input, nameof(input)); + Contracts.Assert(CanSavePfa); + + JArray rootObjects = new JArray(); + for (int i = 0; i < Predictors.Length; ++i) + { + var pred = (IDistCanSavePfa)Predictors[i]; + Contracts.Assert(pred.CanSavePfa); + pred.SaveAsPfa(ctx, input, null, out JToken scoreToken, null, out JToken probToken); + rootObjects.Add(probToken); + } + JObject jobj = null; + var rootResult = jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects); + var resultVar = ctx.DeclareVar(null, rootResult); + var factorVar = ctx.DeclareVar(null, PfaUtils.Call("/", 1.0, PfaUtils.Call("a.sum", resultVar))); + return PfaUtils.Call("la.scale", resultVar, factorVar); + } + } + + private sealed class ImplSoftmax : ImplBase + { + public override DataViewType InputType { get; } + public override IValueMapper[] Predictors { get; } + public override bool CanSavePfa { get; } + + internal ImplSoftmax(T[] predictors) + { + Contracts.CheckNonEmpty(predictors, nameof(predictors)); + + Predictors = new IValueMapper[predictors.Length]; + VectorDataViewType inputType = null; + for (int i = 0; i < predictors.Length; i++) + { + var vm = predictors[i] as IValueMapper; + Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); + Predictors[i] = vm; + } + CanSavePfa = false; + Contracts.AssertValue(inputType); + InputType = inputType; + } + + public override ValueMapper, VBuffer> GetMapper() + { + var maps = new ValueMapper, float>[Predictors.Length]; + for (int i = 0; i < Predictors.Length; i++) + maps[i] = Predictors[i].GetMapper, float>(); + + var buffer = new float[maps.Length]; + return + (in VBuffer src, ref VBuffer dst) => + { + int inputSize = InputType.GetVectorSize(); + if (inputSize > 0) + Contracts.Check(src.Length == inputSize); + + var tmp = src; + Parallel.For(0, maps.Length, i => maps[i](in tmp, ref buffer[i])); + NormalizeSoftmax(buffer, maps.Length); + + var editor = VBufferEditor.Create(ref dst, maps.Length); + buffer.CopyTo(editor.Values); + dst = editor.Commit(); + }; + } + + private void NormalizeSoftmax(float[] scores, int count) + { + double sum = 0; + var score = new double[count]; + + for (int i = 0; i < count; i++) + { + score[i] = Math.Exp(scores[i]); + sum += score[i]; + } + + for (int i = 0; i < count; i++) + scores[i] = (float)(score[i] / sum); + } + + public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) + { + throw new NotImplementedException("Softmax's PFA exporter is not implemented yet."); + } + } + } + + /// + /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. + /// + /// + /// , + /// can be different from , which develops a multi-class classifier directly. + /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always + /// request caching, as it will be performing multiple passes over the data set. + /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. 
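+    /// A minimal usage sketch, for illustration only: it assumes an MLContext named mlContext,
+    /// training data already loaded as an IDataView, and uses AveragedPerceptron purely as an
+    /// example binary learner (any binary trainer estimator could be substituted):
+    ///
+    ///     var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
+    ///         .Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(
+    ///             mlContext.BinaryClassification.Trainers.AveragedPerceptron()))
+    ///         .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
+    ///     var model = pipeline.Fit(trainingData);
+    ///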
+ /// + /// This can allow you to exploit trainers that do not naturally have a + /// multiclass option, for example, using the + /// to solve a multiclass problem. + /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases + /// where the trainer has a multiclass option, but using it directly is not + /// practical due to, usually, memory constraints. For example, while a multiclass + /// logistic regression is a more principled way to solve a multiclass problem, it + /// requires that the trainer store a lot more intermediate state in the form of + /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one + /// as would be needed for a one-versus-all classification model. + /// + /// Check the See Also section for links to usage examples. + /// ]]> + /// + /// + /// + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase + { + //internal const string LoadNameValue = "OVA"; + //internal const string UserNameValue = "One-vs-All"; + //internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, " + // + "which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, " + // + "and choosing the prediction with the highest confidence score."; + + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) + { + if (useProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + if (calibratedModel == null) + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + + Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } } From d82fc6e8053152281aa1acedcb879972d5a93d09 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 11 Jul 2019 09:33:55 -0700 Subject: [PATCH 02/19] small changes --- .../OneVersusAllTrainer.cs | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index bea36a4701..3a317ba545 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -751,6 +751,34 @@ public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. + /// The legacy + internal OneVersusAllTrainer(IHostEnvironment env, Options options) + : base(env, options, LoadNameValue) + { + } + + /// + /// Initializes a new instance of . + /// + /// The instance. 
+ /// An instance of a binary used as the base trainer. + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// Use probabilities (vs. raw outputs) to identify top-score category. + internal OneVersusAllTrainerBase(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + bool useProbabilities = true) + : base(env, + binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities) + { + } + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, bool useProbabilities, IDataView view, string trainerLabel, ISingleFeaturePredictionTransformer transformer) From 06637ca716d29364b9cf8301515c579bd12ddb48 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 11 Jul 2019 15:25:22 -0700 Subject: [PATCH 03/19] working but has copy pasted code --- .../Prediction/IPredictor.cs | 9 +- .../OneVersusAllTrainer.cs | 370 +++++++++++------- .../StandardTrainersCatalog.cs | 74 ++++ 3 files changed, 314 insertions(+), 139 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs index 118375f134..2f5fae8db3 100644 --- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs +++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs @@ -12,8 +12,7 @@ namespace Microsoft.ML /// and it is still useful, but for things based on /// the idiom, it is inappropriate. /// - [BestFriend] - internal enum PredictionKind + public enum PredictionKind { Unknown = 0, Custom = 1, @@ -34,8 +33,7 @@ internal enum PredictionKind /// /// Weakly typed version of IPredictor. /// - [BestFriend] - internal interface IPredictor + public interface IPredictor { /// /// Return the type of prediction task. @@ -47,8 +45,7 @@ internal interface IPredictor /// A predictor the produces values of the indicated type. /// REVIEW: Determine whether this is just a temporary shim or long term solution. /// - [BestFriend] - internal interface IPredictorProducing : IPredictor + public interface IPredictorProducing : IPredictor { } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 3a317ba545..5a43578e01 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -37,8 +37,52 @@ namespace Microsoft.ML.Trainers using TDistPredictor = IDistPredictorProducing; using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - - public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer>, OneVersusAllModelParameters> where T : class + /// + /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. + /// + /// + /// , + /// can be different from , which develops a multi-class classifier directly. + /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always + /// request caching, as it will be performing multiple passes over the data set. + /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. 
+ /// + /// This can allow you to exploit trainers that do not naturally have a + /// multiclass option, for example, using the + /// to solve a multiclass problem. + /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases + /// where the trainer has a multiclass option, but using it directly is not + /// practical due to, usually, memory constraints. For example, while a multiclass + /// logistic regression is a more principled way to solve a multiclass problem, it + /// requires that the trainer store a lot more intermediate state in the form of + /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one + /// as would be needed for a one-versus-all classification model. + /// + /// Check the See Also section for links to usage examples. + /// ]]> + /// + /// + /// + public sealed class OneVersusAllTrainer : MetaMulticlassTrainer, OneVersusAllModelParameters> { internal const string LoadNameValue = "OVA"; internal const string UserNameValue = "One-vs-All"; @@ -49,7 +93,7 @@ public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer - /// Options passed to + /// Options passed to /// internal sealed class Options : OptionsBase { @@ -62,51 +106,56 @@ internal sealed class Options : OptionsBase } /// - /// Constructs a trainer supplying a . + /// Constructs a trainer supplying a . /// /// The private for this estimator. /// The legacy - internal OneVersusAllTrainerBase(IHostEnvironment env, Options options) + internal OneVersusAllTrainer(IHostEnvironment env, Options options) : base(env, options, LoadNameValue) { _options = options; } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. + /// The calibrator. If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. + /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. - internal OneVersusAllTrainerBase(IHostEnvironment env, + internal OneVersusAllTrainer(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) : base(env, new Options { - ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative + ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, + MaxCalibrationExamples = maximumCalibrationExampleCount, }, - LoadNameValue, labelColumnName, binaryEstimator) + LoadNameValue, labelColumnName, binaryEstimator, calibrator) { Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); _options = (Options)Args; _options.UseProbabilities = useProbabilities; } - private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) { // Train one-vs-all models. 
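            // One binary learner is trained per class: learner i sees examples of class i as positive
            // and examples of every other class as negative (see MapLabels below).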
- var predictors = new T[count]; + var predictors = new TScalarPredictor[count]; for (int i = 0; i < predictors.Length; i++) { ch.Info($"Training learner {i}"); - predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; + predictors[i] = TrainOne(ch, Trainer, data, i).Model; } - return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); + return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); } private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) @@ -119,12 +168,23 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel // this is currently unsupported. var transformer = trainer.Fit(view); - return TrainOneHelper(ch, _options.UseProbabilities, view, trainerLabel, transformer); - } + if (_options.UseProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; - private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, - bool useProbabilities, IDataView view, string trainerLabel, - ISingleFeaturePredictionTransformer transformer); + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + if (calibratedModel == null) + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + + Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + } private IDataView MapLabels(RoleMappedData data, int cls) { @@ -145,7 +205,7 @@ private IDataView MapLabels(RoleMappedData data, int cls) /// Trains a model. /// The input data. /// A model./> - public override MulticlassPredictionTransformer> Fit(IDataView input) + public override MulticlassPredictionTransformer Fit(IDataView input) { var roles = new KeyValuePair[1]; roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); @@ -153,7 +213,7 @@ public override MulticlassPredictionTransformer> td.CheckMulticlassLabel(out var numClasses); - var predictors = new T[numClasses]; + var predictors = new TScalarPredictor[numClasses]; string featureColumn = null; using (var ch = Host.Start("Fitting")) @@ -167,25 +227,24 @@ public override MulticlassPredictionTransformer> var transformer = TrainOne(ch, Trainer, td, i); featureColumn = transformer.FeatureColumnName; } - predictors[i] = (T)TrainOne(ch, Trainer, td, i).Model; + predictors[i] = TrainOne(ch, Trainer, td, i).Model; } } - return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); } } /// - /// Model parameters for . + /// Model parameters for . 
/// - public sealed class OneVersusAllModelParameters : + public sealed class OneVersusAllModelParameters : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, ICanSaveInTextFormat, ISingleCanSavePfa - where T : class { internal const string LoaderSignature = "OVAExec"; internal const string RegistrationName = "OVAPredictor"; @@ -198,7 +257,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); + loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -208,7 +267,7 @@ private static VersionInfo GetVersionInfo() /// /// Retrieves the model parameters. /// - internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); + internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); /// /// The type of the prediction task. @@ -233,7 +292,7 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, T[] predictors) + internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) { ImplBase impl; @@ -242,7 +301,7 @@ internal static OneVersusAllModelParameters Create(IHost host, OutputFormula if (outputFormula == OutputFormula.Softmax) { impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParameters(host, impl); + return new OneVersusAllModelParameters(host, impl); } // Caller of this function asks for probability output. We check if input predictor can produce probability. @@ -253,24 +312,27 @@ internal static OneVersusAllModelParameters Create(IHost host, OutputFormula ivmd.OutputType != NumberDataViewType.Single || ivmd.DistType != NumberDataViewType.Single)) { - ch.Warning($"{nameof(OneVersusAllTrainerBase.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerBase.Options.PredictorType)} that can't produce probabilities."); + ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); ivmd = null; } // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. if (ivmd != null) { - impl = new ImplDist(predictors); + var dists = new IValueMapperDist[predictors.Length]; + for (int i = 0; i < predictors.Length; ++i) + dists[i] = (IValueMapperDist)predictors[i]; + impl = new ImplDist(dists); } else impl = new ImplRaw(predictors); } - return new OneVersusAllModelParameters(host, impl); + return new OneVersusAllModelParameters(host, impl); } [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) { var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; @@ -278,10 +340,10 @@ internal static OneVersusAllModelParameters Create(IHost host, bool useProbab } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. 
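+        /// This overload defaults to probability normalization (it forwards to
+        /// Create(host, OutputFormula.ProbabilityNormalization, predictors)).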
/// [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) + internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) { Contracts.CheckValue(host, nameof(host)); host.CheckNonEmpty(predictors, nameof(predictors)); @@ -310,13 +372,13 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) if (useDist) { - var predictors = new T[len]; + var predictors = new IValueMapperDist[len]; LoadPredictors(Host, predictors, ctx); _impl = new ImplDist(predictors); } else { - var predictors = new T[len]; + var predictors = new TScalarPredictor[len]; LoadPredictors(Host, predictors, ctx); _impl = new ImplRaw(predictors); } @@ -324,12 +386,12 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParameters(env, ctx); + return new OneVersusAllModelParameters(env, ctx); } private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) @@ -459,7 +521,7 @@ private sealed class ImplRaw : ImplBase public override IValueMapper[] Predictors { get; } public override bool CanSavePfa { get; } - internal ImplRaw(T[] predictors) + internal ImplRaw(TScalarPredictor[] predictors) { Contracts.CheckNonEmpty(predictors, nameof(predictors)); @@ -524,7 +586,7 @@ private sealed class ImplDist : ImplBase public override IValueMapper[] Predictors => _mappers; public override bool CanSavePfa { get; } - internal ImplDist(T[] predictors) + internal ImplDist(IValueMapperDist[] predictors) { Contracts.Check(Utils.Size(predictors) > 0); @@ -532,7 +594,7 @@ internal ImplDist(T[] predictors) VectorDataViewType inputType = null; for (int i = 0; i < predictors.Length; i++) { - var vm = predictors[i] as IValueMapperDist; + var vm = predictors[i]; Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); _mappers[i] = vm; } @@ -635,7 +697,7 @@ private sealed class ImplSoftmax : ImplBase public override IValueMapper[] Predictors { get; } public override bool CanSavePfa { get; } - internal ImplSoftmax(T[] predictors) + internal ImplSoftmax(TScalarPredictor[] predictors) { Contracts.CheckNonEmpty(predictors, nameof(predictors)); @@ -698,119 +760,164 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } - /// - /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. - /// - /// - /// , - /// can be different from , which develops a multi-class classifier directly. - /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always - /// request caching, as it will be performing multiple passes over the data set. - /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. - /// - /// This can allow you to exploit trainers that do not naturally have a - /// multiclass option, for example, using the - /// to solve a multiclass problem. 
- /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases - /// where the trainer has a multiclass option, but using it directly is not - /// practical due to, usually, memory constraints. For example, while a multiclass - /// logistic regression is a more principled way to solve a multiclass problem, it - /// requires that the trainer store a lot more intermediate state in the form of - /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one - /// as would be needed for a one-versus-all classification model. - /// - /// Check the See Also section for links to usage examples. - /// ]]> - /// - /// - /// - public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase + public sealed class OneVersusAllTrainerTyped : MetaMulticlassTrainer>, OneVersusAllModelParametersTyped> where T : class { - //internal const string LoadNameValue = "OVA"; - //internal const string UserNameValue = "One-vs-All"; - //internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, " - // + "which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, " - // + "and choosing the prediction with the highest confidence score."; + internal const string LoadNameValue = "OVA"; + internal const string UserNameValue = "One-vs-All"; + internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, " + + "which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, " + + "and choosing the prediction with the highest confidence score."; + + private readonly Options _options; /// - /// Constructs a trainer supplying a . + /// Options passed to + /// + internal sealed class Options : OptionsBase + { + /// + /// Whether to use probabilities (vs. raw outputs) to identify top-score category. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Use probability or margins to determine max", ShortName = "useprob")] + [TGUI(Label = "Use Probability", Description = "Use probabilities (vs. raw outputs) to identify top-score category")] + public bool UseProbabilities = true; + } + + /// + /// Constructs a trainer supplying a . /// /// The private for this estimator. - /// The legacy - internal OneVersusAllTrainer(IHostEnvironment env, Options options) + /// The legacy + internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) : base(env, options, LoadNameValue) { + _options = options; } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. /// The name of the label colum. /// If true will treat missing labels as negative labels. /// Use probabilities (vs. raw outputs) to identify top-score category. 
- internal OneVersusAllTrainerBase(IHostEnvironment env, + internal OneVersusAllTrainerTyped(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, bool useProbabilities = true) : base(env, - binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities) + new Options + { + ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative + }, + LoadNameValue, labelColumnName, binaryEstimator) + { + Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); + _options = (Options)Args; + _options.UseProbabilities = useProbabilities; + } + + private protected override OneVersusAllModelParametersTyped TrainCore(IChannel ch, RoleMappedData data, int count) { + // Train one-vs-all models. + var predictors = new T[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersTyped.Create(Host, _options.UseProbabilities, predictors); } - private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, - bool useProbabilities, IDataView view, string trainerLabel, - ISingleFeaturePredictionTransformer transformer) + private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) { - if (useProbabilities) + var view = MapLabels(data, cls); + + string trainerLabel = data.Schema.Label.Value.Name; + + // REVIEW: In principle we could support validation sets and the like via the train context, but + // this is currently unsupported. + var transformer = trainer.Fit(view); + + if (_options.UseProbabilities) { var calibratedModel = transformer.Model as TDistPredictor; + // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw. + Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface"); + // REVIEW: restoring the RoleMappedData, as much as we can. // not having the weight column on the data passed to the TrainCalibrator should be addressed. var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); - if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; - - Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); } return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } + + private IDataView MapLabels(RoleMappedData data, int cls) + { + var label = data.Schema.Label.Value; + Host.Assert(!label.IsHidden); + Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double); + + if (label.Type.GetKeyCount() > 0) + { + // Key values are 1-based. + uint key = (uint)(cls + 1); + return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data); + } + + throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainerTyped: {label.Type.RawType}"); + } + + /// Trains a model. + /// The input data. 
+ /// A model./> + public override MulticlassPredictionTransformer> Fit(IDataView input) + { + var roles = new KeyValuePair[1]; + roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); + var td = new RoleMappedData(input, roles); + + td.CheckMulticlassLabel(out var numClasses); + + var predictors = new T[numClasses]; + string featureColumn = null; + + using (var ch = Host.Start("Fitting")) + { + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + + if (i == 0) + { + var transformer = TrainOne(ch, Trainer, td, i); + featureColumn = transformer.FeatureColumnName; + } + predictors[i] = (T)TrainOne(ch, Trainer, td, i).Model; + + } + } + + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersTyped.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + } } /// - /// Model parameters for . + /// Model parameters for . /// - public sealed class OneVersusAllModelParameters : + public sealed class OneVersusAllModelParametersTyped : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, ICanSaveInTextFormat, ISingleCanSavePfa + where T : class { internal const string LoaderSignature = "OVAExec"; internal const string RegistrationName = "OVAPredictor"; @@ -823,7 +930,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); + loaderAssemblyName: typeof(OneVersusAllModelParametersTyped).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -833,7 +940,7 @@ private static VersionInfo GetVersionInfo() /// /// Retrieves the model parameters. /// - internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); + internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); /// /// The type of the prediction task. @@ -858,7 +965,7 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) + internal static OneVersusAllModelParametersTyped Create(IHost host, OutputFormula outputFormula, T[] predictors) { ImplBase impl; @@ -867,7 +974,7 @@ internal static OneVersusAllModelParameters Create(IHost host, OutputFormula out if (outputFormula == OutputFormula.Softmax) { impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParameters(host, impl); + return new OneVersusAllModelParametersTyped(host, impl); } // Caller of this function asks for probability output. We check if input predictor can produce probability. @@ -878,27 +985,24 @@ internal static OneVersusAllModelParameters Create(IHost host, OutputFormula out ivmd.OutputType != NumberDataViewType.Single || ivmd.DistType != NumberDataViewType.Single)) { - ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); + ch.Warning($"{nameof(OneVersusAllTrainerTyped.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerTyped.Options.PredictorType)} that can't produce probabilities."); ivmd = null; } // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. 
if (ivmd != null) { - var dists = new IValueMapperDist[predictors.Length]; - for (int i = 0; i < predictors.Length; ++i) - dists[i] = (IValueMapperDist)predictors[i]; - impl = new ImplDist(dists); + impl = new ImplDist(predictors); } else impl = new ImplRaw(predictors); } - return new OneVersusAllModelParameters(host, impl); + return new OneVersusAllModelParametersTyped(host, impl); } [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) + internal static OneVersusAllModelParametersTyped Create(IHost host, bool useProbability, T[] predictors) { var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; @@ -906,17 +1010,17 @@ internal static OneVersusAllModelParameters Create(IHost host, bool useProbabili } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. /// [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) + internal static OneVersusAllModelParametersTyped Create(IHost host, T[] predictors) { Contracts.CheckValue(host, nameof(host)); host.CheckNonEmpty(predictors, nameof(predictors)); return Create(host, OutputFormula.ProbabilityNormalization, predictors); } - private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) + private OneVersusAllModelParametersTyped(IHostEnvironment env, ImplBase impl) : base(env, RegistrationName) { Host.AssertValue(impl, nameof(impl)); @@ -926,7 +1030,7 @@ private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + private OneVersusAllModelParametersTyped(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -938,13 +1042,13 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) if (useDist) { - var predictors = new IValueMapperDist[len]; + var predictors = new T[len]; LoadPredictors(Host, predictors, ctx); _impl = new ImplDist(predictors); } else { - var predictors = new TScalarPredictor[len]; + var predictors = new T[len]; LoadPredictors(Host, predictors, ctx); _impl = new ImplRaw(predictors); } @@ -952,12 +1056,12 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static OneVersusAllModelParametersTyped Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParameters(env, ctx); + return new OneVersusAllModelParametersTyped(env, ctx); } private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) @@ -1087,7 +1191,7 @@ private sealed class ImplRaw : ImplBase public override IValueMapper[] Predictors { get; } public override bool CanSavePfa { get; } - internal ImplRaw(TScalarPredictor[] predictors) + internal ImplRaw(T[] predictors) { Contracts.CheckNonEmpty(predictors, nameof(predictors)); @@ -1152,7 +1256,7 @@ private sealed class ImplDist : ImplBase public override IValueMapper[] Predictors => _mappers; public override bool CanSavePfa { get; } - internal 
ImplDist(IValueMapperDist[] predictors) + internal ImplDist(T[] predictors) { Contracts.Check(Utils.Size(predictors) > 0); @@ -1160,7 +1264,7 @@ internal ImplDist(IValueMapperDist[] predictors) VectorDataViewType inputType = null; for (int i = 0; i < predictors.Length; i++) { - var vm = predictors[i]; + var vm = predictors[i] as IValueMapperDist; Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); _mappers[i] = vm; } @@ -1263,7 +1367,7 @@ private sealed class ImplSoftmax : ImplBase public override IValueMapper[] Predictors { get; } public override bool CanSavePfa { get; } - internal ImplSoftmax(TScalarPredictor[] predictors) + internal ImplSoftmax(T[] predictors) { Contracts.CheckNonEmpty(predictors, nameof(predictors)); diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 3127f61116..7c7aedbe2d 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -764,6 +764,80 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); } + /// + /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// the binary classification estimator specified by . + /// + /// + /// + /// In one-versus-all strategy, a binary classification algorithm is used to train one classifier for each class, + /// which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, + /// and choosing the prediction with the highest confidence score. + /// + /// + /// The multiclass classification catalog trainer object. + /// An instance of a binary used as the base trainer. + /// The name of the label column. + /// Whether to treat missing labels as having negative labels, instead of keeping them missing. + /// Use probabilities (vs. raw outputs) to identify top-score category. + /// The type of the model. This type parameter will usually be inferred automatically from . + /// + /// + /// + /// + public static OneVersusAllTrainerTyped OneVersusAllStronglyTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ITrainerEstimator, TModel> binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + bool useProbabilities = true) + where TModel : class + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + } + + /// + /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// the binary classification estimator specified by . + /// + /// + /// + /// In one-versus-all strategy, a binary classification algorithm is used to train one classifier for each class, + /// which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, + /// and choosing the prediction with the highest confidence score. 
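A minimal usage sketch of the extension method introduced above, assuming an existing MLContext (mlContext) and a training IDataView (trainingData); the choice of AveragedPerceptron as the binary learner is for illustration only and is not taken from the patch:

    var averagedPerceptron = mlContext.BinaryClassification.Trainers.AveragedPerceptron();
    // TModel should be inferred as LinearBinaryModelParameters from the binary trainer.
    var ovaEstimator = mlContext.MulticlassClassification.Trainers
        .OneVersusAllStronglyTyped(averagedPerceptron, useProbabilities: false);
    var ovaModel = ovaEstimator.Fit(trainingData);

Probabilities are turned off in this sketch because the averaged perceptron is not calibrated.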
+ /// + /// + /// The multiclass classification catalog trainer object. + /// An instance of a binary used as the base trainer. + /// The name of the label column. + /// Whether to treat missing labels as having negative labels, instead of keeping them missing. + /// Use probabilities (vs. raw outputs) to identify top-score category. + /// The type of the model. This type parameter will usually be inferred automatically from . + /// + /// + /// + /// + public static OneVersusAllTrainerTyped OneVersusAllStronglyTypedNoChange(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + IPredictorProducing binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + bool useProbabilities = true) + where TModel : class + { + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + } + /// /// Create a , which predicts a multiclass target using pairwise coupling strategy with /// the binary classification estimator specified by . From 15401d2295ca45de5a153dcb048fd9890619bc32 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Fri, 12 Jul 2019 15:32:27 -0700 Subject: [PATCH 04/19] changes per feedback --- .../Prediction/Calibrator.cs | 30 +++++++++++----- .../OneVersusAllTrainer.cs | 36 ++++++++++++++++--- .../StandardTrainersCatalog.cs | 27 ++++++++------ .../TrainerEstimators/MetalinearEstimators.cs | 2 +- 4 files changed, 70 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index c4cc248dc7..58f6fbfca2 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -528,7 +528,7 @@ internal sealed class ParameterMixingCalibratedModelParameters _featureWeights; + private readonly TSubModel _featureWeights; internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubModel predictor, TCalibrator calibrator) : base(env, RegistrationName, predictor, calibrator) @@ -536,7 +536,7 @@ internal ParameterMixingCalibratedModelParameters(IHostEnvironment env, TSubMode Host.Check(predictor is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(calibrator is IParameterMixer, "Calibrator does not implement " + nameof(IParameterMixer)); Host.Assert(predictor is IPredictorWithFeatureWeights); - _featureWeights = predictor as IPredictorWithFeatureWeights; + _featureWeights = predictor; } internal const string LoaderSignature = "PMixCaliPredExec"; @@ -558,7 +558,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad { Host.Check(SubModel is IParameterMixer, "Predictor does not implement " + nameof(IParameterMixer)); Host.Check(SubModel is IPredictorWithFeatureWeights, "Predictor does not implement " + nameof(IPredictorWithFeatureWeights)); - _featureWeights = SubModel as IPredictorWithFeatureWeights; + _featureWeights = SubModel; } private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) @@ -579,7 +579,7 @@ void ICanSaveModel.Save(ModelSaveContext ctx) public void GetFeatureWeights(ref VBuffer weights) { - 
_featureWeights.GetFeatureWeights(ref weights); + ((IPredictorWithFeatureWeights)_featureWeights).GetFeatureWeights(ref weights); } IParameterMixer IParameterMixer.CombineParameters(IList> models) @@ -879,6 +879,14 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, trainedCalibrator); } + public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, + T predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) where T: class + { + var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictorProducing)predictor, data, maxRows); + var cp = CreateCalibratedPredictor(env, predictor, trainedCalibrator); + return cp; + } + public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples) { Contracts.CheckValue(env, nameof(env)); @@ -962,13 +970,13 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa return TrainCalibrator(env, ch, caliTrainer, scored, data.Schema.Label.Value.Name, DefaultColumnNames.Score, data.Schema.Weight?.Name, maxRows); } - public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) - where TSubPredictor : class, IPredictorProducing + public static CalibratedModelParametersBase CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) + where TSubPredictor : class where TCalibrator : class, ICalibrator { Contracts.Assert(predictor != null); - if (cali == null) - return predictor; + //if (cali == null) + // return predictor; for (; ; ) { @@ -980,7 +988,11 @@ public static IPredictorProducing CreateCalibratedPredictor; if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) - return new ParameterMixingCalibratedModelParameters, TCalibrator>(env, predWithFeatureScores, cali); + { + var s = typeof(TSubPredictor); + var pm = new ParameterMixingCalibratedModelParameters(env, predictor, cali); + return pm; + } if (predictor is IValueMapper) return new ValueMapperCalibratedModelParameters(env, predictor, cali); diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 5a43578e01..b1b712fdeb 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -769,7 +769,6 @@ public sealed class OneVersusAllTrainerTyped : MetaMulticlassTrainer /// Options passed to /// @@ -831,9 +830,9 @@ private protected override OneVersusAllModelParametersTyped TrainCore(IChanne return OneVersusAllModelParametersTyped.Create(Host, _options.UseProbabilities, predictors); } - private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) + private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int cls) { - var view = MapLabels(data, cls); + /*var view = MapLabels(data, cls); string trainerLabel = data.Schema.Label.Value.Name; @@ -855,6 +854,33 @@ private ISingleFeaturePredictionTransformer 
TrainOne(IChannel return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); } + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);*/ + + var view = MapLabels(data, cls); + + string trainerLabel = data.Schema.Label.Value.Name; + + // REVIEW: In principle we could support validation sets and the like via the train context, but + // this is currently unsupported. + var transformer = trainer.Fit(view); + + if (_options.UseProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + var s = transformer.Model.GetType(); + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + if (calibratedModel == null) + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + + Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } @@ -899,7 +925,9 @@ public override MulticlassPredictionTransformer(this MulticlassClassifica /// [!code-csharp[OneVersusAll](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/OneVersusAll.cs)] /// ]]> /// - public static OneVersusAllTrainerTyped OneVersusAllStronglyTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainerTyped OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModel> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, + IEstimator> calibrator = null, + int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModel : class { - Contracts.CheckValue(catalog, nameof(catalog)); - var env = CatalogUtils.GetEnvironment(catalog); - if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) - throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + return OneVersusAllStronglyTyped( + catalog, + binaryEstimator, + DefaultColumnNames.Label, + false, + true); } /// @@ -817,25 +820,27 @@ public static OneVersusAllTrainerTyped OneVersusAllStronglyTyped /// The name of the label column. /// Whether to treat missing labels as having negative labels, instead of keeping them missing. /// Use probabilities (vs. raw outputs) to identify top-score category. - /// The type of the model. This type parameter will usually be inferred automatically from . + /// The type of the model. This type parameter will usually be inferred automatically from . + /// The type of the model. This type parameter will usually be inferred automatically from . 
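The overload declared just below takes two model type parameters. The reasoning, sketched here as an assumption drawn from the calibration path above: when probabilities are requested, each per-class submodel comes back wrapped by its calibrator, so the type stored in the one-versus-all model is not the type the binary trainer produces.

    // Illustrative example with a non-calibrated SDCA binary trainer and Platt calibration:
    // TModelIn  : LinearBinaryModelParameters                                                 (produced by the trainer)
    // TModelOut : CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator> (stored per class)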
/// /// /// /// - public static OneVersusAllTrainerTyped OneVersusAllStronglyTypedNoChange(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, - IPredictorProducing binaryEstimator, + public static OneVersusAllTrainerTyped OneVersusAllStronglyTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + ITrainerEstimator, TModelIn> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, bool useProbabilities = true) - where TModel : class + where TModelIn : class + where TModelOut : class { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); } /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 9f94bc2560..ced55cced0 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -95,7 +95,7 @@ public void MetacomponentsFeaturesRenamed() var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) - .Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer)) + .Append(ML.MulticlassClassification.Trainers.OneVersusAllStronglyTyped>(sdcaTrainer)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); var model = pipeline.Fit(data); From 5af4c01ef5cb8700ac1757f7e91a66e4ffe10c74 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 15 Jul 2019 09:39:19 -0700 Subject: [PATCH 05/19] fixes in methods --- .../OneVersusAllTrainer.cs | 9 ++++-- .../StandardTrainersCatalog.cs | 28 ++++++++++--------- .../TrainerEstimators/MetalinearEstimators.cs | 3 +- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index b1b712fdeb..095d404c52 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -798,20 +798,25 @@ internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) /// /// The instance. /// An instance of a binary used as the base trainer. + /// The calibrator. If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. + /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. 
internal OneVersusAllTrainerTyped(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) : base(env, new Options { - ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative + ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, + MaxCalibrationExamples = maximumCalibrationExampleCount, }, - LoadNameValue, labelColumnName, binaryEstimator) + LoadNameValue, labelColumnName, binaryEstimator, calibrator) { Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); _options = (Options)Args; diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 96fd802323..cc89a9ff92 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -765,7 +765,7 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica } /// - /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// Create a , which predicts a multiclass target using one-versus-all strategy with /// the binary classification estimator specified by . /// /// @@ -787,25 +787,23 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica /// [!code-csharp[OneVersusAll](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/OneVersusAll.cs)] /// ]]> /// - public static OneVersusAllTrainerTyped OneVersusAll(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainerTyped OneVersusAllTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModel> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, - IEstimator> calibrator = null, - int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModel : class { - return OneVersusAllStronglyTyped( - catalog, - binaryEstimator, - DefaultColumnNames.Label, - false, - true); + return OneVersusAllTyped( + catalog: catalog, + binaryEstimator: binaryEstimator, + labelColumnName: labelColumnName, + imputeMissingLabelsAsNegative: imputeMissingLabelsAsNegative, + useProbabilities: useProbabilities); } /// - /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// Create a , which predicts a multiclass target using one-versus-all strategy with /// the binary classification estimator specified by . /// /// @@ -817,8 +815,10 @@ public static OneVersusAllTrainerTyped OneVersusAll(this Multicl /// /// The multiclass classification catalog trainer object. /// An instance of a binary used as the base trainer. + /// The calibrator. If a calibrator is not explicitly provided, it will default to /// The name of the label column. /// Whether to treat missing labels as having negative labels, instead of keeping them missing. + /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. /// The type of the model. This type parameter will usually be inferred automatically from . /// The type of the model. This type parameter will usually be inferred automatically from . 
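A call sketch for the fully specified overload that follows, mirroring the updated MetalinearEstimators test; mlContext and the non-calibrated SDCA binary trainer sdcaTrainer are assumed to already exist:

    var typedOva = mlContext.MulticlassClassification.Trainers
        .OneVersusAllTyped<LinearBinaryModelParameters,
                           CalibratedModelParametersBase<LinearBinaryModelParameters, PlattCalibrator>>(sdcaTrainer);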
@@ -828,10 +828,12 @@ public static OneVersusAllTrainerTyped OneVersusAll(this Multicl /// [!code-csharp[OneVersusAll](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/OneVersusAll.cs)] /// ]]> /// - public static OneVersusAllTrainerTyped OneVersusAllStronglyTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainerTyped OneVersusAllTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModelIn> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, + IEstimator> calibrator = null, + int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModelIn : class where TModelOut : class @@ -840,7 +842,7 @@ public static OneVersusAllTrainerTyped OneVersusAllStronglyTyped>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); } /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index ced55cced0..54b58ae1f4 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -95,7 +95,8 @@ public void MetacomponentsFeaturesRenamed() var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) - .Append(ML.MulticlassClassification.Trainers.OneVersusAllStronglyTyped>(sdcaTrainer)) + //.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped>(sdcaTrainer)) + .Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: false)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); var model = pipeline.Fit(data); From 31b9aa8065c351c07e37f5f5ca8eb4e2bf8c660a Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 15 Jul 2019 11:43:55 -0700 Subject: [PATCH 06/19] temp checkpoint --- .../Prediction/Calibrator.cs | 9 +- .../OneVersusAllTrainer.cs | 185 +++++++++++++++++- .../StandardTrainersCatalog.cs | 1 + .../TrainerEstimators/MetalinearEstimators.cs | 31 ++- 4 files changed, 219 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 58f6fbfca2..6cb6478c4d 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -879,10 +879,12 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, trainedCalibrator); } - public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, - T predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) where T: class + public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, + 
TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) + where TSubPredictor : class + where TCalibrator : class, ICalibrator { - var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictorProducing)predictor, data, maxRows); + var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictorProducing)predictor, data, maxRows) as TCalibrator; var cp = CreateCalibratedPredictor(env, predictor, trainedCalibrator); return cp; } @@ -990,6 +992,7 @@ public static CalibratedModelParametersBase CreateCa if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) { var s = typeof(TSubPredictor); + var d = typeof(TCalibrator); var pm = new ParameterMixingCalibratedModelParameters(env, predictor, cali); return pm; } diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 095d404c52..1cbd0b08bc 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -760,6 +760,187 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } + public sealed class OneVersusAllTrainerTypedT : MetaMulticlassTrainer>>, OneVersusAllModelParametersTyped>> where TSubPredictor : class where TCalibrator: class, ICalibrator + { + internal const string LoadNameValue = "OVA"; + internal const string UserNameValue = "One-vs-All"; + internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, " + + "which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, " + + "and choosing the prediction with the highest confidence score."; + + private readonly Options _options; + /// + /// Options passed to + /// + internal sealed class Options : OptionsBase + { + /// + /// Whether to use probabilities (vs. raw outputs) to identify top-score category. + /// + [Argument(ArgumentType.AtMostOnce, HelpText = "Use probability or margins to determine max", ShortName = "useprob")] + [TGUI(Label = "Use Probability", Description = "Use probabilities (vs. raw outputs) to identify top-score category")] + public bool UseProbabilities = true; + } + + /// + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. + /// The legacy + internal OneVersusAllTrainerTypedT(IHostEnvironment env, Options options) + : base(env, options, LoadNameValue) + { + _options = options; + } + + /// + /// Initializes a new instance of . + /// + /// The instance. + /// An instance of a binary used as the base trainer. + /// The calibrator. If a calibrator is not provided, it will default to + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// Number of instances to train the calibrator. + /// Use probabilities (vs. raw outputs) to identify top-score category. 
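    // Illustrative note (not from the patch): this variant pins the per-class model type to
    // CalibratedModelParametersBase<TSubPredictor, TCalibrator>, i.e. every submodel is stored together
    // with its calibrator, which is why a predictor and a calibrator type parameter both appear.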
+ internal OneVersusAllTrainerTypedT(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + : base(env, + new Options + { + ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, + MaxCalibrationExamples = maximumCalibrationExampleCount, + }, + LoadNameValue, labelColumnName, binaryEstimator, calibrator) + { + Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); + _options = (Options)Args; + _options.UseProbabilities = useProbabilities; + } + + private protected override OneVersusAllModelParametersTyped> TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new CalibratedModelParametersBase[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (CalibratedModelParametersBase)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersTyped>.Create(Host, _options.UseProbabilities, predictors); + } + + private dynamic TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int cls) + { + /*var view = MapLabels(data, cls); + + string trainerLabel = data.Schema.Label.Value.Name; + + // REVIEW: In principle we could support validation sets and the like via the train context, but + // this is currently unsupported. + var transformer = trainer.Fit(view); + + if (_options.UseProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw. + Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface"); + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. + var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);*/ + + var view = MapLabels(data, cls); + + string trainerLabel = data.Schema.Label.Value.Name; + + // REVIEW: In principle we could support validation sets and the like via the train context, but + // this is currently unsupported. + var transformer = trainer.Fit(view); + + if (_options.UseProbabilities) + { + var calibratedModel = transformer.Model as TDistPredictor; + + var s = transformer.Model.GetType(); + + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. 
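    // Illustrative note (assumption, not from the patch): the generic CalibratorUtils.GetCalibratedPredictor
    // overload added in Calibrator.cs above appears intended to back the calibration step in this class;
    // it returns CalibratedModelParametersBase<TSubPredictor, TCalibrator> rather than the untyped
    // IPredictorProducing<float>, so the concrete submodel type can survive calibration.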
+ var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + if (calibratedModel == null) + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + + Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); + } + + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + } + + private IDataView MapLabels(RoleMappedData data, int cls) + { + var label = data.Schema.Label.Value; + Host.Assert(!label.IsHidden); + Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double); + + if (label.Type.GetKeyCount() > 0) + { + // Key values are 1-based. + uint key = (uint)(cls + 1); + return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data); + } + + throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainerTyped: {label.Type.RawType}"); + } + + /// Trains a model. + /// The input data. + /// A model./> + public override MulticlassPredictionTransformer>> Fit(IDataView input) + { + var roles = new KeyValuePair[1]; + roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); + var td = new RoleMappedData(input, roles); + + td.CheckMulticlassLabel(out var numClasses); + + var predictors = new CalibratedModelParametersBase[numClasses]; + string featureColumn = null; + + using (var ch = Host.Start("Fitting")) + { + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + + if (i == 0) + { + var transformer = TrainOne(ch, Trainer, td, i); + featureColumn = transformer.FeatureColumnName; + } + var model = TrainOne(ch, Trainer, td, i).Model; + var m = model as CalibratedModelParametersBase; + predictors[i] = model as CalibratedModelParametersBase; + + } + } + + return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersTyped>.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + } + } + public sealed class OneVersusAllTrainerTyped : MetaMulticlassTrainer>, OneVersusAllModelParametersTyped> where T : class { internal const string LoadNameValue = "OVA"; @@ -835,7 +1016,7 @@ private protected override OneVersusAllModelParametersTyped TrainCore(IChanne return OneVersusAllModelParametersTyped.Create(Host, _options.UseProbabilities, predictors); } - private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int cls) + private dynamic TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int cls) { /*var view = MapLabels(data, cls); @@ -880,7 +1061,7 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); 
return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index cc89a9ff92..8259cacd81 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -758,6 +758,7 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica where TModel : class { Contracts.CheckValue(catalog, nameof(catalog)); + var s = typeof(TModel); var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 54b58ae1f4..30e9f8d77d 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -52,6 +52,23 @@ public void OVAUncalibrated() Done(); } + /// + /// OVA strongly typed un-calibrated + /// + [Fact] + public void OVATypedUncalibrated() + { + var (pipeline, data) = GetMulticlassPipeline(); + var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated( + new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 }); + + pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: false)) + .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); + + TestEstimatorCore(pipeline, data); + Done(); + } + /// /// Pairwise Coupling trainer /// @@ -93,10 +110,20 @@ public void MetacomponentsFeaturesRenamed() NumberOfThreads = 1, }); + var sdca = ML.BinaryClassification.Trainers.SgdCalibrated( + new SgdCalibratedTrainer.Options + { + LabelColumnName = "Label", + FeatureColumnName = "Vars", + Shuffle = true, + NumberOfThreads = 1, + }); + var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) - //.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped>(sdcaTrainer)) - .Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: false)) + .Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped>(sdcaTrainer)) + + //.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdca)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); var model = pipeline.Fit(data); From 82ec5e26634b6b1f9a260630802dd962325f1a69 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 15 Jul 2019 15:07:17 -0700 Subject: [PATCH 07/19] all working. 
Needs more test cases --- .../Prediction/Calibrator.cs | 20 +- .../OneVersusAllTrainer.cs | 1128 ++++------------- .../StandardTrainersCatalog.cs | 21 +- .../TrainerEstimators/MetalinearEstimators.cs | 2 +- 4 files changed, 256 insertions(+), 915 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 6cb6478c4d..1179a0f306 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -881,12 +881,11 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) - where TSubPredictor : class + where TSubPredictor : class, IPredictorProducing where TCalibrator : class, ICalibrator { - var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictorProducing)predictor, data, maxRows) as TCalibrator; - var cp = CreateCalibratedPredictor(env, predictor, trainedCalibrator); - return cp; + var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, predictor, data, maxRows) as TCalibrator; + return (CalibratedModelParametersBase)CreateCalibratedPredictor(env, predictor, trainedCalibrator); } public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, IDataView scored, string labelColumn, string scoreColumn, string weightColumn = null, int maxRows = _maxCalibrationExamples) @@ -972,13 +971,13 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa return TrainCalibrator(env, ch, caliTrainer, scored, data.Schema.Label.Value.Name, DefaultColumnNames.Score, data.Schema.Weight?.Name, maxRows); } - public static CalibratedModelParametersBase CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) - where TSubPredictor : class + public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) + where TSubPredictor : class, IPredictorProducing where TCalibrator : class, ICalibrator { Contracts.Assert(predictor != null); - //if (cali == null) - // return predictor; + if (cali == null) + return predictor; for (; ; ) { @@ -991,10 +990,7 @@ public static CalibratedModelParametersBase CreateCa var predWithFeatureScores = predictor as IPredictorWithFeatureWeights; if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) { - var s = typeof(TSubPredictor); - var d = typeof(TCalibrator); - var pm = new ParameterMixingCalibratedModelParameters(env, predictor, cali); - return pm; + return new ParameterMixingCalibratedModelParameters(env, predictor, cali); } if (predictor is IValueMapper) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 1cbd0b08bc..7a48d63e4a 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -37,52 +37,8 @@ namespace Microsoft.ML.Trainers using TDistPredictor = IDistPredictorProducing; using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - /// - /// The for 
training a one-versus-all multi-class classifier that uses the specified binary classifier. - /// - /// - /// , - /// can be different from , which develops a multi-class classifier directly. - /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always - /// request caching, as it will be performing multiple passes over the data set. - /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. - /// - /// This can allow you to exploit trainers that do not naturally have a - /// multiclass option, for example, using the - /// to solve a multiclass problem. - /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases - /// where the trainer has a multiclass option, but using it directly is not - /// practical due to, usually, memory constraints. For example, while a multiclass - /// logistic regression is a more principled way to solve a multiclass problem, it - /// requires that the trainer store a lot more intermediate state in the form of - /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one - /// as would be needed for a one-versus-all classification model. - /// - /// Check the See Also section for links to usage examples. - /// ]]> - /// - /// - /// - public sealed class OneVersusAllTrainer : MetaMulticlassTrainer, OneVersusAllModelParameters> + + public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer>, OneVersusAllModelParameters> where T : class { internal const string LoadNameValue = "OVA"; internal const string UserNameValue = "One-vs-All"; @@ -93,7 +49,7 @@ public sealed class OneVersusAllTrainer : MetaMulticlassTrainer - /// Options passed to + /// Options passed to /// internal sealed class Options : OptionsBase { @@ -106,27 +62,27 @@ internal sealed class Options : OptionsBase } /// - /// Constructs a trainer supplying a . + /// Constructs a trainer supplying a . /// /// The private for this estimator. /// The legacy - internal OneVersusAllTrainer(IHostEnvironment env, Options options) + internal OneVersusAllTrainerBase(IHostEnvironment env, Options options) : base(env, options, LoadNameValue) { _options = options; } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. - /// The calibrator. If a calibrator is not provided, it will default to + /// /// The calibrator. If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. - /// Number of instances to train the calibrator. + /// /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. - internal OneVersusAllTrainer(IHostEnvironment env, + internal OneVersusAllTrainerBase(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, @@ -146,16 +102,16 @@ internal OneVersusAllTrainer(IHostEnvironment env, _options.UseProbabilities = useProbabilities; } - private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) { // Train one-vs-all models. 
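            // Illustrative walk-through: each pass of the loop below remaps the key-valued label to a binary
            // "is class i" label (key values are 1-based, so class i corresponds to key i + 1) and fits the
            // binary trainer on that one-vs-rest view; with three classes this produces three submodels.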
- var predictors = new TScalarPredictor[count]; + var predictors = new T[count]; for (int i = 0; i < predictors.Length; i++) { ch.Info($"Training learner {i}"); - predictors[i] = TrainOne(ch, Trainer, data, i).Model; + predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; } - return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); + return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); } private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) @@ -168,24 +124,13 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel // this is currently unsupported. var transformer = trainer.Fit(view); - if (_options.UseProbabilities) - { - var calibratedModel = transformer.Model as TDistPredictor; - - // REVIEW: restoring the RoleMappedData, as much as we can. - // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); - - if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; - - Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); - return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); - } - - return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); + return TrainOneHelper(ch, _options.UseProbabilities, view, trainerLabel, transformer); } + private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer); + private IDataView MapLabels(RoleMappedData data, int cls) { var label = data.Schema.Label.Value; @@ -205,7 +150,7 @@ private IDataView MapLabels(RoleMappedData data, int cls) /// Trains a model. /// The input data. /// A model./> - public override MulticlassPredictionTransformer Fit(IDataView input) + public override MulticlassPredictionTransformer> Fit(IDataView input) { var roles = new KeyValuePair[1]; roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); @@ -213,7 +158,7 @@ public override MulticlassPredictionTransformer Fit td.CheckMulticlassLabel(out var numClasses); - var predictors = new TScalarPredictor[numClasses]; + var predictors = new T[numClasses]; string featureColumn = null; using (var ch = Host.Start("Fitting")) @@ -227,660 +172,107 @@ public override MulticlassPredictionTransformer Fit var transformer = TrainOne(ch, Trainer, td, i); featureColumn = transformer.FeatureColumnName; } + predictors[i] = (T)TrainOne(ch, Trainer, td, i).Model; - predictors[i] = TrainOne(ch, Trainer, td, i).Model; } } - return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); } } /// - /// Model parameters for . + /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. 
/// - public sealed class OneVersusAllModelParameters : - ModelParametersBase>, - IValueMapper, - ICanSaveInSourceCode, - ICanSaveInTextFormat, - ISingleCanSavePfa + /// + /// , + /// can be different from , which develops a multi-class classifier directly. + /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always + /// request caching, as it will be performing multiple passes over the data set. + /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. + /// + /// This can allow you to exploit trainers that do not naturally have a + /// multiclass option, for example, using the + /// to solve a multiclass problem. + /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases + /// where the trainer has a multiclass option, but using it directly is not + /// practical due to, usually, memory constraints. For example, while a multiclass + /// logistic regression is a more principled way to solve a multiclass problem, it + /// requires that the trainer store a lot more intermediate state in the form of + /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one + /// as would be needed for a one-versus-all classification model. + /// + /// Check the See Also section for links to usage examples. + /// ]]> + /// + /// + /// + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase { - internal const string LoaderSignature = "OVAExec"; - internal const string RegistrationName = "OVAPredictor"; - - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "TLC OVA ", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); - } - - private const string SubPredictorFmt = "SubPredictor_{0:000}"; - - private readonly ImplBase _impl; - - /// - /// Retrieves the model parameters. - /// - internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); - - /// - /// The type of the prediction task. - /// - private protected override PredictionKind PredictionKind => PredictionKind.MulticlassClassification; - /// - /// Function applied to output of predictors. Assume that we have n predictors (one per class) and for the i-th predictor, - /// y_i is its raw output and p_i is its probability output. Note that not all predictors are able to produce probability output. - /// - /// : output the result of predictors without post-processing. Output is [y_1, ..., y_n]. - /// : fetch probability output of each class probability from provided predictors and make sure the sume of class probabilities is one. - /// Output is [p_1 / (p_1 + ... + p_n), ..., p_n / (p_1 + ... + p_n)]. - /// : Generate probability by feeding raw outputs to softmax function. Output is [z_1, ..., z_n], where z_i is exp(y_i) / (exp(y_1) + ... + exp(y_n)). - /// + /// Constructs a trainer supplying a . 
/// - [BestFriend] - internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 }; - - private DataViewType DistType { get; } - - bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; - - [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) - { - ImplBase impl; - - using (var ch = host.Start("Creating OVA predictor")) - { - if (outputFormula == OutputFormula.Softmax) - { - impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParameters(host, impl); - } - - // Caller of this function asks for probability output. We check if input predictor can produce probability. - // If that predictor can't produce probability, ivmd will be null. - IValueMapperDist ivmd = null; - if (outputFormula == OutputFormula.ProbabilityNormalization && - ((ivmd = predictors[0] as IValueMapperDist) == null || - ivmd.OutputType != NumberDataViewType.Single || - ivmd.DistType != NumberDataViewType.Single)) - { - ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); - ivmd = null; - } - - // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. - if (ivmd != null) - { - var dists = new IValueMapperDist[predictors.Length]; - for (int i = 0; i < predictors.Length; ++i) - dists[i] = (IValueMapperDist)predictors[i]; - impl = new ImplDist(dists); - } - else - impl = new ImplRaw(predictors); - } - - return new OneVersusAllModelParameters(host, impl); - } - - [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) + /// The private for this estimator. + /// The legacy + internal OneVersusAllTrainer(IHostEnvironment env, Options options) + : base(env, options) { - var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; - - return Create(host, outputFormula, predictors); } /// - /// Create a from an array of predictors. 
- /// - [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) - { - Contracts.CheckValue(host, nameof(host)); - host.CheckNonEmpty(predictors, nameof(predictors)); - return Create(host, OutputFormula.ProbabilityNormalization, predictors); - } - - private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) - : base(env, RegistrationName) - { - Host.AssertValue(impl, nameof(impl)); - Host.Assert(Utils.Size(impl.Predictors) > 0); - - _impl = impl; - DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); - } - - private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) - : base(env, RegistrationName, ctx) - { - // *** Binary format *** - // bool: useDist - // int: predictor count - bool useDist = ctx.Reader.ReadBoolByte(); - int len = ctx.Reader.ReadInt32(); - Host.CheckDecode(len > 0); - - if (useDist) - { - var predictors = new IValueMapperDist[len]; - LoadPredictors(Host, predictors, ctx); - _impl = new ImplDist(predictors); - } - else - { - var predictors = new TScalarPredictor[len]; - LoadPredictors(Host, predictors, ctx); - _impl = new ImplRaw(predictors); - } - - DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); - } - - private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParameters(env, ctx); - } - - private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) - where TPredictor : class - { - for (int i = 0; i < predictors.Length; i++) - ctx.LoadModel(env, out predictors[i], string.Format(SubPredictorFmt, i)); - } - - private protected override void SaveCore(ModelSaveContext ctx) - { - base.SaveCore(ctx); - ctx.SetVersionInfo(GetVersionInfo()); - - var preds = _impl.Predictors; - - // *** Binary format *** - // bool: useDist - // int: predictor count - ctx.Writer.WriteBoolByte(_impl is ImplDist); - ctx.Writer.Write(preds.Length); - - // Save other streams. 
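            // Illustrative note: the layout written here is one byte recording whether the distribution-capable
            // implementation (ImplDist) was used, an int predictor count, and then each submodel saved under the
            // key pattern SubPredictor_000, SubPredictor_001, ...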
- for (int i = 0; i < preds.Length; i++) - ctx.SaveModel(preds[i], string.Format(SubPredictorFmt, i)); - } - - JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input) - { - Host.CheckValue(ctx, nameof(ctx)); - Host.CheckValue(input, nameof(input)); - return _impl.SaveAsPfa(ctx, input); - } - - DataViewType IValueMapper.InputType - { - get { return _impl.InputType; } - } - - DataViewType IValueMapper.OutputType - { - get { return DistType; } - } - ValueMapper IValueMapper.GetMapper() - { - Host.Check(typeof(TIn) == typeof(VBuffer)); - Host.Check(typeof(TOut) == typeof(VBuffer)); - - return (ValueMapper)(Delegate)_impl.GetMapper(); - } - - void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema) - { - Host.CheckValue(writer, nameof(writer)); - Host.CheckValue(schema, nameof(schema)); - - var preds = _impl.Predictors; - writer.WriteLine("double[] outputs = new double[{0}];", preds.Length); - - for (int i = 0; i < preds.Length; i++) - { - var saveInSourceCode = preds[i] as ICanSaveInSourceCode; - Host.Check(saveInSourceCode != null, "Saving in code is not supported."); - - writer.WriteLine("{"); - saveInSourceCode.SaveAsCode(writer, schema); - writer.WriteLine("outputs[{0}] = output;", i); - writer.WriteLine("}"); - } - } - - void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) - { - Host.CheckValue(writer, nameof(writer)); - Host.CheckValue(schema, nameof(schema)); - - var preds = _impl.Predictors; - - for (int i = 0; i < preds.Length; i++) - { - var saveInText = preds[i] as ICanSaveInTextFormat; - Host.Check(saveInText != null, "Saving in text is not supported."); - - writer.WriteLine("#region: class-{0} classifier", i); - saveInText.SaveAsText(writer, schema); - - writer.WriteLine("#endregion: class-{0} classifier", i); - writer.WriteLine(); - } - } - - private abstract class ImplBase : ISingleCanSavePfa - { - public abstract DataViewType InputType { get; } - public abstract IValueMapper[] Predictors { get; } - public abstract bool CanSavePfa { get; } - public abstract ValueMapper, VBuffer> GetMapper(); - public abstract JToken SaveAsPfa(BoundPfaContext ctx, JToken input); - - protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType) - { - Contracts.AssertValueOrNull(mapper); - Contracts.AssertValueOrNull(inputType); - - if (mapper == null) - return false; - if (mapper.OutputType != NumberDataViewType.Single) - return false; - if (!(mapper.InputType is VectorDataViewType mapperVectorType) || mapperVectorType.ItemType != NumberDataViewType.Single) - return false; - if (inputType == null) - inputType = mapperVectorType; - else if (inputType.Size != mapperVectorType.Size) - { - if (inputType.Size == 0) - inputType = mapperVectorType; - else if (mapperVectorType.Size != 0) - return false; - } - return true; - } - } - - private sealed class ImplRaw : ImplBase - { - public override DataViewType InputType { get; } - public override IValueMapper[] Predictors { get; } - public override bool CanSavePfa { get; } - - internal ImplRaw(TScalarPredictor[] predictors) - { - Contracts.CheckNonEmpty(predictors, nameof(predictors)); - - Predictors = new IValueMapper[predictors.Length]; - VectorDataViewType inputType = null; - for (int i = 0; i < predictors.Length; i++) - { - var vm = predictors[i] as IValueMapper; - Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); - Predictors[i] = vm; - } - CanSavePfa = Predictors.All(m => (m as ISingleCanSavePfa)?.CanSavePfa == true); - 
Contracts.AssertValue(inputType); - InputType = inputType; - } - - public override ValueMapper, VBuffer> GetMapper() - { - var maps = new ValueMapper, float>[Predictors.Length]; - for (int i = 0; i < Predictors.Length; i++) - maps[i] = Predictors[i].GetMapper, float>(); - - var buffer = new float[maps.Length]; - return - (in VBuffer src, ref VBuffer dst) => - { - int inputSize = InputType.GetVectorSize(); - if (inputSize > 0) - Contracts.Check(src.Length == inputSize); - - var tmp = src; - Parallel.For(0, maps.Length, i => maps[i](in tmp, ref buffer[i])); - - var editor = VBufferEditor.Create(ref dst, maps.Length); - buffer.CopyTo(editor.Values); - dst = editor.Commit(); - }; - } - - public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) - { - Contracts.CheckValue(ctx, nameof(ctx)); - Contracts.CheckValue(input, nameof(input)); - Contracts.Assert(CanSavePfa); - - JArray rootObjects = new JArray(); - for (int i = 0; i < Predictors.Length; ++i) - { - var pred = (ISingleCanSavePfa)Predictors[i]; - Contracts.Assert(pred.CanSavePfa); - rootObjects.Add(ctx.DeclareVar(null, pred.SaveAsPfa(ctx, input))); - } - JObject jobj = null; - return jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects); - } - } - - private sealed class ImplDist : ImplBase - { - private readonly IValueMapperDist[] _mappers; - public override DataViewType InputType { get; } - public override IValueMapper[] Predictors => _mappers; - public override bool CanSavePfa { get; } - - internal ImplDist(IValueMapperDist[] predictors) - { - Contracts.Check(Utils.Size(predictors) > 0); - - _mappers = new IValueMapperDist[predictors.Length]; - VectorDataViewType inputType = null; - for (int i = 0; i < predictors.Length; i++) - { - var vm = predictors[i]; - Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); - _mappers[i] = vm; - } - CanSavePfa = Predictors.All(m => (m as IDistCanSavePfa)?.CanSavePfa == true); - Contracts.AssertValue(inputType); - InputType = inputType; - } - - private bool IsValid(IValueMapperDist mapper, ref VectorDataViewType inputType) - { - return base.IsValid(mapper, ref inputType) && mapper.DistType == NumberDataViewType.Single; - } - - /// - /// Each predictor produces a probability of a class. All classes' probabilities are normalized so that - /// their sum is one. - /// - public override ValueMapper, VBuffer> GetMapper() - { - var maps = new ValueMapper, float, float>[Predictors.Length]; - for (int i = 0; i < Predictors.Length; i++) - maps[i] = _mappers[i].GetMapper, float, float>(); - - var buffer = new float[maps.Length]; - return - (in VBuffer src, ref VBuffer dst) => - { - int inputSize = InputType.GetVectorSize(); - if (inputSize > 0) - Contracts.Check(src.Length == inputSize); - - var tmp = src; - Parallel.For(0, maps.Length, - i => - { - float score = 0; - // buffer[i] is the probability of the i-th class. - // score is the raw prediction score. - maps[i](in tmp, ref score, ref buffer[i]); - }); - - // buffer[i] is the probability of the i-th class. - // score is the raw prediction score. - NormalizeSumToOne(buffer, maps.Length); - - var editor = VBufferEditor.Create(ref dst, maps.Length); - buffer.CopyTo(editor.Values); - dst = editor.Commit(); - }; - } - - private void NormalizeSumToOne(float[] output, int count) - { - // Clamp to zero and normalize. 
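For reference, the clamp-and-normalize step that follows (and that carries over to the generic model-parameters class later in this diff) maps each per-class probability p_i to

    p_i' = max(p_i, 0) / Σ_j max(p_j, 0)

NaN entries are skipped while accumulating the denominator, and the vector is left unscaled when the accumulated sum is not positive.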
- Double sum = 0; - for (int i = 0; i < count; i++) - { - var value = output[i]; - if (float.IsNaN(value)) - continue; - - if (value >= 0) - sum += value; - else - output[i] = 0; - } - - if (sum > 0) - { - for (int i = 0; i < count; i++) - output[i] = (float)(output[i] / sum); - } - } - - public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) - { - Contracts.CheckValue(ctx, nameof(ctx)); - Contracts.CheckValue(input, nameof(input)); - Contracts.Assert(CanSavePfa); - - JArray rootObjects = new JArray(); - for (int i = 0; i < Predictors.Length; ++i) - { - var pred = (IDistCanSavePfa)Predictors[i]; - Contracts.Assert(pred.CanSavePfa); - pred.SaveAsPfa(ctx, input, null, out JToken scoreToken, null, out JToken probToken); - rootObjects.Add(probToken); - } - JObject jobj = null; - var rootResult = jobj.AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)).AddReturn("new", rootObjects); - var resultVar = ctx.DeclareVar(null, rootResult); - var factorVar = ctx.DeclareVar(null, PfaUtils.Call("/", 1.0, PfaUtils.Call("a.sum", resultVar))); - return PfaUtils.Call("la.scale", resultVar, factorVar); - } - } - - private sealed class ImplSoftmax : ImplBase - { - public override DataViewType InputType { get; } - public override IValueMapper[] Predictors { get; } - public override bool CanSavePfa { get; } - - internal ImplSoftmax(TScalarPredictor[] predictors) - { - Contracts.CheckNonEmpty(predictors, nameof(predictors)); - - Predictors = new IValueMapper[predictors.Length]; - VectorDataViewType inputType = null; - for (int i = 0; i < predictors.Length; i++) - { - var vm = predictors[i] as IValueMapper; - Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface"); - Predictors[i] = vm; - } - CanSavePfa = false; - Contracts.AssertValue(inputType); - InputType = inputType; - } - - public override ValueMapper, VBuffer> GetMapper() - { - var maps = new ValueMapper, float>[Predictors.Length]; - for (int i = 0; i < Predictors.Length; i++) - maps[i] = Predictors[i].GetMapper, float>(); - - var buffer = new float[maps.Length]; - return - (in VBuffer src, ref VBuffer dst) => - { - int inputSize = InputType.GetVectorSize(); - if (inputSize > 0) - Contracts.Check(src.Length == inputSize); - - var tmp = src; - Parallel.For(0, maps.Length, i => maps[i](in tmp, ref buffer[i])); - NormalizeSoftmax(buffer, maps.Length); - - var editor = VBufferEditor.Create(ref dst, maps.Length); - buffer.CopyTo(editor.Values); - dst = editor.Commit(); - }; - } - - private void NormalizeSoftmax(float[] scores, int count) - { - double sum = 0; - var score = new double[count]; - - for (int i = 0; i < count; i++) - { - score[i] = Math.Exp(scores[i]); - sum += score[i]; - } - - for (int i = 0; i < count; i++) - scores[i] = (float)(score[i] / sum); - } - - public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) - { - throw new NotImplementedException("Softmax's PFA exporter is not implemented yet."); - } - } - } - - public sealed class OneVersusAllTrainerTypedT : MetaMulticlassTrainer>>, OneVersusAllModelParametersTyped>> where TSubPredictor : class where TCalibrator: class, ICalibrator - { - internal const string LoadNameValue = "OVA"; - internal const string UserNameValue = "One-vs-All"; - internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, " - + "which distinguishes that class from all other classes. 
Prediction is then performed by running these binary classifiers, " - + "and choosing the prediction with the highest confidence score."; - - private readonly Options _options; - /// - /// Options passed to - /// - internal sealed class Options : OptionsBase - { - /// - /// Whether to use probabilities (vs. raw outputs) to identify top-score category. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Use probability or margins to determine max", ShortName = "useprob")] - [TGUI(Label = "Use Probability", Description = "Use probabilities (vs. raw outputs) to identify top-score category")] - public bool UseProbabilities = true; - } - - /// - /// Constructs a trainer supplying a . - /// - /// The private for this estimator. - /// The legacy - internal OneVersusAllTrainerTypedT(IHostEnvironment env, Options options) - : base(env, options, LoadNameValue) - { - _options = options; - } - - /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. - /// The calibrator. If a calibrator is not provided, it will default to + /// /// The calibrator. If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. - /// Number of instances to train the calibrator. + /// /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. - internal OneVersusAllTrainerTypedT(IHostEnvironment env, + internal OneVersusAllTrainer(IHostEnvironment env, TScalarTrainer binaryEstimator, - string labelColumnName = DefaultColumnNames.Label, - bool imputeMissingLabelsAsNegative = false, - ICalibratorTrainer calibrator = null, - int maximumCalibrationExampleCount = 1000000000, - bool useProbabilities = true) - : base(env, - new Options - { - ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, - MaxCalibrationExamples = maximumCalibrationExampleCount, - }, - LoadNameValue, labelColumnName, binaryEstimator, calibrator) - { - Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); - _options = (Options)Args; - _options.UseProbabilities = useProbabilities; - } - - private protected override OneVersusAllModelParametersTyped> TrainCore(IChannel ch, RoleMappedData data, int count) - { - // Train one-vs-all models. - var predictors = new CalibratedModelParametersBase[count]; - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); - predictors[i] = (CalibratedModelParametersBase)TrainOne(ch, Trainer, data, i).Model; - } - return OneVersusAllModelParametersTyped>.Create(Host, _options.UseProbabilities, predictors); - } - - private dynamic TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int cls) - { - /*var view = MapLabels(data, cls); - - string trainerLabel = data.Schema.Label.Value.Name; - - // REVIEW: In principle we could support validation sets and the like via the train context, but - // this is currently unsupported. - var transformer = trainer.Fit(view); - - if (_options.UseProbabilities) - { - var calibratedModel = transformer.Model as TDistPredictor; - - // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw. - Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface"); - - // REVIEW: restoring the RoleMappedData, as much as we can. 
- // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); - - return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); - } - - return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);*/ - - var view = MapLabels(data, cls); - - string trainerLabel = data.Schema.Label.Value.Name; - - // REVIEW: In principle we could support validation sets and the like via the train context, but - // this is currently unsupported. - var transformer = trainer.Fit(view); + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) + { + } - if (_options.UseProbabilities) + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) + { + if (useProbabilities) { var calibratedModel = transformer.Model as TDistPredictor; - var s = transformer.Model.GetType(); - // REVIEW: restoring the RoleMappedData, as much as we can. // not having the weight column on the data passed to the TrainCalibrator should be addressed. var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); @@ -888,90 +280,20 @@ private dynamic TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } - - private IDataView MapLabels(RoleMappedData data, int cls) - { - var label = data.Schema.Label.Value; - Host.Assert(!label.IsHidden); - Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double); - - if (label.Type.GetKeyCount() > 0) - { - // Key values are 1-based. - uint key = (uint)(cls + 1); - return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data); - } - - throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainerTyped: {label.Type.RawType}"); - } - - /// Trains a model. - /// The input data. 
- /// A model./> - public override MulticlassPredictionTransformer>> Fit(IDataView input) - { - var roles = new KeyValuePair[1]; - roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); - var td = new RoleMappedData(input, roles); - - td.CheckMulticlassLabel(out var numClasses); - - var predictors = new CalibratedModelParametersBase[numClasses]; - string featureColumn = null; - - using (var ch = Host.Start("Fitting")) - { - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); - - if (i == 0) - { - var transformer = TrainOne(ch, Trainer, td, i); - featureColumn = transformer.FeatureColumnName; - } - var model = TrainOne(ch, Trainer, td, i).Model; - var m = model as CalibratedModelParametersBase; - predictors[i] = model as CalibratedModelParametersBase; - - } - } - - return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersTyped>.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); - } } - public sealed class OneVersusAllTrainerTyped : MetaMulticlassTrainer>, OneVersusAllModelParametersTyped> where T : class + public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase> + where TSubPredictor: class, IPredictorProducing + where TCalibrator: class, ICalibrator { - internal const string LoadNameValue = "OVA"; - internal const string UserNameValue = "One-vs-All"; - internal const string Summary = "In this strategy, a binary classification algorithm is used to train one classifier for each class, " - + "which distinguishes that class from all other classes. Prediction is then performed by running these binary classifiers, " - + "and choosing the prediction with the highest confidence score."; - - private readonly Options _options; - /// - /// Options passed to - /// - internal sealed class Options : OptionsBase - { - /// - /// Whether to use probabilities (vs. raw outputs) to identify top-score category. - /// - [Argument(ArgumentType.AtMostOnce, HelpText = "Use probability or margins to determine max", ShortName = "useprob")] - [TGUI(Label = "Use Probability", Description = "Use probabilities (vs. raw outputs) to identify top-score category")] - public bool UseProbabilities = true; - } - /// - /// Constructs a trainer supplying a . + /// Constructs a trainer supplying a . /// /// The private for this estimator. - /// The legacy + /// The legacy internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) - : base(env, options, LoadNameValue) + : base(env, options) { - _options = options; } /// @@ -991,77 +313,24 @@ internal OneVersusAllTrainerTyped(IHostEnvironment env, ICalibratorTrainer calibrator = null, int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) - : base(env, - new Options - { - ImputeMissingLabelsAsNegative = imputeMissingLabelsAsNegative, - MaxCalibrationExamples = maximumCalibrationExampleCount, - }, - LoadNameValue, labelColumnName, binaryEstimator, calibrator) - { - Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); - _options = (Options)Args; - _options.UseProbabilities = useProbabilities; - } - - private protected override OneVersusAllModelParametersTyped TrainCore(IChannel ch, RoleMappedData data, int count) + : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) { - // Train one-vs-all models. 
- var predictors = new T[count]; - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); - predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; - } - return OneVersusAllModelParametersTyped.Create(Host, _options.UseProbabilities, predictors); } - private dynamic TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int cls) + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) { - /*var view = MapLabels(data, cls); - - string trainerLabel = data.Schema.Label.Value.Name; - - // REVIEW: In principle we could support validation sets and the like via the train context, but - // this is currently unsupported. - var transformer = trainer.Fit(view); - - if (_options.UseProbabilities) - { - var calibratedModel = transformer.Model as TDistPredictor; - - // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw. - Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface"); - - // REVIEW: restoring the RoleMappedData, as much as we can. - // not having the weight column on the data passed to the TrainCalibrator should be addressed. - var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); - - return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); - } - - return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName);*/ - - var view = MapLabels(data, cls); - - string trainerLabel = data.Schema.Label.Value.Name; - - // REVIEW: In principle we could support validation sets and the like via the train context, but - // this is currently unsupported. - var transformer = trainer.Fit(view); - - if (_options.UseProbabilities) + if (useProbabilities) { var calibratedModel = transformer.Model as TDistPredictor; - var s = transformer.Model.GetType(); - // REVIEW: restoring the RoleMappedData, as much as we can. // not having the weight column on the data passed to the TrainCalibrator should be addressed. var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); if (calibratedModel == null) - calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; + calibratedModel = CalibratorUtils.GetCalibratedPredictor(Host, ch, Calibrator, (TSubPredictor)transformer.Model, trainedData, Args.MaxCalibrationExamples) as TDistPredictor; Host.Check(calibratedModel != null, "Calibrated predictor does not implement the expected interface"); return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); @@ -1069,63 +338,67 @@ private dynamic TrainOne(IChannel ch, dynamic trainer, RoleMappedData data, int return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } + } - private IDataView MapLabels(RoleMappedData data, int cls) + public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase where T: class + { + /// + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. 
+ /// The legacy + internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) + : base(env, options) { - var label = data.Schema.Label.Value; - Host.Assert(!label.IsHidden); - Host.Assert(label.Type.GetKeyCount() > 0 || label.Type == NumberDataViewType.Single || label.Type == NumberDataViewType.Double); - - if (label.Type.GetKeyCount() > 0) - { - // Key values are 1-based. - uint key = (uint)(cls + 1); - return MapLabelsCore(NumberDataViewType.UInt32, (in uint val) => key == val, data); - } - - throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainerTyped: {label.Type.RawType}"); } - /// Trains a model. - /// The input data. - /// A model./> - public override MulticlassPredictionTransformer> Fit(IDataView input) + /// + /// Initializes a new instance of . + /// + /// The instance. + /// An instance of a binary used as the base trainer. + /// The calibrator. If a calibrator is not provided, it will default to + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// Number of instances to train the calibrator. + /// Use probabilities (vs. raw outputs) to identify top-score category. + internal OneVersusAllTrainerTyped(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) { - var roles = new KeyValuePair[1]; - roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); - var td = new RoleMappedData(input, roles); - - td.CheckMulticlassLabel(out var numClasses); - - var predictors = new T[numClasses]; - string featureColumn = null; + } - using (var ch = Host.Start("Fitting")) + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer) + { + if (useProbabilities) { - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); + var calibratedModel = transformer.Model as TDistPredictor; - if (i == 0) - { - var transformer = TrainOne(ch, Trainer, td, i); - featureColumn = transformer.FeatureColumnName; - } - var model = TrainOne(ch, Trainer, td, i).Model; - var m = model as T; - predictors[i] = model as T; + // If probabilities are requested and the Predictor is not calibrated or if it doesn't implement the right interface then throw. + Host.Check(calibratedModel != null, "Predictor is either not calibrated or does not implement the expected interface"); - } + // REVIEW: restoring the RoleMappedData, as much as we can. + // not having the weight column on the data passed to the TrainCalibrator should be addressed. 
+ var trainedData = new RoleMappedData(view, label: trainerLabel, feature: transformer.FeatureColumnName); + + return new BinaryPredictionTransformer(Host, calibratedModel, trainedData.Data.Schema, transformer.FeatureColumnName); } - return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersTyped.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } } /// - /// Model parameters for . + /// Model parameters for . /// - public sealed class OneVersusAllModelParametersTyped : + public class OneVersusAllModelParameters : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, @@ -1136,7 +409,7 @@ public sealed class OneVersusAllModelParametersTyped : internal const string LoaderSignature = "OVAExec"; internal const string RegistrationName = "OVAPredictor"; - private static VersionInfo GetVersionInfo() + private protected static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "TLC OVA ", @@ -1144,7 +417,7 @@ private static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OneVersusAllModelParametersTyped).Assembly.FullName); + loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -1179,7 +452,7 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; [BestFriend] - internal static OneVersusAllModelParametersTyped Create(IHost host, OutputFormula outputFormula, T[] predictors) + internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, T[] predictors) { ImplBase impl; @@ -1188,7 +461,7 @@ internal static OneVersusAllModelParametersTyped Create(IHost host, OutputFor if (outputFormula == OutputFormula.Softmax) { impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParametersTyped(host, impl); + return new OneVersusAllModelParameters(host, impl); } // Caller of this function asks for probability output. We check if input predictor can produce probability. @@ -1199,7 +472,7 @@ internal static OneVersusAllModelParametersTyped Create(IHost host, OutputFor ivmd.OutputType != NumberDataViewType.Single || ivmd.DistType != NumberDataViewType.Single)) { - ch.Warning($"{nameof(OneVersusAllTrainerTyped.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerTyped.Options.PredictorType)} that can't produce probabilities."); + ch.Warning($"{nameof(OneVersusAllTrainerBase.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerBase.Options.PredictorType)} that can't produce probabilities."); ivmd = null; } @@ -1212,11 +485,11 @@ internal static OneVersusAllModelParametersTyped Create(IHost host, OutputFor impl = new ImplRaw(predictors); } - return new OneVersusAllModelParametersTyped(host, impl); + return new OneVersusAllModelParameters(host, impl); } [BestFriend] - internal static OneVersusAllModelParametersTyped Create(IHost host, bool useProbability, T[] predictors) + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) { var outputFormula = useProbability ? 
OutputFormula.ProbabilityNormalization : OutputFormula.Raw; @@ -1224,17 +497,17 @@ internal static OneVersusAllModelParametersTyped Create(IHost host, bool useP } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. /// [BestFriend] - internal static OneVersusAllModelParametersTyped Create(IHost host, T[] predictors) + internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) { Contracts.CheckValue(host, nameof(host)); host.CheckNonEmpty(predictors, nameof(predictors)); return Create(host, OutputFormula.ProbabilityNormalization, predictors); } - private OneVersusAllModelParametersTyped(IHostEnvironment env, ImplBase impl) + private protected OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) : base(env, RegistrationName) { Host.AssertValue(impl, nameof(impl)); @@ -1244,7 +517,7 @@ private OneVersusAllModelParametersTyped(IHostEnvironment env, ImplBase impl) DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private OneVersusAllModelParametersTyped(IHostEnvironment env, ModelLoadContext ctx) + private protected OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -1270,12 +543,12 @@ private OneVersusAllModelParametersTyped(IHostEnvironment env, ModelLoadContext DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private static OneVersusAllModelParametersTyped Create(IHostEnvironment env, ModelLoadContext ctx) + private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParametersTyped(env, ctx); + return new OneVersusAllModelParameters(env, ctx); } private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) @@ -1367,7 +640,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema) } } - private abstract class ImplBase : ISingleCanSavePfa + private protected abstract class ImplBase : ISingleCanSavePfa { public abstract DataViewType InputType { get; } public abstract IValueMapper[] Predictors { get; } @@ -1399,7 +672,7 @@ protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType) } } - private sealed class ImplRaw : ImplBase + private protected sealed class ImplRaw : ImplBase { public override DataViewType InputType { get; } public override IValueMapper[] Predictors { get; } @@ -1463,7 +736,7 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } - private sealed class ImplDist : ImplBase + private protected sealed class ImplDist : ImplBase { private readonly IValueMapperDist[] _mappers; public override DataViewType InputType { get; } @@ -1575,7 +848,7 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } - private sealed class ImplSoftmax : ImplBase + private protected sealed class ImplSoftmax : ImplBase { public override DataViewType InputType { get; } public override IValueMapper[] Predictors { get; } @@ -1643,4 +916,77 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } } -} \ No newline at end of file + + /// + /// Model parameters for . 
+ /// + public sealed class OneVersusAllModelParameters : + OneVersusAllModelParameters + { + private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) + : base(env, impl) + { + } + + private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx) + { + } + + /// + /// Create a from an array of predictors. + /// + [BestFriend] + internal static new OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) + { + Contracts.CheckValue(host, nameof(host)); + host.CheckNonEmpty(predictors, nameof(predictors)); + return Create(host, OutputFormula.ProbabilityNormalization, predictors); + } + + [BestFriend] + internal static new OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) + { + ImplBase impl; + + using (var ch = host.Start("Creating OVA predictor")) + { + if (outputFormula == OutputFormula.Softmax) + { + impl = new ImplSoftmax(predictors); + return new OneVersusAllModelParameters(host, impl); + } + + // Caller of this function asks for probability output. We check if input predictor can produce probability. + // If that predictor can't produce probability, ivmd will be null. + IValueMapperDist ivmd = null; + if (outputFormula == OutputFormula.ProbabilityNormalization && + ((ivmd = predictors[0] as IValueMapperDist) == null || + ivmd.OutputType != NumberDataViewType.Single || + ivmd.DistType != NumberDataViewType.Single)) + { + ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); + ivmd = null; + } + + // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. + if (ivmd != null) + { + impl = new ImplDist(predictors); + } + else + impl = new ImplRaw(predictors); + } + + return new OneVersusAllModelParameters(host, impl); + } + + private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + return new OneVersusAllModelParameters(env, ctx); + } + } +} diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 8259cacd81..67c55079a8 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -795,12 +795,11 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu bool useProbabilities = true) where TModel : class { - return OneVersusAllTyped( - catalog: catalog, - binaryEstimator: binaryEstimator, - labelColumnName: labelColumnName, - imputeMissingLabelsAsNegative: imputeMissingLabelsAsNegative, - useProbabilities: useProbabilities); + Contracts.CheckValue(catalog, nameof(catalog)); + var env = CatalogUtils.GetEnvironment(catalog); + if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) + throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, null, 10000, useProbabilities); } /// @@ -822,28 +821,28 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. 
/// The type of the model. This type parameter will usually be inferred automatically from . - /// The type of the model. This type parameter will usually be inferred automatically from . + /// The type of the model. This type parameter will usually be inferred automatically from . /// /// /// /// - public static OneVersusAllTrainerTyped OneVersusAllTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainerTyped OneVersusAllUnCalibratedToCalibratedTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModelIn> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, IEstimator> calibrator = null, int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) - where TModelIn : class - where TModelOut : class + where TModelIn : class, IPredictorProducing + where TCalibrator : class, ICalibrator { Contracts.CheckValue(catalog, nameof(catalog)); var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); } /// diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 30e9f8d77d..501bfcf94f 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -121,7 +121,7 @@ public void MetacomponentsFeaturesRenamed() var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) - .Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped>(sdcaTrainer)) + .Append(ML.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibratedTyped(sdcaTrainer)) //.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdca)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); From ece3bd0e1b75b7191f2cfb19a906c0624f2627fc Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 15 Jul 2019 15:28:36 -0700 Subject: [PATCH 08/19] constructor changes --- .../MulticlassClassification/OneVersusAllTrainer.cs | 6 +----- .../StandardTrainersCatalog.cs | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 7a48d63e4a..5a488bd42f 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -357,19 +357,15 @@ internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) /// /// The instance. /// An instance of a binary used as the base trainer. - /// The calibrator. 
If a calibrator is not provided, it will default to /// The name of the label colum. /// If true will treat missing labels as negative labels. - /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. internal OneVersusAllTrainerTyped(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, - ICalibratorTrainer calibrator = null, - int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) - : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) + : base(env: env, binaryEstimator: binaryEstimator, labelColumnName: labelColumnName, imputeMissingLabelsAsNegative: imputeMissingLabelsAsNegative, useProbabilities: useProbabilities) { } diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 67c55079a8..3742861dc2 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -799,7 +799,7 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, null, 10000, useProbabilities); + return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); } /// From b56fe0e1553372fad7c7f6ea5b46239359f845e4 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Tue, 16 Jul 2019 11:01:25 -0700 Subject: [PATCH 09/19] more tests added --- .../StandardTrainersCatalog.cs | 1 - test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 107 ++++++++++++++++++ .../TrainerEstimators/MetalinearEstimators.cs | 12 -- 3 files changed, 107 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 3742861dc2..4d2781f9d4 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -758,7 +758,6 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica where TModel : class { Contracts.CheckValue(catalog, nameof(catalog)); - var s = typeof(TModel); var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index dcfacdec98..9322e7b54d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
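The tests added below exercise the new typed entry points next to the existing OneVersusAll catalog method and assert that both paths reach the same micro-accuracy. A minimal usage sketch for the raw-score variant follows, assuming an MLContext named mlContext and the OneVersusAllTyped method added earlier in this series; the calibrated variant, OneVersusAllUnCalibratedToCalibratedTyped, is called the same way but wraps each per-class model in a calibrator. This sketch is not part of the diff itself.

    // One-versus-all over an uncalibrated binary learner, keeping raw scores.
    var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron();
    var ova = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped(ap, useProbabilities: false);
    // The estimator is then appended after a MapValueToKey("Label") step and fitted
    // exactly like the existing OneVersusAll trainer.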
+using System.Reflection.Metadata; +using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.Trainers; using Microsoft.ML.Trainers.FastTree; @@ -35,14 +37,24 @@ public void OvaLogisticRegression() // Pipeline var logReg = mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(); var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(logReg, useProbabilities: false); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped(logReg, useProbabilities: false); var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.94); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.94); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } + //.Append(ML.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibratedTyped(sdcaTrainer)) [Fact] public void OvaAveragedPerceptron() @@ -69,15 +81,73 @@ public void OvaAveragedPerceptron() // Pipeline var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron( new AveragedPerceptronTrainer.Options { Shuffle = true }); + var apTyped = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap, useProbabilities: false); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped(apTyped, useProbabilities: false); var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.66); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.66); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); + } + + [Fact] + public void OvaCalibratedAveragedPerceptron() + { + string dataPath = GetDataPath("iris.txt"); + + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // as a catalog of available operations and as the source of randomness. 
+ var mlContext = new MLContext(seed: 1); + var reader = new TextLoader(mlContext, new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("Label", DataKind.Single, 0), + new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), + } + }); + + // Data + var textData = reader.Load(GetDataPath(dataPath)); + var data = mlContext.Data.Cache(mlContext.Transforms.Conversion.MapValueToKey("Label") + .Fit(textData).Transform(textData)); + + // Pipeline + var ap = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); + var apTyped = mlContext.BinaryClassification.Trainers.AveragedPerceptron( + new AveragedPerceptronTrainer.Options { Shuffle = true }); + + var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibratedTyped(apTyped); + + var model = pipeline.Fit(data); + var predictions = model.Transform(data); + + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + + // Metrics + var metrics = mlContext.MulticlassClassification.Evaluate(predictions); + Assert.True(metrics.MicroAccuracy > 0.95); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.95); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } [Fact] @@ -107,12 +177,24 @@ public void OvaFastTree() mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 }), useProbabilities: false); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllTyped( + mlContext.BinaryClassification.Trainers.FastTree(new FastTreeBinaryTrainer.Options { NumberOfThreads = 1 }), + useProbabilities: false); + var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.99); + + var metricsTyped = mlContext.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.99); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } [Fact] @@ -123,6 +205,8 @@ public void OvaLinearSvm() // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. 
var mlContext = new MLContext(seed: 1); + var mlContextTyped = new MLContext(seed: 1); + var reader = new TextLoader(mlContext, new TextLoader.Options() { Columns = new[] @@ -131,22 +215,45 @@ public void OvaLinearSvm() new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), } }); + var readerTyped = new TextLoader(mlContextTyped, new TextLoader.Options() + { + Columns = new[] + { + new TextLoader.Column("Label", DataKind.Single, 0), + new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), + } + }); // Data var textData = reader.Load(GetDataPath(dataPath)); + var textDataTyped = reader.Load(GetDataPath(dataPath)); var data = mlContext.Data.Cache(mlContext.Transforms.Conversion.MapValueToKey("Label") .Fit(textData).Transform(textData)); + var dataTyped = mlContextTyped.Data.Cache(mlContextTyped.Transforms.Conversion.MapValueToKey("Label") + .Fit(textDataTyped).Transform(textDataTyped)); // Pipeline var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( mlContext.BinaryClassification.Trainers.LinearSvm(new LinearSvmTrainer.Options { NumberOfIterations = 100 }), useProbabilities: false); + var pipelineTyped = mlContextTyped.MulticlassClassification.Trainers.OneVersusAllTyped( + mlContextTyped.BinaryClassification.Trainers.LinearSvm(new LinearSvmTrainer.Options { NumberOfIterations = 100 }), + useProbabilities: false); + var model = pipeline.Fit(data); var predictions = model.Transform(data); + var modelTyped = pipelineTyped.Fit(dataTyped); + var predictionsTyped = modelTyped.Transform(dataTyped); + // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); Assert.True(metrics.MicroAccuracy > 0.83); + + var metricsTyped = mlContextTyped.MulticlassClassification.Evaluate(predictionsTyped); + Assert.True(metricsTyped.MicroAccuracy > 0.83); + + Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index 501bfcf94f..b3ba9877af 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -110,20 +110,8 @@ public void MetacomponentsFeaturesRenamed() NumberOfThreads = 1, }); - var sdca = ML.BinaryClassification.Trainers.SgdCalibrated( - new SgdCalibratedTrainer.Options - { - LabelColumnName = "Label", - FeatureColumnName = "Vars", - Shuffle = true, - NumberOfThreads = 1, - }); - var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) - .Append(ML.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibratedTyped(sdcaTrainer)) - - //.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdca)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); var model = pipeline.Fit(data); From cd04fd90f1f7c474d62a75a69dfd11ceb5f4e88d Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Tue, 16 Jul 2019 15:29:48 -0700 Subject: [PATCH 10/19] fixes for API compat based on Artidoro's comments --- .../OneVersusAllTrainer.cs | 149 ++++++++++++------ 1 file changed, 103 insertions(+), 46 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs 
b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 5a488bd42f..3a3ae54f7a 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -38,7 +38,7 @@ namespace Microsoft.ML.Trainers using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; - public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer>, OneVersusAllModelParameters> where T : class + public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer, T> where T : class { internal const string LoadNameValue = "OVA"; internal const string UserNameValue = "One-vs-All"; @@ -46,7 +46,7 @@ public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer /// Options passed to @@ -69,7 +69,7 @@ internal sealed class Options : OptionsBase internal OneVersusAllTrainerBase(IHostEnvironment env, Options options) : base(env, options, LoadNameValue) { - _options = options; + TrainerOptions = options; } /// @@ -98,23 +98,15 @@ internal OneVersusAllTrainerBase(IHostEnvironment env, LoadNameValue, labelColumnName, binaryEstimator, calibrator) { Host.CheckValue(labelColumnName, nameof(labelColumnName), "Label column should not be null."); - _options = (Options)Args; - _options.UseProbabilities = useProbabilities; + TrainerOptions = (Options)Args; + TrainerOptions.UseProbabilities = useProbabilities; } - private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) - { - // Train one-vs-all models. - var predictors = new T[count]; - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); - predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; - } - return OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors); - } + private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, + bool useProbabilities, IDataView view, string trainerLabel, + ISingleFeaturePredictionTransformer transformer); - private ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) + private protected ISingleFeaturePredictionTransformer TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappedData data, int cls) { var view = MapLabels(data, cls); @@ -124,13 +116,9 @@ private ISingleFeaturePredictionTransformer TrainOne(IChannel // this is currently unsupported. var transformer = trainer.Fit(view); - return TrainOneHelper(ch, _options.UseProbabilities, view, trainerLabel, transformer); + return TrainOneHelper(ch, TrainerOptions.UseProbabilities, view, trainerLabel, transformer); } - private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, - bool useProbabilities, IDataView view, string trainerLabel, - ISingleFeaturePredictionTransformer transformer); - private IDataView MapLabels(RoleMappedData data, int cls) { var label = data.Schema.Label.Value; @@ -147,10 +135,12 @@ private IDataView MapLabels(RoleMappedData data, int cls) throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}"); } + private protected abstract MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName); + /// Trains a model. /// The input data. 
/// A model./> - public override MulticlassPredictionTransformer> Fit(IDataView input) + public override MulticlassPredictionTransformer Fit(IDataView input) { var roles = new KeyValuePair[1]; roles[0] = new KeyValuePair(new CR(DefaultColumnNames.Label), LabelColumn.Name); @@ -158,7 +148,7 @@ public override MulticlassPredictionTransformer> td.CheckMulticlassLabel(out var numClasses); - var predictors = new T[numClasses]; + var predictors = new TScalarPredictor[numClasses]; string featureColumn = null; using (var ch = Host.Start("Fitting")) @@ -172,12 +162,62 @@ public override MulticlassPredictionTransformer> var transformer = TrainOne(ch, Trainer, td, i); featureColumn = transformer.FeatureColumnName; } - predictors[i] = (T)TrainOne(ch, Trainer, td, i).Model; + predictors[i] = TrainOne(ch, Trainer, td, i).Model; } } + return FitHelper(Host, TrainerOptions.UseProbabilities, predictors, input.Schema, featureColumn, LabelColumn.Name); + } + } - return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParameters.Create(Host, _options.UseProbabilities, predictors), input.Schema, featureColumn, LabelColumn.Name); + public abstract class OneVersusAllTypedTrainerBase : OneVersusAllTrainerBase> where T : class + { + /// + /// Constructs a trainer supplying a . + /// + /// The private for this estimator. + /// The legacy + internal OneVersusAllTypedTrainerBase(IHostEnvironment env, Options options) + : base(env, options) + { + } + + /// + /// Initializes a new instance of . + /// + /// The instance. + /// An instance of a binary used as the base trainer. + /// /// The calibrator. If a calibrator is not provided, it will default to + /// The name of the label colum. + /// If true will treat missing labels as negative labels. + /// /// Number of instances to train the calibrator. + /// Use probabilities (vs. raw outputs) to identify top-score category. + internal OneVersusAllTypedTrainerBase(IHostEnvironment env, + TScalarTrainer binaryEstimator, + string labelColumnName = DefaultColumnNames.Label, + bool imputeMissingLabelsAsNegative = false, + ICalibratorTrainer calibrator = null, + int maximumCalibrationExampleCount = 1000000000, + bool useProbabilities = true) + : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) + { + } + + private protected override OneVersusAllModelParametersBase TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new T[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersBase.Create(Host, TrainerOptions.UseProbabilities, predictors); + } + + private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBase.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name); } } @@ -226,7 +266,7 @@ public override MulticlassPredictionTransformer> /// /// /// - public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase { /// /// Constructs a trainer supplying a . 
@@ -259,6 +299,23 @@ internal OneVersusAllTrainer(IHostEnvironment env, { } + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new TScalarPredictor[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParameters.Create(Host, TrainerOptions.UseProbabilities, predictors) as OneVersusAllModelParameters; + } + + private protected override MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, useProbabilities, predictors), schema, featureColumn, LabelColumn.Name); + } + private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, bool useProbabilities, IDataView view, string trainerLabel, ISingleFeaturePredictionTransformer transformer) @@ -282,9 +339,9 @@ private protected override ISingleFeaturePredictionTransformer } } - public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase> - where TSubPredictor: class, IPredictorProducing - where TCalibrator: class, ICalibrator + public sealed class OneVersusAllTrainerTyped : OneVersusAllTypedTrainerBase>> + where TSubPredictor : class, IPredictorProducing + where TCalibrator : class, ICalibrator { /// /// Constructs a trainer supplying a . @@ -340,7 +397,7 @@ private protected override ISingleFeaturePredictionTransformer } } - public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase where T: class + public sealed class OneVersusAllTrainerTyped : OneVersusAllTypedTrainerBase> where T : class { /// /// Constructs a trainer supplying a . @@ -392,9 +449,9 @@ private protected override ISingleFeaturePredictionTransformer } /// - /// Model parameters for . + /// Model parameters for . /// - public class OneVersusAllModelParameters : + public class OneVersusAllModelParametersBase : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, @@ -413,7 +470,7 @@ private protected static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); + loaderAssemblyName: typeof(OneVersusAllModelParametersBase).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; @@ -448,7 +505,7 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, T[] predictors) + internal static OneVersusAllModelParametersBase Create(IHost host, OutputFormula outputFormula, T[] predictors) { ImplBase impl; @@ -457,7 +514,7 @@ internal static OneVersusAllModelParameters Create(IHost host, OutputFormula if (outputFormula == OutputFormula.Softmax) { impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParameters(host, impl); + return new OneVersusAllModelParametersBase(host, impl); } // Caller of this function asks for probability output. We check if input predictor can produce probability. 
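The output-formula selection that spans this hunk and the next can be condensed as below; this is a sketch rather than the literal method body, and the type names simply follow the patch.

    // Condensed decision sketch (not the literal factory code).
    static OutputFormula ResolveOutputFormula(OutputFormula requested, object firstPredictor)
    {
        if (requested == OutputFormula.Softmax)
            return OutputFormula.Softmax;                     // ImplSoftmax over raw scores

        var ivmd = firstPredictor as IValueMapperDist;
        bool canProduceProbabilities = ivmd != null
            && ivmd.OutputType == NumberDataViewType.Single   // raw score type
            && ivmd.DistType == NumberDataViewType.Single;    // probability type

        return requested == OutputFormula.ProbabilityNormalization && canProduceProbabilities
            ? OutputFormula.ProbabilityNormalization          // ImplDist: probabilities renormalized to sum to one
            : OutputFormula.Raw;                              // ImplRaw (a channel warning is issued if probabilities were requested but unavailable)
    }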
@@ -481,11 +538,11 @@ internal static OneVersusAllModelParameters Create(IHost host, OutputFormula impl = new ImplRaw(predictors); } - return new OneVersusAllModelParameters(host, impl); + return new OneVersusAllModelParametersBase(host, impl); } [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) + internal static OneVersusAllModelParametersBase Create(IHost host, bool useProbability, T[] predictors) { var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; @@ -493,17 +550,17 @@ internal static OneVersusAllModelParameters Create(IHost host, bool useProbab } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. /// [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) + internal static OneVersusAllModelParametersBase Create(IHost host, T[] predictors) { Contracts.CheckValue(host, nameof(host)); host.CheckNonEmpty(predictors, nameof(predictors)); return Create(host, OutputFormula.ProbabilityNormalization, predictors); } - private protected OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) + private protected OneVersusAllModelParametersBase(IHostEnvironment env, ImplBase impl) : base(env, RegistrationName) { Host.AssertValue(impl, nameof(impl)); @@ -513,7 +570,7 @@ private protected OneVersusAllModelParameters(IHostEnvironment env, ImplBase imp DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private protected OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + private protected OneVersusAllModelParametersBase(IHostEnvironment env, ModelLoadContext ctx) : base(env, RegistrationName, ctx) { // *** Binary format *** @@ -539,12 +596,12 @@ private protected OneVersusAllModelParameters(IHostEnvironment env, ModelLoadCon DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private static OneVersusAllModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParameters(env, ctx); + return new OneVersusAllModelParametersBase(env, ctx); } private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) @@ -917,7 +974,7 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) /// Model parameters for . /// public sealed class OneVersusAllModelParameters : - OneVersusAllModelParameters + OneVersusAllModelParametersBase { private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) : base(env, impl) @@ -930,7 +987,7 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. 
/// [BestFriend] internal static new OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) From 2f0fe035d30a2e62f9b73513e3abcb49856d3930 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Wed, 17 Jul 2019 11:27:23 -0700 Subject: [PATCH 11/19] final changes for testing --- .../ModelOperations.cs | 2 +- .../LightGbmMulticlassTrainer.cs | 4 +- .../OneVersusAllTrainer.cs | 297 ++++++++---------- .../StandardTrainersCatalog.cs | 2 +- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 13 +- .../TrainerEstimators/MetalinearEstimators.cs | 18 +- 6 files changed, 139 insertions(+), 197 deletions(-) diff --git a/src/Microsoft.ML.EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs index 7b1b2afc8a..c44f86974f 100644 --- a/src/Microsoft.ML.EntryPoints/ModelOperations.cs +++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs @@ -155,7 +155,7 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin return new PredictorModelOutput { PredictorModel = new PredictorModelImpl(env, data, input.TrainingData, - OneVersusAllModelParameters.Create(host, input.UseProbabilities, + OneVersusAllModelParametersBuilder.Create(host, input.UseProbabilities, input.ModelArray.Select(p => p.Predictor as IPredictorProducing).ToArray())) }; } diff --git a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs index 353856e7bd..589be23714 100644 --- a/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs +++ b/src/Microsoft.ML.LightGbm/LightGbmMulticlassTrainer.cs @@ -198,9 +198,9 @@ private protected override OneVersusAllModelParameters CreatePredictor() } string obj = (string)GetGbmParameters()["objective"]; if (obj == "multiclass") - return OneVersusAllModelParameters.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors); + return OneVersusAllModelParametersBuilder.Create(Host, OneVersusAllModelParameters.OutputFormula.Softmax, predictors); else - return OneVersusAllModelParameters.Create(Host, predictors); + return OneVersusAllModelParametersBuilder.Create(Host, predictors); } private protected override void CheckDataValid(IChannel ch, RoleMappedData data) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 3a3ae54f7a..2fbf89aab4 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -170,57 +170,6 @@ public override MulticlassPredictionTransformer Fit(IDataView input) } } - public abstract class OneVersusAllTypedTrainerBase : OneVersusAllTrainerBase> where T : class - { - /// - /// Constructs a trainer supplying a . - /// - /// The private for this estimator. - /// The legacy - internal OneVersusAllTypedTrainerBase(IHostEnvironment env, Options options) - : base(env, options) - { - } - - /// - /// Initializes a new instance of . - /// - /// The instance. - /// An instance of a binary used as the base trainer. - /// /// The calibrator. If a calibrator is not provided, it will default to - /// The name of the label colum. - /// If true will treat missing labels as negative labels. - /// /// Number of instances to train the calibrator. - /// Use probabilities (vs. raw outputs) to identify top-score category. 
- internal OneVersusAllTypedTrainerBase(IHostEnvironment env, - TScalarTrainer binaryEstimator, - string labelColumnName = DefaultColumnNames.Label, - bool imputeMissingLabelsAsNegative = false, - ICalibratorTrainer calibrator = null, - int maximumCalibrationExampleCount = 1000000000, - bool useProbabilities = true) - : base(env, binaryEstimator, labelColumnName, imputeMissingLabelsAsNegative, calibrator, maximumCalibrationExampleCount, useProbabilities) - { - } - - private protected override OneVersusAllModelParametersBase TrainCore(IChannel ch, RoleMappedData data, int count) - { - // Train one-vs-all models. - var predictors = new T[count]; - for (int i = 0; i < predictors.Length; i++) - { - ch.Info($"Training learner {i}"); - predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; - } - return OneVersusAllModelParametersBase.Create(Host, TrainerOptions.UseProbabilities, predictors); - } - - private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) - { - return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBase.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name); - } - } - /// /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. /// @@ -308,12 +257,12 @@ private protected override OneVersusAllModelParameters TrainCore(IChannel ch, Ro ch.Info($"Training learner {i}"); predictors[i] = TrainOne(ch, Trainer, data, i).Model; } - return OneVersusAllModelParameters.Create(Host, TrainerOptions.UseProbabilities, predictors) as OneVersusAllModelParameters; + return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors) as OneVersusAllModelParameters; } private protected override MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) { - return new MulticlassPredictionTransformer(Host, OneVersusAllModelParameters.Create(Host, useProbabilities, predictors), schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors), schema, featureColumn, LabelColumn.Name); } private protected override ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, @@ -339,7 +288,7 @@ private protected override ISingleFeaturePredictionTransformer } } - public sealed class OneVersusAllTrainerTyped : OneVersusAllTypedTrainerBase>> + public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase>> where TSubPredictor : class, IPredictorProducing where TCalibrator : class, ICalibrator { @@ -395,9 +344,25 @@ private protected override ISingleFeaturePredictionTransformer return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } + private protected override MulticlassPredictionTransformer>> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast>().ToArray()), schema, featureColumn, LabelColumn.Name); + } + + private protected override OneVersusAllModelParametersBase> TrainCore(IChannel ch, 
RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new CalibratedModelParametersBase[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (CalibratedModelParametersBase)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors); + } } - public sealed class OneVersusAllTrainerTyped : OneVersusAllTypedTrainerBase> where T : class + public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase> where T : class { /// /// Constructs a trainer supplying a . @@ -446,12 +411,82 @@ private protected override ISingleFeaturePredictionTransformer return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } + + private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + { + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name); + } + + private protected override OneVersusAllModelParametersBase TrainCore(IChannel ch, RoleMappedData data, int count) + { + // Train one-vs-all models. + var predictors = new T[count]; + for (int i = 0; i < predictors.Length; i++) + { + ch.Info($"Training learner {i}"); + predictors[i] = (T)TrainOne(ch, Trainer, data, i).Model; + } + return OneVersusAllModelParametersBuilder.Create(Host, TrainerOptions.UseProbabilities, predictors); + } + } + + public class OneVersusAllModelParametersBuilder { + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParametersBase.OutputFormula outputFormula, T[] predictors) where T : class + { + return new OneVersusAllModelParameters(host, outputFormula, predictors); + } + + [BestFriend] + internal static OneVersusAllModelParametersBase Create(IHost host, bool useProbability, T[] predictors) where T : class + { + var outputFormula = useProbability ? OneVersusAllModelParametersBase.OutputFormula.ProbabilityNormalization : OneVersusAllModelParametersBase.OutputFormula.Raw; + + return Create(host, outputFormula, predictors); + } + + /// + /// Create a from an array of predictors. + /// + [BestFriend] + internal static OneVersusAllModelParametersBase Create(IHost host, T[] predictors) where T : class + { + Contracts.CheckValue(host, nameof(host)); + host.CheckNonEmpty(predictors, nameof(predictors)); + return Create(host, OneVersusAllModelParametersBase.OutputFormula.ProbabilityNormalization, predictors); + } + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParameters.OutputFormula outputFormula, TScalarPredictor[] predictors) + { + return new OneVersusAllModelParameters(host, outputFormula, predictors); + } + + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, TScalarPredictor[] predictors) + { + var outputFormula = useProbability ? OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization : OneVersusAllModelParameters.OutputFormula.Raw; + + return Create(host, outputFormula, predictors); + } + + /// + /// Create a from an array of predictors. 
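One plausible reason for funnelling creation through the non-generic static OneVersusAllModelParametersBuilder is that a static method on a non-generic class lets C# infer the generic argument from the predictor array, which neither a constructor of a generic type nor a static member declared on that generic type offers. A minimal sketch of that pattern, using invented names rather than the real model-parameter types:

    // Generic result type whose type argument callers would otherwise have to spell out.
    sealed class Bundle<T> where T : class
    {
        public T[] Items { get; }
        public Bundle(T[] items) => Items = items;
    }

    // Non-generic static entry point in the style of OneVersusAllModelParametersBuilder:
    // the type argument is inferred from the argument at the call site.
    static class BundleBuilder
    {
        public static Bundle<T> Create<T>(T[] items) where T : class => new Bundle<T>(items);
    }

    static class BundleBuilderDemo
    {
        public static void Run()
        {
            var predictors = new[] { "p0", "p1", "p2" };
            var viaConstructor = new Bundle<string>(predictors); // type argument written out
            var viaBuilder = BundleBuilder.Create(predictors);   // type argument inferred
            System.Console.WriteLine(viaBuilder.Items.Length == viaConstructor.Items.Length);
        }
    }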
+ /// + [BestFriend] + internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) + { + Contracts.CheckValue(host, nameof(host)); + host.CheckNonEmpty(predictors, nameof(predictors)); + return Create(host, OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization, predictors); + } + } /// /// Model parameters for . /// - public class OneVersusAllModelParametersBase : + public abstract class OneVersusAllModelParametersBase : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, @@ -504,69 +539,43 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; - [BestFriend] - internal static OneVersusAllModelParametersBase Create(IHost host, OutputFormula outputFormula, T[] predictors) + internal OneVersusAllModelParametersBase(IHostEnvironment env, OutputFormula outputFormula, T[] predictors) + : base(env, RegistrationName) { - ImplBase impl; - - using (var ch = host.Start("Creating OVA predictor")) + using (var ch = env.Start("Creating OVA predictor")) { if (outputFormula == OutputFormula.Softmax) { - impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParametersBase(host, impl); + _impl = new ImplSoftmax(predictors); } // Caller of this function asks for probability output. We check if input predictor can produce probability. // If that predictor can't produce probability, ivmd will be null. - IValueMapperDist ivmd = null; - if (outputFormula == OutputFormula.ProbabilityNormalization && + else + { + IValueMapperDist ivmd = null; + if (outputFormula == OutputFormula.ProbabilityNormalization && ((ivmd = predictors[0] as IValueMapperDist) == null || ivmd.OutputType != NumberDataViewType.Single || ivmd.DistType != NumberDataViewType.Single)) - { - ch.Warning($"{nameof(OneVersusAllTrainerBase.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerBase.Options.PredictorType)} that can't produce probabilities."); - ivmd = null; - } + { + ch.Warning($"{nameof(OneVersusAllTrainerBase.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerBase.Options.PredictorType)} that can't produce probabilities."); + ivmd = null; + } - // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. - if (ivmd != null) - { - impl = new ImplDist(predictors); + // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. + if (ivmd != null) + { + _impl = new ImplDist(predictors); + } + else + _impl = new ImplRaw(predictors); } - else - impl = new ImplRaw(predictors); } - return new OneVersusAllModelParametersBase(host, impl); - } - - [BestFriend] - internal static OneVersusAllModelParametersBase Create(IHost host, bool useProbability, T[] predictors) - { - var outputFormula = useProbability ? OutputFormula.ProbabilityNormalization : OutputFormula.Raw; - - return Create(host, outputFormula, predictors); - } - - /// - /// Create a from an array of predictors. 
- /// - [BestFriend] - internal static OneVersusAllModelParametersBase Create(IHost host, T[] predictors) - { - Contracts.CheckValue(host, nameof(host)); - host.CheckNonEmpty(predictors, nameof(predictors)); - return Create(host, OutputFormula.ProbabilityNormalization, predictors); - } - - private protected OneVersusAllModelParametersBase(IHostEnvironment env, ImplBase impl) - : base(env, RegistrationName) - { - Host.AssertValue(impl, nameof(impl)); - Host.Assert(Utils.Size(impl.Predictors) > 0); + Host.AssertValue(_impl, nameof(_impl)); + Host.Assert(Utils.Size(_impl.Predictors) > 0); - _impl = impl; DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } @@ -596,14 +605,6 @@ private protected OneVersusAllModelParametersBase(IHostEnvironment env, ModelLoa DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length); } - private static OneVersusAllModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParametersBase(env, ctx); - } - private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx) where TPredictor : class { @@ -970,76 +971,34 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } - /// - /// Model parameters for . - /// - public sealed class OneVersusAllModelParameters : - OneVersusAllModelParametersBase + public sealed class OneVersusAllModelParameters : + OneVersusAllModelParametersBase where T : class { - private OneVersusAllModelParameters(IHostEnvironment env, ImplBase impl) - : base(env, impl) + internal OneVersusAllModelParameters(IHostEnvironment env, OutputFormula outputFormula, T[] predictors) + : base(env, outputFormula, predictors) { } private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) - : base(env, ctx) - { - } - - /// - /// Create a from an array of predictors. - /// - [BestFriend] - internal static new OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) + : base(env, ctx) { - Contracts.CheckValue(host, nameof(host)); - host.CheckNonEmpty(predictors, nameof(predictors)); - return Create(host, OutputFormula.ProbabilityNormalization, predictors); } + } - [BestFriend] - internal static new OneVersusAllModelParameters Create(IHost host, OutputFormula outputFormula, TScalarPredictor[] predictors) + /// + /// Model parameters for . + /// + public sealed class OneVersusAllModelParameters : + OneVersusAllModelParametersBase + { + internal OneVersusAllModelParameters(IHostEnvironment env, OutputFormula outputFormula, TScalarPredictor[] predictors) + : base(env, outputFormula, predictors) { - ImplBase impl; - - using (var ch = host.Start("Creating OVA predictor")) - { - if (outputFormula == OutputFormula.Softmax) - { - impl = new ImplSoftmax(predictors); - return new OneVersusAllModelParameters(host, impl); - } - - // Caller of this function asks for probability output. We check if input predictor can produce probability. - // If that predictor can't produce probability, ivmd will be null. 
- IValueMapperDist ivmd = null; - if (outputFormula == OutputFormula.ProbabilityNormalization && - ((ivmd = predictors[0] as IValueMapperDist) == null || - ivmd.OutputType != NumberDataViewType.Single || - ivmd.DistType != NumberDataViewType.Single)) - { - ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities."); - ivmd = null; - } - - // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability. - if (ivmd != null) - { - impl = new ImplDist(predictors); - } - else - impl = new ImplRaw(predictors); - } - - return new OneVersusAllModelParameters(host, impl); } - private static OneVersusAllModelParameters Create(IHostEnvironment env, ModelLoadContext ctx) + private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx) { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - return new OneVersusAllModelParameters(env, ctx); } } } diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 4d2781f9d4..2133d4376b 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -827,7 +827,7 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu /// [!code-csharp[OneVersusAll](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/OneVersusAll.cs)] /// ]]> /// - public static OneVersusAllTrainerTyped OneVersusAllUnCalibratedToCalibratedTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainerTyped OneVersusAllUnCalibratedToCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModelIn> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 9322e7b54d..8ae7ded3a3 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -132,7 +132,7 @@ public void OvaCalibratedAveragedPerceptron() new AveragedPerceptronTrainer.Options { Shuffle = true }); var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(ap); - var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibratedTyped(apTyped); + var pipelineTyped = mlContext.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibrated(apTyped); var model = pipeline.Fit(data); var predictions = model.Transform(data); @@ -225,11 +225,10 @@ public void OvaLinearSvm() }); // Data var textData = reader.Load(GetDataPath(dataPath)); - var textDataTyped = reader.Load(GetDataPath(dataPath)); var data = mlContext.Data.Cache(mlContext.Transforms.Conversion.MapValueToKey("Label") .Fit(textData).Transform(textData)); var dataTyped = mlContextTyped.Data.Cache(mlContextTyped.Transforms.Conversion.MapValueToKey("Label") - .Fit(textDataTyped).Transform(textDataTyped)); + .Fit(textData).Transform(textData)); // Pipeline var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll( @@ -243,15 +242,15 @@ public void OvaLinearSvm() var model = pipeline.Fit(data); var predictions = 
model.Transform(data); - var modelTyped = pipelineTyped.Fit(dataTyped); - var predictionsTyped = modelTyped.Transform(dataTyped); + var modelTyped = pipelineTyped.Fit(data); + var predictionsTyped = modelTyped.Transform(data); // Metrics var metrics = mlContext.MulticlassClassification.Evaluate(predictions); - Assert.True(metrics.MicroAccuracy > 0.83); + Assert.True(metrics.MicroAccuracy > 0.95); var metricsTyped = mlContextTyped.MulticlassClassification.Evaluate(predictionsTyped); - Assert.True(metricsTyped.MicroAccuracy > 0.83); + Assert.True(metricsTyped.MicroAccuracy > 0.95); Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs index b3ba9877af..9f94bc2560 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs @@ -52,23 +52,6 @@ public void OVAUncalibrated() Done(); } - /// - /// OVA strongly typed un-calibrated - /// - [Fact] - public void OVATypedUncalibrated() - { - var (pipeline, data) = GetMulticlassPipeline(); - var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated( - new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 }); - - pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: false)) - .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); - - TestEstimatorCore(pipeline, data); - Done(); - } - /// /// Pairwise Coupling trainer /// @@ -112,6 +95,7 @@ public void MetacomponentsFeaturesRenamed() var pipeline = new ColumnConcatenatingEstimator(Env, "Vars", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(new ValueToKeyMappingEstimator(Env, "Label"), TransformerScope.TrainTest) + .Append(ML.MulticlassClassification.Trainers.OneVersusAll(sdcaTrainer)) .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel")); var model = pipeline.Fit(data); From d488049bf3c10682608a43a9a6d0da958456dd45 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Wed, 17 Jul 2019 14:16:01 -0700 Subject: [PATCH 12/19] expanded comments --- .../OneVersusAllTrainer.cs | 141 ++++++++++++------ .../StandardTrainersCatalog.cs | 8 +- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 3 + 3 files changed, 101 insertions(+), 51 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 2fbf89aab4..2b3eddff55 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -38,6 +38,51 @@ namespace Microsoft.ML.Trainers using TScalarPredictor = IPredictorProducing; using TScalarTrainer = ITrainerEstimator>, IPredictorProducing>; + /// + /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. + /// + /// + /// , + /// can be different from , which develops a multi-class classifier directly. + /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always + /// request caching, as it will be performing multiple passes over the data set. 
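Condensing the OvaTest scenarios touched in this patch, end-to-end use of the non-generic OneVersusAll catalog entry looks roughly like the sketch below; the input IDataView is assumed to already carry a "Features" vector column and a "Label" column.

    using Microsoft.ML;
    using Microsoft.ML.Trainers;

    static class OvaUsageSketch
    {
        public static double TrainAndScore(MLContext mlContext, IDataView trainingData)
        {
            // Any binary estimator can be wrapped; the tests use averaged perceptron and linear SVM.
            var binary = mlContext.BinaryClassification.Trainers.AveragedPerceptron(
                new AveragedPerceptronTrainer.Options { Shuffle = true });

            var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
                .Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binary));

            var model = pipeline.Fit(trainingData);
            var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(trainingData));
            return metrics.MicroAccuracy;
        }
    }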
+ /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. + /// + /// This can allow you to exploit trainers that do not naturally have a + /// multiclass option, for example, using the + /// to solve a multiclass problem. + /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases + /// where the trainer has a multiclass option, but using it directly is not + /// practical due to, usually, memory constraints. For example, while a multiclass + /// logistic regression is a more principled way to solve a multiclass problem, it + /// requires that the trainer store a lot more intermediate state in the form of + /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one + /// as would be needed for a one-versus-all classification model. + /// + /// Check the See Also section for links to usage examples. + /// ]]> + /// + /// + /// public abstract class OneVersusAllTrainerBase : MetaMulticlassTrainer, T> where T : class { internal const string LoadNameValue = "OVA"; @@ -102,6 +147,16 @@ internal OneVersusAllTrainerBase(IHostEnvironment env, TrainerOptions.UseProbabilities = useProbabilities; } + /// + /// Training helper method that is called by . This allows the + /// classes that inherit from this class to do any custom training changes needed, such as casting. + /// + /// The instance. + /// Whether probabilities should be used or not. Is pulled from the trainer . + /// The that has the data. + /// The label for the trainer. + /// The used by the trainer + /// private protected abstract ISingleFeaturePredictionTransformer TrainOneHelper(IChannel ch, bool useProbabilities, IDataView view, string trainerLabel, ISingleFeaturePredictionTransformer transformer); @@ -135,6 +190,17 @@ private IDataView MapLabels(RoleMappedData data, int cls) throw Host.ExceptNotSupp($"Label column type is not supported by OneVersusAllTrainer: {label.Type.RawType}"); } + /// + /// Fit helper method that is called by . This allows the + /// classes that inherit from this class to do any custom fit changes needed, such as casting. + /// + /// The . + /// Whether probabilities should be used or not. Is pulled from the trainer . + /// The array of used. + /// The of the transformer. + /// The feature column. + /// The name of the label column. + /// private protected abstract MulticlassPredictionTransformer FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName); /// Trains a model. @@ -171,50 +237,9 @@ public override MulticlassPredictionTransformer Fit(IDataView input) } /// - /// The for training a one-versus-all multi-class classifier that uses the specified binary classifier. + /// Implementation of the where T is a + /// to maintain api compatability. /// - /// - /// , - /// can be different from , which develops a multi-class classifier directly. - /// Note that even if the classifier indicates that it does not need caching, OneVersusAll will always - /// request caching, as it will be performing multiple passes over the data set. - /// This trainer will request normalization from the data pipeline if the classifier indicates it would benefit from it. - /// - /// This can allow you to exploit trainers that do not naturally have a - /// multiclass option, for example, using the - /// to solve a multiclass problem. 
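The TrainOneHelper and FitHelper hooks documented above give the refactor its template-method shape: the base class owns the per-class training loop while each subclass only decides how the per-class models are packaged and, where needed, cast. A schematic restatement with invented names, not the actual trainer hierarchy:

    // Base class owns the shared one-vs-all loop; subclasses only package the per-class models.
    abstract class OvaTemplateSketch<TModel> where TModel : class
    {
        public TModel Fit(int classCount)
        {
            var perClassModels = new object[classCount];
            for (int cls = 0; cls < classCount; cls++)
                perClassModels[cls] = TrainOneClass(cls);   // shared training loop
            return Package(perClassModels);                 // subclass-specific packaging/casting
        }

        protected virtual object TrainOneClass(int cls) => $"binary model for class {cls}";

        protected abstract TModel Package(object[] perClassModels);
    }

    // One concrete packaging choice, standing in for the typed trainers in this file.
    sealed class JoinedOvaSketch : OvaTemplateSketch<string>
    {
        protected override string Package(object[] perClassModels)
            => string.Join("; ", perClassModels);
    }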
- /// Alternately, it can allow ML.NET to solve a "simpler" problem even in the cases - /// where the trainer has a multiclass option, but using it directly is not - /// practical due to, usually, memory constraints. For example, while a multiclass - /// logistic regression is a more principled way to solve a multiclass problem, it - /// requires that the trainer store a lot more intermediate state in the form of - /// L-BFGS history for all classes *simultaneously*, rather than just one-by-one - /// as would be needed for a one-versus-all classification model. - /// - /// Check the See Also section for links to usage examples. - /// ]]> - /// - /// - /// public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase { /// @@ -288,6 +313,12 @@ private protected override ISingleFeaturePredictionTransformer } } + /// + /// Strongly typed implementation of the where T is a of type + /// This is used to turn a non calibrated binary classification estimator into its calibrated version. + /// + /// + /// public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase>> where TSubPredictor : class, IPredictorProducing where TCalibrator : class, ICalibrator @@ -362,6 +393,12 @@ private protected override OneVersusAllModelParametersBase + /// Strongly typed implementation of the where T is a . T can either be + /// a calibrated binary estimator of type , or a non calibrated binary estimary. + /// This cannot be used to turn a non calibrated binary classification estimator into its calibrated version. If that is required, use instead. + /// + /// public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase> where T : class { /// @@ -430,6 +467,9 @@ private protected override OneVersusAllModelParametersBase TrainCore(IChannel } } + /// + /// Class that holds the static create methods for the classes. + /// public class OneVersusAllModelParametersBuilder { [BestFriend] internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParametersBase.OutputFormula outputFormula, T[] predictors) where T : class @@ -438,7 +478,7 @@ internal static OneVersusAllModelParameters Create(IHost host, OneVersusAl } [BestFriend] - internal static OneVersusAllModelParametersBase Create(IHost host, bool useProbability, T[] predictors) where T : class + internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) where T : class { var outputFormula = useProbability ? OneVersusAllModelParametersBase.OutputFormula.ProbabilityNormalization : OneVersusAllModelParametersBase.OutputFormula.Raw; @@ -446,14 +486,14 @@ internal static OneVersusAllModelParametersBase Create(IHost host, bool us } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. /// [BestFriend] - internal static OneVersusAllModelParametersBase Create(IHost host, T[] predictors) where T : class + internal static OneVersusAllModelParameters Create(IHost host, T[] predictors) where T : class { Contracts.CheckValue(host, nameof(host)); host.CheckNonEmpty(predictors, nameof(predictors)); - return Create(host, OneVersusAllModelParametersBase.OutputFormula.ProbabilityNormalization, predictors); + return Create(host, OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization, predictors); } [BestFriend] @@ -471,7 +511,7 @@ internal static OneVersusAllModelParameters Create(IHost host, bool useProbabili } /// - /// Create a from an array of predictors. + /// Create a from an array of predictors. This is for backwards API compatability. 
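Numerically, the three OutputFormula options differ only in how the per-class binary outputs are folded into one score vector. The sketch below is a plausible reading of those formulas rather than a copy of the ImplRaw/ImplDist/ImplSoftmax code: Raw keeps the scores as-is, ProbabilityNormalization rescales calibrated per-class probabilities so they sum to one, and Softmax exponentiates raw scores, which is what the LightGBM "multiclass" objective path earlier in this patch asks for.

    using System;
    using System.Linq;

    enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 }

    static class OutputFormulaSketch
    {
        public static double[] Combine(double[] perClassOutputs, OutputFormula formula)
        {
            switch (formula)
            {
                case OutputFormula.Raw:
                    // Leave the raw scores untouched; the predicted class is simply the arg-max.
                    return (double[])perClassOutputs.Clone();

                case OutputFormula.ProbabilityNormalization:
                    // Inputs are per-class probabilities from calibrated binary models; rescale to sum to one.
                    double sum = perClassOutputs.Sum();
                    return perClassOutputs.Select(p => sum > 0 ? p / sum : 0.0).ToArray();

                case OutputFormula.Softmax:
                    // Softmax over raw scores, shifted by the max for numerical stability.
                    double max = perClassOutputs.Max();
                    double[] exps = perClassOutputs.Select(s => Math.Exp(s - max)).ToArray();
                    double z = exps.Sum();
                    return exps.Select(e => e / z).ToArray();

                default:
                    throw new ArgumentOutOfRangeException(nameof(formula));
            }
        }
    }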
/// [BestFriend] internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[] predictors) @@ -484,7 +524,7 @@ internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[ } /// - /// Model parameters for . + /// Base model parameters for . /// public abstract class OneVersusAllModelParametersBase : ModelParametersBase>, @@ -971,6 +1011,9 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input) } } + /// + /// Model parameters for typed versions of . + /// public sealed class OneVersusAllModelParameters : OneVersusAllModelParametersBase where T : class { diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 2133d4376b..86ad260477 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -766,7 +766,9 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica /// /// Create a , which predicts a multiclass target using one-versus-all strategy with - /// the binary classification estimator specified by . + /// the binary classification estimator specified by . This method works with binary classifiers that + /// are either already calibrated, or non calibrated ones you don't want calibrated. If you need to have your classifier calibrated, use the + /// method instead. /// /// /// @@ -803,7 +805,9 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu /// /// Create a , which predicts a multiclass target using one-versus-all strategy with - /// the binary classification estimator specified by . + /// the binary classification estimator specified by .This method works with binary classifiers that + /// are not calibrated and need to be calibrated before use. If your classifier is already calibrated or it does not need to be, use the + /// method instead. /// /// /// diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 8ae7ded3a3..89f7e19f43 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -215,6 +215,9 @@ public void OvaLinearSvm() new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }), } }); + + // REVIEW: readerTyped and dataTyped aren't used anywhere in this test, but if I take them out + // the test will fail. It seems to me that something is changing state somewhere, maybe in the cache? 
var readerTyped = new TextLoader(mlContextTyped, new TextLoader.Options() { Columns = new[] From 46fab0410cdd74b51a5b5b1da024ca28aba95cc7 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Wed, 17 Jul 2019 14:30:37 -0700 Subject: [PATCH 13/19] small formatting fixes --- src/Microsoft.ML.Data/Prediction/Calibrator.cs | 2 -- .../StandardTrainersCatalog.cs | 9 +++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 1179a0f306..74a558d78f 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -989,9 +989,7 @@ public static IPredictorProducing CreateCalibratedPredictor; if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) - { return new ParameterMixingCalibratedModelParameters(env, predictor, cali); - } if (predictor is IValueMapper) return new ValueMapperCalibratedModelParameters(env, predictor, cali); diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 86ad260477..84a91dbcde 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -804,9 +804,10 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu } /// - /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// Create a , which predicts a multiclass target using one-versus-all strategy with /// the binary classification estimator specified by .This method works with binary classifiers that - /// are not calibrated and need to be calibrated before use. If your classifier is already calibrated or it does not need to be, use the + /// are not calibrated and need to be calibrated before use. Due to the type of estimator changing (from uncalibrated to calibrated), you must manually + /// specify both the type of the model and the type of the calibrator. If your classifier is already calibrated or it does not need to be, use the /// method instead. /// /// @@ -823,8 +824,8 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu /// Whether to treat missing labels as having negative labels, instead of keeping them missing. /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. - /// The type of the model. This type parameter will usually be inferred automatically from . - /// The type of the model. This type parameter will usually be inferred automatically from . + /// The type of the model. This type parameter cannot be inferred and must be specified manually. It is usually a . + /// The calibrator for the model. This type parameter cannot be inferred automatically and must be specified manually and must be of type . /// /// /// Date: Thu, 18 Jul 2019 10:01:24 -0700 Subject: [PATCH 14/19] removed unused import. 
Changed type of class to static internal per PR comments --- .../Standard/MulticlassClassification/OneVersusAllTrainer.cs | 3 ++- test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 2b3eddff55..5f1c06f305 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -470,7 +470,8 @@ private protected override OneVersusAllModelParametersBase TrainCore(IChannel /// /// Class that holds the static create methods for the classes. /// - public class OneVersusAllModelParametersBuilder { + [BestFriend] + internal static class OneVersusAllModelParametersBuilder { [BestFriend] internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParametersBase.OutputFormula outputFormula, T[] predictors) where T : class { diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 89f7e19f43..06c30e2598 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System.Reflection.Metadata; using Microsoft.ML.Calibrators; using Microsoft.ML.Data; using Microsoft.ML.Trainers; From 8f851fa6e02649fec590bd2b5894f54567a81eab Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 18 Jul 2019 10:45:28 -0700 Subject: [PATCH 15/19] class name changes from PR comments --- .../OneVersusAllTrainer.cs | 20 +++++++++---------- .../StandardTrainersCatalog.cs | 12 +++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 5f1c06f305..8daea2fb2d 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -319,22 +319,22 @@ private protected override ISingleFeaturePredictionTransformer /// /// /// - public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase>> + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase>> where TSubPredictor : class, IPredictorProducing where TCalibrator : class, ICalibrator { /// - /// Constructs a trainer supplying a . + /// Constructs a trainer supplying a . /// /// The private for this estimator. /// The legacy - internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) + internal OneVersusAllTrainer(IHostEnvironment env, Options options) : base(env, options) { } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. @@ -343,7 +343,7 @@ internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) /// If true will treat missing labels as negative labels. /// Number of instances to train the calibrator. /// Use probabilities (vs. raw outputs) to identify top-score category. 
- internal OneVersusAllTrainerTyped(IHostEnvironment env, + internal OneVersusAllTrainer(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, @@ -396,30 +396,30 @@ private protected override OneVersusAllModelParametersBase /// Strongly typed implementation of the where T is a . T can either be /// a calibrated binary estimator of type , or a non calibrated binary estimary. - /// This cannot be used to turn a non calibrated binary classification estimator into its calibrated version. If that is required, use instead. + /// This cannot be used to turn a non calibrated binary classification estimator into its calibrated version. If that is required, use instead. /// /// - public sealed class OneVersusAllTrainerTyped : OneVersusAllTrainerBase> where T : class + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase> where T : class { /// /// Constructs a trainer supplying a . /// /// The private for this estimator. /// The legacy - internal OneVersusAllTrainerTyped(IHostEnvironment env, Options options) + internal OneVersusAllTrainer(IHostEnvironment env, Options options) : base(env, options) { } /// - /// Initializes a new instance of . + /// Initializes a new instance of . /// /// The instance. /// An instance of a binary used as the base trainer. /// The name of the label colum. /// If true will treat missing labels as negative labels. /// Use probabilities (vs. raw outputs) to identify top-score category. - internal OneVersusAllTrainerTyped(IHostEnvironment env, + internal OneVersusAllTrainer(IHostEnvironment env, TScalarTrainer binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs index 84a91dbcde..9ac45d8390 100644 --- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs +++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs @@ -765,7 +765,7 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica } /// - /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// Create a , which predicts a multiclass target using one-versus-all strategy with /// the binary classification estimator specified by . This method works with binary classifiers that /// are either already calibrated, or non calibrated ones you don't want calibrated. If you need to have your classifier calibrated, use the /// method instead. 
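For the two typed catalog entries this patch settles on, the intended call shapes appear to be roughly the following. The explicit generic arguments (LinearBinaryModelParameters, PlattCalibrator) are illustrative guesses, since the generic parameter lists are elided in this patch text; only the method names and the useProbabilities parameter are taken from the diff and its tests.

    using Microsoft.ML;
    using Microsoft.ML.Calibrators;
    using Microsoft.ML.Trainers;

    static class TypedOvaSketch
    {
        public static void Build(MLContext mlContext)
        {
            var svm = mlContext.BinaryClassification.Trainers.LinearSvm();

            // Binary trainer used as-is (already calibrated, or raw scores are acceptable):
            // only the model type is involved and it is normally inferred.
            var rawOva = mlContext.MulticlassClassification.Trainers
                .OneVersusAllTyped(svm, useProbabilities: false);

            // Uncalibrated binary trainer that should be calibrated per class: the model type and
            // the calibrator type have to be written out.
            var calibratedOva = mlContext.MulticlassClassification.Trainers
                .OneVersusAllUnCalibratedToCalibrated<LinearBinaryModelParameters, PlattCalibrator>(svm);
        }
    }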
@@ -789,7 +789,7 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica /// [!code-csharp[OneVersusAll](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/OneVersusAll.cs)] /// ]]> /// - public static OneVersusAllTrainerTyped OneVersusAllTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainer OneVersusAllTyped(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModel> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, @@ -800,11 +800,11 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); + return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, useProbabilities); } /// - /// Create a , which predicts a multiclass target using one-versus-all strategy with + /// Create a , which predicts a multiclass target using one-versus-all strategy with /// the binary classification estimator specified by .This method works with binary classifiers that /// are not calibrated and need to be calibrated before use. Due to the type of estimator changing (from uncalibrated to calibrated), you must manually /// specify both the type of the model and the type of the calibrator. If your classifier is already calibrated or it does not need to be, use the @@ -832,7 +832,7 @@ public static OneVersusAllTrainerTyped OneVersusAllTyped(this Mu /// [!code-csharp[OneVersusAll](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/OneVersusAll.cs)] /// ]]> /// - public static OneVersusAllTrainerTyped OneVersusAllUnCalibratedToCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, + public static OneVersusAllTrainer OneVersusAllUnCalibratedToCalibrated(this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, ITrainerEstimator, TModelIn> binaryEstimator, string labelColumnName = DefaultColumnNames.Label, bool imputeMissingLabelsAsNegative = false, @@ -846,7 +846,7 @@ public static OneVersusAllTrainerTyped OneVersusAllUnCali var env = CatalogUtils.GetEnvironment(catalog); if (!(binaryEstimator is ITrainerEstimator>, IPredictorProducing> est)) throw env.ExceptParam(nameof(binaryEstimator), "Trainer estimator does not appear to produce the right kind of model."); - return new OneVersusAllTrainerTyped(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); + return new OneVersusAllTrainer(env, est, labelColumnName, imputeMissingLabelsAsNegative, GetCalibratorTrainerOrThrow(env, calibrator), maximumCalibrationExampleCount, useProbabilities); } /// From 32ddb544d57eaaef3726d5c64fdb36cbefcaec67 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 18 Jul 2019 13:20:24 -0700 Subject: [PATCH 16/19] restructured OVAModelParameters. 
changed interfaces back to internal --- .../Prediction/IPredictor.cs | 6 +- .../Prediction/Calibrator.cs | 8 +- .../OneVersusAllTrainer.cs | 113 ++++++++++-------- .../StandardTrainersCatalog.cs | 2 +- 4 files changed, 69 insertions(+), 60 deletions(-) diff --git a/src/Microsoft.ML.Core/Prediction/IPredictor.cs b/src/Microsoft.ML.Core/Prediction/IPredictor.cs index 2f5fae8db3..728afe7802 100644 --- a/src/Microsoft.ML.Core/Prediction/IPredictor.cs +++ b/src/Microsoft.ML.Core/Prediction/IPredictor.cs @@ -33,7 +33,8 @@ public enum PredictionKind /// /// Weakly typed version of IPredictor. /// - public interface IPredictor + [BestFriend] + internal interface IPredictor { /// /// Return the type of prediction task. @@ -45,7 +46,8 @@ public interface IPredictor /// A predictor the produces values of the indicated type. /// REVIEW: Determine whether this is just a temporary shim or long term solution. /// - public interface IPredictorProducing : IPredictor + [BestFriend] + internal interface IPredictorProducing : IPredictor { } diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index 74a558d78f..d3735e6460 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -881,10 +881,10 @@ public static IPredictor GetCalibratedPredictor(IHostEnvironment env, IChannel c public static CalibratedModelParametersBase GetCalibratedPredictor(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, TSubPredictor predictor, RoleMappedData data, int maxRows = _maxCalibrationExamples) - where TSubPredictor : class, IPredictorProducing + where TSubPredictor : class where TCalibrator : class, ICalibrator { - var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, predictor, data, maxRows) as TCalibrator; + var trainedCalibrator = TrainCalibrator(env, ch, caliTrainer, (IPredictor)predictor, data, maxRows) as TCalibrator; return (CalibratedModelParametersBase)CreateCalibratedPredictor(env, predictor, trainedCalibrator); } @@ -972,12 +972,12 @@ public static ICalibrator TrainCalibrator(IHostEnvironment env, IChannel ch, ICa } public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, TSubPredictor predictor, TCalibrator cali) - where TSubPredictor : class, IPredictorProducing + where TSubPredictor : class where TCalibrator : class, ICalibrator { Contracts.Assert(predictor != null); if (cali == null) - return predictor; + return (IPredictorProducing)predictor; for (; ; ) { diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 8daea2fb2d..1c4e2c04c3 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -237,7 +237,7 @@ public override MulticlassPredictionTransformer Fit(IDataView input) } /// - /// Implementation of the where T is a + /// Implementation of the where T is a /// to maintain api compatability. 
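The calibration utilities touched in this commit wrap each per-class binary model so that its raw score can be reported as a probability. A Platt-style calibrator is just a fitted sigmoid; the slope and offset below are made-up numbers, whereas the real calibrator trainer fits them to held-out scores.

    using System;

    static class PlattSketch
    {
        // Maps a raw binary score to a probability with a fitted sigmoid: p = 1 / (1 + exp(a*s + b)).
        public static double ToProbability(double rawScore, double slope = -1.5, double offset = 0.0)
            => 1.0 / (1.0 + Math.Exp(slope * rawScore + offset));
    }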
/// public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase @@ -319,8 +319,8 @@ private protected override ISingleFeaturePredictionTransformer /// /// /// - public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase>> - where TSubPredictor : class, IPredictorProducing + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase>> + where TSubPredictor : class where TCalibrator : class, ICalibrator { /// @@ -375,12 +375,12 @@ private protected override ISingleFeaturePredictionTransformer return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } - private protected override MulticlassPredictionTransformer>> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + private protected override MulticlassPredictionTransformer>> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) { - return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast>().ToArray()), schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer>>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast>().ToArray()), schema, featureColumn, LabelColumn.Name); } - private protected override OneVersusAllModelParametersBase> TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override OneVersusAllModelParameters> TrainCore(IChannel ch, RoleMappedData data, int count) { // Train one-vs-all models. var predictors = new CalibratedModelParametersBase[count]; @@ -399,7 +399,7 @@ private protected override OneVersusAllModelParametersBase instead. /// /// - public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase> where T : class + public sealed class OneVersusAllTrainer : OneVersusAllTrainerBase> where T : class { /// /// Constructs a trainer supplying a . @@ -449,12 +449,12 @@ private protected override ISingleFeaturePredictionTransformer return new BinaryPredictionTransformer(Host, transformer.Model, view.Schema, transformer.FeatureColumnName); } - private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) + private protected override MulticlassPredictionTransformer> FitHelper(IHost host, bool useProbabilities, TScalarPredictor[] predictors, DataViewSchema schema, string featureColumn, string labelColumnName) { - return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name); + return new MulticlassPredictionTransformer>(Host, OneVersusAllModelParametersBuilder.Create(Host, useProbabilities, predictors.Cast().ToArray()), schema, featureColumn, LabelColumn.Name); } - private protected override OneVersusAllModelParametersBase TrainCore(IChannel ch, RoleMappedData data, int count) + private protected override OneVersusAllModelParameters TrainCore(IChannel ch, RoleMappedData data, int count) { // Train one-vs-all models. 
var predictors = new T[count]; @@ -473,7 +473,7 @@ private protected override OneVersusAllModelParametersBase TrainCore(IChannel [BestFriend] internal static class OneVersusAllModelParametersBuilder { [BestFriend] - internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParametersBase.OutputFormula outputFormula, T[] predictors) where T : class + internal static OneVersusAllModelParameters Create(IHost host, OneVersusAllModelParameters.OutputFormula outputFormula, T[] predictors) where T : class { return new OneVersusAllModelParameters(host, outputFormula, predictors); } @@ -481,7 +481,7 @@ internal static OneVersusAllModelParameters Create(IHost host, OneVersusAl [BestFriend] internal static OneVersusAllModelParameters Create(IHost host, bool useProbability, T[] predictors) where T : class { - var outputFormula = useProbability ? OneVersusAllModelParametersBase.OutputFormula.ProbabilityNormalization : OneVersusAllModelParametersBase.OutputFormula.Raw; + var outputFormula = useProbability ? OneVersusAllModelParameters.OutputFormula.ProbabilityNormalization : OneVersusAllModelParameters.OutputFormula.Raw; return Create(host, outputFormula, predictors); } @@ -527,18 +527,17 @@ internal static OneVersusAllModelParameters Create(IHost host, TScalarPredictor[ /// /// Base model parameters for . /// - public abstract class OneVersusAllModelParametersBase : + public abstract class OneVersusAllModelParametersBase : ModelParametersBase>, IValueMapper, ICanSaveInSourceCode, ICanSaveInTextFormat, ISingleCanSavePfa - where T : class { internal const string LoaderSignature = "OVAExec"; internal const string RegistrationName = "OVAPredictor"; - private protected static VersionInfo GetVersionInfo() + private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "TLC OVA ", @@ -546,17 +545,12 @@ private protected static VersionInfo GetVersionInfo() verReadableCur: 0x00010001, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(OneVersusAllModelParametersBase).Assembly.FullName); + loaderAssemblyName: typeof(OneVersusAllModelParameters).Assembly.FullName); } private const string SubPredictorFmt = "SubPredictor_{0:000}"; - private readonly ImplBase _impl; - - /// - /// Retrieves the model parameters. - /// - internal ImmutableArray SubModelParameters => _impl.Predictors.Cast().ToImmutableArray(); + private protected readonly ImplBase Impl; /// /// The type of the prediction task. @@ -578,16 +572,16 @@ internal enum OutputFormula { Raw = 0, ProbabilityNormalization = 1, Softmax = 2 private DataViewType DistType { get; } - bool ICanSavePfa.CanSavePfa => _impl.CanSavePfa; + bool ICanSavePfa.CanSavePfa => Impl.CanSavePfa; - internal OneVersusAllModelParametersBase(IHostEnvironment env, OutputFormula outputFormula, T[] predictors) + internal OneVersusAllModelParametersBase(IHostEnvironment env, OutputFormula outputFormula, TScalarPredictor[] predictors) : base(env, RegistrationName) { using (var ch = env.Start("Creating OVA predictor")) { if (outputFormula == OutputFormula.Softmax) { - _impl = new ImplSoftmax(predictors); + Impl = new ImplSoftmax(predictors); } // Caller of this function asks for probability output. We check if input predictor can produce probability. 
@@ -600,24 +594,27 @@ internal OneVersusAllModelParametersBase(IHostEnvironment env, OutputFormula out
                        ivmd.OutputType != NumberDataViewType.Single ||
                        ivmd.DistType != NumberDataViewType.Single))
                    {
-                        ch.Warning($"{nameof(OneVersusAllTrainerBase.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainerBase.Options.PredictorType)} that can't produce probabilities.");
+                        ch.Warning($"{nameof(OneVersusAllTrainer.Options.UseProbabilities)} specified with {nameof(OneVersusAllTrainer.Options.PredictorType)} that can't produce probabilities.");
                        ivmd = null;
                    }

                    // If ivmd is null, either the user didn't ask for probability or the provided predictors can't produce probability.
                    if (ivmd != null)
                    {
-                        _impl = new ImplDist(predictors);
+                        var dists = new IValueMapperDist[predictors.Length];
+                        for (int i = 0; i < predictors.Length; ++i)
+                            dists[i] = (IValueMapperDist)predictors[i];
+                        Impl = new ImplDist(dists);
                    }
                    else
-                        _impl = new ImplRaw(predictors);
+                        Impl = new ImplRaw(predictors);
                }
            }

-            Host.AssertValue(_impl, nameof(_impl));
-            Host.Assert(Utils.Size(_impl.Predictors) > 0);
+            Host.AssertValue(Impl, nameof(Impl));
+            Host.Assert(Utils.Size(Impl.Predictors) > 0);

-            DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
+            DistType = new VectorDataViewType(NumberDataViewType.Single, Impl.Predictors.Length);
        }

        private protected OneVersusAllModelParametersBase(IHostEnvironment env, ModelLoadContext ctx)
@@ -632,18 +629,18 @@ private protected OneVersusAllModelParametersBase(IHostEnvironment env, ModelLoa
            if (useDist)
            {
-                var predictors = new T[len];
+                var predictors = new IValueMapperDist[len];
                LoadPredictors(Host, predictors, ctx);
-                _impl = new ImplDist(predictors);
+                Impl = new ImplDist(predictors);
            }
            else
            {
-                var predictors = new T[len];
+                var predictors = new TScalarPredictor[len];
                LoadPredictors(Host, predictors, ctx);
-                _impl = new ImplRaw(predictors);
+                Impl = new ImplRaw(predictors);
            }

-            DistType = new VectorDataViewType(NumberDataViewType.Single, _impl.Predictors.Length);
+            DistType = new VectorDataViewType(NumberDataViewType.Single, Impl.Predictors.Length);
        }

        private static void LoadPredictors(IHostEnvironment env, TPredictor[] predictors, ModelLoadContext ctx)
@@ -658,12 +655,12 @@ private protected override void SaveCore(ModelSaveContext ctx)
            base.SaveCore(ctx);
            ctx.SetVersionInfo(GetVersionInfo());

-            var preds = _impl.Predictors;
+            var preds = Impl.Predictors;

            // *** Binary format ***
            // bool: useDist
            // int: predictor count
-            ctx.Writer.WriteBoolByte(_impl is ImplDist);
+            ctx.Writer.WriteBoolByte(Impl is ImplDist);
            ctx.Writer.Write(preds.Length);

            // Save other streams.
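
For readers following the Impl selection above, the following is a minimal, self-contained sketch of the three score-combination rules that the OutputFormula values (Raw, ProbabilityNormalization, Softmax) name. It illustrates only the arithmetic; the real ImplRaw/ImplDist/ImplSoftmax classes operate on VBuffer<float> through IValueMapper delegates, and the class and method names below are invented for this example.

    using System;
    using System.Linq;

    // Sketch of how a one-versus-all model can turn per-class binary scores into a
    // multiclass output vector, mirroring the intent of the OutputFormula options.
    internal static class OvaOutputFormulaSketch
    {
        // Raw: each class keeps the score its binary classifier produced.
        public static float[] Raw(float[] binaryScores) => binaryScores;

        // ProbabilityNormalization: treat each score as that class's probability and
        // rescale the vector so it sums to one (the calibrated, ImplDist-style path).
        public static float[] ProbabilityNormalization(float[] probabilities)
        {
            var sum = probabilities.Sum();
            return sum > 0 ? probabilities.Select(p => p / sum).ToArray() : probabilities;
        }

        // Softmax: exponentiate the raw scores and normalize (the ImplSoftmax-style path).
        public static float[] Softmax(float[] binaryScores)
        {
            var max = binaryScores.Max(); // shift by the max score for numerical stability
            var exps = binaryScores.Select(s => (float)Math.Exp(s - max)).ToArray();
            var sum = exps.Sum();
            return exps.Select(e => e / sum).ToArray();
        }
    }
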
@@ -675,12 +672,12 @@ JToken ISingleCanSavePfa.SaveAsPfa(BoundPfaContext ctx, JToken input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.CheckValue(input, nameof(input));
-            return _impl.SaveAsPfa(ctx, input);
+            return Impl.SaveAsPfa(ctx, input);
        }

        DataViewType IValueMapper.InputType
        {
-            get { return _impl.InputType; }
+            get { return Impl.InputType; }
        }

        DataViewType IValueMapper.OutputType
@@ -692,7 +689,7 @@ ValueMapper IValueMapper.GetMapper()
            Host.Check(typeof(TIn) == typeof(VBuffer));
            Host.Check(typeof(TOut) == typeof(VBuffer));

-            return (ValueMapper)(Delegate)_impl.GetMapper();
+            return (ValueMapper)(Delegate)Impl.GetMapper();
        }

        void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema)
@@ -700,7 +697,7 @@ void ICanSaveInSourceCode.SaveAsCode(TextWriter writer, RoleMappedSchema schema)
            Host.CheckValue(writer, nameof(writer));
            Host.CheckValue(schema, nameof(schema));

-            var preds = _impl.Predictors;
+            var preds = Impl.Predictors;

            writer.WriteLine("double[] outputs = new double[{0}];", preds.Length);

            for (int i = 0; i < preds.Length; i++)
@@ -720,7 +717,7 @@ void ICanSaveInTextFormat.SaveAsText(TextWriter writer, RoleMappedSchema schema)
            Host.CheckValue(writer, nameof(writer));
            Host.CheckValue(schema, nameof(schema));

-            var preds = _impl.Predictors;
+            var preds = Impl.Predictors;

            for (int i = 0; i < preds.Length; i++)
            {
@@ -767,13 +764,13 @@ protected bool IsValid(IValueMapper mapper, ref VectorDataViewType inputType)
            }
        }

-        private protected sealed class ImplRaw : ImplBase
+        private sealed class ImplRaw : ImplBase
        {
            public override DataViewType InputType { get; }
            public override IValueMapper[] Predictors { get; }
            public override bool CanSavePfa { get; }

-            internal ImplRaw(T[] predictors)
+            internal ImplRaw(TScalarPredictor[] predictors)
            {
                Contracts.CheckNonEmpty(predictors, nameof(predictors));
@@ -831,14 +828,14 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
            }
        }

-        private protected sealed class ImplDist : ImplBase
+        private sealed class ImplDist : ImplBase
        {
            private readonly IValueMapperDist[] _mappers;

            public override DataViewType InputType { get; }
            public override IValueMapper[] Predictors => _mappers;
            public override bool CanSavePfa { get; }

-            internal ImplDist(T[] predictors)
+            internal ImplDist(IValueMapperDist[] predictors)
            {
                Contracts.Check(Utils.Size(predictors) > 0);
@@ -846,7 +843,7 @@ internal ImplDist(T[] predictors)
                VectorDataViewType inputType = null;
                for (int i = 0; i < predictors.Length; i++)
                {
-                    var vm = predictors[i] as IValueMapperDist;
+                    var vm = predictors[i];
                    Contracts.Check(IsValid(vm, ref inputType), "Predictor doesn't implement the expected interface");
                    _mappers[i] = vm;
                }
@@ -943,13 +940,13 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
            }
        }

-        private protected sealed class ImplSoftmax : ImplBase
+        private sealed class ImplSoftmax : ImplBase
        {
            public override DataViewType InputType { get; }
            public override IValueMapper[] Predictors { get; }
            public override bool CanSavePfa { get; }

-            internal ImplSoftmax(T[] predictors)
+            internal ImplSoftmax(TScalarPredictor[] predictors)
            {
                Contracts.CheckNonEmpty(predictors, nameof(predictors));
@@ -1016,10 +1013,10 @@ public override JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
    /// Model parameters for typed versions of .
    /// 
    public sealed class OneVersusAllModelParameters :
-        OneVersusAllModelParametersBase where T : class
+        OneVersusAllModelParametersBase where T : class
    {
        internal OneVersusAllModelParameters(IHostEnvironment env, OutputFormula outputFormula, T[] predictors)
-            : base(env, outputFormula, predictors)
+            : base(env, outputFormula, predictors.Cast>().ToArray())
        {
        }
@@ -1027,13 +1024,18 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, ctx)
        {
        }
+
+        /// 
+        /// Retrieves the model parameters.
+        /// 
+        internal ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray();
    }

    /// 
    /// Model parameters for .
    /// 
    public sealed class OneVersusAllModelParameters :
-        OneVersusAllModelParametersBase
+        OneVersusAllModelParametersBase
    {
        internal OneVersusAllModelParameters(IHostEnvironment env, OutputFormula outputFormula, TScalarPredictor[] predictors)
            : base(env, outputFormula, predictors)
@@ -1044,5 +1046,10 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx)
            : base(env, ctx)
        {
        }
+
+        /// 
+        /// Retrieves the model parameters.
+        /// 
+        internal ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray();
    }
}

diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
index 9ac45d8390..a7df1208d2 100644
--- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
+++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
@@ -839,7 +839,7 @@ public static OneVersusAllTrainer OneVersusAllUnCalibrate
            IEstimator> calibrator = null,
            int maximumCalibrationExampleCount = 1000000000,
            bool useProbabilities = true)
-            where TModelIn : class, IPredictorProducing
+            where TModelIn : class
            where TCalibrator : class, ICalibrator
        {
            Contracts.CheckValue(catalog, nameof(catalog));

From cb3ef0f2c515584ef512ded8878a6b5d6631ec93 Mon Sep 17 00:00:00 2001
From: Michael Sharp
Date: Fri, 2 Aug 2019 09:20:23 -0700
Subject: [PATCH 17/19] added tests exercising 'TestEstimatorCore' with the new typed classes

---
 .../TrainerEstimators/MetalinearEstimators.cs | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs
index 9f94bc2560..7dc885a336 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/MetalinearEstimators.cs
@@ -52,6 +52,62 @@ public void OVAUncalibrated()
            Done();
        }

+        /// 
+        /// Tests passing in a non-calibrated trainer.
+        /// 
+        [Fact]
+        public void OVATypedUncalibrated()
+        {
+            var (pipeline, data) = GetMulticlassPipeline();
+            var sdcaTrainer = ML.BinaryClassification.Trainers.SdcaNonCalibrated(
+                new SdcaNonCalibratedBinaryTrainer.Options { MaximumNumberOfIterations = 100, Shuffle = true, NumberOfThreads = 1 });
+
+            pipeline = pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: false))
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+
+            TestEstimatorCore(pipeline, data);
+            Done();
+        }
+
+        /// 
+        /// Tests passing in a trainer that is already calibrated.
+        /// 
+        [Fact]
+        public void OVATypedCalibrated()
+        {
+            var (pipeline, data) = GetMulticlassPipeline();
+            var sdcaTrainer = ML.BinaryClassification.Trainers.SgdCalibrated(
+                new SgdCalibratedTrainer.Options { Shuffle = true, NumberOfThreads = 1 });
+
+            pipeline =
+                pipeline.Append(ML.MulticlassClassification.Trainers.OneVersusAllTyped(sdcaTrainer, useProbabilities: true))
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+
+            TestEstimatorCore(pipeline, data);
+            Done();
+        }
+
+        /// 
+        /// Tests passing an uncalibrated trainer with a calibrator and having it be auto-calibrated
+        /// using the strongly typed API.
+        /// 
+        [Fact]
+        public void OVATypedUncalibratedToCalibrated()
+        {
+            var (pipeline, data) = GetMulticlassPipeline();
+            var calibrator = new PlattCalibratorEstimator(Env);
+            var averagePerceptron = ML.BinaryClassification.Trainers.AveragedPerceptron(
+                new AveragedPerceptronTrainer.Options { Shuffle = true });
+
+            var ova = ML.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibrated(averagePerceptron, imputeMissingLabelsAsNegative: true,
+                calibrator: calibrator, maximumCalibrationExampleCount: 10000, useProbabilities: true);
+
+            pipeline = pipeline.Append(ova)
+                    .Append(new KeyToValueMappingEstimator(Env, "PredictedLabel"));
+
+            TestEstimatorCore(pipeline, data);
+            Done();
+        }
+
        /// 
        /// Pairwise Coupling trainer
        /// 

From d676d461458ad287ad73ca382a68c5f83dbba7a4 Mon Sep 17 00:00:00 2001
From: Michael Sharp
Date: Fri, 2 Aug 2019 12:59:39 -0700
Subject: [PATCH 18/19] comment fixes

---
 .../StandardTrainersCatalog.cs               | 9 ++++++---
 test/Microsoft.ML.Tests/Scenarios/OvaTest.cs | 1 -
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
index a7df1208d2..a84b258bbd 100644
--- a/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
+++ b/src/Microsoft.ML.StandardTrainers/StandardTrainersCatalog.cs
@@ -725,7 +725,8 @@ private static ICalibratorTrainer GetCalibratorTrainerOrThrow(IExceptionContext
        /// 
        /// Create a , which predicts a multiclass target using one-versus-all strategy with
-        /// the binary classification estimator specified by .
+        /// the binary classification estimator specified by . If you want to retrieve strongly typed model parameters,
+        /// use either the or methods instead.
        /// 
        /// 
        /// 
@@ -768,7 +769,8 @@ public static OneVersusAllTrainer OneVersusAll(this MulticlassClassifica
        /// Create a , which predicts a multiclass target using one-versus-all strategy with
        /// the binary classification estimator specified by . This method works with binary classifiers that
        /// are either already calibrated, or non calibrated ones you don't want calibrated. If you need to have your classifier calibrated, use the
-        /// method instead.
+        /// method instead. If you want to retrieve strongly typed model parameters,
+        /// you must use either this method or method.
        /// 
        /// 
        /// 
@@ -808,7 +810,8 @@ public static OneVersusAllTrainer OneVersusAllTyped(this Multicl
        /// the binary classification estimator specified by .This method works with binary classifiers that
        /// are not calibrated and need to be calibrated before use. Due to the type of estimator changing (from uncalibrated to calibrated), you must manually
/// /// /// diff --git a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs index 06c30e2598..1f9bc269f6 100644 --- a/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs +++ b/test/Microsoft.ML.Tests/Scenarios/OvaTest.cs @@ -53,7 +53,6 @@ public void OvaLogisticRegression() Assert.Equal(metrics.MicroAccuracy, metricsTyped.MicroAccuracy); } - //.Append(ML.MulticlassClassification.Trainers.OneVersusAllUnCalibratedToCalibratedTyped(sdcaTrainer)) [Fact] public void OvaAveragedPerceptron() From cf9ef428a1c6bbb6f113b207798a31e6cca81286 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 5 Aug 2019 14:42:22 -0700 Subject: [PATCH 19/19] small fix from PR comment --- .../Standard/MulticlassClassification/OneVersusAllTrainer.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs index 1c4e2c04c3..741923c4a2 100644 --- a/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs +++ b/src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/OneVersusAllTrainer.cs @@ -1028,7 +1028,7 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) /// /// Retrieves the model parameters. /// - internal ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray(); + public ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray(); } /// @@ -1050,6 +1050,6 @@ private OneVersusAllModelParameters(IHostEnvironment env, ModelLoadContext ctx) /// /// Retrieves the model parameters. /// - internal ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray(); + public ImmutableArray SubModelParameters => Impl.Predictors.Cast().ToImmutableArray(); } }