Skip to content

Commit 633654c

Browse files
codemzs eerhardt
authored and committed
Scores to Label mapping (dotnet#239)
* Scores to label mapping for multi-class classification problem.
1 parent b5a60a4 commit 633654c

File tree

4 files changed

+61
-8
lines changed

4 files changed

+61
-8
lines changed

src/Microsoft.ML.Core/Data/ITransformModel.cs

+11-4
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,20 @@ public interface ITransformModel
1818
/// Note that the schema may have columns that aren't needed by this transform model.
1919
/// If an IDataView exists with this schema, then applying this transform model to it
2020
/// shouldn't fail because of column type issues.
21-
/// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
22-
/// however that doing so may cause issues for composing transform models. For example,
23-
/// if transform model A needs column X and model B needs Y, that is NOT produced by A,
24-
/// then trimming A's input schema would cause composition to fail.
2521
/// </summary>
22+
// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
23+
// however that doing so may cause issues for composing transform models. For example,
24+
// if transform model A needs column X and model B needs Y, that is NOT produced by A,
25+
// then trimming A's input schema would cause composition to fail.
2626
ISchema InputSchema { get; }
2727

28+
/// <summary>
29+
/// The output schema that this transform model was originally instantiated on. The schema resulting
30+
/// from <see cref="Apply(IHostEnvironment, ITransformModel)"/> may differ from this, similarly to how
31+
/// <see cref="InputSchema"/> may differ from the schema of dataviews we apply this transform model to.
32+
/// </summary>
33+
ISchema OutputSchema { get; }
34+
2835
/// <summary>
2936
/// Apply the transform(s) in the model to the given input data.
3037
/// </summary>

src/Microsoft.ML.Data/EntryPoints/TransformModel.cs

+8-4
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,14 @@ public sealed class TransformModel : ITransformModel
3939
/// if transform model A needs column X and model B needs Y, that is NOT produced by A,
4040
/// then trimming A's input schema would cause composition to fail.
4141
/// </summary>
42-
public ISchema InputSchema
43-
{
44-
get { return _schemaRoot; }
45-
}
42+
public ISchema InputSchema => _schemaRoot;
43+
44+
/// <summary>
45+
/// The schema at the end of this model's transform chain. The <see cref="InputSchema"/> might have
46+
/// columns that are not needed by this transform and these columns will be seen in the
47+
/// <see cref="OutputSchema"/> produced by this transform.
48+
/// </summary>
49+
public ISchema OutputSchema => _chain.Schema;
4650

4751
/// <summary>
4852
/// Create a TransformModel containing the transforms from "result" back to "input".

src/Microsoft.ML/PredictionModel.cs

+34
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.Api;
77
using Microsoft.ML.Runtime.Data;
88
using Microsoft.ML.Runtime.EntryPoints;
9+
using Microsoft.ML.Runtime.Internal.Utilities;
910
using System;
1011
using System.Collections.Generic;
1112
using System.IO;
@@ -29,6 +30,39 @@ internal Runtime.EntryPoints.TransformModel PredictorModel
2930
get { return _predictorModel; }
3031
}
3132

33+
/// <summary>
34+
/// Tries to get the label names that correspond to the indices of the score array in the case of
35+
/// a multi-class classification problem, read from the slot names of the score column.
36+
/// </summary>
37+
/// <param name="names">On success, the label names, where the entry at index i is the label for score index i; otherwise null.</param>
38+
/// <param name="scoreColumnName">Name of the score column to look up in the model's output schema.</param>
39+
/// <returns>True if the label names were retrieved; false if the score column is not found, it has no slot names of the expected count, or the number of slot names does not match the score vector size.</returns>
40+
public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score)
41+
{
42+
names = null;
43+
ISchema schema = _predictorModel.OutputSchema;
44+
int colIndex = -1;
45+
if (!schema.TryGetColumnIndex(scoreColumnName, out colIndex))
46+
return false;
47+
48+
int expectedLabelCount = schema.GetColumnType(colIndex).ValueCount;
49+
if (!schema.HasSlotNames(colIndex, expectedLabelCount))
50+
return false;
51+
52+
VBuffer<DvText> labels = default;
53+
schema.GetMetadata(MetadataUtils.Kinds.SlotNames, colIndex, ref labels);
54+
55+
if (labels.Length != expectedLabelCount)
56+
return false;
57+
58+
names = new string[expectedLabelCount];
59+
int index = 0;
60+
foreach(var label in labels.DenseValues())
61+
names[index++] = label.ToString();
62+
63+
return true;
64+
}
65+
3266
/// <summary>
3367
/// Read model from file asynchronously.
3468
/// </summary>

test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs

+8
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
3030
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
3131

3232
PredictionModel<IrisDataWithStringLabel, IrisPrediction> model = pipeline.Train<IrisDataWithStringLabel, IrisPrediction>();
33+
string[] scoreLabels;
34+
model.TryGetScoreLabelNames(out scoreLabels);
35+
36+
Assert.NotNull(scoreLabels);
37+
Assert.Equal(3, scoreLabels.Length);
38+
Assert.Equal("Iris-setosa", scoreLabels[0]);
39+
Assert.Equal("Iris-versicolor", scoreLabels[1]);
40+
Assert.Equal("Iris-virginica", scoreLabels[2]);
3341

3442
IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
3543
{

0 commit comments

Comments
 (0)