Skip to content

Commit b1af7a6

Browse files
authored
add task agnostic wrappers for autofit calls (#3860)
1 parent 3846384 commit b1af7a6

File tree

2 files changed

+231
-0
lines changed

2 files changed

+231
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace Microsoft.ML.AutoML.Test
11+
{
12+
/// <summary>
/// Machine learning task kinds that the task agnostic test wrappers can dispatch on.
/// </summary>
public enum TaskType
{
    /// <summary>
    /// Classification; the wrappers in this file run it as a multiclass experiment.
    /// </summary>
    Classification = 1,

    /// <summary>
    /// Regression.
    /// </summary>
    Regression = 2
}
17+
18+
/// <summary>
/// Makes AutoFit and Score calls uniform across task types, so callers can run an
/// AutoML experiment and evaluate a model without branching on the task themselves.
/// </summary>
internal class TaskAgnosticAutoFit
{
    private readonly TaskType taskType;
    private readonly MLContext context;

    /// <summary>
    /// Progress handler able to receive run details from both the regression and the
    /// multiclass classification experiments.
    /// </summary>
    internal interface IUniversalProgressHandler : IProgress<RunDetail<RegressionMetrics>>, IProgress<RunDetail<MulticlassClassificationMetrics>>
    {
    }

    /// <summary>
    /// Creates a wrapper bound to one task type and one <see cref="MLContext"/>.
    /// </summary>
    /// <param name="taskType">Task that <see cref="AutoFit"/> and <see cref="Score"/> dispatch on.</param>
    /// <param name="context">Context used to create experiments and evaluate models.</param>
    internal TaskAgnosticAutoFit(TaskType taskType, MLContext context)
    {
        this.taskType = taskType;
        this.context = context;
    }

    /// <summary>
    /// Runs an AutoML experiment for the configured task type and converts every run
    /// into a <see cref="TaskAgnosticIterationResult"/>.
    /// </summary>
    /// <param name="trainData">Training data handed to the experiment.</param>
    /// <param name="label">Name of the label column.</param>
    /// <param name="maxModels">Value assigned to the experiment settings' MaxModels.</param>
    /// <param name="maxExperimentTimeInSeconds">Value assigned to the experiment settings' MaxExperimentTimeInSeconds.</param>
    /// <param name="validationData">Optional validation data handed to the experiment.</param>
    /// <param name="preFeaturizers">Optional estimator forwarded to the experiment as its pre-featurizer.</param>
    /// <param name="columnPurposes">
    /// NOTE(review): currently not applied; only the label column is set on the column
    /// information. Mapping purposes onto the column information still needs to be done.
    /// </param>
    /// <param name="progressHandler">Optional handler forwarded to the experiment.</param>
    /// <returns>One result per run detail reported by the experiment, materialized as a list.</returns>
    /// <exception cref="ArgumentException">The configured task type is not supported.</exception>
    internal IEnumerable<TaskAgnosticIterationResult> AutoFit(
        IDataView trainData,
        string label,
        int maxModels,
        uint maxExperimentTimeInSeconds,
        IDataView validationData = null,
        IEstimator<ITransformer> preFeaturizers = null,
        IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
        IUniversalProgressHandler progressHandler = null)
    {
        var columnInformation = new ColumnInformation() { LabelColumnName = label };

        switch (this.taskType)
        {
            case TaskType.Classification:
            {
                var mcs = new MulticlassExperimentSettings
                {
                    OptimizingMetric = MulticlassClassificationMetric.MicroAccuracy,

                    MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds,
                    MaxModels = maxModels
                };

                // preFeaturizers used to be accepted but silently dropped; forward it.
                var classificationResult = this.context.Auto()
                    .CreateMulticlassClassificationExperiment(mcs)
                    .Execute(
                        trainData,
                        validationData,
                        columnInformation,
                        preFeaturizers,
                        progressHandler: progressHandler);

                return classificationResult.RunDetails.Select(i => new TaskAgnosticIterationResult(i)).ToList();
            }

            case TaskType.Regression:
            {
                var rs = new RegressionExperimentSettings
                {
                    OptimizingMetric = RegressionMetric.RSquared,

                    MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds,
                    MaxModels = maxModels
                };

                // preFeaturizers used to be accepted but silently dropped; forward it.
                var regressionResult = this.context.Auto()
                    .CreateRegressionExperiment(rs)
                    .Execute(
                        trainData,
                        validationData,
                        columnInformation,
                        preFeaturizers,
                        progressHandler: progressHandler);

                return regressionResult.RunDetails.Select(i => new TaskAgnosticIterationResult(i)).ToList();
            }

            default:
                throw new ArgumentException($"Unknown task type {this.taskType}.", "TaskType");
        }
    }

    /// <summary>
    /// Outcome of scoring a model on test data.
    /// </summary>
    internal struct ScoreResult
    {
        /// <summary>Test data after the model's Transform has been applied.</summary>
        public IDataView ScoredTestData;

        /// <summary>MicroAccuracy for classification, RSquared for regression.</summary>
        public double PrimaryMetricResult;

        /// <summary>Metric values keyed by metric property name.</summary>
        public Dictionary<string, double> Metrics;
    }

    /// <summary>
    /// Applies <paramref name="model"/> to <paramref name="testData"/> and evaluates the
    /// scored data with the evaluator matching the configured task type.
    /// </summary>
    /// <param name="testData">Data to score.</param>
    /// <param name="model">Model whose Transform produces the scored data.</param>
    /// <param name="label">Name of the label column used for evaluation.</param>
    /// <returns>The scored data together with the primary metric and all metric values.</returns>
    /// <exception cref="ArgumentException">The configured task type is not supported.</exception>
    internal ScoreResult Score(
        IDataView testData,
        ITransformer model,
        string label)
    {
        var result = new ScoreResult();

        result.ScoredTestData = model.Transform(testData);

        switch (this.taskType)
        {
            case TaskType.Classification:

                var classificationMetrics = this.context.MulticlassClassification.Evaluate(result.ScoredTestData, labelColumnName: label);

                result.PrimaryMetricResult = classificationMetrics.MicroAccuracy; // TODO: don't hardcode metric
                result.Metrics = TaskAgnosticIterationResult.MetricValuesToDictionary(classificationMetrics);

                break;

            case TaskType.Regression:

                var regressionMetrics = this.context.Regression.Evaluate(result.ScoredTestData, labelColumnName: label);

                result.PrimaryMetricResult = regressionMetrics.RSquared; // TODO: don't hardcode metric
                result.Metrics = TaskAgnosticIterationResult.MetricValuesToDictionary(regressionMetrics);

                break;

            default:
                throw new ArgumentException($"Unknown task type {this.taskType}.", "TaskType");
        }

        return result;
    }
}
143+
}
144+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace Microsoft.ML.AutoML.Test
11+
{
12+
/// <summary>
/// Task independent view of a single AutoML run: trainer, timings, model or exception,
/// and the validation metrics flattened into a name-to-value dictionary.
/// </summary>
internal class TaskAgnosticIterationResult
{
    /// <summary>Value of the primary metric, or -1 when no validation metrics are available.</summary>
    internal double PrimaryMetricValue;

    /// <summary>Validation metric values keyed by property name; empty when no metrics are available.</summary>
    internal Dictionary<string, double> MetricValues = new Dictionary<string, double>();

    /// <summary>Model of the run; left null when the run reported an exception.</summary>
    internal readonly ITransformer Model;

    /// <summary>Exception reported by the run, if any.</summary>
    internal readonly Exception Exception;

    internal string TrainerName;
    internal double RuntimeInSeconds;
    internal IEstimator<ITransformer> Estimator;
    internal Pipeline Pipeline;
    internal int PipelineInferenceTimeInSeconds;

    private readonly string primaryMetricName;

    /// <summary>
    /// Shared initialization: copies the task independent run details and, when metrics
    /// are present, fills <see cref="MetricValues"/> and <see cref="PrimaryMetricValue"/>.
    /// </summary>
    /// <param name="baseRunDetail">Run detail providing trainer, estimator, pipeline and timings.</param>
    /// <param name="validationMetrics">Metrics object of the run; may be null.</param>
    /// <param name="primaryMetricName">Property name of the metric to expose as primary.</param>
    private TaskAgnosticIterationResult(RunDetail baseRunDetail, object validationMetrics, string primaryMetricName)
    {
        this.TrainerName = baseRunDetail.TrainerName;
        this.Estimator = baseRunDetail.Estimator;
        this.Pipeline = baseRunDetail.Pipeline;

        this.PipelineInferenceTimeInSeconds = (int)baseRunDetail.PipelineInferenceTimeInSeconds;

        // The field is a double; do not truncate the fractional seconds through an int cast.
        this.RuntimeInSeconds = baseRunDetail.RuntimeInSeconds;

        this.primaryMetricName = primaryMetricName;
        this.PrimaryMetricValue = -1; // default value in case of exception. TODO: won't work for minimizing metrics, use nullable?

        if (validationMetrics == null)
        {
            return;
        }

        this.MetricValues = MetricValuesToDictionary(validationMetrics);

        this.PrimaryMetricValue = this.MetricValues[this.primaryMetricName];
    }

    /// <summary>
    /// Wraps a regression run.
    /// </summary>
    /// <param name="runDetail">Run detail to wrap.</param>
    /// <param name="primaryMetricName">Property name of the primary metric.</param>
    public TaskAgnosticIterationResult(RunDetail<RegressionMetrics> runDetail, string primaryMetricName = "RSquared")
        : this(runDetail, runDetail.ValidationMetrics, primaryMetricName)
    {
        // The model is only read when the run did not report an exception.
        if (runDetail.Exception == null)
        {
            this.Model = runDetail.Model;
        }

        this.Exception = runDetail.Exception;
    }

    /// <summary>
    /// Wraps a multiclass classification run.
    /// </summary>
    /// <param name="runDetail">Run detail to wrap.</param>
    /// <param name="primaryMetricName">Property name of the primary metric.</param>
    public TaskAgnosticIterationResult(RunDetail<MulticlassClassificationMetrics> runDetail, string primaryMetricName = "MicroAccuracy")
        : this(runDetail, runDetail.ValidationMetrics, primaryMetricName)
    {
        // The model is only read when the run did not report an exception.
        if (runDetail.Exception == null)
        {
            this.Model = runDetail.Model;
        }

        this.Exception = runDetail.Exception;
    }

    /// <summary>
    /// Flattens the double-typed properties of a metrics object into a dictionary keyed
    /// by property name.
    /// </summary>
    /// <typeparam name="T">Static type of the argument; the check uses the runtime type.</typeparam>
    /// <param name="metric">A <see cref="MulticlassClassificationMetrics"/> or <see cref="RegressionMetrics"/> instance.</param>
    /// <returns>Property name to value for every property whose type is double.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="metric"/> is null.</exception>
    /// <exception cref="ArgumentException">The runtime type of <paramref name="metric"/> is not supported.</exception>
    public static Dictionary<string, double> MetricValuesToDictionary<T>(T metric)
    {
        if (metric == null)
        {
            throw new ArgumentNullException(nameof(metric));
        }

        var supportedTypes = new[] { typeof(MulticlassClassificationMetrics), typeof(RegressionMetrics) };

        var metricType = metric.GetType();

        if (!supportedTypes.Contains(metricType))
        {
            // Report the runtime type: T is object when called with an untyped metrics reference.
            throw new ArgumentException($"Unsupported metric type {metricType.Name}.");
        }

        return metricType
            .GetProperties()
            .Where(p => p.PropertyType == typeof(double))
            .ToDictionary(p => p.Name, p => (double)p.GetValue(metric));
    }
}
86+
}
87+

0 commit comments

Comments
 (0)