From 815a687e50779b6997243a10453140a8a45a9f51 Mon Sep 17 00:00:00 2001 From: Ivan Agarsky Date: Tue, 16 Feb 2021 14:22:08 +0100 Subject: [PATCH 1/2] Offer suggestions for possibly mistyped label column names --- .../Utils/StringEditDistance.cs | 60 +++++++++++++++++++ .../Utils/UserInputValidationUtil.cs | 23 +++++++ .../UserInputValidationTests.cs | 16 +++++ 3 files changed, 99 insertions(+) create mode 100644 src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs diff --git a/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs b/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs new file mode 100644 index 0000000000..40f22501ad --- /dev/null +++ b/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs @@ -0,0 +1,60 @@ +using System; + +namespace Microsoft.ML.AutoML.Utils +{ + public static class StringEditDistance + { + public static int GetEditDistance(string first, string second) + { + if (first is null) + { + throw new ArgumentNullException(nameof(first)); + } + + if (second is null) + { + throw new ArgumentNullException(nameof(second)); + } + + if (first.Length == 0 || second.Length == 0) + { + return first.Length + second.Length; + } + + var currentRow = 0; + var nextRow = 1; + var rows = new int[second.Length + 1, second.Length + 1]; + + for (var j = 0; j <= second.Length; ++j) + { + rows[currentRow, j] = j; + } + + for (var i = 1; i <= first.Length; ++i) + { + rows[nextRow, 0] = i; + for (var j = 1; j <= second.Length; ++j) + { + var deletion = rows[currentRow, j] + 1; + var insertion = rows[nextRow, j - 1] + 1; + var substitution = rows[currentRow, j - 1] + (first[i - 1].Equals(second[j - 1]) ? 0 : 1); + + rows[nextRow, j] = Math.Min(deletion, Math.Min(insertion, substitution)); + } + + if (currentRow == 0) + { + currentRow = 1; + nextRow = 0; + } + else + { + currentRow = 0; + nextRow = 1; + } + } + + return rows[currentRow, second.Length]; + } + } +} diff --git a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs index cddbead9d4..240b03a94e 100644 --- a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.ML.AutoML.Utils; using Microsoft.ML.Data; namespace Microsoft.ML.AutoML @@ -248,6 +249,11 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa var nullableColumn = trainData.Schema.GetColumnOrNull(columnName); if (nullableColumn == null) { + var closestNamed = ClosestNamed(trainData, columnName, 2); + if (closestNamed != string.Empty) + { + throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data. Did you mean '{closestNamed}'."); + } throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data."); } @@ -272,6 +278,23 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa } } + private static string ClosestNamed(IDataView trainData, string columnName, int maxAllowableEditDistance = int.MaxValue) + { + var minEditDistance = int.MaxValue; + var closestNamed = string.Empty; + foreach (var column in trainData.Schema) + { + var editDistance = StringEditDistance.GetEditDistance(column.Name, columnName); + if (editDistance < minEditDistance) + { + minEditDistance = editDistance; + closestNamed = column.Name; + } + } + + return minEditDistance <= maxAllowableEditDistance ? closestNamed : string.Empty; + } + private static string FindFirstDuplicate(IEnumerable values) { var groups = values.GroupBy(v => v); diff --git a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs index 6c007c8279..8759b30cb5 100644 --- a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading.Tasks; using Microsoft.ML.Data; using Microsoft.ML.TestFramework; @@ -50,6 +51,21 @@ public void ValidateExperimentExecuteLabelNotInTrain() } } + [Fact] + public void ValidateExperimentExecuteLabelNotInTrainMistyped() + { + foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking }) + { + var originalColumnName = _data.Schema.First().Name; + var mistypedColumnName = originalColumnName + "a"; + var ex = Assert.Throws(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, + new ColumnInformation() { LabelColumnName = mistypedColumnName }, null, task)); + + Assert.Equal($"Provided label column '{mistypedColumnName}' not found in training data. Did you mean '{originalColumnName}'.", + ex.Message); + } + } + [Fact] public void ValidateExperimentExecuteNumericColNotInTrain() { From 80c0760e91b0ea962b0963f7f460ad7d7b31d118 Mon Sep 17 00:00:00 2001 From: Ivan Agarsky Date: Tue, 16 Feb 2021 19:32:46 +0100 Subject: [PATCH 2/2] review changes --- src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs | 4 ++-- .../Utils/UserInputValidationUtil.cs | 11 +++++++---- .../UserInputValidationTests.cs | 5 +++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs b/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs index 40f22501ad..63b56e815a 100644 --- a/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs +++ b/src/Microsoft.ML.AutoML/Utils/StringEditDistance.cs @@ -2,9 +2,9 @@ namespace Microsoft.ML.AutoML.Utils { - public static class StringEditDistance + internal static class StringEditDistance { - public static int GetEditDistance(string first, string second) + public static int GetLevenshteinDistance(string first, string second) { if (first is null) { diff --git a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs index 240b03a94e..6255e526ee 100644 --- a/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs +++ b/src/Microsoft.ML.AutoML/Utils/UserInputValidationUtil.cs @@ -249,12 +249,15 @@ private static void ValidateTrainDataColumn(IDataView trainData, string columnNa var nullableColumn = trainData.Schema.GetColumnOrNull(columnName); if (nullableColumn == null) { - var closestNamed = ClosestNamed(trainData, columnName, 2); + var closestNamed = ClosestNamed(trainData, columnName, 7); + + var exceptionMessage = $"Provided {columnPurpose} column '{columnName}' not found in training data."; if (closestNamed != string.Empty) { - throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data. Did you mean '{closestNamed}'."); + exceptionMessage += $" Did you mean '{closestNamed}'."; } - throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data."); + + throw new ArgumentException(exceptionMessage); } if(allowedTypes == null) @@ -284,7 +287,7 @@ private static string ClosestNamed(IDataView trainData, string columnName, int m var closestNamed = string.Empty; foreach (var column in trainData.Schema) { - var editDistance = StringEditDistance.GetEditDistance(column.Name, columnName); + var editDistance = StringEditDistance.GetLevenshteinDistance(column.Name, columnName); if (editDistance < minEditDistance) { minEditDistance = editDistance; diff --git a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs index 8759b30cb5..259acede05 100644 --- a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs +++ b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs @@ -44,10 +44,11 @@ public void ValidateExperimentExecuteLabelNotInTrain() { foreach (var task in new[] { TaskKind.Recommendation, TaskKind.Regression, TaskKind.Ranking }) { + const string columnName = "ReallyLongNonExistingColumnName"; var ex = Assert.Throws(() => UserInputValidationUtil.ValidateExperimentExecuteArgs(_data, - new ColumnInformation() { LabelColumnName = "L" }, null, task)); + new ColumnInformation() { LabelColumnName = columnName }, null, task)); - Assert.Equal("Provided label column 'L' not found in training data.", ex.Message); + Assert.Equal($"Provided label column '{columnName}' not found in training data.", ex.Message); } }