Skip to content

Commit ea07be8

Browse files
authored
Fixed exception: "InvalidOperationException: Source column 'Label' is required but not found." (#121)
* Checking for both ColumnAttribute and ColumnNameAttribute when creating schema in CreateBatchPredictionEngine. * Addressed reviewers' comments.
1 parent 2e732ea commit ea07be8

File tree

2 files changed

+139
-2
lines changed

2 files changed

+139
-2
lines changed

src/Microsoft.ML.Api/SchemaDefinition.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,9 @@ public static SchemaDefinition Create(Type userType)
332332

333333
if (fieldInfo.GetCustomAttribute<NoColumnAttribute>() != null)
334334
continue;
335-
var mappingAttr = fieldInfo.GetCustomAttribute<ColumnNameAttribute>();
336-
var name = mappingAttr == null ? fieldInfo.Name : (mappingAttr.Name ?? fieldInfo.Name);
335+
var mappingAttr = fieldInfo.GetCustomAttribute<ColumnAttribute>();
336+
var mappingNameAttr = fieldInfo.GetCustomAttribute<ColumnNameAttribute>();
337+
string name = mappingAttr?.Name ?? mappingNameAttr?.Name ?? fieldInfo.Name;
337338
// Disallow duplicate names, because the field enumeration order is not actually
338339
// well defined, so we are not gauranteed to have consistent "hiding" from run to
339340
// run, across different .NET versions.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Models;
6+
using Microsoft.ML.Runtime.Api;
7+
using Microsoft.ML.Trainers;
8+
using Microsoft.ML.Transforms;
9+
using Xunit;
10+
11+
namespace Microsoft.ML.Scenarios
12+
{
13+
public partial class ScenariosTests
14+
{
15+
[Fact]
16+
public void TrainAndPredictIrisModelWithStringLabelTest()
17+
{
18+
string dataPath = GetDataPath("iris.data");
19+
20+
var pipeline = new LearningPipeline();
21+
22+
pipeline.Add(new TextLoader<IrisDataWithStringLabel>(dataPath, useHeader: false, separator: ","));
23+
24+
pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field.
25+
26+
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
27+
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
28+
29+
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
30+
31+
PredictionModel<IrisDataWithStringLabel, IrisPrediction> model = pipeline.Train<IrisDataWithStringLabel, IrisPrediction>();
32+
33+
IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
34+
{
35+
SepalLength = 3.3f,
36+
SepalWidth = 1.6f,
37+
PetalLength = 0.2f,
38+
PetalWidth = 5.1f,
39+
});
40+
41+
Assert.Equal(1, prediction.PredictedLabels[0], 2);
42+
Assert.Equal(0, prediction.PredictedLabels[1], 2);
43+
Assert.Equal(0, prediction.PredictedLabels[2], 2);
44+
45+
prediction = model.Predict(new IrisDataWithStringLabel()
46+
{
47+
SepalLength = 3.1f,
48+
SepalWidth = 5.5f,
49+
PetalLength = 2.2f,
50+
PetalWidth = 6.4f,
51+
});
52+
53+
Assert.Equal(0, prediction.PredictedLabels[0], 2);
54+
Assert.Equal(0, prediction.PredictedLabels[1], 2);
55+
Assert.Equal(1, prediction.PredictedLabels[2], 2);
56+
57+
prediction = model.Predict(new IrisDataWithStringLabel()
58+
{
59+
SepalLength = 3.1f,
60+
SepalWidth = 2.5f,
61+
PetalLength = 1.2f,
62+
PetalWidth = 4.4f,
63+
});
64+
65+
Assert.Equal(.2, prediction.PredictedLabels[0], 1);
66+
Assert.Equal(.8, prediction.PredictedLabels[1], 1);
67+
Assert.Equal(0, prediction.PredictedLabels[2], 2);
68+
69+
// Note: Testing against the same data set as a simple way to test evaluation.
70+
// This isn't appropriate in real-world scenarios.
71+
string testDataPath = GetDataPath("iris.data");
72+
var testData = new TextLoader<IrisDataWithStringLabel>(testDataPath, useHeader: false, separator: ",");
73+
74+
var evaluator = new ClassificationEvaluator();
75+
evaluator.OutputTopKAcc = 3;
76+
ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
77+
78+
Assert.Equal(.98, metrics.AccuracyMacro);
79+
Assert.Equal(.98, metrics.AccuracyMicro, 2);
80+
Assert.Equal(.06, metrics.LogLoss, 2);
81+
Assert.InRange(metrics.LogLossReduction, 94, 96);
82+
Assert.Equal(1, metrics.TopKAccuracy);
83+
84+
Assert.Equal(3, metrics.PerClassLogLoss.Length);
85+
Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
86+
Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
87+
Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);
88+
89+
ConfusionMatrix matrix = metrics.ConfusionMatrix;
90+
Assert.Equal(3, matrix.Order);
91+
Assert.Equal(3, matrix.ClassNames.Count);
92+
Assert.Equal("Iris-setosa", matrix.ClassNames[0]);
93+
Assert.Equal("Iris-versicolor", matrix.ClassNames[1]);
94+
Assert.Equal("Iris-virginica", matrix.ClassNames[2]);
95+
96+
Assert.Equal(50, matrix[0, 0]);
97+
Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]);
98+
Assert.Equal(0, matrix[0, 1]);
99+
Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]);
100+
Assert.Equal(0, matrix[0, 2]);
101+
Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]);
102+
103+
Assert.Equal(0, matrix[1, 0]);
104+
Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]);
105+
Assert.Equal(48, matrix[1, 1]);
106+
Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]);
107+
Assert.Equal(2, matrix[1, 2]);
108+
Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]);
109+
110+
Assert.Equal(0, matrix[2, 0]);
111+
Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]);
112+
Assert.Equal(1, matrix[2, 1]);
113+
Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]);
114+
Assert.Equal(49, matrix[2, 2]);
115+
Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]);
116+
}
117+
118+
public class IrisDataWithStringLabel
119+
{
120+
[Column("0")]
121+
public float PetalWidth;
122+
123+
[Column("1")]
124+
public float SepalLength;
125+
126+
[Column("2")]
127+
public float SepalWidth;
128+
129+
[Column("3")]
130+
public float PetalLength;
131+
132+
[Column("4", name: "Label")]
133+
public string IrisPlantType;
134+
}
135+
}
136+
}

0 commit comments

Comments
 (0)