Skip to content

Commit 815a687

Browse files
committed
Offer suggestions for possibly mistyped label column names
1 parent 3d3d45c commit 815a687

File tree

3 files changed

+99
-0
lines changed

3 files changed

+99
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using System;
2+
3+
namespace Microsoft.ML.AutoML.Utils
4+
{
5+
/// <summary>
/// Helpers for measuring how different two strings are.
/// </summary>
public static class StringEditDistance
{
    /// <summary>
    /// Computes the edit distance between two strings, where inserting, deleting
    /// or substituting a single character each cost 1. Characters are compared
    /// by exact <see cref="char"/> value, so the comparison is case-sensitive.
    /// </summary>
    /// <param name="first">The first string to compare.</param>
    /// <param name="second">The second string to compare.</param>
    /// <returns>
    /// The minimum number of single-character edits that turn <paramref name="first"/>
    /// into <paramref name="second"/>. When either string is empty this is the
    /// length of the other string.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="first"/> or <paramref name="second"/> is <c>null</c>.
    /// </exception>
    public static int GetEditDistance(string first, string second)
    {
        if (first is null)
        {
            throw new ArgumentNullException(nameof(first));
        }

        if (second is null)
        {
            throw new ArgumentNullException(nameof(second));
        }

        if (first.Length == 0 || second.Length == 0)
        {
            return first.Length + second.Length;
        }

        // Only the previously completed row and the row being filled are ever
        // read, so two one-dimensional buffers are enough; a full
        // (second.Length + 1) x (second.Length + 1) table is never needed.
        var previousRow = new int[second.Length + 1];
        var currentRow = new int[second.Length + 1];

        // Distance from the empty prefix of 'first' to each prefix of 'second'.
        for (var j = 0; j <= second.Length; ++j)
        {
            previousRow[j] = j;
        }

        for (var i = 1; i <= first.Length; ++i)
        {
            // Distance from the first i characters of 'first' to the empty prefix of 'second'.
            currentRow[0] = i;

            for (var j = 1; j <= second.Length; ++j)
            {
                var deletion = previousRow[j] + 1;
                var insertion = currentRow[j - 1] + 1;
                var substitution = previousRow[j - 1] + (first[i - 1] == second[j - 1] ? 0 : 1);

                currentRow[j] = Math.Min(deletion, Math.Min(insertion, substitution));
            }

            // The row just filled becomes the "previous" row for the next
            // iteration; the old buffer is reused and fully overwritten.
            var reusable = previousRow;
            previousRow = currentRow;
            currentRow = reusable;
        }

        // After the final swap, previousRow holds the last completed row.
        return previousRow[second.Length];
    }
}
60+
}

src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs

+23
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Collections.Generic;
77
using System.IO;
88
using System.Linq;
9+
using Microsoft.ML.AutoML.Utils;
910
using Microsoft.ML.Data;
1011

1112
namespace Microsoft.ML.AutoML
@@ -248,6 +249,11 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
248249
var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
249250
if (nullableColumn == null)
250251
{
252+
var closestNamed = ClosestNamed(trainData, columnName, 2);
253+
if (closestNamed != string.Empty)
254+
{
255+
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data. Did you mean '{closestNamed}'.");
256+
}
251257
throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
252258
}
253259

@@ -272,6 +278,23 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa
272278
}
273279
}
274280

281+
/// <summary>
/// Finds the column in <paramref name="trainData"/>'s schema whose name has the
/// smallest edit distance to <paramref name="columnName"/>.
/// </summary>
/// <param name="trainData">Data view whose schema columns are searched.</param>
/// <param name="columnName">The name to match against the schema column names.</param>
/// <param name="maxAllowableEditDistance">
/// Largest edit distance at which a column still counts as a match.
/// </param>
/// <returns>
/// The name of the closest column (the first one encountered when several are
/// equally close), or <see cref="string.Empty"/> when the best distance exceeds
/// <paramref name="maxAllowableEditDistance"/> or the schema has no columns.
/// </returns>
private static string ClosestNamed(IDataView trainData, string columnName, int maxAllowableEditDistance = int.MaxValue)
{
    var bestName = string.Empty;
    var bestDistance = int.MaxValue;

    foreach (var candidate in trainData.Schema)
    {
        var distance = StringEditDistance.GetEditDistance(candidate.Name, columnName);

        // Strict comparison: on a tie the earlier column is kept.
        if (distance >= bestDistance)
        {
            continue;
        }

        bestDistance = distance;
        bestName = candidate.Name;
    }

    if (bestDistance > maxAllowableEditDistance)
    {
        return string.Empty;
    }

    return bestName;
}
297+
275298
private static string FindFirstDuplicate(IEnumerable<string> values)
276299
{
277300
var groups = values.GroupBy(v => v);

test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs

+16
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IO;
8+
using System.Linq;
89
using System.Threading.Tasks;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.TestFramework;
@@ -50,6 +51,21 @@ public void ValidateExperimentExecuteLabelNotInTrain()
5051
}
5152
}
5253

54+
[Fact]
public void ValidateExperimentExecuteLabelNotInTrainMistyped()
{
    // A label name that is one character off from an existing column should
    // fail validation with a message that suggests the existing column.
    var tasks = new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking };

    foreach (var task in tasks)
    {
        var existingName = _data.Schema.First().Name;
        var typoName = existingName + "a";
        var expectedMessage = $"Provided label column '{typoName}' not found in training data. Did you mean '{existingName}'.";

        Action validate = () => UserInputValidationUtil.ValidateExperimentExecuteArgs(
            _data,
            new ColumnInformation() { LabelColumnName = typoName },
            null,
            task);

        var thrown = Assert.Throws<ArgumentException>(validate);

        Assert.Equal(expectedMessage, thrown.Message);
    }
}
68+
5369
[Fact]
5470
public void ValidateExperimentExecuteNumericColNotInTrain()
5571
{

0 commit comments

Comments
 (0)