Skip to content

Conversion of Multi Class Naive Bayes classifier to estimator #1111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 4, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.StaticPipe.Runtime;

Copy link
Contributor Author

@artidoro artidoro Oct 3, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change file name to static #Resolved

namespace Microsoft.ML.StaticPipe
{
    /// <summary>
    /// Statically-typed pipeline extensions that create multiclass classification trainer estimators.
    /// </summary>
    public static partial class MultiClassClassificationTrainers
    {
        /// <summary>
        /// Predict a target using a multiclass Naive Bayes model trained with the
        /// <see cref="Microsoft.ML.Runtime.Learners.MultiClassNaiveBayesTrainer"/> trainer.
        /// </summary>
        /// <param name="ctx">The multiclass classification context trainer object.</param>
        /// <param name="label">The label, or dependent variable. Must not be null.</param>
        /// <param name="features">The features, or independent variables. Must not be null.</param>
        /// <param name="onFit">An optional delegate that is invoked with the trained
        /// <see cref="MultiClassNaiveBayesPredictor"/> whenever the
        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
        /// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this.
        /// It is purely informational: the delegate's return value is not used, so it cannot alter the result.</param>
        /// <returns>The output columns, in order: the per-class score vector and the predicted label.</returns>
        public static (Vector<float> score, Key<uint, TVal> predictedLabel)
            MultiClassNaiveBayesTrainer<TVal>(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx,
                Key<uint, TVal> label,
                Vector<float> features,
                Action<MultiClassNaiveBayesPredictor> onFit = null)
        {
            Contracts.CheckValue(features, nameof(features));
            Contracts.CheckValue(label, nameof(label));
            Contracts.CheckValueOrNull(onFit);

            var reconciler = new TrainerEstimatorReconciler.MulticlassClassifier<TVal>(
                (env, labelName, featuresName, weightsName) =>
                {
                    // weightsName is ignored: the trainer is built from the feature and label names only.
                    var estimator = new MultiClassNaiveBayesTrainer(env, featuresName, labelName);

                    // Without a callback the bare trainer is returned as-is.
                    if (onFit == null)
                        return estimator;

                    // Otherwise hand the fitted transformer's model to the caller's delegate.
                    return estimator.WithOnFitDelegate(transformer => onFit(transformer.Model));
                },
                label, features, null);

            return reconciler.Output;
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
Expand All @@ -12,6 +11,9 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
using System;
using System.Collections.Generic;
using System.Linq;

[assembly: LoadableClass(MultiClassNaiveBayesTrainer.Summary, typeof(MultiClassNaiveBayesTrainer), typeof(MultiClassNaiveBayesTrainer.Arguments),
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer) },
Expand All @@ -22,12 +24,12 @@
[assembly: LoadableClass(typeof(MultiClassNaiveBayesPredictor), null, typeof(SignatureLoadModel),
"Multi Class Naive Bayes predictor", MultiClassNaiveBayesPredictor.LoaderSignature)]

[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), "MultiClassNaiveBayes")]
[assembly: LoadableClass(typeof(void), typeof(MultiClassNaiveBayesTrainer), null, typeof(SignatureEntryPointModule), MultiClassNaiveBayesTrainer.LoadName)]

namespace Microsoft.ML.Runtime.Learners
{
/// <include file='doc.xml' path='doc/members/member[@name="MultiClassNaiveBayesTrainer"]' />
public sealed class MultiClassNaiveBayesTrainer : TrainerBase<MultiClassNaiveBayesPredictor>
public sealed class MultiClassNaiveBayesTrainer : TrainerEstimatorBase<MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor>, MultiClassNaiveBayesPredictor>
{
public const string LoadName = "MultiClassNaiveBayes";
internal const string UserName = "Multiclass Naive Bayes";
Expand All @@ -43,13 +45,52 @@ public sealed class Arguments : LearnerInputBaseWithLabel
private static readonly TrainerInfo _info = new TrainerInfo(normalization: false, caching: false);
public override TrainerInfo Info => _info;

public MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadName)
/// <summary>
/// Creates a <see cref="MultiClassNaiveBayesTrainer"/> that reads the given feature and label columns.
/// </summary>
/// <param name="env">The host environment to register with; must not be null.</param>
/// <param name="featureColumn">The name of the feature column; must be non-empty.</param>
/// <param name="labelColumn">The name of the label column; must be non-empty.</param>
public MultiClassNaiveBayesTrainer(IHostEnvironment env, string featureColumn, string labelColumn)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(LoadName),
        TrainerUtils.MakeR4VecFeature(featureColumn),
        TrainerUtils.MakeU4ScalarLabel(labelColumn))
{
    // These checks go through Host, an instance member, so they run after the base constructor.
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
}

/// <summary>
/// Initializes a new instance of <see cref="MultiClassNaiveBayesTrainer"/> from an <see cref="Arguments"/> object.
/// </summary>
/// <param name="env">The host environment to register with; must not be null.</param>
/// <param name="args">The trainer arguments supplying the feature and label column names; must not be null.</param>
internal MultiClassNaiveBayesTrainer(IHostEnvironment env, Arguments args)
    // The base initializer dereferences args before the constructor body runs, so args has to be
    // validated here; a check in the body alone would be preceded by a NullReferenceException.
    : base(Contracts.CheckRef(env, nameof(env)).Register(LoadName),
        TrainerUtils.MakeR4VecFeature(Contracts.CheckRef(args, nameof(args)).FeatureColumn),
        TrainerUtils.MakeU4ScalarLabel(args.LabelColumn))
{
    // Kept so the host-scoped check matches the other trainers; args is already known to be non-null.
    Host.CheckValue(args, nameof(args));
}

public override MultiClassNaiveBayesPredictor Train(TrainContext context)
/// <summary>
/// Describes the columns produced by this trainer's model: an R4 vector score column followed by
/// a U4 scalar predicted-label column.
/// </summary>
/// <param name="inputSchema">The input schema shape; the label column is looked up in it by name.</param>
/// <returns>A two-element array holding the score column and then the predicted-label column.</returns>
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    bool found = inputSchema.TryFindColumn(LabelColumn.Name, out var labelShape);
    Contracts.Assert(found);

