-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Conversion of Multi Class Naive Bayes classifier to estimator #1111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ad24a46
95c74b8
739df0c
915b1e6
0ad46ea
c89aa2d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Internal.Calibration; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Microsoft.ML.StaticPipe.Runtime; | ||
|
||
namespace Microsoft.ML.StaticPipe | ||
{ | ||
/// <summary> | ||
/// MultiClass Classification trainer estimators. | ||
/// </summary> | ||
public static partial class MultiClassClassificationTrainers | ||
{ | ||
/// <summary> | ||
/// Predict a target using a linear multiclass classification model trained with the <see cref="Microsoft.ML.Runtime.Learners.MultiClassNaiveBayesTrainer"/> trainer. | ||
/// </summary> | ||
/// <param name="ctx">The multiclass classification context trainer object.</param> | ||
/// <param name="label">The label, or dependent variable.</param> | ||
/// <param name="features">The features, or independent variables.</param> | ||
/// <param name="onFit">A delegate that is called every time the | ||
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the | ||
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive | ||
/// the linear model that was trained. Note that this action cannot change the | ||
/// result in any way; it is only a way for the caller to be informed about what was learnt.</param> | ||
/// <returns>The set of output columns including in order the predicted per-class likelihoods (between 0 and 1, and summing up to 1), and the predicted label.</returns> | ||
public static (Vector<float> score, Key<uint, TVal> predictedLabel) | ||
MultiClassNaiveBayesTrainer<TVal>(this MulticlassClassificationContext.MulticlassClassificationTrainers ctx, | ||
Key<uint, TVal> label, | ||
Vector<float> features, | ||
Action<MultiClassNaiveBayesPredictor> onFit = null) | ||
{ | ||
Contracts.CheckValue(features, nameof(features)); | ||
Contracts.CheckValue(label, nameof(label)); | ||
Contracts.CheckValueOrNull(onFit); | ||
|
||
var rec = new TrainerEstimatorReconciler.MulticlassClassifier<TVal>( | ||
(env, labelName, featuresName, weightsName) => | ||
{ | ||
var trainer = new MultiClassNaiveBayesTrainer(env, featuresName, labelName); | ||
|
||
if (onFit != null) | ||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); | ||
return trainer; | ||
}, label, features, null); | ||
|
||
return rec.Output; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -749,5 +749,42 @@ public void FastTreeRanking() | |
Assert.InRange(metrics.Ndcg[1], 36.5, 37); | ||
Assert.InRange(metrics.Ndcg[2], 36.5, 37); | ||
} | ||
|
||
[Fact] | ||
public void MultiClassNaiveBayesTrainer() | ||
{ | ||
var env = new ConsoleEnvironment(seed: 0); | ||
var dataPath = GetDataPath(TestDatasets.iris.trainFilename); | ||
var dataSource = new MultiFileSource(dataPath); | ||
|
||
var ctx = new MulticlassClassificationContext(env); | ||
var reader = TextLoader.CreateReader(env, | ||
c => (label: c.LoadText(0), features: c.LoadFloat(1, 4))); | ||
|
||
MultiClassNaiveBayesPredictor pred = null; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Should we expose Label and Feature histograms? Otherwise, I don't see any reason why we actually expose it. #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created two public methods that will output a copy of the featurehistogram and labelhistogram In reply to: 222122921 [](ancestors = 222122921) |
||
|
||
// With a custom loss function we no longer get calibrated predictions. | ||
var est = reader.MakeNewEstimator() | ||
.Append(r => (label: r.label.ToKey(), r.features)) | ||
.Append(r => (r.label, preds: ctx.Trainers.MultiClassNaiveBayesTrainer( | ||
r.label, | ||
r.features, onFit: p => pred = p))); | ||
|
||
var pipe = reader.Append(est); | ||
|
||
Assert.Null(pred); | ||
var model = pipe.Fit(dataSource); | ||
Assert.NotNull(pred); | ||
var data = model.Read(dataSource); | ||
|
||
// Just output some data on the schema for fun. | ||
var schema = data.AsDynamic.Schema; | ||
for (int c = 0; c < schema.ColumnCount; ++c) | ||
Console.WriteLine($"{schema.GetColumnName(c)}, {schema.GetColumnType(c)}"); | ||
|
||
var metrics = ctx.Evaluate(data, r => r.label, r => r.preds, 2); | ||
Assert.True(metrics.LogLoss > 0); | ||
Assert.True(metrics.TopKAccuracy > 0); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Xunit; | ||
|
||
namespace Microsoft.ML.Tests.TrainerEstimators | ||
{ | ||
public partial class TrainerEstimators | ||
{ | ||
[Fact] | ||
public void TestEstimatorMultiClassNaiveBayesTrainer() | ||
{ | ||
(IEstimator<ITransformer> pipe, IDataView dataView) = GetMultiClassPipeline(); | ||
pipe.Append(new MultiClassNaiveBayesTrainer(Env, "Features", "Label")); | ||
TestEstimatorCore(pipe, dataView); | ||
Done(); | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change file name to static #Resolved