Skip to content

[AutoML] AutoML SDK API: validate schema types of input IDataView #3597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Auto/API/ExperimentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
IProgress<RunDetail<TMetrics>> progressHandler)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData, _task);

// Apply pre-featurizer
ITransformer preprocessorTransform = null;
Expand All @@ -263,7 +263,7 @@ private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] tr
IProgress<CrossValidationRunDetail<TMetrics>> progressHandler)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0], _task);

// Apply pre-featurizer
ITransformer[] preprocessorTransforms = null;
Expand All @@ -290,7 +290,7 @@ private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatas
IProgress<RunDetail<TMetrics>> progressHandler)
{
columnInfo = columnInfo ?? new ColumnInformation();
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0], _task);

// Apply pre-featurizer
ITransformer[] preprocessorTransforms = null;
Expand Down
48 changes: 39 additions & 9 deletions src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ internal static class UserInputValidationUtil
private const string SamplingKeyColumnPurposeName = "sampling key";

// Entry point for validating the arguments of an experiment run.
// NOTE(review): this hunk is a diff rendering; the 3-argument lines are the removed
// version and the lines that take `task` / `columnInformation` are the added version.
// The added version runs three checks in order: training data, column information
// (now task-aware), and validation data against the training data.
public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
IDataView validationData)
IDataView validationData, TaskKind task)
{
ValidateTrainData(trainData);
ValidateColumnInformation(trainData, columnInformation);
// columnInformation is passed to the training-data check as well as the column check.
ValidateTrainData(trainData, columnInformation);
// task is forwarded to the column-information check.
ValidateColumnInformation(trainData, columnInformation, task);
ValidateValidationData(trainData, validationData);
}

Expand Down Expand Up @@ -54,24 +54,37 @@ public static void ValidateNumberOfCVFoldsArg(uint numberOfCVFolds)
}
}

// Validates the training data view: it must be non-null, a column named
// DefaultColumnNames.Features must have item type Single, and (added version) every
// column other than the label must have item type Boolean, Single, or Text.
// NOTE(review): this hunk is a diff rendering; the 1-argument signature and the
// GetColumnOrNull-based check are the removed version, the foreach loop is the added one.
private static void ValidateTrainData(IDataView trainData)
private static void ValidateTrainData(IDataView trainData, ColumnInformation columnInformation)
{
if (trainData == null)
{
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
}

var type = trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType();
if (type != null && type != NumberDataViewType.Single)
foreach (var column in trainData.Schema)
{
throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type Single", nameof(trainData));
// A column named Features, if present, must have item type Single.
if (column.Name == DefaultColumnNames.Features && column.Type.GetItemType() != NumberDataViewType.Single)
{
throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type {NumberDataViewType.Single}", nameof(trainData));
}

// Only the label column is exempt from the supported-type check below.
// NOTE(review): columns named as example weight or sampling key in columnInformation
// are not excluded here, so they are held to the feature types too — confirm intended.
// NOTE(review): columnInformation is dereferenced without a null check in this method;
// presumably callers guarantee it is non-null — verify.
if (column.Name != columnInformation.LabelColumnName &&
column.Type.GetItemType() != BooleanDataViewType.Instance &&
column.Type.GetItemType() != NumberDataViewType.Single &&
column.Type.GetItemType() != TextDataViewType.Instance)
{
throw new ArgumentException($"Only supported feature column types are " +
$"{BooleanDataViewType.Instance}, {NumberDataViewType.Single}, and {TextDataViewType.Instance}. " +
$"Please change the feature column {column.Name} of type {column.Type} to one of " +
$"the supported types.", nameof(trainData));
}
}
}

private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation, TaskKind task)
{
ValidateColumnInformation(columnInformation);
ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName);
ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName, GetAllowedLabelTypes(task));
ValidateTrainDataColumn(trainData, columnInformation.ExampleWeightColumnName, WeightColumnPurposeName);
ValidateTrainDataColumn(trainData, columnInformation.SamplingKeyColumnName, SamplingKeyColumnPurposeName);
ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName,
Expand Down Expand Up @@ -228,5 +241,22 @@ private static string FindFirstDuplicate(IEnumerable<string> values)
var groups = values.GroupBy(v => v);
return groups.FirstOrDefault(g => g.Count() > 1)?.Key;
}

