Skip to content

Commit 904cfd7

Browse files
daholsteDmitry-A
authored andcommitted
[AutoML] AutoML SDK API: validate schema types of input IDataView (dotnet#3597)
1 parent 37c08ae commit 904cfd7

File tree

3 files changed

+124
-23
lines changed

3 files changed

+124
-23
lines changed

src/Microsoft.ML.Auto/API/ExperimentBase.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
239239
IProgress<RunDetail<TMetrics>> progressHandler)
240240
{
241241
columnInfo = columnInfo ?? new ColumnInformation();
242-
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
242+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData, _task);
243243

244244
// Apply pre-featurizer
245245
ITransformer preprocessorTransform = null;
@@ -263,7 +263,7 @@ private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] tr
263263
IProgress<CrossValidationRunDetail<TMetrics>> progressHandler)
264264
{
265265
columnInfo = columnInfo ?? new ColumnInformation();
266-
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
266+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0], _task);
267267

268268
// Apply pre-featurizer
269269
ITransformer[] preprocessorTransforms = null;
@@ -290,7 +290,7 @@ private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatas
290290
IProgress<RunDetail<TMetrics>> progressHandler)
291291
{
292292
columnInfo = columnInfo ?? new ColumnInformation();
293-
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
293+
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0], _task);
294294

295295
// Apply pre-featurizer
296296
ITransformer[] preprocessorTransforms = null;

src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ internal static class UserInputValidationUtil
2222
private const string SamplingKeyColumnPurposeName = "sampling key";
2323

2424
public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
25-
IDataView validationData)
25+
IDataView validationData, TaskKind task)
2626
{
27-
ValidateTrainData(trainData);
28-
ValidateColumnInformation(trainData, columnInformation);
27+
ValidateTrainData(trainData, columnInformation);
28+
ValidateColumnInformation(trainData, columnInformation, task);
2929
ValidateValidationData(trainData, validationData);
3030
}
3131

@@ -54,24 +54,37 @@ public static void ValidateNumberOfCVFoldsArg(uint numberOfCVFolds)
5454
}
5555
}
5656

57-
private static void ValidateTrainData(IDataView trainData)
57+
private static void ValidateTrainData(IDataView trainData, ColumnInformation columnInformation)
5858
{
5959
if (trainData == null)
6060
{
6161
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
6262
}
6363

64-
var type = trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType();
65-
if (type != null && type != NumberDataViewType.Single)
64+
foreach (var column in trainData.Schema)
6665
{
67-
throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type Single", nameof(trainData));
66+
if (column.Name == DefaultColumnNames.Features && column.Type.GetItemType() != NumberDataViewType.Single)
67+
{
68+
throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type {NumberDataViewType.Single}", nameof(trainData));
69+
}
70+
71+
if (column.Name != columnInformation.LabelColumnName &&
72+
column.Type.GetItemType() != BooleanDataViewType.Instance &&
73+
column.Type.GetItemType() != NumberDataViewType.Single &&
74+
column.Type.GetItemType() != TextDataViewType.Instance)
75+
{
76+
throw new ArgumentException($"Only supported feature column types are " +
77+
$"{BooleanDataViewType.Instance}, {NumberDataViewType.Single}, and {TextDataViewType.Instance}. " +
78+
$"Please change the feature column {column.Name} of type {column.Type} to one of " +
79+
$"the supported types.", nameof(trainData));
80+
}
6881
}
6982
}
7083

71-
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
84+
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation, TaskKind task)
7285
{
7386
ValidateColumnInformation(columnInformation);
74-
ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName);
87+
ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName, GetAllowedLabelTypes(task));
7588
ValidateTrainDataColumn(trainData, columnInformation.ExampleWeightColumnName, WeightColumnPurposeName);
7689
ValidateTrainDataColumn(trainData, columnInformation.SamplingKeyColumnName, SamplingKeyColumnPurposeName);
7790
ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName,
@@ -228,5 +241,22 @@ private static string FindFirstDuplicate(IEnumerable<string> values)
228241
var groups = values.GroupBy(v => v);
229242
return groups.FirstOrDefault(g => g.Count() > 1)?.Key;
230243
}
244+
245+
private static IEnumerable<DataViewType> GetAllowedLabelTypes(TaskKind task)
246+
{
247+
switch (task)
248+
{
249+
case TaskKind.BinaryClassification:
250+
return new DataViewType[] { BooleanDataViewType.Instance };
251+
// Multiclass label types are flexible, as we convert the label to a key type
252+
// (if input label is not already a key) before invoking the trainer.
253+
case TaskKind.MulticlassClassification:
254+
return null;
255+
case TaskKind.Regression:
256+
return new DataViewType[] { NumberDataViewType.Single };
257+
default:
258+
throw new NotSupportedException($"Unsupported task type: {task}");
259+
}
260+
}
231261
}
232262
}

test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,23 @@ public class UserInputValidationTests
1818
[ExpectedException(typeof(ArgumentNullException))]
1919
public void ValidateExperimentExecuteNullTrainData()
2020
{
21-
UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null);
21+
UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null, TaskKind.Regression);
2222
}
2323

2424
[TestMethod]
2525
[ExpectedException(typeof(ArgumentException))]
2626
public void ValidateExperimentExecuteNullLabel()
2727
{
2828
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data,
29-
new ColumnInformation() { LabelColumnName = null }, null);
29+
new ColumnInformation() { LabelColumnName = null }, null, TaskKind.Regression);
3030
}
3131

