@@ -528,6 +528,89 @@ public void TestCrossValidationMacroWithMultiClass()
528
528
}
529
529
Assert . Equal ( 0 , rowCount ) ;
530
530
}
531
+
532
+ var warnings = experiment . GetOutput ( crossValidateOutput . Warnings ) ;
533
+ using ( var cursor = warnings . GetRowCursor ( col => true ) )
534
+ Assert . False ( cursor . MoveNext ( ) ) ;
535
+ }
536
+ }
537
+
538
+ [ Fact ]
539
+ public void TestCrossValidationMacroMultiClassWithWarnings ( )
540
+ {
541
+ var dataPath = GetDataPath ( @"Train-Tiny-28x28.txt" ) ;
542
+ using ( var env = new TlcEnvironment ( 42 ) )
543
+ {
544
+ var subGraph = env . CreateExperiment ( ) ;
545
+
546
+ var nop = new ML . Transforms . NoOperation ( ) ;
547
+ var nopOutput = subGraph . Add ( nop ) ;
548
+
549
+ var learnerInput = new ML . Trainers . LogisticRegressionClassifier
550
+ {
551
+ TrainingData = nopOutput . OutputData ,
552
+ NumThreads = 1
553
+ } ;
554
+ var learnerOutput = subGraph . Add ( learnerInput ) ;
555
+
556
+ var experiment = env . CreateExperiment ( ) ;
557
+ var importInput = new ML . Data . TextLoader ( dataPath ) ;
558
+ var importOutput = experiment . Add ( importInput ) ;
559
+
560
+ var filter = new ML . Transforms . RowRangeFilter ( ) ;
561
+ filter . Data = importOutput . Data ;
562
+ filter . Column = "Label" ;
563
+ filter . Min = 0 ;
564
+ filter . Max = 5 ;
565
+ var filterOutput = experiment . Add ( filter ) ;
566
+
567
+ var term = new ML . Transforms . TextToKeyConverter ( ) ;
568
+ term . Column = new [ ]
569
+ {
570
+ new ML . Transforms . TermTransformColumn ( )
571
+ {
572
+ Source = "Label" , Name = "Strat" , Sort = ML . Transforms . TermTransformSortOrder . Value
573
+ }
574
+ } ;
575
+ term . Data = filterOutput . OutputData ;
576
+ var termOutput = experiment . Add ( term ) ;
577
+
578
+ var crossValidate = new ML . Models . CrossValidator
579
+ {
580
+ Data = termOutput . OutputData ,
581
+ Nodes = subGraph ,
582
+ Kind = ML . Models . MacroUtilsTrainerKinds . SignatureMultiClassClassifierTrainer ,
583
+ TransformModel = null ,
584
+ StratificationColumn = "Strat"
585
+ } ;
586
+ crossValidate . Inputs . Data = nop . Data ;
587
+ crossValidate . Outputs . PredictorModel = learnerOutput . PredictorModel ;
588
+ var crossValidateOutput = experiment . Add ( crossValidate ) ;
589
+
590
+ experiment . Compile ( ) ;
591
+ importInput . SetInput ( env , experiment ) ;
592
+ experiment . Run ( ) ;
593
+ var warnings = experiment . GetOutput ( crossValidateOutput . Warnings ) ;
594
+
595
+ var schema = warnings . Schema ;
596
+ var b = schema . TryGetColumnIndex ( "WarningText" , out int warningCol ) ;
597
+ Assert . True ( b ) ;
598
+ using ( var cursor = warnings . GetRowCursor ( col => col == warningCol ) )
599
+ {
600
+ var getter = cursor . GetGetter < DvText > ( warningCol ) ;
601
+
602
+ b = cursor . MoveNext ( ) ;
603
+ Assert . True ( b ) ;
604
+ var warning = default ( DvText ) ;
605
+ getter ( ref warning ) ;
606
+ Assert . Contains ( "test instances with class values not seen in the training set." , warning . ToString ( ) ) ;
607
+ b = cursor . MoveNext ( ) ;
608
+ Assert . True ( b ) ;
609
+ getter ( ref warning ) ;
610
+ Assert . Contains ( "Detected columns of variable length: SortedScores, SortedClasses" , warning . ToString ( ) ) ;
611
+ b = cursor . MoveNext ( ) ;
612
+ Assert . False ( b ) ;
613
+ }
531
614
}
532
615
}
533
616
0 commit comments