/// <summary>
/// Maps a task kind to the label column types accepted for it.
/// </summary>
/// <param name="task">The kind of ML task being run.</param>
/// <returns>
/// A single-element array for binary classification (Boolean) and regression (Single);
/// <c>null</c> for multiclass classification.
/// </returns>
/// <exception cref="NotSupportedException">Thrown for any other task kind.</exception>
private static IEnumerable<DataViewType> GetAllowedLabelTypes(TaskKind task)
{
    if (task == TaskKind.BinaryClassification)
    {
        return new DataViewType[] { BooleanDataViewType.Instance };
    }

    if (task == TaskKind.MulticlassClassification)
    {
        // Multiclass label types are flexible, as we convert the label to a key type
        // (if input label is not already a key) before invoking the trainer.
        // NOTE(review): null is presumably read by the caller as "no restriction" — confirm.
        return null;
    }

    if (task == TaskKind.Regression)
    {
        return new DataViewType[] { NumberDataViewType.Single };
    }

    throw new NotSupportedException($"Unsupported task type: {task}");
}
}
}
93 changes: 82 additions & 11 deletions test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,23 @@ public class UserInputValidationTests
[ExpectedException(typeof(ArgumentNullException))]
public void ValidateExperimentExecuteNullTrainData()
{
UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null, TaskKind.Regression);
}

// Expects an ArgumentException when the label column name in ColumnInformation is null.
// NOTE(review): diff rendering — the call ending in `null);` is the removed line, the
// call that also passes TaskKind.Regression is the added line.
[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateExperimentExecuteNullLabel()
{
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data,
new ColumnInformation() { LabelColumnName = null }, null);
new ColumnInformation() { LabelColumnName = null }, null, TaskKind.Regression);
}

// Expects an ArgumentException when the label column name is set to "L".
// presumably "L" is not a column of the shared Data fixture — verify against its definition.
// NOTE(review): diff rendering — the call ending in `null);` is the removed line, the
// call that also passes TaskKind.Regression is the added line.
[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateExperimentExecuteLabelNotInTrain()
{
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data,
new ColumnInformation() { LabelColumnName = "L" }, null);
new ColumnInformation() { LabelColumnName = "L" }, null, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -44,7 +44,7 @@ public void ValidateExperimentExecuteNumericColNotInTrain()
var columnInfo = new ColumnInformation();
columnInfo.NumericColumnNames.Add("N");

UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -53,7 +53,7 @@ public void ValidateExperimentExecuteNullNumericCol()
{
var columnInfo = new ColumnInformation();
columnInfo.NumericColumnNames.Add(null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -63,7 +63,7 @@ public void ValidateExperimentExecuteDuplicateCol()
var columnInfo = new ColumnInformation();
columnInfo.NumericColumnNames.Add(DefaultColumnNames.Label);

UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -82,7 +82,7 @@ public void ValidateExperimentExecuteArgsTrainValidColCountMismatch()
var validData = validDataBuilder.GetDataView();

UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
new ColumnInformation() { LabelColumnName = "0" }, validData);
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -102,7 +102,7 @@ public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch()
var validData = validDataBuilder.GetDataView();

UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
new ColumnInformation() { LabelColumnName = "0" }, validData);
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -122,7 +122,7 @@ public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch()
var validData = validDataBuilder.GetDataView();

UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
new ColumnInformation() { LabelColumnName = "0" }, validData);
new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
}

[TestMethod]
Expand Down Expand Up @@ -163,7 +163,7 @@ public void ValidateFeaturesColInvalidType()
schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
var schema = schemaBuilder.ToSchema();
var dataView = new EmptyDataView(new MLContext(), schema);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, TaskKind.Regression);
}

[TestMethod]
Expand All @@ -181,7 +181,78 @@ public void ValidateTextColumnNotText()
var columnInfo = new ColumnInformation();
columnInfo.NumericColumnNames.Add(TextPurposeColName);

UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null);
UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null, TaskKind.Regression);
}

