@@ -739,5 +739,64 @@ public void TestCrossValidationMacroWithNonDefaultNames()
739
739
}
740
740
}
741
741
}
742
+
743
+ [ Fact ]
744
+ public void TestOvaMacro ( )
745
+ {
746
+ var dataPath = GetDataPath ( @"iris.txt" ) ;
747
+ using ( var env = new TlcEnvironment ( 42 ) )
748
+ {
749
+ // Specify subgraph for OVA
750
+ var subGraph = env . CreateExperiment ( ) ;
751
+ var learnerInput = new Trainers . StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 } ;
752
+ var learnerOutput = subGraph . Add ( learnerInput ) ;
753
+ // Create pipeline with OVA and multiclass scoring.
754
+ var experiment = env . CreateExperiment ( ) ;
755
+ var importInput = new ML . Data . TextLoader ( dataPath ) ;
756
+ importInput . Arguments . Column = new TextLoaderColumn [ ]
757
+ {
758
+ new TextLoaderColumn { Name = "Label" , Source = new [ ] { new TextLoaderRange ( 0 ) } } ,
759
+ new TextLoaderColumn { Name = "Features" , Source = new [ ] { new TextLoaderRange ( 1 , 4 ) } }
760
+ } ;
761
+ var importOutput = experiment . Add ( importInput ) ;
762
+ var oneVersusAll = new Models . OneVersusAll
763
+ {
764
+ TrainingData = importOutput . Data ,
765
+ Nodes = subGraph ,
766
+ UseProbabilities = true ,
767
+ } ;
768
+ var ovaOutput = experiment . Add ( oneVersusAll ) ;
769
+ var scoreInput = new ML . Transforms . DatasetScorer
770
+ {
771
+ Data = importOutput . Data ,
772
+ PredictorModel = ovaOutput . PredictorModel
773
+ } ;
774
+ var scoreOutput = experiment . Add ( scoreInput ) ;
775
+ var evalInput = new ML . Models . ClassificationEvaluator
776
+ {
777
+ Data = scoreOutput . ScoredData
778
+ } ;
779
+ var evalOutput = experiment . Add ( evalInput ) ;
780
+ experiment . Compile ( ) ;
781
+ experiment . SetInput ( importInput . InputFile , new SimpleFileHandle ( env , dataPath , false , false ) ) ;
782
+ experiment . Run ( ) ;
783
+
784
+ var data = experiment . GetOutput ( evalOutput . OverallMetrics ) ;
785
+ var schema = data . Schema ;
786
+ var b = schema . TryGetColumnIndex ( MultiClassClassifierEvaluator . AccuracyMacro , out int accCol ) ;
787
+ Assert . True ( b ) ;
788
+ using ( var cursor = data . GetRowCursor ( col => col == accCol ) )
789
+ {
790
+ var getter = cursor . GetGetter < double > ( accCol ) ;
791
+ b = cursor . MoveNext ( ) ;
792
+ Assert . True ( b ) ;
793
+ double acc = 0 ;
794
+ getter ( ref acc ) ;
795
+ Assert . Equal ( 0.96 , acc , 2 ) ;
796
+ b = cursor . MoveNext ( ) ;
797
+ Assert . False ( b ) ;
798
+ }
799
+ }
800
+ }
742
801
}
743
802
}
0 commit comments