@@ -19,7 +19,7 @@ public sealed partial class CrossValidator
19
19
/// <typeparam name="TOutput">Class type that represents prediction schema.</typeparam>
20
20
/// <param name="pipeline">Machine learning pipeline may contain loader, transforms and at least one trainer.</param>
21
21
/// <returns>List containing metrics and predictor model for each fold</returns>
22
- public CrossValidationOutput < TInput , TOutput > CrossValidate < TInput , TOutput > ( LearningPipeline pipeline )
22
+ public CrossValidationOutput < TInput , TOutput > CrossValidate < TInput , TOutput > ( LearningPipeline pipeline )
23
23
where TInput : class
24
24
where TOutput : class , new ( )
25
25
{
@@ -76,7 +76,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
76
76
{
77
77
PredictorModel = predictorModel
78
78
} ;
79
-
79
+
80
80
var scorerOutput = subGraph . Add ( scorer ) ;
81
81
lastTransformModel = scorerOutput . ScoringTransform ;
82
82
step = new ScorerPipelineStep ( scorerOutput . ScoredData , scorerOutput . ScoringTransform ) ;
@@ -129,7 +129,7 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
129
129
experiment . GetOutput ( crossValidateOutput . OverallMetrics ) ,
130
130
experiment . GetOutput ( crossValidateOutput . ConfusionMatrix ) , 2 ) ;
131
131
}
132
- else if ( Kind == MacroUtilsTrainerKinds . SignatureMultiClassClassifierTrainer )
132
+ else if ( Kind == MacroUtilsTrainerKinds . SignatureMultiClassClassifierTrainer )
133
133
{
134
134
cvOutput . ClassificationMetrics = ClassificationMetrics . FromMetrics (
135
135
environment ,
@@ -142,6 +142,12 @@ public CrossValidationOutput<TInput, TOutput> CrossValidate<TInput, TOutput>(Lea
142
142
environment ,
143
143
experiment . GetOutput ( crossValidateOutput . OverallMetrics ) ) ;
144
144
}
145
+ else if ( Kind == MacroUtilsTrainerKinds . SignatureClusteringTrainer )
146
+ {
147
+ cvOutput . ClusterMetrics = ClusterMetrics . FromOverallMetrics (
148
+ environment ,
149
+ experiment . GetOutput ( crossValidateOutput . OverallMetrics ) ) ;
150
+ }
145
151
else
146
152
{
147
153
//Implement metrics for ranking, clustering and anomaly detection.
@@ -174,6 +180,7 @@ public class CrossValidationOutput<TInput, TOutput>
174
180
public List < BinaryClassificationMetrics > BinaryClassificationMetrics ;
175
181
public List < ClassificationMetrics > ClassificationMetrics ;
176
182
public List < RegressionMetrics > RegressionMetrics ;
183
+ public List < ClusterMetrics > ClusterMetrics ;
177
184
public PredictionModel < TInput , TOutput > [ ] PredictorModels ;
178
185
179
186
//REVIEW: Add warnings and per instance results and implement
0 commit comments