@@ -402,14 +402,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
402
402
// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
403
403
var res = metrics [ _rnd . Next ( fold ) ] ;
404
404
var model = res . Model ;
405
- var metric = metricManager . Metric switch
406
- {
407
- RegressionMetric . RootMeanSquaredError => res . Metrics . RootMeanSquaredError ,
408
- RegressionMetric . RSquared => res . Metrics . RSquared ,
409
- RegressionMetric . MeanSquaredError => res . Metrics . MeanSquaredError ,
410
- RegressionMetric . MeanAbsoluteError => res . Metrics . MeanAbsoluteError ,
411
- _ => throw new NotImplementedException ( $ "{ metricManager . MetricName } is not supported!") ,
412
- } ;
405
+ var metric = GetMetric ( metricManager . Metric , res . Metrics ) ;
413
406
var loss = metricManager . IsMaximize ? - metric : metric ;
414
407
415
408
stopWatch . Stop ( ) ;
@@ -434,16 +427,8 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
434
427
stopWatch . Start ( ) ;
435
428
var model = pipeline . Fit ( trainTestDatasetManager . TrainDataset ) ;
436
429
var eval = model . Transform ( trainTestDatasetManager . TestDataset ) ;
437
- var res = _context . Regression . Evaluate ( eval , metricManager . LabelColumn , scoreColumnName : metricManager . ScoreColumn ) ;
438
-
439
- var metric = metricManager . Metric switch
440
- {
441
- RegressionMetric . RootMeanSquaredError => res . RootMeanSquaredError ,
442
- RegressionMetric . RSquared => res . RSquared ,
443
- RegressionMetric . MeanSquaredError => res . MeanSquaredError ,
444
- RegressionMetric . MeanAbsoluteError => res . MeanAbsoluteError ,
445
- _ => throw new NotImplementedException ( $ "{ metricManager . Metric } is not supported!") ,
446
- } ;
430
+ var metrics = _context . Regression . Evaluate ( eval , metricManager . LabelColumn , scoreColumnName : metricManager . ScoreColumn ) ;
431
+ var metric = GetMetric ( metricManager . Metric , metrics ) ;
447
432
var loss = metricManager . IsMaximize ? - metric : metric ;
448
433
449
434
stopWatch . Stop ( ) ;
@@ -453,10 +438,10 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
453
438
{
454
439
Loss = loss ,
455
440
Metric = metric ,
441
+ Metrics = metrics ,
456
442
Model = model ,
457
443
TrialSettings = settings ,
458
444
DurationInMilliseconds = stopWatch . ElapsedMilliseconds ,
459
- Metrics = res ,
460
445
Pipeline = refitPipeline ,
461
446
} as TrialResult ) ;
462
447
}
@@ -480,5 +465,17 @@ public void Dispose()
480
465
_context . CancelExecution ( ) ;
481
466
_context = null ;
482
467
}
468
+
469
+ private double GetMetric ( RegressionMetric metric , RegressionMetrics metrics )
470
+ {
471
+ return metric switch
472
+ {
473
+ RegressionMetric . RootMeanSquaredError => metrics . RootMeanSquaredError ,
474
+ RegressionMetric . RSquared => metrics . RSquared ,
475
+ RegressionMetric . MeanSquaredError => metrics . MeanSquaredError ,
476
+ RegressionMetric . MeanAbsoluteError => metrics . MeanAbsoluteError ,
477
+ _ => throw new NotImplementedException ( $ "{ metric } is not supported!") ,
478
+ } ;
479
+ }
483
480
}
484
481
}
0 commit comments