9
9
using Microsoft . Data . DataView ;
10
10
using Microsoft . ML . Data ;
11
11
12
- // todo: re-write & test user input validation once final API nailed down.
13
- // Tracked by Github issue: https://github.com/dotnet/machinelearning-automl/issues/159
14
-
15
12
namespace Microsoft . ML . Auto
16
13
{
17
- /* internal static class UserInputValidationUtil
14
+ internal static class UserInputValidationUtil
18
15
{
19
- public static void ValidateAutoFitArgs (IDataView trainData, string label, IDataView validationData ,
20
- AutoFitSettings settings, IEnumerable<(string, ColumnPurpose)> purposeOverrides )
16
+ public static void ValidateExperimentExecuteArgs ( IDataView trainData , ColumnInformation columnInformation ,
17
+ IDataView validationData )
21
18
{
22
19
ValidateTrainData ( trainData ) ;
20
+ ValidateColumnInformation ( trainData , columnInformation ) ;
23
21
ValidateValidationData ( trainData , validationData ) ;
24
- ValidateLabel(trainData, label);
25
- ValidateSettings(settings);
26
- ValidatePurposeOverrides(trainData, validationData, label, purposeOverrides);
27
22
}
28
23
29
- public static void ValidateInferColumnsArgs(string path, string label)
24
+ public static void ValidateInferColumnsArgs ( string path , ColumnInformation columnInformation )
25
+ {
26
+ ValidateColumnInformation ( columnInformation ) ;
27
+ ValidatePath ( path ) ;
28
+ }
29
+
30
+ public static void ValidateInferColumnsArgs ( string path , string labelColumn )
30
31
{
31
- ValidateLabel(label );
32
+ ValidateLabelColumn ( labelColumn ) ;
32
33
ValidatePath ( path ) ;
33
34
}
34
35
@@ -51,21 +52,55 @@ private static void ValidateTrainData(IDataView trainData)
51
52
}
52
53
}
53
54
54
- private static void ValidateLabel (IDataView trainData, string label )
55
+ private static void ValidateColumnInformation ( IDataView trainData , ColumnInformation columnInformation )
55
56
{
56
- ValidateLabel(label);
57
+ ValidateColumnInformation ( columnInformation ) ;
58
+ ValidateTrainDataColumnExists ( trainData , columnInformation . LabelColumn ) ;
59
+ ValidateTrainDataColumnExists ( trainData , columnInformation . WeightColumn ) ;
60
+ ValidateTrainDataColumnsExist ( trainData , columnInformation . CategoricalColumns ) ;
61
+ ValidateTrainDataColumnsExist ( trainData , columnInformation . NumericColumns ) ;
62
+ ValidateTrainDataColumnsExist ( trainData , columnInformation . TextColumns ) ;
63
+ ValidateTrainDataColumnsExist ( trainData , columnInformation . IgnoredColumns ) ;
64
+ }
57
65
58
- if (trainData.Schema.GetColumnOrNull(label) == null)
66
+ private static void ValidateColumnInformation ( ColumnInformation columnInformation )
67
+ {
68
+ ValidateLabelColumn ( columnInformation . LabelColumn ) ;
69
+
70
+ ValidateColumnInfoEnumerationProperty ( columnInformation . CategoricalColumns , "categorical" ) ;
71
+ ValidateColumnInfoEnumerationProperty ( columnInformation . NumericColumns , "numeric" ) ;
72
+ ValidateColumnInfoEnumerationProperty ( columnInformation . TextColumns , "text" ) ;
73
+ ValidateColumnInfoEnumerationProperty ( columnInformation . IgnoredColumns , "ignored" ) ;
74
+
75
+ // keep a list of all columns, to detect duplicates
76
+ var allColumns = new List < string > ( ) ;
77
+ allColumns . Add ( columnInformation . LabelColumn ) ;
78
+ if ( columnInformation . WeightColumn != null ) { allColumns . Add ( columnInformation . WeightColumn ) ; }
79
+ if ( columnInformation . CategoricalColumns != null ) { allColumns . AddRange ( columnInformation . CategoricalColumns ) ; }
80
+ if ( columnInformation . NumericColumns != null ) { allColumns . AddRange ( columnInformation . NumericColumns ) ; }
81
+ if ( columnInformation . TextColumns != null ) { allColumns . AddRange ( columnInformation . TextColumns ) ; }
82
+ if ( columnInformation . IgnoredColumns != null ) { allColumns . AddRange ( columnInformation . IgnoredColumns ) ; }
83
+
84
+ var duplicateColName = FindFirstDuplicate ( allColumns ) ;
85
+ if ( duplicateColName != null )
59
86
{
60
- throw new ArgumentException($"Provided label column '{label}' not found in training data. ", nameof(label ));
87
+ throw new ArgumentException ( $ "Duplicate column name { duplicateColName } is present in two or more distinct properties of provided column information ", nameof ( columnInformation ) ) ;
61
88
}
62
89
}
63
90
64
- private static void ValidateLabel( string label )
91
+ private static void ValidateColumnInfoEnumerationProperty ( IEnumerable < string > columns , string propertyName )
65
92
{
66
- if (label == null )
93
+ if ( columns ? . Contains ( null ) == true )
67
94
{
68
- throw new ArgumentNullException(nameof(label), "Provided label cannot be null");
95
+ throw new ArgumentException ( $ "Null column string was specified as { propertyName } in column information") ;
96
+ }
97
+ }
98
+
99
+ private static void ValidateLabelColumn ( string labelColumn )
100
+ {
101
+ if ( labelColumn == null )
102
+ {
103
+ throw new ArgumentException ( "Provided label column cannot be null" ) ;
69
104
}
70
105
}
71
106
@@ -120,55 +155,24 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
120
155
}
121
156
}
122
157
123
- private static void ValidateSettings(AutoFitSettings settings )
158
+ private static void ValidateTrainDataColumnsExist ( IDataView trainData , IEnumerable < string > columnNames )
124
159
{
125
- if (settings?.StoppingCriteria == null)
160
+ if ( columnNames == null )
126
161
{
127
162
return ;
128
163
}
129
164
130
- if (settings.StoppingCriteria.MaxIterations <= 0 )
165
+ foreach ( var columnName in columnNames )
131
166
{
132
- throw new ArgumentOutOfRangeException(nameof(settings), "Max iterations must be > 0" );
167
+ ValidateTrainDataColumnExists ( trainData , columnName ) ;
133
168
}
134
169
}
135
170
136
- private static void ValidatePurposeOverrides(IDataView trainData, IDataView validationData,
137
- string label, IEnumerable<(string, ColumnPurpose)> purposeOverrides)
171
+ private static void ValidateTrainDataColumnExists ( IDataView trainData , string columnName )
138
172
{
139
- if (purposeOverrides == null)
140
- {
141
- return;
142
- }
143
-
144
- foreach (var purposeOverride in purposeOverrides)
145
- {
146
- var colName = purposeOverride.Item1;
147
- var colPurpose = purposeOverride.Item2;
148
-
149
- if (colName == null)
150
- {
151
- throw new ArgumentException("Purpose override column name cannot be null.", nameof(purposeOverrides));
152
- }
153
-
154
- if (trainData.Schema.GetColumnOrNull(colName) == null)
155
- {
156
- throw new ArgumentException($"Purpose override column name '{colName}' not found in training data.", nameof(purposeOverride));
157
- }
158
-
159
- // if column w/ purpose = 'Label' found, ensure it matches the passed-in label
160
- if (colPurpose == ColumnPurpose.Label && colName != label)
161
- {
162
- throw new ArgumentException($"Label column name in provided list of purposes '{colName}' must match " +
163
- $"the label column name '{label}'", nameof(purposeOverrides));
164
- }
165
- }
166
-
167
- // ensure all column names unique
168
- var duplicateColName = FindFirstDuplicate(purposeOverrides.Select(p => p.Item1));
169
- if (duplicateColName != null)
173
+ if ( columnName != null && trainData . Schema . GetColumnOrNull ( columnName ) == null )
170
174
{
171
- throw new ArgumentException($"Duplicate column name '{duplicateColName }' in purpose overrides.", nameof(purposeOverrides) );
175
+ throw new ArgumentException ( $ "Provided column ' { columnName } ' not found in training data." ) ;
172
176
}
173
177
}
174
178
@@ -177,5 +181,5 @@ private static string FindFirstDuplicate(IEnumerable<string> values)
177
181
var groups = values . GroupBy ( v => v ) ;
178
182
return groups . FirstOrDefault ( g => g . Count ( ) > 1 ) ? . Key ;
179
183
}
180
- }*/
184
+ }
181
185
}
0 commit comments