Skip to content

Commit 2a4681b

Browse files
authored
Conversion of Multi Class Naive Bayes classifier to estimator (#1111)
* conversion of multiclass naive bayes classifier to estimator * added pigstension and related test * added public methods to access label and feature histograms in the predictor * fixed review comments on new access functions * moved test to main file
1 parent b770281 commit 2a4681b

File tree

4 files changed

+193
-7
lines changed

4 files changed

+193
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.Internal.Calibration;
9+
using Microsoft.ML.Runtime.Learners;
10+
using Microsoft.ML.StaticPipe.Runtime;
11+
12+
namespace Microsoft.ML.StaticPipe
13+
{
14+
/// <summary>
15+
/// MultiClass Classification trainer estimators.
16+
/// </summary>
17+
public static partial class MultiClassClassificationTrainers
18+
{
19+
/// <summary>
20+
/// Predict a target using a linear multiclass classification model trained with the <see cref="Microsoft.ML.Runtime.Learners.MultiClassNaiveBayesTrainer"/> trainer.
21+
/// </summary>
22+
/// <param name="ctx">The multiclass classification context trainer object.</param>
23+
/// <param name="label">The label, or dependent variable.</param>
24+
/// <param name="features">The features, or independent variables.</param>
25+
/// <param name="onFit">A delegate that is called every time the
26+
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
27+
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
28+
/// the linear model that was trained. Note that this action cannot change the
29+
/// result in any way; it is only a way for the caller to be informed about what was learnt.</param>
30+
/// <returns>The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.</returns>
31+
public static (Vector<float> score, Key<uint, TVal> predictedLabel)
32+
MultiClassNaiveBayesTrainer<TVal>(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
33+
Key<uint, TVal> label,
34+
Vector<float> features,
35+
Action<MultiClassNaiveBayesPredictor> onFit = null)
36+
{
37+
Contracts.CheckValue(features, nameof(features));
38+
Contracts.CheckValue(label, nameof(label));
39+
Contracts.CheckValueOrNull(onFit);
40+
41+
var rec = new TrainerEstimatorReconciler.MulticlassClassifier<TVal>(
42+
(env, labelName, featuresName, weightsName) =>
43+
{
44+
var trainer = new MultiClassNaiveBayesTrainer(env, featuresName, labelName);
45+
46+
if (onFit != null)
47+
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
48+
return trainer;
49+
}, label, features, null);
50+
51+
return rec.Output;
52+
}
53+
}
54+
}

src/Microsoft.ML.StandardLearners/Standard/MultiClass/MultiClassNaiveBayesTrainer.cs

+82-7
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
6-
using System.Linq;
5+
using Microsoft.ML.Core.Data;
76
using Microsoft.ML.Runtime;
87
using Microsoft.ML.Runtime.Data;
98
using Microsoft.ML.Runtime.EntryPoints;
@@ -12,6 +11,9 @@
1211
using Microsoft.ML.Runtime.Model;
1312
using Microsoft.ML.Runtime.Training;
1413
using Microsoft.ML.Runtime.Internal.Internallearn;
14+
using System;
15+
using System.Collections.Generic;
16+
using System.Linq;
1517

1618
[assembly: LoadableClass(MultiClassNaiveBayesTrainer.Summary, typeof(MultiClassNaiveBayesTrainer), typeof(MultiClassNaiveBayesTrainer.Arguments),
1719
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) },
@@ -22,12 +24,12 @@
2224
[assembly: LoadableClass(typeof(MultiClassNaiveBayesPredictor), null, typeof(SignatureLoadModel),
2325
"Multi Class Naive Bayes predictor", MultiClassNaiveBayesPredictor.LoaderSignature)]
2426

25-
[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), "MultiClassNaiveBayes")]
27+
[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), MultiClassNaiveBayesTrainer.LoadName)]
2628