    // Score metadata: slot names first, followed by the common trainer output metadata.
    var slotNames = new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false);
    var scoreMetadata = new List<SchemaShape.Column>();
    scoreMetadata.Add(slotNames);
    scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata());

    // Predicted-label metadata: the label's key-values entry (when present) plus the common trainer output metadata.
    var labelKeyValues = labelShape.Metadata.Columns.Where(col => col.Name == MetadataUtils.Kinds.KeyValues);
    var predLabelMetadata = new SchemaShape(labelKeyValues.Concat(MetadataUtils.GetTrainerOutputMetadata()));

    var scoreColumn = new SchemaShape.Column(
        DefaultColumnNames.Score,
        SchemaShape.Column.VectorKind.Vector,
        NumberType.R4,
        false,
        new SchemaShape(scoreMetadata));

    var predictedLabelColumn = new SchemaShape.Column(
        DefaultColumnNames.PredictedLabel,
        SchemaShape.Column.VectorKind.Scalar,
        NumberType.U4,
        true,
        predLabelMetadata);

    return new[] { scoreColumn, predictedLabelColumn };
}

/// <summary>
/// Wraps a trained predictor in a multiclass prediction transformer bound to this trainer's host,
/// feature column name and label column name.
/// </summary>
/// <param name="model">The trained Naive Bayes predictor.</param>
/// <param name="trainSchema">The schema of the data the model was trained on.</param>
/// <returns>The transformer built around <paramref name="model"/>.</returns>
protected override MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor> MakeTransformer(MultiClassNaiveBayesPredictor model, ISchema trainSchema)
{
    string featureName = FeatureColumn.Name;
    string labelName = LabelColumn.Name;
    return new MulticlassPredictionTransformer<MultiClassNaiveBayesPredictor>(Host, model, trainSchema, featureName, labelName);
}

protected override MultiClassNaiveBayesPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));
var data = context.TrainingSet;
Expand Down
37 changes: 37 additions & 0 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -749,5 +749,42 @@ public void FastTreeRanking()
Assert.InRange(metrics.Ndcg[1], 36.5, 37);
Assert.InRange(metrics.Ndcg[2], 36.5, 37);
}

[Fact]
public void MultiClassNaiveBayesTrainer()
{
var env = new ConsoleEnvironment(seed: 0);
var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
var dataSource = new MultiFileSource(dataPath);

var ctx = new MulticlassClassificationContext(env);
var reader = TextLoader.CreateReader(env,
c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));

MultiClassNaiveBayesPredictor pred = null;
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Oct 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiClassNaiveBayesPredictor [](start = 12, length = 29)

Should we expose Label and Feature histograms? Otherwise, I don't see any reason why we actually expose it. #Closed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created two public methods that will output a copy of the featurehistogram and labelhistogram


In reply to: 222122921 [](ancestors = 222122921)


// Train a multiclass Naive Bayes model; the onFit delegate captures the trained predictor in 'pred'.
var est = reader.MakeNewEstimator()
.Append(r => (label: r.label.ToKey(), r.features))
.Append(r => (r.label, preds: ctx.Trainers.MultiClassNaiveBayesTrainer(
r.label,
r.features, onFit: p => pred = p)));

var pipe = reader.Append(est);

Assert.Null(pred);
var model = pipe.Fit(dataSource);
Assert.NotNull(pred);
var data = model.Read(dataSource);

// Just output some data on the schema for fun.
var schema = data.AsDynamic.Schema;
for (int c = 0; c < schema.ColumnCount; ++c)
Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}");

var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2);
Assert.True(metrics.LogLoss > 0);
Assert.True(metrics.TopKAccuracy > 0);
}
}
}
23 changes: 23 additions & 0 deletions test/Microsoft.ML.Tests/TrainerEstimators/NaiveBayesTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Xunit;

namespace Microsoft.ML.Tests.TrainerEstimators
{
    public partial class TrainerEstimators
    {
        /// <summary>
        /// Runs the shared estimator checks on the multiclass pipeline with a
        /// <see cref="MultiClassNaiveBayesTrainer"/> appended as its final step.
        /// </summary>
        [Fact]
        public void TestEstimatorMultiClassNaiveBayesTrainer()
        {
            (IEstimator<ITransformer> pipe, IDataView dataView) = GetMultiClassPipeline();

            // Append yields a new estimator chain instead of modifying 'pipe', so the result must be
            // kept; discarding it would leave the trainer out of the pipeline under test.
            pipe = pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label"));

            TestEstimatorCore(pipe, dataView);
            Done();
        }
    }
}