Skip to content

Commit f5776b0

Browse files
Adding more metrics to BinaryClassification Experiment (#6571)
* fix #6570 * fix build error * fix build error
1 parent db19715 commit f5776b0

File tree

3 files changed

+50
-56
lines changed

3 files changed

+50
-56
lines changed

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

+19-18
Original file line numberDiff line numberDiff line change
@@ -381,14 +381,8 @@ public TrialResult Run(TrialSettings settings)
381381
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
382382
var res = metrics[_rnd.Next(fold)];
383383
var model = res.Model;
384-
var metric = metricManager.Metric switch
385-
{
386-
BinaryClassificationMetric.PositivePrecision => res.Metrics.PositivePrecision,
387-
BinaryClassificationMetric.Accuracy => res.Metrics.Accuracy,
388-
BinaryClassificationMetric.AreaUnderRocCurve => res.Metrics.AreaUnderRocCurve,
389-
BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => res.Metrics.AreaUnderPrecisionRecallCurve,
390-
_ => throw new NotImplementedException($"{metricManager.MetricName} is not supported!"),
391-
};
384+
var metric = GetMetric(metricManager.Metric, res.Metrics);
385+
392386
var loss = metricManager.IsMaximize ? -metric : metric;
393387
stopWatch.Stop();
394388

@@ -413,16 +407,7 @@ public TrialResult Run(TrialSettings settings)
413407
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
414408
var eval = model.Transform(trainTestDatasetManager.TestDataset);
415409
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
416-
417-
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
418-
var metric = Enum.Parse(typeof(BinaryClassificationMetric), metricManager.MetricName) switch
419-
{
420-
BinaryClassificationMetric.PositivePrecision => metrics.PositivePrecision,
421-
BinaryClassificationMetric.Accuracy => metrics.Accuracy,
422-
BinaryClassificationMetric.AreaUnderRocCurve => metrics.AreaUnderRocCurve,
423-
BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => metrics.AreaUnderPrecisionRecallCurve,
424-
_ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"),
425-
};
410+
var metric = GetMetric(metricManager.Metric, metrics);
426411
var loss = metricManager.IsMaximize ? -metric : metric;
427412

428413
stopWatch.Stop();
@@ -465,5 +450,21 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
465450
throw;
466451
}
467452
}
453+
454+
private double GetMetric(BinaryClassificationMetric metric, BinaryClassificationMetrics metrics)
455+
{
456+
return metric switch
457+
{
458+
BinaryClassificationMetric.PositivePrecision => metrics.PositivePrecision,
459+
BinaryClassificationMetric.Accuracy => metrics.Accuracy,
460+
BinaryClassificationMetric.AreaUnderRocCurve => metrics.AreaUnderRocCurve,
461+
BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => metrics.AreaUnderPrecisionRecallCurve,
462+
BinaryClassificationMetric.PositiveRecall => metrics.PositiveRecall,
463+
BinaryClassificationMetric.NegativePrecision => metrics.NegativePrecision,
464+
BinaryClassificationMetric.NegativeRecall => metrics.NegativeRecall,
465+
BinaryClassificationMetric.F1Score => metrics.F1Score,
466+
_ => throw new NotImplementedException($"{metric} is not supported!"),
467+
};
468+
}
468469
}
469470
}

src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs

+15-19
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,7 @@ public TrialResult Run(TrialSettings settings)
375375
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
376376
var res = metrics[_rnd.Next(fold)];
377377
var model = res.Model;
378-
var metric = metricManager.Metric switch
379-
{
380-
MulticlassClassificationMetric.MacroAccuracy => res.Metrics.MacroAccuracy,
381-
MulticlassClassificationMetric.MicroAccuracy => res.Metrics.MicroAccuracy,
382-
MulticlassClassificationMetric.LogLoss => res.Metrics.LogLoss,
383-
MulticlassClassificationMetric.LogLossReduction => res.Metrics.LogLossReduction,
384-
MulticlassClassificationMetric.TopKAccuracy => res.Metrics.TopKAccuracy,
385-
_ => throw new NotImplementedException($"{metricManager.MetricName} is not supported!"),
386-
};
378+
var metric = GetMetric(metricManager.Metric, res.Metrics);
387379
var loss = metricManager.IsMaximize ? -metric : metric;
388380

