Skip to content

Commit 2b6b973

Browse files
daholste authored and Dmitry-A committed
Rev user input validation for new API (dotnet#210)
1 parent 6957d2f commit 2b6b973

6 files changed

+99
-142
lines changed

src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs

+4 -3
Original file line number | Diff line number | Diff line change
@@ -57,21 +57,22 @@ public MulticlassClassificationExperiment CreateMulticlassClassificationExperime
5757
public ColumnInferenceResults InferColumns(string path, string labelColumn = DefaultColumnNames.Label, char? separatorChar = null, bool? allowQuotedStrings = null,
5858
bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
5959
{
60-
//UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
60+
UserInputValidationUtil.ValidateInferColumnsArgs(path, labelColumn);
6161
return ColumnInferenceApi.InferColumns(_context, path, labelColumn, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
6262
}
6363

6464
public ColumnInferenceResults InferColumns(string path, ColumnInformation columnInformation, char? separatorChar = null, bool? allowQuotedStrings = null,
6565
bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
6666
{
67-
//UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
67+
columnInformation = columnInformation ?? new ColumnInformation();
68+
UserInputValidationUtil.ValidateInferColumnsArgs(path, columnInformation);
6869
return ColumnInferenceApi.InferColumns(_context, path, columnInformation, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
6970
}
7071

7172
public ColumnInferenceResults InferColumns(string path, uint labelColumnIndex, bool hasHeader = false, char? separatorChar = null,
7273
bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
7374
{
74-
//UserInputValidationUtil.ValidateInferColumnsArgs(path);
75+
UserInputValidationUtil.ValidateInferColumnsArgs(path);
7576
return ColumnInferenceApi.InferColumns(_context, path, labelColumnIndex, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
7677
}
7778
}

src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(MLContext c
8181
IEstimator<ITransformer> preFeaturizers = null)
8282
{
8383
columnInfo = columnInfo ?? new ColumnInformation();
84-
//UserInputValidationUtil.ValidateAutoFitArgs(trainData, labelColunName, validationData, settings, columnPurposes)
84+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
8585

8686
// run autofit & get all pipelines run in that process
8787
var experiment = new Experiment<BinaryClassificationMetrics>(context, TaskKind.BinaryClassification, trainData, columnInfo,

src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(MLContext c
7979
IEstimator<ITransformer> preFeaturizers = null)
8080
{
8181
columnInfo = columnInfo ?? new ColumnInformation();
82-
//UserInputValidationUtil.ValidateAutoFitArgs(trainData, labelColunName, validationData, settings, columnPurposes)
82+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
8383

8484
// run autofit & get all pipelines run in that process
8585
var experiment = new Experiment<MultiClassClassifierMetrics>(context, TaskKind.MulticlassClassification, trainData,

src/Microsoft.ML.Auto/API/RegressionExperiment.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ internal IEnumerable<RunResult<RegressionMetrics>> Execute(MLContext context,
7676
IEstimator<ITransformer> preFeaturizers = null)
7777
{
7878
columnInfo = columnInfo ?? new ColumnInformation();
79-
//UserInputValidationUtil.ValidateAutoFitArgs(trainData, labelColunName, validationData, settings, columnPurposes);
79+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
8080

8181
// run autofit & get all pipelines run in that process
8282
var experiment = new Experiment<RegressionMetrics>(context, TaskKind.Regression, trainData, columnInfo,

src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs

+61 -57
Original file line number | Diff line number | Diff line change
@@ -9,26 +9,27 @@
99
using Microsoft.Data.DataView;
1010
using Microsoft.ML.Data;
1111

12-
// todo: re-write & test user input validation once final API nailed down.
13-
// Tracked by Github issue: https://github.com/dotnet/machinelearning-automl/issues/159
14-
1512
namespace Microsoft.ML.Auto
1613
{
17-
/*internal static class UserInputValidationUtil
14+
internal static class UserInputValidationUtil
1815
{
19-
public static void ValidateAutoFitArgs(IDataView trainData, string label, IDataView validationData,
20-
AutoFitSettings settings, IEnumerable<(string, ColumnPurpose)> purposeOverrides)
16+
public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
17+
IDataView validationData)
2118
{
2219
ValidateTrainData(trainData);
20+
ValidateColumnInformation(trainData, columnInformation);
2321
ValidateValidationData(trainData, validationData);
24-
ValidateLabel(trainData, label);
25-
ValidateSettings(settings);
26-
ValidatePurposeOverrides(trainData, validationData, label, purposeOverrides);
2722
}
2823

29-
public static void ValidateInferColumnsArgs(string path, string label)
24+
public static void ValidateInferColumnsArgs(string path, ColumnInformation columnInformation)
25+
{
26+
ValidateColumnInformation(columnInformation);
27+
ValidatePath(path);
28+
}
29+
30+
public static void ValidateInferColumnsArgs(string path, string labelColumn)
3031
{
31-
ValidateLabel(label);
32+
ValidateLabelColumn(labelColumn);
3233
ValidatePath(path);
3334
}
3435

@@ -51,21 +52,55 @@ private static void ValidateTrainData(IDataView trainData)
5152
}
5253
}
5354

54-
private static void ValidateLabel(IDataView trainData, string label)
55+
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
5556
{
56-
ValidateLabel(label);
57+
ValidateColumnInformation(columnInformation);
58+
ValidateTrainDataColumnExists(trainData, columnInformation.LabelColumn);
59+
ValidateTrainDataColumnExists(trainData, columnInformation.WeightColumn);
60+
ValidateTrainDataColumnsExist(trainData, columnInformation.CategoricalColumns);
61+
ValidateTrainDataColumnsExist(trainData, columnInformation.NumericColumns);
62+
ValidateTrainDataColumnsExist(trainData, columnInformation.TextColumns);
63+
ValidateTrainDataColumnsExist(trainData, columnInformation.IgnoredColumns);
64+
}
5765

58-
if (trainData.Schema.GetColumnOrNull(label) == null)
66+
private static void ValidateColumnInformation(ColumnInformation columnInformation)
67+
{
68+
ValidateLabelColumn(columnInformation.LabelColumn);
69+
70+
ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumns, "categorical");
71+
ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumns, "numeric");
72+
ValidateColumnInfoEnumerationProperty(columnInformation.TextColumns, "text");
73+
ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumns, "ignored");
74+
75+
// keep a list of all columns, to detect duplicates
76+
var allColumns = new List<string>();
77+
allColumns.Add(columnInformation.LabelColumn);
78+
if (columnInformation.WeightColumn != null) { allColumns.Add(columnInformation.WeightColumn); }
79+
if (columnInformation.CategoricalColumns != null) { allColumns.AddRange(columnInformation.CategoricalColumns); }
80+
if (columnInformation.NumericColumns != null) { allColumns.AddRange(columnInformation.NumericColumns); }
81+
if (columnInformation.TextColumns != null) { allColumns.AddRange(columnInformation.TextColumns); }
82+
if (columnInformation.IgnoredColumns != null) { allColumns.AddRange(columnInformation.IgnoredColumns); }
83+
84+
var duplicateColName = FindFirstDuplicate(allColumns);
85+
if (duplicateColName != null)
5986
{
60-
throw new ArgumentException($"Provided label column '{label}' not found in training data.", nameof(label));
87+
throw new ArgumentException($"Duplicate column name {duplicateColName} is present in two or more distinct properties of provided column information", nameof(columnInformation));
6188
}
6289
}
6390

64-
private static void ValidateLabel(string label)
91+
private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string propertyName)
6592
{
66-
if (label == null)
93+
if (columns?.Contains(null) == true)
6794
{
68-
throw new ArgumentNullException(nameof(label), "Provided label cannot be null");
95+
throw new ArgumentException($"Null column string was specified as {propertyName} in column information");
96+
}
97+
}
98+
99+
private static void ValidateLabelColumn(string labelColumn)
100+
{
101+
if (labelColumn == null)
102+
{
103+
throw new ArgumentException("Provided label column cannot be null");
69104
}
70105
}
71106

@@ -120,55 +155,24 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
120155
}
121156
}
122157

123-
private static void ValidateSettings(AutoFitSettings settings)
158+
private static void ValidateTrainDataColumnsExist(IDataView trainData, IEnumerable<string> columnNames)
124159
{
125-
if (settings?.StoppingCriteria == null)
160+
if (columnNames == null)
126161
{
127162
return;
128163
}
129164

130-
if (settings.StoppingCriteria.MaxIterations <= 0)
165+
foreach (var columnName in columnNames)
131166
{
132-
throw new ArgumentOutOfRangeException(nameof(settings), "Max iterations must be > 0");
167+
ValidateTrainDataColumnExists(trainData, columnName);
133168
}
134169
}
135170

136-
private static void ValidatePurposeOverrides(IDataView trainData, IDataView validationData,
137-
string label, IEnumerable<(string, ColumnPurpose)> purposeOverrides)
171+
private static void ValidateTrainDataColumnExists(IDataView trainData, string columnName)
138172
{
139-
if (purposeOverrides == null)
140-
{
141-
return;
142-
}
143-
144-
foreach (var purposeOverride in purposeOverrides)
145-
{
146-
var colName = purposeOverride.Item1;
147-
var colPurpose = purposeOverride.Item2;
148-
149-
if (colName == null)
150-
{
151-
throw new ArgumentException("Purpose override column name cannot be null.", nameof(purposeOverrides));
152-
}
153-
154-
if (trainData.Schema.GetColumnOrNull(colName) == null)
155-
{
156-
throw new ArgumentException($"Purpose override column name '{colName}' not found in training data.", nameof(purposeOverride));
157-
}
158-
159-
// if column w/ purpose = 'Label' found, ensure it matches the passed-in label
160-
if (colPurpose == ColumnPurpose.Label && colName != label)
161-
{
162-
throw new ArgumentException($"Label column name in provided list of purposes '{colName}' must match " +
163-
$"the label column name '{label}'", nameof(purposeOverrides));
164-
}
165-
}
166-
167-
// ensure all column names unique
168-
var duplicateColName = FindFirstDuplicate(purposeOverrides.Select(p => p.Item1));
169-
if (duplicateColName != null)
173+
if (columnName != null && trainData.Schema.GetColumnOrNull(columnName) == null)
170174
{
171-
throw new ArgumentException($"Duplicate column name '{duplicateColName}' in purpose overrides.", nameof(purposeOverrides));
175+
throw new ArgumentException($"Provided column '{columnName}' not found in training data.");
172176
}
173177
}
174178

@@ -177,5 +181,5 @@ private static string FindFirstDuplicate(IEnumerable<string> values)
177181
var groups = values.GroupBy(v => v);
178182
return groups.FirstOrDefault(g => g.Count() > 1)?.Key;
179183
}
180-
}*/
184+
}
181185
}

0 commit comments

Comments
 (0)