Skip to content

Commit d254f4e

Browse files
authored
Throw error on incorrect Label name in InferColumns API (dotnet#47)
* Added sequential grouping of columns * reverted the file * addded infer columns label name checking * added column detection error * removed unsed usings * added quotes * replace Where with Any clause * replace Where with Any clause
1 parent bd900ff commit d254f4e

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/AutoML/ColumnInference/ColumnInferenceApi.cs

+4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ public static ColumnInferenceResult InferColumns(MLContext context, string path,
1616
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
1717
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);
1818
var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
19+
if (!loaderColumns.Any(t => label.Equals(t.Name)))
20+
{
21+
throw new InferenceException(InferenceType.Label, $"Specified Label Column '{label}' was not found.");
22+
}
1923
var typedLoaderArgs = new TextLoader.Arguments
2024
{
2125
Column = loaderColumns,

src/Test/ColumnInferenceTests.cs

+8
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,13 @@ public void UnGroupColumnsTest()
2020
var columnInferenceWithGrouping = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel, true, groupColumns: true);
2121
Assert.IsTrue(columnInferenceWithGrouping.Columns.Count() < columnInferenceWithoutGrouping.Columns.Count());
2222
}
23+
24+
[TestMethod]
25+
public void IncorrectLabelColumnTest()
26+
{
27+
var dataPath = DatasetUtil.DownloadUciAdultDataset();
28+
var context = new MLContext();
29+
Assert.ThrowsException<InferenceException>(new System.Action(() => context.Data.InferColumns(dataPath, "Junk", true, groupColumns: false)));
30+
}
2331
}
2432
}

0 commit comments

Comments
 (0)