diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs
new file mode 100644
index 0000000000..14e0c00d9c
--- /dev/null
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesStatic.cs
@@ -0,0 +1,54 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Calibration;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.StaticPipe.Runtime;
+
+namespace Microsoft.ML.StaticPipe
+{
+ ///
+ /// MultiClass Classification trainer estimators.
+ ///
+ public static partial class MultiClassClassificationTrainers
+ {
+ ///
+ /// Predict a target using a linear multiclass classification model trained with the trainer.
+ ///
+ /// The multiclass classification context trainer object.
+ /// The label, or dependent variable.
+ /// The features, or independent variables.
+ /// A delegate that is called every time the
+ /// method is called on the
+ /// instance created out of this. This delegate will receive
+ /// the linear model that was trained. Note that this action cannot change the
+ /// result in any way; it is only a way for the caller to be informed about what was learnt.
+ /// The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.
+ public static (Vector score, Key predictedLabel)
+ MultiClassNaiveBayesTrainer(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
+ Key label,
+ Vector features,
+ Action onFit = null)
+ {
+ Contracts.CheckValue(features, nameof(features));
+ Contracts.CheckValue(label, nameof(label));
+ Contracts.CheckValueOrNull(onFit);
+
+ var rec = new TrainerEstimatorReconciler.MulticlassClassifier(
+ (env, labelName, featuresName, weightsName) =>
+ {
+ var trainer = new MultiClassNaiveBayesTrainer(env, featuresName, labelName);
+
+ if (onFit != null)
+ return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
+ return trainer;
+ }, label, features, null);
+
+ return rec.Output;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
index 0817db47a5..908bf34c65 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs
@@ -2,8 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
-using System;
-using System.Linq;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
@@ -12,6 +11,9 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
+using System;
+using System.Collections.Generic;
+using System.Linq;
[assembly: LoadableClass(MultiClassNaiveBayesTrainer.Summary, typeof(MultiClassNaiveBayesTrainer), typeof(MultiClassNaiveBayesTrainer.Arguments),
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) },
@@ -22,12 +24,12 @@
[assembly: LoadableClass(typeof(MultiClassNaiveBayesPredictor), null, typeof(SignatureLoadModel),
"Multi Class Naive Bayes predictor", MultiClassNaiveBayesPredictor.LoaderSignature)]
-[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), "MultiClassNaiveBayes")]
+[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), MultiClassNaiveBayesTrainer.LoadName)]
namespace Microsoft.ML.Runtime.Learners
{
///
- public sealed class MultiClassNaiveBayesTrainer : TrainerBase
+ public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase, MultiClassNaiveBayesPredictor>
{
public const string LoadName = "MultiClassNaiveBayes";
internal const string UserName = "Multiclass Naive Bayes";
@@ -43,13 +45,52 @@ public sealed class Arguments : LearnerInputBaseWithLabel
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
public override TrainerInfo Info => _info;
- public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
- : base(env, LoadName)
+ ///
+ /// Initializes a new instance of
+ ///
+ /// The environment to use.
+ /// The name of the label column.
+ /// The name of the feature column.
+ public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn),
+ TrainerUtils.MakeU4ScalarLabel(labelColumn))
+ {
+ Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
+ Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+ }
+
+ ///
+ /// Initializes a new instance of
+ ///
+ internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
+ : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
+ TrainerUtils.MakeU4ScalarLabel(args.LabelColumn))
{
Host.CheckValue(args, nameof(args));
}
- public override MultiClassNaiveBayesPredictor Train(TrainContext context)
+ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
+ {
+ bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
+ Contracts.Assert(success);
+
+ var scoreMetadata = new List() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) };
+ scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata());
+
+ var predLabelMetadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
+ .Concat(MetadataUtils.GetTrainerOutputMetadata()));
+
+ return new[]
+ {
+ new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(scoreMetadata)),
+ new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, predLabelMetadata)
+ };
+ }
+
+ protected override MulticlassPredictionTransformer MakeTransformer(MultiClassNaiveBayesPredictor model, ISchema trainSchema)
+ => new MulticlassPredictionTransformer(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
+
+ protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
@@ -170,6 +211,40 @@ private static VersionInfo GetVersionInfo()
public ColumnType OutputType => _outputType;
+ ///
+ /// Copies the label histogram into a buffer.
+ ///
+ /// A possibly reusable array, which will
+ /// be expanded as necessary to accomodate the data.
+ /// Set to the length of the resized array, which is also the number of different labels.
+ public void GetLabelHistogram(ref int[] labelHistogram, out int labelCount)
+ {
+ labelCount = _labelCount;
+ Utils.EnsureSize(ref labelHistogram, _labelCount);
+ Array.Copy(_labelHistogram, labelHistogram, _labelCount);
+ }
+
+ ///
+ /// Copies the feature histogram into a buffer.
+ ///
+ /// A possibly reusable array, which will
+ /// be expanded as necessary to accomodate the data.
+ /// Set to the first dimension of the resized array,
+ /// which is the number of different labels encountered in training.
+ /// Set to the second dimension of the resized array,
+ /// which is also the number of different feature combinations encountered in training.
+ public void GetFeatureHistogram(ref int[][] featureHistogram, out int labelCount, out int featureCount)
+ {
+ labelCount = _labelCount;
+ featureCount = _featureCount;
+ Utils.EnsureSize(ref featureHistogram, _labelCount);
+ for(int i = 0; i < _labelCount; i++)
+ {
+ Utils.EnsureSize(ref featureHistogram[i], _featureCount);
+ Array.Copy(_featureHistogram[i], featureHistogram[i], _featureCount);
+ }
+ }
+
internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
: base(env, LoaderSignature)
{
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
index 92bfe1e2c1..c3cdb476e9 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -749,5 +749,50 @@ public void FastTreeRanking()
Assert.InRange(metrics.Ndcg[1], 36.5, 37);
Assert.InRange(metrics.Ndcg[2], 36.5, 37);
}
+
+ [Fact]
+ public void MultiClassNaiveBayesTrainer()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+
+ var ctx = new MulticlassClassificationContext(env);
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
+
+ MultiClassNaiveBayesPredictor pred = null;
+
+ // With a custom loss function we no longer get calibrated predictions.
+ var est = reader.MakeNewEstimator()
+ .Append(r => (label: r.label.ToKey(), r.features))
+ .Append(r => (r.label, preds: ctx.Trainers.MultiClassNaiveBayesTrainer(
+ r.label,
+ r.features, onFit: p => pred = p)));
+
+ var pipe = reader.Append(est);
+
+ Assert.Null(pred);
+ var model = pipe.Fit(dataSource);
+ Assert.NotNull(pred);
+ int[] labelHistogram = default;
+ int[][] featureHistogram = default;
+ pred.GetLabelHistogram(ref labelHistogram, out int labelCount1);
+ pred.GetFeatureHistogram(ref featureHistogram, out int labelCount2, out int featureCount);
+ Assert.True(labelCount1 == 3 && labelCount1 == labelCount2 && labelCount1 <= labelHistogram.Length);
+ for (int i = 0; i < labelCount1; i++)
+ Assert.True(featureCount == 4 && (featureCount <= featureHistogram[i].Length));
+
+ var data = model.Read(dataSource);
+
+ // Just output some data on the schema for fun.
+ var schema = data.AsDynamic.Schema;
+ for (int c = 0; c < schema.ColumnCount; ++c)
+ Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
+
+ var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);
+ Assert.True(metrics.LogLoss > 0);
+ Assert.True(metrics.TopKAccuracy > 0);
+ }
}
}
diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs
index 49b74c7fa2..70a822a239 100644
--- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs
+++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs
@@ -77,6 +77,18 @@ public void KMeansEstimator()
Done();
}
+ ///
+ /// MultiClassNaiveBayes TrainerEstimator test
+ ///
+ [Fact]
+ public void TestEstimatorMultiClassNaiveBayesTrainer()
+ {
+ (IEstimator pipe, IDataView dataView) = GetMultiClassPipeline();
+ pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label"));
+ TestEstimatorCore(pipe, dataView);
+ Done();
+ }
+
private (IEstimator, IDataView) GetBinaryClassificationPipeline()
{
var data = new TextLoader(Env,