@@ -218,11 +218,32 @@ public void TransformInferenceNumericCols()
218
218
}
219
219
220
220
[ TestMethod ]
221
- public void TransformInferenceFeatCol ( )
221
+ public void TransformInferenceFeatColScalar ( )
222
222
{
223
223
TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
224
224
{
225
225
( DefaultColumnNames . Features , NumberType . R4 , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
226
+ } , @"[
227
+ {
228
+ ""Name"": ""ColumnConcatenating"",
229
+ ""NodeType"": ""Transform"",
230
+ ""InColumns"": [
231
+ ""Features""
232
+ ],
233
+ ""OutColumns"": [
234
+ ""Features""
235
+ ],
236
+ ""Properties"": {}
237
+ }
238
+ ]" ) ;
239
+ }
240
+
241
+ [ TestMethod ]
242
+ public void TransformInferenceFeatColVector ( )
243
+ {
244
+ TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
245
+ {
246
+ ( DefaultColumnNames . Features , new VectorType ( NumberType . R4 ) , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
226
247
} , @"[]" ) ;
227
248
}
228
249
@@ -249,6 +270,48 @@ public void NumericAndFeatCol()
249
270
]" ) ;
250
271
}
251
272
273
+ [ TestMethod ]
274
+ public void NumericScalarCol ( )
275
+ {
276
+ TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
277
+ {
278
+ ( "Numeric" , NumberType . R4 , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
279
+ } , @"[
280
+ {
281
+ ""Name"": ""ColumnConcatenating"",
282
+ ""NodeType"": ""Transform"",
283
+ ""InColumns"": [
284
+ ""Numeric""
285
+ ],
286
+ ""OutColumns"": [
287
+ ""Features""
288
+ ],
289
+ ""Properties"": {}
290
+ }
291
+ ]" ) ;
292
+ }
293
+
294
+ [ TestMethod ]
295
+ public void NumericVectorCol ( )
296
+ {
297
+ TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
298
+ {
299
+ ( "Numeric" , new VectorType ( NumberType . R4 ) , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
300
+ } , @"[
301
+ {
302
+ ""Name"": ""ColumnCopying"",
303
+ ""NodeType"": ""Transform"",
304
+ ""InColumns"": [
305
+ ""Numeric""
306
+ ],
307
+ ""OutColumns"": [
308
+ ""Features""
309
+ ],
310
+ ""Properties"": {}
311
+ }
312
+ ]" ) ;
313
+ }
314
+
252
315
[ TestMethod ]
253
316
public void TransformInferenceTextCol ( )
254
317
{
@@ -268,7 +331,7 @@ public void TransformInferenceTextCol()
268
331
""Properties"": {}
269
332
},
270
333
{
271
- ""Name"": ""ColumnConcatenating "",
334
+ ""Name"": ""ColumnCopying "",
272
335
""NodeType"": ""Transform"",
273
336
""InColumns"": [
274
337
""Text_tf""
@@ -566,7 +629,7 @@ public void TransformInferenceDefaultLabelCol()
566
629
{
567
630
TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
568
631
{
569
- ( DefaultColumnNames . Features , NumberType . R4 , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
632
+ ( DefaultColumnNames . Features , new VectorType ( NumberType . R4 ) , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
570
633
( DefaultColumnNames . Label , NumberType . R4 , ColumnPurpose . Label , new ColumnDimensions ( null , null ) ) ,
571
634
} , @"[]" ) ;
572
635
}
@@ -576,7 +639,7 @@ public void TransformInferenceCustomLabelCol()
576
639
{
577
640
TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
578
641
{
579
- ( DefaultColumnNames . Features , NumberType . R4 , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
642
+ ( DefaultColumnNames . Features , new VectorType ( NumberType . R4 ) , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
580
643
( "CustomLabel" , NumberType . R4 , ColumnPurpose . Label , new ColumnDimensions ( null , null ) ) ,
581
644
} , @"[
582
645
{
@@ -598,7 +661,7 @@ public void TransformInferenceDefaultGroupIdCol()
598
661
{
599
662
TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
600
663
{
601
- ( DefaultColumnNames . Features , NumberType . R4 , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
664
+ ( DefaultColumnNames . Features , new VectorType ( NumberType . R4 ) , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
602
665
( DefaultColumnNames . GroupId , NumberType . R4 , ColumnPurpose . Group , new ColumnDimensions ( null , null ) ) ,
603
666
} , @"[]" ) ;
604
667
}
@@ -608,7 +671,7 @@ public void TransformInferenceCustomGroupIdCol()
608
671
{
609
672
TransformInferenceTestCore ( new ( string , ColumnType , ColumnPurpose , ColumnDimensions ) [ ]
610
673
{
611
- ( DefaultColumnNames . Features , NumberType . R4 , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
674
+ ( DefaultColumnNames . Features , new VectorType ( NumberType . R4 ) , ColumnPurpose . NumericFeature , new ColumnDimensions ( null , null ) ) ,
612
675
( "CustomGroupId" , NumberType . R4 , ColumnPurpose . Group , new ColumnDimensions ( null , null ) ) ,
613
676
} , @"[
614
677
{
@@ -709,6 +772,7 @@ private static void TestApplyTransformsToRealDataView(IEnumerable<SuggestedTrans
709
772
// assert Features column of type 'R4' exists
710
773
var featuresCol = data . Schema . GetColumnOrNull ( DefaultColumnNames . Features ) ;
711
774
Assert . IsNotNull ( featuresCol ) ;
775
+ Assert . AreEqual ( true , featuresCol . Value . Type . IsVector ( ) ) ;
712
776
Assert . AreEqual ( NumberType . R4 , featuresCol . Value . Type . GetItemType ( ) ) ;
713
777
}
714
778
@@ -735,6 +799,11 @@ private static IDataView BuildDummyDataView(IEnumerable<(string name, ColumnType
735
799
{
736
800
dataBuilder . AddColumn ( column . name , new string [ ] { "a" } ) ;
737
801
}
802
+ else if ( column . type . IsVector ( ) && column . type . GetItemType ( ) == NumberType . R4 )
803
+ {
804
+ dataBuilder . AddColumn ( column . name , Util . GetKeyValueGetter ( new [ ] { "1" , "2" } ) ,
805
+ NumberType . R4 , new float [ ] { 0 , 0 } ) ;
806
+ }
738
807
}
739
808
return dataBuilder . GetDataView ( ) ;
740
809
}
0 commit comments