2729
namespace Microsoft.ML.Runtime.Learners
2830
{
2931
/// <include file='doc.xml' path='doc/members/member[@name="MultiClassNaiveBayesTrainer"]' />
30-
public sealed class MultiClassNaiveBayesTrainer : TrainerBase<MultiClassNaiveBayesPredictor>
32+
public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase<MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor>, MultiClassNaiveBayesPredictor>
3133
{
3234
public const string LoadName = "MultiClassNaiveBayes";
3335
internal const string UserName = "Multiclass Naive Bayes";
@@ -43,13 +45,52 @@ public sealed class Arguments : LearnerInputBaseWithLabel
4345
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
4446
public override TrainerInfo Info => _info;
4547

46-
public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
47-
: base(env, LoadName)
48+
/// <summary>
49+
/// Initializes a new instance of <see cref="MultiClassNaiveBayesTrainer"/>
50+
/// </summary>
51+
/// <param name="env">The environment to use.</param>
52+
/// <param name="labelColumn">The name of the label column.</param>
53+
/// <param name="featureColumn">The name of the feature column.</param>
54+
public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn)
55+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(featureColumn),
56+
TrainerUtils.MakeU4ScalarLabel(labelColumn))
57+
{
58+
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
59+
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
60+
}
61+
62+
/// <summary>
63+
/// Initializes a new instance of <see cref="MultiClassNaiveBayesTrainer"/>
64+
/// </summary>
65+
internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
66+
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadName), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
67+
TrainerUtils.MakeU4ScalarLabel(args.LabelColumn))
4868
{
4969
Host.CheckValue(args, nameof(args));
5070
}
5171

