Skip to content

Commit ebb5789

Browse files
authored
Change test to validate (#6599)
1 parent c696e09 commit ebb5789

6 files changed

+17
-17
lines changed

src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ public static class AutoMLExperimentExtension
2727
/// <returns><see cref="AutoMLExperiment"/></returns>
2828
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation)
2929
{
30-
var datasetManager = new TrainTestDatasetManager()
30+
var datasetManager = new TrainValidateDatasetManager()
3131
{
3232
TrainDataset = train,
33-
TestDataset = validation
33+
ValidateDataset = validation
3434
};
3535

3636
experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -400,12 +400,12 @@ public TrialResult Run(TrialSettings settings)
400400
};
401401
}
402402

403-
if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
403+
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
404404
{
405405
var stopWatch = new Stopwatch();
406406
stopWatch.Start();
407407
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
408-
var eval = model.Transform(trainTestDatasetManager.TestDataset);
408+
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
409409
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
410410
var metric = GetMetric(metricManager.Metric, metrics);
411411
var loss = metricManager.IsMaximize ? -metric : metric;
@@ -426,7 +426,7 @@ public TrialResult Run(TrialSettings settings)
426426
}
427427
}
428428

429-
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
429+
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
430430
}
431431

432432
public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)

src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,12 @@ public TrialResult Run(TrialSettings settings)
394394
};
395395
}
396396

397-
if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
397+
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
398398
{
399399
var stopWatch = new Stopwatch();
400400
stopWatch.Start();
401401
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
402-
var eval = model.Transform(trainTestDatasetManager.TestDataset);
402+
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
403403
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
404404
var metric = GetMetric(metricManager.Metric, metrics);
405405
var loss = metricManager.IsMaximize ? -metric : metric;
@@ -420,7 +420,7 @@ public TrialResult Run(TrialSettings settings)
420420
}
421421
}
422422

423-
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
423+
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
424424
}
425425

426426
public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)

src/Microsoft.ML.AutoML/API/RegressionExperiment.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -421,12 +421,12 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
421421
} as TrialResult);
422422
}
423423

424-
if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
424+
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
425425
{
426426
var stopWatch = new Stopwatch();
427427
stopWatch.Start();
428428
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
429-
var eval = model.Transform(trainTestDatasetManager.TestDataset);
429+
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
430430
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
431431
var metric = GetMetric(metricManager.Metric, metrics);
432432
var loss = metricManager.IsMaximize ? -metric : metric;
@@ -447,7 +447,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
447447
}
448448
}
449449

450-
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
450+
throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
451451
}
452452
}
453453
catch (Exception ex) when (ct.IsCancellationRequested)

src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@ internal interface ICrossValidateDatasetManager
1919
IDataView Dataset { get; set; }
2020
}
2121

22-
internal interface ITrainTestDatasetManager
22+
internal interface ITrainValidateDatasetManager
2323
{
2424
IDataView TrainDataset { get; set; }
2525

26-
IDataView TestDataset { get; set; }
26+
IDataView ValidateDataset { get; set; }
2727
}
2828

29-
internal class TrainTestDatasetManager : IDatasetManager, ITrainTestDatasetManager
29+
internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
3030
{
3131
public IDataView TrainDataset { get; set; }
3232

33-
public IDataView TestDataset { get; set; }
33+
public IDataView ValidateDataset { get; set; }
3434
}
3535

3636
internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager

src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ public TrialResult Run(TrialSettings settings)
6666
};
6767
}
6868

69-
if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
69+
if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
7070
{
7171
var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset);
72-
var eval = model.Transform(trainTestDatasetManager.TestDataset);
72+
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
7373
var metric = _metricManager.Evaluate(_mLContext, eval);
7474
stopWatch.Stop();
7575
var loss = _metricManager.IsMaximize ? -metric : metric;

0 commit comments

Comments
 (0)