389381
stopWatch.Stop();
@@ -409,16 +401,7 @@ public TrialResult Run(TrialSettings settings)
409401
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
410402
var eval = model.Transform(trainTestDatasetManager.TestDataset);
411403
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
412-
413-
var metric = metricManager.Metric switch
414-
{
415-
MulticlassClassificationMetric.MacroAccuracy => metrics.MacroAccuracy,
416-
MulticlassClassificationMetric.MicroAccuracy => metrics.MicroAccuracy,
417-
MulticlassClassificationMetric.LogLoss => metrics.LogLoss,
418-
MulticlassClassificationMetric.LogLossReduction => metrics.LogLossReduction,
419-
MulticlassClassificationMetric.TopKAccuracy => metrics.TopKAccuracy,
420-
_ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"),
421-
};
404+
var metric = GetMetric(metricManager.Metric, metrics);
422405
var loss = metricManager.IsMaximize ? -metric : metric;
423406

424407
stopWatch.Stop();
@@ -462,6 +445,19 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
462445
}
463446
}
464447

448+
private double GetMetric(MulticlassClassificationMetric metric, MulticlassClassificationMetrics metrics)
449+
{
450+
return metric switch
451+
{
452+
MulticlassClassificationMetric.MacroAccuracy => metrics.MacroAccuracy,
453+
MulticlassClassificationMetric.MicroAccuracy => metrics.MicroAccuracy,
454+
MulticlassClassificationMetric.LogLoss => metrics.LogLoss,
455+
MulticlassClassificationMetric.LogLossReduction => metrics.LogLossReduction,
456+
MulticlassClassificationMetric.TopKAccuracy => metrics.TopKAccuracy,
457+
_ => throw new NotImplementedException($"{metric} is not supported!"),
458+
};
459+
}
460+
465461
public void Dispose()
466462
{
467463
_context.CancelExecution();

src/Microsoft.ML.AutoML/API/RegressionExperiment.cs

+16-19
Original file line numberDiff line numberDiff line change
@@ -402,14 +402,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
402402
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
403403
var res = metrics[_rnd.Next(fold)];
404404
var model = res.Model;
405-
var metric = metricManager.Metric switch
406-
{
407-
RegressionMetric.RootMeanSquaredError => res.Metrics.RootMeanSquaredError,
408-
RegressionMetric.RSquared => res.Metrics.RSquared,
409-
RegressionMetric.MeanSquaredError => res.Metrics.MeanSquaredError,
410-
RegressionMetric.MeanAbsoluteError => res.Metrics.MeanAbsoluteError,
411-
_ => throw new NotImplementedException($"{metricManager.MetricName} is not supported!"),
412-
};
405+
var metric = GetMetric(metricManager.Metric, res.Metrics);
413406
var loss = metricManager.IsMaximize ? -metric : metric;
414407

415408
stopWatch.Stop();
@@ -434,16 +427,8 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
434427
stopWatch.Start();
435428
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
436429
var eval = model.Transform(trainTestDatasetManager.TestDataset);
437-
var res = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
438-
439-
var metric = metricManager.Metric switch
440-
{
441-
RegressionMetric.RootMeanSquaredError => res.RootMeanSquaredError,
442-
RegressionMetric.RSquared => res.RSquared,
443-
RegressionMetric.MeanSquaredError => res.MeanSquaredError,
444-
RegressionMetric.MeanAbsoluteError => res.MeanAbsoluteError,
445-
_ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"),
446-
};
430+
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
431+
var metric = GetMetric(metricManager.Metric, metrics);
447432
var loss = metricManager.IsMaximize ? -metric : metric;
448433

449434
stopWatch.Stop();
@@ -453,10 +438,10 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
453438
{
454439
Loss = loss,
455440
Metric = metric,
441+
Metrics = metrics,
456442
Model = model,
457443
TrialSettings = settings,
458444
DurationInMilliseconds = stopWatch.ElapsedMilliseconds,
459-
Metrics = res,
460445
Pipeline = refitPipeline,
461446
} as TrialResult);
462447
}
@@ -480,5 +465,17 @@ public void Dispose()
480465
_context.CancelExecution();
481466
_context = null;
482467
}
468+
469+
private double GetMetric(RegressionMetric metric, RegressionMetrics metrics)
470+
{
471+
return metric switch
472+
{
473+
RegressionMetric.RootMeanSquaredError => metrics.RootMeanSquaredError,
474+
RegressionMetric.RSquared => metrics.RSquared,
475+
RegressionMetric.MeanSquaredError => metrics.MeanSquaredError,
476+
RegressionMetric.MeanAbsoluteError => metrics.MeanAbsoluteError,
477+
_ => throw new NotImplementedException($"{metric} is not supported!"),
478+
};
479+
}
483480
}
484481
}

0 commit comments

Comments
 (0)