diff --git a/src/Microsoft.ML.Auto/API/ExperimentBase.cs b/src/Microsoft.ML.Auto/API/ExperimentBase.cs
index c9a028e926..8a7ffcb86e 100644
--- a/src/Microsoft.ML.Auto/API/ExperimentBase.cs
+++ b/src/Microsoft.ML.Auto/API/ExperimentBase.cs
@@ -239,7 +239,7 @@ private ExperimentResult<TMetrics> ExecuteTrainValidate(IDataView trainData,
             IProgress<RunDetail<TMetrics>> progressHandler)
         {
             columnInfo = columnInfo ?? new ColumnInformation();
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData, _task);
 
             // Apply pre-featurizer
             ITransformer preprocessorTransform = null;
@@ -263,7 +263,7 @@ private CrossValidationExperimentResult<TMetrics> ExecuteCrossVal(IDataView[] tr
             IProgress<CrossValidationRunDetail<TMetrics>> progressHandler)
         {
             columnInfo = columnInfo ?? new ColumnInformation();
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0], _task);
 
             // Apply pre-featurizer
             ITransformer[] preprocessorTransforms = null;
@@ -290,7 +290,7 @@ private ExperimentResult<TMetrics> ExecuteCrossValSummary(IDataView[] trainDatas
             IProgress<RunDetail<TMetrics>> progressHandler)
         {
             columnInfo = columnInfo ?? new ColumnInformation();
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0], _task);
 
             // Apply pre-featurizer
             ITransformer[] preprocessorTransforms = null;
diff --git a/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
index eddd143b37..11c493989b 100644
--- a/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
+++ b/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs
@@ -22,10 +22,10 @@ internal static class UserInputValidationUtil
         private const string SamplingKeyColumnPurposeName = "sampling key";
 
         public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
-            IDataView validationData)
+            IDataView validationData, TaskKind task)
         {
-            ValidateTrainData(trainData);
-            ValidateColumnInformation(trainData, columnInformation);
+            ValidateTrainData(trainData, columnInformation);
+            ValidateColumnInformation(trainData, columnInformation, task);
             ValidateValidationData(trainData, validationData);
         }
 
@@ -54,24 +54,53 @@ public static void ValidateNumberOfCVFoldsArg(uint numberOfCVFolds)
             }
         }
 
-        private static void ValidateTrainData(IDataView trainData)
+        /// <summary>
+        /// Verifies that training data was supplied and that its columns have supported types.
+        /// </summary>
+        /// <param name="trainData">Training data whose schema is checked.</param>
+        /// <param name="columnInformation">Supplies the label column name, which is exempt from the feature type check.</param>
+        /// <exception cref="ArgumentNullException">An argument is null.</exception>
+        /// <exception cref="ArgumentException">A column has an unsupported type.</exception>
+        private static void ValidateTrainData(IDataView trainData, ColumnInformation columnInformation)
         {
             if (trainData == null)
             {
                 throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
             }
 
-            var type = trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType();
-            if (type != null && type != NumberDataViewType.Single)
+            // This method runs before ValidateColumnInformation, so guard the dereference below here.
+            if (columnInformation == null)
+            {
+                throw new ArgumentNullException(nameof(columnInformation), "Column information cannot be null");
+            }
+
+            foreach (var column in trainData.Schema)
             {
-                throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type Single", nameof(trainData));
+                var itemType = column.Type.GetItemType();
+
+                if (column.Name == DefaultColumnNames.Features && itemType != NumberDataViewType.Single)
+                {
+                    throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type {NumberDataViewType.Single}", nameof(trainData));
+                }
+
+                // The label column is exempt here; its type is checked against the task separately.
+                if (column.Name != columnInformation.LabelColumnName &&
+                    itemType != BooleanDataViewType.Instance &&
+                    itemType != NumberDataViewType.Single &&
+                    itemType != TextDataViewType.Instance)
+                {
+                    throw new ArgumentException($"Only supported feature column types are " +
+                        $"{BooleanDataViewType.Instance}, {NumberDataViewType.Single}, and {TextDataViewType.Instance}. " +
+                        $"Please change the feature column {column.Name} of type {column.Type} to one of " +
+                        $"the supported types.", nameof(trainData));
+                }
             }
         }
 