// Regression accepts only a Single label; Boolean, Double and Text labels must be rejected.
[TestMethod]
public void ValidateRegressionLabelTypes()
{
    const TaskKind task = TaskKind.Regression;

    ValidateLabelTypeTestCore(task, labelType: NumberDataViewType.Single, labelTypeShouldBeValid: true);

    var rejectedLabelTypes = new DataViewType[]
    {
        BooleanDataViewType.Instance,
        NumberDataViewType.Double,
        TextDataViewType.Instance,
    };
    foreach (var rejected in rejectedLabelTypes)
    {
        ValidateLabelTypeTestCore(task, labelType: rejected, labelTypeShouldBeValid: false);
    }
}

// Binary classification rejects a Single label and accepts a Boolean label.
[TestMethod]
public void ValidateBinaryClassificationLabelTypes()
{
    const TaskKind task = TaskKind.BinaryClassification;

    ValidateLabelTypeTestCore(task, labelType: NumberDataViewType.Single, labelTypeShouldBeValid: false);
    ValidateLabelTypeTestCore(task, labelType: BooleanDataViewType.Instance, labelTypeShouldBeValid: true);
}

// Multiclass classification is expected to accept every label type exercised here.
[TestMethod]
public void ValidateMulticlassLabelTypes()
{
    var acceptedLabelTypes = new DataViewType[]
    {
        NumberDataViewType.Single,
        BooleanDataViewType.Instance,
        NumberDataViewType.Double,
        TextDataViewType.Instance,
    };

    foreach (var labelType in acceptedLabelTypes)
    {
        ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, labelType, true);
    }
}

// Boolean, Single and Text columns alongside a Single label must pass validation
// for a regression task (the test passes when no exception is thrown).
[TestMethod]
public void ValidateAllowedFeatureColumnTypes()
{
    var builder = new DataViewSchema.Builder();
    builder.AddColumn("Boolean", BooleanDataViewType.Instance);
    builder.AddColumn("Number", NumberDataViewType.Single);
    builder.AddColumn("Text", TextDataViewType.Instance);
    builder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
    var builtSchema = builder.ToSchema();

    var emptyData = new EmptyDataView(new MLContext(), builtSchema);
    var columnInfo = new ColumnInformation();

    UserInputValidationUtil.ValidateExperimentExecuteArgs(emptyData, columnInfo, null, TaskKind.Regression);
}

// A UInt64 column that is not the label must be rejected with an ArgumentException.
[TestMethod]
[ExpectedException(typeof(ArgumentException))]
public void ValidateProhibitedFeatureColumnType()
{
    var builder = new DataViewSchema.Builder();
    builder.AddColumn("UInt64", NumberDataViewType.UInt64);
    builder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
    var builtSchema = builder.ToSchema();

    var emptyData = new EmptyDataView(new MLContext(), builtSchema);
    var columnInfo = new ColumnInformation();

    UserInputValidationUtil.ValidateExperimentExecuteArgs(emptyData, columnInfo, null, TaskKind.Regression);
}

/// <summary>
/// Builds an empty data view with a Single Features column and a Label column of
/// <paramref name="labelType"/>, runs experiment-argument validation for
/// <paramref name="task"/>, and asserts whether validation accepted the label type.
/// </summary>
/// <param name="task">Task kind passed to the validator.</param>
/// <param name="labelType">Data view type given to the label column.</param>
/// <param name="labelTypeShouldBeValid">
/// True when validation is expected to succeed; false when it is expected to reject the input.
/// </param>
private static void ValidateLabelTypeTestCore(TaskKind task, DataViewType labelType, bool labelTypeShouldBeValid)
{
    var schemaBuilder = new DataViewSchema.Builder();
    schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
    schemaBuilder.AddColumn(DefaultColumnNames.Label, labelType);
    var schema = schemaBuilder.ToSchema();
    var dataView = new EmptyDataView(new MLContext(), schema);

    var validationExceptionThrown = false;
    try
    {
        UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, task);
    }
    catch (ArgumentException)
    {
        // Only argument-validation failures count as a rejection. Any other exception
        // (for example a NullReferenceException) propagates and fails the test instead
        // of being mistaken for a correct rejection.
        validationExceptionThrown = true;
    }

    Assert.AreEqual(labelTypeShouldBeValid, !validationExceptionThrown,
        $"Unexpected validation outcome for task {task} with label type {labelType}.");
}
}
}