52-
public override MultiClassNaiveBayesPredictor Train(TrainContext context)
72+
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
73+
{
74+
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
75+
Contracts.Assert(success);
76+
77+
var scoreMetadata = new List<SchemaShape.Column>() { new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false) };
78+
scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata());
79+
80+
var predLabelMetadata = new SchemaShape(labelCol.Metadata.Columns.Where(x => x.Name == MetadataUtils.Kinds.KeyValues)
81+
.Concat(MetadataUtils.GetTrainerOutputMetadata()));
82+
83+
return new[]
84+
{
85+
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(scoreMetadata)),
86+
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, predLabelMetadata)
87+
};
88+
}
89+
90+
protected override MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor> MakeTransformer(MultiClassNaiveBayesPredictor model, ISchema trainSchema)
91+
=> new MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor>(Host, model, trainSchema, FeatureColumn.Name, LabelColumn.Name);
92+
93+
protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context)
5394
{
5495
Host.CheckValue(context, nameof(context));
5596
var data = context.TrainingSet;
@@ -170,6 +211,40 @@ private static VersionInfo GetVersionInfo()
170211

171212
public ColumnType OutputType => _outputType;
172213

214+
/// <summary>
215+
/// Copies the label histogram into a buffer.
216+
/// </summary>
217+
/// <param name="labelHistogram">A possibly reusable array, which will
218+
/// be expanded as necessary to accomodate the data.</param>
219+
/// <param name="labelCount">Set to the length of the resized array, which is also the number of different labels.</param>
220+
public void GetLabelHistogram(ref int[] labelHistogram, out int labelCount)
221+
{
222+
labelCount = _labelCount;
223+
Utils.EnsureSize(ref labelHistogram, _labelCount);
224+
Array.Copy(_labelHistogram, labelHistogram, _labelCount);
225+
}
226+
227+
/// <summary>
228+
/// Copies the feature histogram into a buffer.
229+
/// </summary>
230+
/// <param name="featureHistogram">A possibly reusable array, which will
231+
/// be expanded as necessary to accomodate the data.</param>
232+
/// <param name="labelCount">Set to the first dimension of the resized array,
233+
/// which is the number of different labels encountered in training.</param>
234+
/// <param name="featureCount">Set to the second dimension of the resized array,
235+
/// which is also the number of different feature combinations encountered in training.</param>
236+
public void GetFeatureHistogram(ref int[][] featureHistogram, out int labelCount, out int featureCount)
237+
{
238+
labelCount = _labelCount;
239+
featureCount = _featureCount;
240+
Utils.EnsureSize(ref featureHistogram, _labelCount);
241+
for(int i = 0; i < _labelCount; i++)
242+
{
243+
Utils.EnsureSize(ref featureHistogram[i], _featureCount);
244+
Array.Copy(_featureHistogram[i], featureHistogram[i], _featureCount);
245+
}
246+
}
247+
173248
internal MultiClassNaiveBayesPredictor(IHostEnvironment env, int[] labelHistogram, int[][] featureHistogram, int featureCount)
174249
: base(env, LoaderSignature)
175250
{

test/Microsoft.ML.StaticPipelineTesting/Training.cs

+45
Original file line numberDiff line numberDiff line change
@@ -749,5 +749,50 @@ public void FastTreeRanking()
749749
Assert.InRange(metrics.Ndcg[1], 36.5, 37);
750750
Assert.InRange(metrics.Ndcg[2], 36.5, 37);
751751
}
752+
753+
[Fact]
754+
public void MultiClassNaiveBayesTrainer()
755+
{
756+
var env = new ConsoleEnvironment(seed: 0);
757+
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
758+
var dataSource = new MultiFileSource(dataPath);
759+
760+
var ctx = new MulticlassClassificationContext(env);
761+
var reader = TextLoader.CreateReader(env,
762+
c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
763+
764+
MultiClassNaiveBayesPredictor pred = null;
765+
766+
// With a custom loss function we no longer get calibrated predictions.
767+
var est = reader.MakeNewEstimator()
768+
.Append(r => (label: r.label.ToKey(), r.features))
769+
.Append(r => (r.label, preds: ctx.Trainers.MultiClassNaiveBayesTrainer(
770+
r.label,
771+
r.features, onFit: p => pred = p)));
772+
773+
var pipe = reader.Append(est);
774+
775+
Assert.Null(pred);
776+
var model = pipe.Fit(dataSource);
777+
Assert.NotNull(pred);
778+
int[] labelHistogram = default;
779+
int[][] featureHistogram = default;
780+
pred.GetLabelHistogram(ref labelHistogram, out int labelCount1);
781+
pred.GetFeatureHistogram(ref featureHistogram, out int labelCount2, out int featureCount);
782+
Assert.True(labelCount1 == 3 && labelCount1 == labelCount2 && labelCount1 <= labelHistogram.Length);
783+
for (int i = 0; i < labelCount1; i++)
784+
Assert.True(featureCount == 4 && (featureCount <= featureHistogram[i].Length));
785+
786+
var data = model.Read(dataSource);
787+
788+
// Just output some data on the schema for fun.
789+
var schema = data.AsDynamic.Schema;
790+
for (int c = 0; c < schema.ColumnCount; ++c)
791+
Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");
792+
793+
var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);
794+
Assert.True(metrics.LogLoss > 0);
795+
Assert.True(metrics.TopKAccuracy > 0);
796+
}
752797
}
753798
}

test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs

+12
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,18 @@ public void KMeansEstimator()
7777
Done();
7878
}
7979

80+
/// <summary>
81+
/// MultiClassNaiveBayes TrainerEstimator test
82+
/// </summary>
83+
[Fact]
84+
public void TestEstimatorMultiClassNaiveBayesTrainer()
85+
{
86+
(IEstimator<ITransformer> pipe, IDataView dataView) = GetMultiClassPipeline();
87+
pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label"));
88+
TestEstimatorCore(pipe, dataView);
89+
Done();
90+
}
91+
8092
private (IEstimator<ITransformer>, IDataView) GetBinaryClassificationPipeline()
8193
{
8294
var data = new TextLoader(Env,

0 commit comments

Comments
 (0)