From a5fba6c65a80854554ccbb7485e5c280f49fac23 Mon Sep 17 00:00:00 2001 From: zewditu Date: Wed, 15 Mar 2023 18:41:49 -0700 Subject: [PATCH] Change test to validate --- src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs | 4 ++-- .../API/BinaryClassificationExperiment.cs | 6 +++--- .../API/MulticlassClassificationExperiment.cs | 6 +++--- src/Microsoft.ML.AutoML/API/RegressionExperiment.cs | 6 +++--- .../AutoMLExperiment/IDatasetManager.cs | 8 ++++---- .../AutoMLExperiment/Runner/SweepablePipelineRunner.cs | 4 ++-- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs index 35d0a679ad..727e443852 100644 --- a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs +++ b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs @@ -27,10 +27,10 @@ public static class AutoMLExperimentExtension /// public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation) { - var datasetManager = new TrainTestDatasetManager() + var datasetManager = new TrainValidateDatasetManager() { TrainDataset = train, - TestDataset = validation + ValidateDataset = validation }; experiment.ServiceCollection.AddSingleton(datasetManager); diff --git a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs index bbf4ef6531..6db1da1dec 100644 --- a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs @@ -400,12 +400,12 @@ public TrialResult Run(TrialSettings settings) }; } - if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager) + if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager) { var stopWatch = new Stopwatch(); stopWatch.Start(); var model = pipeline.Fit(trainTestDatasetManager.TrainDataset); - var eval = model.Transform(trainTestDatasetManager.TestDataset); + var eval = model.Transform(trainTestDatasetManager.ValidateDataset); var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn); var metric = GetMetric(metricManager.Metric, metrics); var loss = metricManager.IsMaximize ? -metric : metric; @@ -426,7 +426,7 @@ public TrialResult Run(TrialSettings settings) } } - throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}"); + throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}"); } public Task RunAsync(TrialSettings settings, CancellationToken ct) diff --git a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs index 975381f2c6..b855c9e71a 100644 --- a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs @@ -394,12 +394,12 @@ public TrialResult Run(TrialSettings settings) }; } - if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager) + if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager) { var stopWatch = new Stopwatch(); stopWatch.Start(); var model = pipeline.Fit(trainTestDatasetManager.TrainDataset); - var eval = model.Transform(trainTestDatasetManager.TestDataset); + var eval = model.Transform(trainTestDatasetManager.ValidateDataset); var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn); var metric = GetMetric(metricManager.Metric, metrics); var loss = metricManager.IsMaximize ? -metric : metric; @@ -420,7 +420,7 @@ public TrialResult Run(TrialSettings settings) } } - throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}"); + throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}"); } public Task RunAsync(TrialSettings settings, CancellationToken ct) diff --git a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs index 3adad3fddc..99f9e9800f 100644 --- a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs +++ b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs @@ -421,12 +421,12 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) } as TrialResult); } - if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager) + if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager) { var stopWatch = new Stopwatch(); stopWatch.Start(); var model = pipeline.Fit(trainTestDatasetManager.TrainDataset); - var eval = model.Transform(trainTestDatasetManager.TestDataset); + var eval = model.Transform(trainTestDatasetManager.ValidateDataset); var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn); var metric = GetMetric(metricManager.Metric, metrics); var loss = metricManager.IsMaximize ? -metric : metric; @@ -447,7 +447,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct) } } - throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}"); + throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}"); } } catch (Exception ex) when (ct.IsCancellationRequested) diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs index ee752f3e07..e1cac220bc 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs @@ -19,18 +19,18 @@ internal interface ICrossValidateDatasetManager IDataView Dataset { get; set; } } - internal interface ITrainTestDatasetManager + internal interface ITrainValidateDatasetManager { IDataView TrainDataset { get; set; } - IDataView TestDataset { get; set; } + IDataView ValidateDataset { get; set; } } - internal class TrainTestDatasetManager : IDatasetManager, ITrainTestDatasetManager + internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager { public IDataView TrainDataset { get; set; } - public IDataView TestDataset { get; set; } + public IDataView ValidateDataset { get; set; } } internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs index ff3298c291..c0b2b7d6a8 100644 --- a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs +++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs @@ -66,10 +66,10 @@ public TrialResult Run(TrialSettings settings) }; } - if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager) + if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager) { var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset); - var eval = model.Transform(trainTestDatasetManager.TestDataset); + var eval = model.Transform(trainTestDatasetManager.ValidateDataset); var metric = _metricManager.Evaluate(_mLContext, eval); stopWatch.Stop(); var loss = _metricManager.IsMaximize ? -metric : metric;