3232
[TestMethod]
3333
[ExpectedException(typeof(ArgumentException))]
3434
public void ValidateExperimentExecuteLabelNotInTrain()
3535
{
3636
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data,
37-
new ColumnInformation() { LabelColumnName = "L" }, null);
37+
new ColumnInformation() { LabelColumnName = "L" }, null, TaskKind.Regression);
3838
}
3939

4040
[TestMethod]
@@ -44,7 +44,7 @@ public void ValidateExperimentExecuteNumericColNotInTrain()
4444
var columnInfo = new ColumnInformation();
4545
columnInfo.NumericColumnNames.Add("N");
4646

47-
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
47+
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
4848
}
4949

5050
[TestMethod]
@@ -53,7 +53,7 @@ public void ValidateExperimentExecuteNullNumericCol()
5353
{
5454
var columnInfo = new ColumnInformation();
5555
columnInfo.NumericColumnNames.Add(null);
56-
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
56+
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
5757
}
5858

5959
[TestMethod]
@@ -63,7 +63,7 @@ public void ValidateExperimentExecuteDuplicateCol()
6363
var columnInfo = new ColumnInformation();
6464
columnInfo.NumericColumnNames.Add(DefaultColumnNames.Label);
6565

66-
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
66+
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
6767
}
6868

6969
[TestMethod]
@@ -82,7 +82,7 @@ public void ValidateExperimentExecuteArgsTrainValidColCountMismatch()
8282
var validData = validDataBuilder.GetDataView();
8383

8484
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
85-
new ColumnInformation() { LabelColumnName = "0" }, validData);
85+
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
8686
}
8787

8888
[TestMethod]
@@ -102,7 +102,7 @@ public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch()
102102
var validData = validDataBuilder.GetDataView();
103103

104104
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
105-
new ColumnInformation() { LabelColumnName = "0" }, validData);
105+
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
106106
}
107107

108108
[TestMethod]
@@ -122,7 +122,7 @@ public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch()
122122
var validData = validDataBuilder.GetDataView();
123123

124124
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
125-
new ColumnInformation() { LabelColumnName = "0" }, validData);
125+
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
126126
}
127127

128128
[TestMethod]
@@ -163,7 +163,7 @@ public void ValidateFeaturesColInvalidType()
163163
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
164164
var schema = schemaBuilder.ToSchema();
165165
var dataView = new EmptyDataView(new MLContext(), schema);
166-
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null);
166+
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, TaskKind.Regression);
167167
}
168168

169169
[TestMethod]
@@ -181,7 +181,78 @@ public void ValidateTextColumnNotText()
181181
var columnInfo = new ColumnInformation();
182182
columnInfo.NumericColumnNames.Add(TextPurposeColName);
183183

184-
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null);
184+
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null, TaskKind.Regression);
185+
}
186+
187+
[TestMethod]
188+
public void ValidateRegressionLabelTypes()
189+
{
190+
ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Single, true);
191+
ValidateLabelTypeTestCore(TaskKind.Regression, BooleanDataViewType.Instance, false);
192+
ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Double, false);
193+
ValidateLabelTypeTestCore(TaskKind.Regression, TextDataViewType.Instance, false);
194+
}
195+
196+
[TestMethod]
197+
public void ValidateBinaryClassificationLabelTypes()
198+
{
199+
ValidateLabelTypeTestCore(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
200+
ValidateLabelTypeTestCore(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
201+
}
202+
203+
[TestMethod]
204+
public void ValidateMulticlassLabelTypes()
205+
{
206+
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
207+
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
208+
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
209+
ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
210+
}
211+
212+
[TestMethod]
213+
public void ValidateAllowedFeatureColumnTypes()
214+
{
215+
var schemaBuilder = new DataViewSchema.Builder();
216+
schemaBuilder.AddColumn("Boolean", BooleanDataViewType.Instance);
217+
schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
218+
schemaBuilder.AddColumn("Text", TextDataViewType.Instance);
219+
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
220+
var schema = schemaBuilder.ToSchema();
221+
var dataView = new EmptyDataView(new MLContext(), schema);
222+
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
223+
null, TaskKind.Regression);
224+
}
225+
226+
[TestMethod]
227+
[ExpectedException(typeof(ArgumentException))]
228+
public void ValidateProhibitedFeatureColumnType()
229+
{
230+
var schemaBuilder = new DataViewSchema.Builder();
231+
schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
232+
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
233+
var schema = schemaBuilder.ToSchema();
234+
var dataView = new EmptyDataView(new MLContext(), schema);
235+
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
236+
null, TaskKind.Regression);
237+
}
238+
239+
private static void ValidateLabelTypeTestCore(TaskKind task, DataViewType labelType, bool labelTypeShouldBeValid)
240+
{
241+
var schemaBuilder = new DataViewSchema.Builder();
242+
schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
243+
schemaBuilder.AddColumn(DefaultColumnNames.Label, labelType);
244+
var schema = schemaBuilder.ToSchema();
245+
var dataView = new EmptyDataView(new MLContext(), schema);
246+
var validationExceptionThrown = false;
247+
try
248+
{
249+
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, task);
250+
}
251+
catch
252+
{
253+
validationExceptionThrown = true;
254+
}
255+
Assert.AreEqual(labelTypeShouldBeValid, !validationExceptionThrown);
185256
}
186257
}
187258
}

0 commit comments

Comments
 (0)