diff --git a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs index ea7e12c769..bbf4ef6531 100644 --- a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs @@ -381,14 +381,8 @@ public TrialResult Run(TrialSettings settings) // now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best. var res = metrics[_rnd.Next(fold)]; var model = res.Model; - var metric = metricManager.Metric switch - { - BinaryClassificationMetric.PositivePrecision => res.Metrics.PositivePrecision, - BinaryClassificationMetric.Accuracy => res.Metrics.Accuracy, - BinaryClassificationMetric.AreaUnderRocCurve => res.Metrics.AreaUnderRocCurve, - BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => res.Metrics.AreaUnderPrecisionRecallCurve, - _ => throw new NotImplementedException($"{metricManager.MetricName} is not supported!"), - }; + var metric = GetMetric(metricManager.Metric, res.Metrics); + var loss = metricManager.IsMaximize ? -metric : metric; stopWatch.Stop(); @@ -413,16 +407,7 @@ public TrialResult Run(TrialSettings settings) var model = pipeline.Fit(trainTestDatasetManager.TrainDataset); var eval = model.Transform(trainTestDatasetManager.TestDataset); var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn); - - // now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best. - var metric = Enum.Parse(typeof(BinaryClassificationMetric), metricManager.MetricName) switch - { - BinaryClassificationMetric.PositivePrecision => metrics.PositivePrecision, - BinaryClassificationMetric.Accuracy => metrics.Accuracy, - BinaryClassificationMetric.AreaUnderRocCurve => metrics.AreaUnderRocCurve, - BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => metrics.AreaUnderPrecisionRecallCurve, - _ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"), - }; + var metric = GetMetric(metricManager.Metric, metrics); var loss = metricManager.IsMaximize ? -metric : metric; stopWatch.Stop(); @@ -465,5 +450,21 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) throw; } } + + private double GetMetric(BinaryClassificationMetric metric, BinaryClassificationMetrics metrics) + { + return metric switch + { + BinaryClassificationMetric.PositivePrecision => metrics.PositivePrecision, + BinaryClassificationMetric.Accuracy => metrics.Accuracy, + BinaryClassificationMetric.AreaUnderRocCurve => metrics.AreaUnderRocCurve, + BinaryClassificationMetric.AreaUnderPrecisionRecallCurve => metrics.AreaUnderPrecisionRecallCurve, + BinaryClassificationMetric.PositiveRecall => metrics.PositiveRecall, + BinaryClassificationMetric.NegativePrecision => metrics.NegativePrecision, + BinaryClassificationMetric.NegativeRecall => metrics.NegativeRecall, + BinaryClassificationMetric.F1Score => metrics.F1Score, + _ => throw new NotImplementedException($"{metric} is not supported!"), + }; + } } } diff --git a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs index dfddaaa78e..975381f2c6 100644 --- a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs @@ -375,15 +375,7 @@ public TrialResult Run(TrialSettings settings) // now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best. var res = metrics[_rnd.Next(fold)]; var model = res.Model; - var metric = metricManager.Metric switch - { - MulticlassClassificationMetric.MacroAccuracy => res.Metrics.MacroAccuracy, - MulticlassClassificationMetric.MicroAccuracy => res.Metrics.MicroAccuracy, - MulticlassClassificationMetric.LogLoss => res.Metrics.LogLoss, - MulticlassClassificationMetric.LogLossReduction => res.Metrics.LogLossReduction, - MulticlassClassificationMetric.TopKAccuracy => res.Metrics.TopKAccuracy, - _ => throw new NotImplementedException($"{metricManager.MetricName} is not supported!"), - }; + var metric = GetMetric(metricManager.Metric, res.Metrics); var loss = metricManager.IsMaximize ? -metric : metric; stopWatch.Stop(); @@ -409,16 +401,7 @@ public TrialResult Run(TrialSettings settings) var model = pipeline.Fit(trainTestDatasetManager.TrainDataset); var eval = model.Transform(trainTestDatasetManager.TestDataset); var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn); - - var metric = metricManager.Metric switch - { - MulticlassClassificationMetric.MacroAccuracy => metrics.MacroAccuracy, - MulticlassClassificationMetric.MicroAccuracy => metrics.MicroAccuracy, - MulticlassClassificationMetric.LogLoss => metrics.LogLoss, - MulticlassClassificationMetric.LogLossReduction => metrics.LogLossReduction, - MulticlassClassificationMetric.TopKAccuracy => metrics.TopKAccuracy, - _ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"), - }; + var metric = GetMetric(metricManager.Metric, metrics); var loss = metricManager.IsMaximize ? -metric : metric; stopWatch.Stop(); @@ -462,6 +445,19 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) } } + private double GetMetric(MulticlassClassificationMetric metric, MulticlassClassificationMetrics metrics) + { + return metric switch + { + MulticlassClassificationMetric.MacroAccuracy => metrics.MacroAccuracy, + MulticlassClassificationMetric.MicroAccuracy => metrics.MicroAccuracy, + MulticlassClassificationMetric.LogLoss => metrics.LogLoss, + MulticlassClassificationMetric.LogLossReduction => metrics.LogLossReduction, + MulticlassClassificationMetric.TopKAccuracy => metrics.TopKAccuracy, + _ => throw new NotImplementedException($"{metric} is not supported!"), + }; + } + public void Dispose() { _context.CancelExecution(); diff --git a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs index b044ad664c..3adad3fddc 100644 --- a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs @@ -402,14 +402,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) // now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best. var res = metrics[_rnd.Next(fold)]; var model = res.Model; - var metric = metricManager.Metric switch - { - RegressionMetric.RootMeanSquaredError => res.Metrics.RootMeanSquaredError, - RegressionMetric.RSquared => res.Metrics.RSquared, - RegressionMetric.MeanSquaredError => res.Metrics.MeanSquaredError, - RegressionMetric.MeanAbsoluteError => res.Metrics.MeanAbsoluteError, - _ => throw new NotImplementedException($"{metricManager.MetricName} is not supported!"), - }; + var metric = GetMetric(metricManager.Metric, res.Metrics); var loss = metricManager.IsMaximize ? -metric : metric; stopWatch.Stop(); @@ -434,16 +427,8 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) stopWatch.Start(); var model = pipeline.Fit(trainTestDatasetManager.TrainDataset); var eval = model.Transform(trainTestDatasetManager.TestDataset); - var res = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn); - - var metric = metricManager.Metric switch - { - RegressionMetric.RootMeanSquaredError => res.RootMeanSquaredError, - RegressionMetric.RSquared => res.RSquared, - RegressionMetric.MeanSquaredError => res.MeanSquaredError, - RegressionMetric.MeanAbsoluteError => res.MeanAbsoluteError, - _ => throw new NotImplementedException($"{metricManager.Metric} is not supported!"), - }; + var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn); + var metric = GetMetric(metricManager.Metric, metrics); var loss = metricManager.IsMaximize ? -metric : metric; stopWatch.Stop(); @@ -453,10 +438,10 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) { Loss = loss, Metric = metric, + Metrics = metrics, Model = model, TrialSettings = settings, DurationInMilliseconds = stopWatch.ElapsedMilliseconds, - Metrics = res, Pipeline = refitPipeline, } as TrialResult); } @@ -480,5 +465,17 @@ public void Dispose() _context.CancelExecution(); _context = null; } + + private double GetMetric(RegressionMetric metric, RegressionMetrics metrics) + { + return metric switch + { + RegressionMetric.RootMeanSquaredError => metrics.RootMeanSquaredError, + RegressionMetric.RSquared => metrics.RSquared, + RegressionMetric.MeanSquaredError => metrics.MeanSquaredError, + RegressionMetric.MeanAbsoluteError => metrics.MeanAbsoluteError, + _ => throw new NotImplementedException($"{metric} is not supported!"), + }; + } } }