-        private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
+        private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation, TaskKind task)
         {
             ValidateColumnInformation(columnInformation);
-            ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName);
+            ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName, GetAllowedLabelTypes(task));
             ValidateTrainDataColumn(trainData, columnInformation.ExampleWeightColumnName, WeightColumnPurposeName);
             ValidateTrainDataColumn(trainData, columnInformation.SamplingKeyColumnName, SamplingKeyColumnPurposeName);
             ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName,
@@ -228,5 +257,28 @@ private static string FindFirstDuplicate(IEnumerable<string> values)
             var groups = values.GroupBy(v => v);
             return groups.FirstOrDefault(g => g.Count() > 1)?.Key;
         }
+
+        /// <summary>
+        /// Returns the label column types accepted for the given task.
+        /// </summary>
+        /// <param name="task">The kind of task the experiment runs.</param>
+        /// <returns>The allowed label types, or null when the task does not restrict the label type.</returns>
+        /// <exception cref="NotSupportedException">The task is not one of the handled kinds.</exception>
+        private static IEnumerable<DataViewType> GetAllowedLabelTypes(TaskKind task)
+        {
+            switch (task)
+            {
+                case TaskKind.BinaryClassification:
+                    return new DataViewType[] { BooleanDataViewType.Instance };
+                // Multiclass label types are flexible, as we convert the label to a key type
+                // (if input label is not already a key) before invoking the trainer.
+                case TaskKind.MulticlassClassification:
+                    return null;
+                case TaskKind.Regression:
+                    return new DataViewType[] { NumberDataViewType.Single };
+                default:
+                    throw new NotSupportedException($"Unsupported task type: {task}");
+            }
+        }
     }
 }
diff --git a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
index e8962c484c..187af7c796 100644
--- a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs
@@ -18,7 +18,7 @@ public class UserInputValidationTests
         [ExpectedException(typeof(ArgumentNullException))]
         public void ValidateExperimentExecuteNullTrainData()
         {
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -26,7 +26,7 @@ public void ValidateExperimentExecuteNullTrainData()
         public void ValidateExperimentExecuteNullLabel()
         {
             UserInputValidationUtil.ValidateExperimentExecuteArgs(Data,
-                new ColumnInformation() { LabelColumnName = null }, null);
+                new ColumnInformation() { LabelColumnName = null }, null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -34,7 +34,7 @@ public void ValidateExperimentExecuteNullLabel()
         public void ValidateExperimentExecuteLabelNotInTrain()
         {
             UserInputValidationUtil.ValidateExperimentExecuteArgs(Data,
-                new ColumnInformation() { LabelColumnName = "L" }, null);
+                new ColumnInformation() { LabelColumnName = "L" }, null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -44,7 +44,7 @@ public void ValidateExperimentExecuteNumericColNotInTrain()
             var columnInfo = new ColumnInformation();
             columnInfo.NumericColumnNames.Add("N");
 
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -53,7 +53,7 @@ public void ValidateExperimentExecuteNullNumericCol()
         {
             var columnInfo = new ColumnInformation();
             columnInfo.NumericColumnNames.Add(null);
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -63,7 +63,7 @@ public void ValidateExperimentExecuteDuplicateCol()
             var columnInfo = new ColumnInformation();
             columnInfo.NumericColumnNames.Add(DefaultColumnNames.Label);
 
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -82,7 +82,7 @@ public void ValidateExperimentExecuteArgsTrainValidColCountMismatch()
             var validData = validDataBuilder.GetDataView();
 
             UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
-                new ColumnInformation() { LabelColumnName = "0" }, validData);
+                new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -102,7 +102,7 @@ public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch()
             var validData = validDataBuilder.GetDataView();
 
             UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
-                new ColumnInformation() { LabelColumnName = "0" }, validData);
+                new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -122,7 +122,7 @@ public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch()
             var validData = validDataBuilder.GetDataView();
 
             UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData,
-                new ColumnInformation() { LabelColumnName = "0" }, validData);
+                new ColumnInformation() { LabelColumnName = "0" }, validData, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -163,7 +163,7 @@ public void ValidateFeaturesColInvalidType()
             schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
             var schema = schemaBuilder.ToSchema();
             var dataView = new EmptyDataView(new MLContext(), schema);
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, TaskKind.Regression);
         }
 
         [TestMethod]
