@@ -18,23 +18,23 @@ public class UserInputValidationTests
18
18
[ ExpectedException ( typeof ( ArgumentNullException ) ) ]
19
19
public void ValidateExperimentExecuteNullTrainData ( )
20
20
{
21
- UserInputValidationUtil . ValidateExperimentExecuteArgs ( null , new ColumnInformation ( ) , null ) ;
21
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( null , new ColumnInformation ( ) , null , TaskKind . Regression ) ;
22
22
}
23
23
24
24
[ TestMethod ]
25
25
[ ExpectedException ( typeof ( ArgumentException ) ) ]
26
26
public void ValidateExperimentExecuteNullLabel ( )
27
27
{
28
28
UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data ,
29
- new ColumnInformation ( ) { LabelColumnName = null } , null ) ;
29
+ new ColumnInformation ( ) { LabelColumnName = null } , null , TaskKind . Regression ) ;
30
30
}
31
31
32
32
[ TestMethod ]
33
33
[ ExpectedException ( typeof ( ArgumentException ) ) ]
34
34
public void ValidateExperimentExecuteLabelNotInTrain ( )
35
35
{
36
36
UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data ,
37
- new ColumnInformation ( ) { LabelColumnName = "L" } , null ) ;
37
+ new ColumnInformation ( ) { LabelColumnName = "L" } , null , TaskKind . Regression ) ;
38
38
}
39
39
40
40
[ TestMethod ]
@@ -44,7 +44,7 @@ public void ValidateExperimentExecuteNumericColNotInTrain()
44
44
var columnInfo = new ColumnInformation ( ) ;
45
45
columnInfo . NumericColumnNames . Add ( "N" ) ;
46
46
47
- UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data , columnInfo , null ) ;
47
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data , columnInfo , null , TaskKind . Regression ) ;
48
48
}
49
49
50
50
[ TestMethod ]
@@ -53,7 +53,7 @@ public void ValidateExperimentExecuteNullNumericCol()
53
53
{
54
54
var columnInfo = new ColumnInformation ( ) ;
55
55
columnInfo . NumericColumnNames . Add ( null ) ;
56
- UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data , columnInfo , null ) ;
56
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data , columnInfo , null , TaskKind . Regression ) ;
57
57
}
58
58
59
59
[ TestMethod ]
@@ -63,7 +63,7 @@ public void ValidateExperimentExecuteDuplicateCol()
63
63
var columnInfo = new ColumnInformation ( ) ;
64
64
columnInfo . NumericColumnNames . Add ( DefaultColumnNames . Label ) ;
65
65
66
- UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data , columnInfo , null ) ;
66
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( Data , columnInfo , null , TaskKind . Regression ) ;
67
67
}
68
68
69
69
[ TestMethod ]
@@ -82,7 +82,7 @@ public void ValidateExperimentExecuteArgsTrainValidColCountMismatch()
82
82
var validData = validDataBuilder . GetDataView ( ) ;
83
83
84
84
UserInputValidationUtil . ValidateExperimentExecuteArgs ( trainData ,
85
- new ColumnInformation ( ) { LabelColumnName = "0" } , validData ) ;
85
+ new ColumnInformation ( ) { LabelColumnName = "0" } , validData , TaskKind . Regression ) ;
86
86
}
87
87
88
88
[ TestMethod ]
@@ -102,7 +102,7 @@ public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch()
102
102
var validData = validDataBuilder . GetDataView ( ) ;
103
103
104
104
UserInputValidationUtil . ValidateExperimentExecuteArgs ( trainData ,
105
- new ColumnInformation ( ) { LabelColumnName = "0" } , validData ) ;
105
+ new ColumnInformation ( ) { LabelColumnName = "0" } , validData , TaskKind . Regression ) ;
106
106
}
107
107
108
108
[ TestMethod ]
@@ -122,7 +122,7 @@ public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch()
122
122
var validData = validDataBuilder . GetDataView ( ) ;
123
123
124
124
UserInputValidationUtil . ValidateExperimentExecuteArgs ( trainData ,
125
- new ColumnInformation ( ) { LabelColumnName = "0" } , validData ) ;
125
+ new ColumnInformation ( ) { LabelColumnName = "0" } , validData , TaskKind . Regression ) ;
126
126
}
127
127
128
128
[ TestMethod ]
@@ -163,7 +163,7 @@ public void ValidateFeaturesColInvalidType()
163
163
schemaBuilder . AddColumn ( DefaultColumnNames . Label , NumberDataViewType . Single ) ;
164
164
var schema = schemaBuilder . ToSchema ( ) ;
165
165
var dataView = new EmptyDataView ( new MLContext ( ) , schema ) ;
166
- UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , new ColumnInformation ( ) , null ) ;
166
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , new ColumnInformation ( ) , null , TaskKind . Regression ) ;
167
167
}
168
168
169
169
[ TestMethod ]
@@ -181,7 +181,78 @@ public void ValidateTextColumnNotText()
181
181
var columnInfo = new ColumnInformation ( ) ;
182
182
columnInfo . NumericColumnNames . Add ( TextPurposeColName ) ;
183
183
184
- UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , columnInfo , null ) ;
184
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , columnInfo , null , TaskKind . Regression ) ;
185
+ }
186
+
187
+ [ TestMethod ]
188
+ public void ValidateRegressionLabelTypes ( )
189
+ {
190
+ ValidateLabelTypeTestCore ( TaskKind . Regression , NumberDataViewType . Single , true ) ;
191
+ ValidateLabelTypeTestCore ( TaskKind . Regression , BooleanDataViewType . Instance , false ) ;
192
+ ValidateLabelTypeTestCore ( TaskKind . Regression , NumberDataViewType . Double , false ) ;
193
+ ValidateLabelTypeTestCore ( TaskKind . Regression , TextDataViewType . Instance , false ) ;
194
+ }
195
+
196
+ [ TestMethod ]
197
+ public void ValidateBinaryClassificationLabelTypes ( )
198
+ {
199
+ ValidateLabelTypeTestCore ( TaskKind . BinaryClassification , NumberDataViewType . Single , false ) ;
200
+ ValidateLabelTypeTestCore ( TaskKind . BinaryClassification , BooleanDataViewType . Instance , true ) ;
201
+ }
202
+
203
+ [ TestMethod ]
204
+ public void ValidateMulticlassLabelTypes ( )
205
+ {
206
+ ValidateLabelTypeTestCore ( TaskKind . MulticlassClassification , NumberDataViewType . Single , true ) ;
207
+ ValidateLabelTypeTestCore ( TaskKind . MulticlassClassification , BooleanDataViewType . Instance , true ) ;
208
+ ValidateLabelTypeTestCore ( TaskKind . MulticlassClassification , NumberDataViewType . Double , true ) ;
209
+ ValidateLabelTypeTestCore ( TaskKind . MulticlassClassification , TextDataViewType . Instance , true ) ;
210
+ }
211
+
212
+ [ TestMethod ]
213
+ public void ValidateAllowedFeatureColumnTypes ( )
214
+ {
215
+ var schemaBuilder = new DataViewSchema . Builder ( ) ;
216
+ schemaBuilder . AddColumn ( "Boolean" , BooleanDataViewType . Instance ) ;
217
+ schemaBuilder . AddColumn ( "Number" , NumberDataViewType . Single ) ;
218
+ schemaBuilder . AddColumn ( "Text" , TextDataViewType . Instance ) ;
219
+ schemaBuilder . AddColumn ( DefaultColumnNames . Label , NumberDataViewType . Single ) ;
220
+ var schema = schemaBuilder . ToSchema ( ) ;
221
+ var dataView = new EmptyDataView ( new MLContext ( ) , schema ) ;
222
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , new ColumnInformation ( ) ,
223
+ null , TaskKind . Regression ) ;
224
+ }
225
+
226
+ [ TestMethod ]
227
+ [ ExpectedException ( typeof ( ArgumentException ) ) ]
228
+ public void ValidateProhibitedFeatureColumnType ( )
229
+ {
230
+ var schemaBuilder = new DataViewSchema . Builder ( ) ;
231
+ schemaBuilder . AddColumn ( "UInt64" , NumberDataViewType . UInt64 ) ;
232
+ schemaBuilder . AddColumn ( DefaultColumnNames . Label , NumberDataViewType . Single ) ;
233
+ var schema = schemaBuilder . ToSchema ( ) ;
234
+ var dataView = new EmptyDataView ( new MLContext ( ) , schema ) ;
235
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , new ColumnInformation ( ) ,
236
+ null , TaskKind . Regression ) ;
237
+ }
238
+
239
+ private static void ValidateLabelTypeTestCore ( TaskKind task , DataViewType labelType , bool labelTypeShouldBeValid )
240
+ {
241
+ var schemaBuilder = new DataViewSchema . Builder ( ) ;
242
+ schemaBuilder . AddColumn ( DefaultColumnNames . Features , NumberDataViewType . Single ) ;
243
+ schemaBuilder . AddColumn ( DefaultColumnNames . Label , labelType ) ;
244
+ var schema = schemaBuilder . ToSchema ( ) ;
245
+ var dataView = new EmptyDataView ( new MLContext ( ) , schema ) ;
246
+ var validationExceptionThrown = false ;
247
+ try
248
+ {
249
+ UserInputValidationUtil . ValidateExperimentExecuteArgs ( dataView , new ColumnInformation ( ) , null , task ) ;
250
+ }
251
+ catch
252
+ {
253
+ validationExceptionThrown = true ;
254
+ }
255
+ Assert . AreEqual ( labelTypeShouldBeValid , ! validationExceptionThrown ) ;
185
256
}
186
257
}
187
258
}
0 commit comments