diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs new file mode 100644 index 0000000000..14e0c00d9c --- /dev/null +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.StaticPipe.Runtime; + +namespace Microsoft.ML.StaticPipe +{ + /// + /// MultiClass Classification trainer estimators. + /// + public static partial class MultiClassClassificationTrainers + { + /// + /// Predict a target using a linear multiclass classification model trained with the trainer. + /// + /// The multiclass classification context trainer object. + /// The label, or dependent variable. + /// The features, or independent variables. + /// A delegate that is called every time the + /// method is called on the + /// instance created out of this. This delegate will receive + /// the linear model that was trained. Note that this action cannot change the + /// result in any way; it is only a way for the caller to be informed about what was learnt. + /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label. + public static (Vector score, Key predictedLabel) + MultiClassNaiveBayesTrainer(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, + Key label, + Vector features, + Action onFit = null) + { + Contracts.CheckValue(features, nameof(features)); + Contracts.CheckValue(label, nameof(label)); + Contracts.CheckValueOrNull(onFit); + + var rec = new TrainerEstimatorReconciler.MulticlassClassifier( + (env, labelName, featuresName, weightsName) => + { + var trainer = new MultiClassNaiveBayesTrainer(env, featuresName, labelName); + + if (onFit != null) + return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); + return trainer; + }, label, features, null); + + return rec.Output; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs index 0817db47a5..908bf34c65 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs @@ -2,8 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Linq; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -12,6 +11,9 @@ using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Training; using Microsoft.ML.Runtime.Internal.Internallearn; +using System; +using System.Collections.Generic; +using System.Linq; [assembly: LoadableClass(MultiClassNaiveBayesTrainer.Summary, typeof(MultiClassNaiveBayesTrainer), typeof(MultiClassNaiveBayesTrainer.Arguments), new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) }, @@ -22,12 +24,12 @@ [assembly: LoadableClass(typeof(MultiClassNaiveBayesPredictor), null, typeof(SignatureLoadModel), "Multi Class Naive Bayes predictor", MultiClassNaiveBayesPredictor.LoaderSignature)] -[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), "MultiClassNaiveBayes")] +[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), MultiClassNaiveBayesTrainer.LoadName)] namespace Microsoft.ML.Runtime.Learners { /// - public sealed class MultiClassNaiveBayesTrainer : TrainerBase + public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase, MultiClassNaiveBayesPredictor> { public const string LoadName = "MultiClassNaiveBayes"; internal const string UserName = "Multiclass Naive Bayes"; @@ -43,13 +45,52 @@ public sealed class Arguments : LearnerInputBaseWithLabel private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false); public override TrainerInfo Info => _info; - public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) - : base(env, LoadName) + /// + /// Initializes a new instance of + /// + /// The environment to use. + /// The name of the label column. + /// The name of the feature column. + public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn), + TrainerUtils.MakeU4ScalarLabel(labelColumn)) + { + Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); + Host.CheckNonEmpty(labelColumn, nameof(labelColumn)); + } + + /// + /// Initializes a new instance of + /// + internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args) + : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn), + TrainerUtils.MakeU4ScalarLabel(args.LabelColumn)) { Host.CheckValue(args, nameof(args)); } - public override MultiClassNaiveBayesPredictor Train(TrainContext context) + protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) + { + bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol); + Contracts.Assert(success); + + var scoreMetadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) }; + scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata()); + + var predLabelMetadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues) + .Concat(MetadataUtils.GetTrainerOutputMetadata())); + + return new[] + { + new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(scoreMetadata)), + new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, predLabelMetadata) + }; + } + + protected override MulticlassPredictionTransformer MakeTransformer(MultiClassNaiveBayesPredictor model, ISchema trainSchema) + => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name); + + protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context) { Host.CheckValue(context, nameof(context)); var data = context.TrainingSet; @@ -170,6 +211,40 @@ private static VersionInfo GetVersionInfo() public ColumnType OutputType => _outputType; + /// + /// Copies the label histogram into a buffer. + /// + /// A possibly reusable array, which will + /// be expanded as necessary to accomodate the data. + /// Set to the length of the resized array, which is also the number of different labels. + public void GetLabelHistogram(ref int[] labelHistogram, out int labelCount) + { + labelCount = _labelCount; + Utils.EnsureSize(ref labelHistogram, _labelCount); + Array.Copy(_labelHistogram, labelHistogram, _labelCount); + } + + /// + /// Copies the feature histogram into a buffer. + /// + /// A possibly reusable array, which will + /// be expanded as necessary to accomodate the data. + /// Set to the first dimension of the resized array, + /// which is the number of different labels encountered in training. + /// Set to the second dimension of the resized array, + /// which is also the number of different feature combinations encountered in training. + public void GetFeatureHistogram(ref int[][] featureHistogram, out int labelCount, out int featureCount) + { + labelCount = _labelCount; + featureCount = _featureCount; + Utils.EnsureSize(ref featureHistogram, _labelCount); + for(int i = 0; i < _labelCount; i++) + { + Utils.EnsureSize(ref featureHistogram[i], _featureCount); + Array.Copy(_featureHistogram[i], featureHistogram[i], _featureCount); + } + } + internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount) : base(env, LoaderSignature) { diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs index 92bfe1e2c1..c3cdb476e9 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs @@ -749,5 +749,50 @@ public void FastTreeRanking() Assert.InRange(metrics.Ndcg[1], 36.5, 37); Assert.InRange(metrics.Ndcg[2], 36.5, 37); } + + [Fact] + public void MultiClassNaiveBayesTrainer() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath(TestDatasets.iris.trainFilename); + var dataSource = new MultiFileSource(dataPath); + + var ctx = new MulticlassClassificationContext(env); + var reader = TextLoader.CreateReader(env, + c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); + + MultiClassNaiveBayesPredictor pred = null; + + // With a custom loss function we no longer get calibrated predictions. + var est = reader.MakeNewEstimator() + .Append(r => (label: r.label.ToKey(), r.features)) + .Append(r => (r.label, preds: ctx.Trainers.MultiClassNaiveBayesTrainer( + r.label, + r.features, onFit: p => pred = p))); + + var pipe = reader.Append(est); + + Assert.Null(pred); + var model = pipe.Fit(dataSource); + Assert.NotNull(pred); + int[] labelHistogram = default; + int[][] featureHistogram = default; + pred.GetLabelHistogram(ref labelHistogram, out int labelCount1); + pred.GetFeatureHistogram(ref featureHistogram, out int labelCount2, out int featureCount); + Assert.True(labelCount1 == 3 && labelCount1 == labelCount2 && labelCount1 <= labelHistogram.Length); + for (int i = 0; i < labelCount1; i++) + Assert.True(featureCount == 4 && (featureCount <= featureHistogram[i].Length)); + + var data = model.Read(dataSource); + + // Just output some data on the schema for fun. + var schema = data.AsDynamic.Schema; + for (int c = 0; c < schema.ColumnCount; ++c) + Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); + + var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2); + Assert.True(metrics.LogLoss > 0); + Assert.True(metrics.TopKAccuracy > 0); + } } } diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index 49b74c7fa2..70a822a239 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -77,6 +77,18 @@ public void KMeansEstimator() Done(); } + /// + /// MultiClassNaiveBayes TrainerEstimator test + /// + [Fact] + public void TestEstimatorMultiClassNaiveBayesTrainer() + { + (IEstimator pipe, IDataView dataView) = GetMultiClassPipeline(); + pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label")); + TestEstimatorCore(pipe, dataView); + Done(); + } + private (IEstimator, IDataView) GetBinaryClassificationPipeline() { var data = new TextLoader(Env,