@@ -181,7 +181,87 @@ public void ValidateTextColumnNotText()
             var columnInfo = new ColumnInformation();
             columnInfo.NumericColumnNames.Add(TextPurposeColName);
 
-            UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null, TaskKind.Regression);
+        }
+
+        [TestMethod]
+        public void ValidateRegressionLabelTypes()
+        {
+            ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Single, true);
+            ValidateLabelTypeTestCore(TaskKind.Regression, BooleanDataViewType.Instance, false);
+            ValidateLabelTypeTestCore(TaskKind.Regression, NumberDataViewType.Double, false);
+            ValidateLabelTypeTestCore(TaskKind.Regression, TextDataViewType.Instance, false);
+        }
+
+        [TestMethod]
+        public void ValidateBinaryClassificationLabelTypes()
+        {
+            ValidateLabelTypeTestCore(TaskKind.BinaryClassification, NumberDataViewType.Single, false);
+            ValidateLabelTypeTestCore(TaskKind.BinaryClassification, BooleanDataViewType.Instance, true);
+        }
+
+        [TestMethod]
+        public void ValidateMulticlassLabelTypes()
+        {
+            ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Single, true);
+            ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, BooleanDataViewType.Instance, true);
+            ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, NumberDataViewType.Double, true);
+            ValidateLabelTypeTestCore(TaskKind.MulticlassClassification, TextDataViewType.Instance, true);
+        }
+
+        [TestMethod]
+        public void ValidateAllowedFeatureColumnTypes()
+        {
+            var schemaBuilder = new DataViewSchema.Builder();
+            schemaBuilder.AddColumn("Boolean", BooleanDataViewType.Instance);
+            schemaBuilder.AddColumn("Number", NumberDataViewType.Single);
+            schemaBuilder.AddColumn("Text", TextDataViewType.Instance);
+            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
+            var schema = schemaBuilder.ToSchema();
+            var dataView = new EmptyDataView(new MLContext(), schema);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
+                null, TaskKind.Regression);
+        }
+
+        [TestMethod]
+        [ExpectedException(typeof(ArgumentException))]
+        public void ValidateProhibitedFeatureColumnType()
+        {
+            var schemaBuilder = new DataViewSchema.Builder();
+            schemaBuilder.AddColumn("UInt64", NumberDataViewType.UInt64);
+            schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single);
+            var schema = schemaBuilder.ToSchema();
+            var dataView = new EmptyDataView(new MLContext(), schema);
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(),
+                null, TaskKind.Regression);
+        }
+
+        [TestMethod]
+        [ExpectedException(typeof(ArgumentNullException))]
+        public void ValidateExperimentExecuteNullColumnInformation()
+        {
+            UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, null, null, TaskKind.Regression);
+        }
+
+        // Runs input validation on a Features/Label schema and asserts whether the label type was accepted.
+        private static void ValidateLabelTypeTestCore(TaskKind task, DataViewType labelType, bool labelTypeShouldBeValid)
+        {
+            var schemaBuilder = new DataViewSchema.Builder();
+            schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single);
+            schemaBuilder.AddColumn(DefaultColumnNames.Label, labelType);
+            var schema = schemaBuilder.ToSchema();
+            var dataView = new EmptyDataView(new MLContext(), schema);
+            var validationExceptionThrown = false;
+            try
+            {
+                UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null, task);
+            }
+            catch (ArgumentException)
+            {
+                // Only validation failures count as a rejection; any other exception should fail the test.
+                validationExceptionThrown = true;
+            }
+            Assert.AreEqual(labelTypeShouldBeValid, !validationExceptionThrown);
         }
     }
 }