Skip to content

Commit 8ec02d8

Browse files
Autofit overloads + cancellation + progress callbacks
1) Introduce AutoFit overloads (basic and advanced) 2) AutoFit Cancellation 3) AutoFit progress callbacks
1 parent b8b577a commit 8ec02d8

18 files changed

+549
-531
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Linq;
7+
using Microsoft.ML.Core.Data;
8+
9+
namespace Microsoft.ML.Auto
10+
{
11+
public class AutoFitRunResult<T>
12+
{
13+
public readonly T Metrics;
14+
public readonly ITransformer Model;
15+
public readonly Exception Exception;
16+
public readonly string TrainerName;
17+
public readonly int RuntimeInSeconds;
18+
19+
internal readonly Pipeline Pipeline;
20+
internal readonly int PipelineInferenceTimeInSeconds;
21+
22+
internal AutoFitRunResult(
23+
ITransformer model,
24+
T metrics,
25+
Pipeline pipeline,
26+
Exception exception,
27+
int runtimeInSeconds,
28+
int pipelineInferenceTimeInSeconds)
29+
{
30+
Model = model;
31+
Metrics = metrics;
32+
Pipeline = pipeline;
33+
Exception = exception;
34+
RuntimeInSeconds = runtimeInSeconds;
35+
PipelineInferenceTimeInSeconds = pipelineInferenceTimeInSeconds;
36+
37+
TrainerName = pipeline?.Nodes.Where(n => n.NodeType == PipelineNodeType.Trainer).Last().Name;
38+
}
39+
}
40+
}

src/Microsoft.ML.Auto/API/AutoFitSettings.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace Microsoft.ML.Auto
99
{
1010
internal static class AutoFitDefaults
1111
{
12-
public const uint TimeOutInMinutes = 24 * 60;
12+
public const uint TimeoutInSeconds = 60 * 60;
1313
public const uint MaxIterations = 1000;
1414
}
1515

@@ -32,12 +32,12 @@ internal class AutoFitSettings
3232
internal bool DisableSubSampling;
3333
internal bool DisableCaching;
3434
internal bool ExternalizeTraining;
35-
internal TraceLevel TraceLevel;
35+
internal TraceLevel TraceLevel;
3636
}
3737

3838
internal class ExperimentStoppingCriteria
3939
{
40-
public uint TimeOutInMinutes = AutoFitDefaults.TimeOutInMinutes;
40+
public uint TimeoutInSeconds = AutoFitDefaults.TimeoutInSeconds;
4141
public uint MaxIterations = AutoFitDefaults.MaxIterations;
4242
internal bool StopAfterConverging;
4343
internal double ExperimentExitScore;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using System.Threading;
9+
using Microsoft.Data.DataView;
10+
using Microsoft.ML.Core.Data;
11+
using Microsoft.ML.Data;
12+
13+
namespace Microsoft.ML.Auto
14+
{
15+
public class AutoFitBinaryClassificationOptions
16+
{
17+
public IDataView TrainData;
18+
public string LabelColumnName = DefaultColumnNames.Label;
19+
public IDataView ValidationData;
20+
public uint TimeoutInSeconds = AutoFitDefaults.TimeoutInSeconds;
21+
public CancellationToken CancellationToken = default;
22+
public IProgress<AutoFitRunResult<BinaryClassificationMetrics>> ProgressCallback;
23+
public IEstimator<ITransformer> PreFeaturizers;
24+
public IEnumerable<(string, ColumnPurpose)> ColumnPurposes;
25+
}
26+
27+
public static class BinaryClassificationExtensions
28+
{
29+
public static List<AutoFitRunResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationCatalog catalog,
30+
IDataView trainData,
31+
string labelColumnName = DefaultColumnNames.Label,
32+
IDataView validationData = null,
33+
uint timeoutInSeconds = AutoFitDefaults.TimeoutInSeconds,
34+
CancellationToken cancellationToken = default,
35+
IProgress<AutoFitRunResult<BinaryClassificationMetrics>> progressCallback = null)
36+
{
37+
var settings = new AutoFitSettings();
38+
settings.StoppingCriteria.TimeoutInSeconds = timeoutInSeconds;
39+
40+
return AutoFit(catalog, trainData, labelColumnName, validationData, settings,
41+
null, null, cancellationToken, progressCallback, null);
42+
}
43+
44+
public static List<AutoFitRunResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationCatalog catalog,
45+
AutoFitBinaryClassificationOptions options)
46+
{
47+
var settings = new AutoFitSettings();
48+
settings.StoppingCriteria.TimeoutInSeconds = options.TimeoutInSeconds;
49+
50+
return AutoFit(catalog, options.TrainData, options.LabelColumnName, options.ValidationData, settings,
51+
options.PreFeaturizers, options.ColumnPurposes, options.CancellationToken, options.ProgressCallback, null);
52+
}
53+
54+
internal static List<AutoFitRunResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationCatalog catalog,
55+
IDataView trainData,
56+
string labelColumnName = DefaultColumnNames.Label,
57+
IDataView validationData = null,
58+
AutoFitSettings settings = null,
59+
IEstimator<ITransformer> preFeaturizers = null,
60+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
61+
CancellationToken cancellationToken = default,
62+
IProgress<AutoFitRunResult<BinaryClassificationMetrics>> progressCallback = null,
63+
IDebugLogger debugLogger = null)
64+
{
65+
UserInputValidationUtil.ValidateAutoFitArgs(trainData, labelColumnName, validationData, settings, columnPurposes);
66+
67+
if (validationData == null)
68+
{
69+
(trainData, validationData) = catalog.TestValidateSplit(trainData);
70+
}
71+
72+
// run autofit & get all pipelines run in that process
73+
var autoFitter = new AutoFitter<BinaryClassificationMetrics>(TaskKind.BinaryClassification, trainData, labelColumnName, validationData,
74+
settings, preFeaturizers, columnPurposes,
75+
OptimizingMetric.RSquared, cancellationToken, progressCallback, debugLogger);
76+
77+
return autoFitter.Fit();
78+
}
79+
80+
public static AutoFitRunResult<BinaryClassificationMetrics> Best(this IEnumerable<AutoFitRunResult<BinaryClassificationMetrics>> results)
81+
{
82+
double maxScore = results.Select(r => r.Metrics.Accuracy).Max();
83+
return results.First(r => r.Metrics.Accuracy == maxScore);
84+
}
85+
}
86+
87+
}

src/Microsoft.ML.Auto/API/MLContextAutoFitExtensions.cs

-167
This file was deleted.

0 commit comments

Comments
 (0)