Skip to content

Commit eac3695

Browse files
daholste authored and Dmitry-A committed
get next pipeline API rev -- refactor API to consume column dimensions, purpose, type, and name instead of available trainers & transforms (dotnet#19)
1 parent 816e8e8 commit eac3695

14 files changed

+205
-293
lines changed

src/AutoML/API/MLContextAutoFitExtensions.cs

-15
Original file line number | Diff line number | Diff line change
@@ -47,11 +47,6 @@ internal static RegressionResult AutoFit(this RegressionContext context,
4747
var bestResult = new RegressionIterationResult(bestPipeline.Model, (RegressionMetrics)bestPipeline.EvaluatedMetrics, bestPipeline.ScoredValidationData, bestPipeline.Pipeline.ToPipeline());
4848
return new RegressionResult(bestResult, results);
4949
}
50-
51-
public static Pipeline GetPipeline(this RegressionContext context, IDataView dataView, string label)
52-
{
53-
return PipelineSuggesterApi.GetPipeline(TaskKind.Regression, dataView, label);
54-
}
5550
}
5651

5752
public static class BinaryClassificationExtensions
@@ -96,11 +91,6 @@ internal static BinaryClassificationResult AutoFit(this BinaryClassificationCont
9691
var bestResult = new BinaryClassificationItertionResult(bestPipeline.Model, (BinaryClassificationMetrics)bestPipeline.EvaluatedMetrics, bestPipeline.ScoredValidationData, bestPipeline.Pipeline.ToPipeline());
9792
return new BinaryClassificationResult(bestResult, results);
9893
}
99-
100-
public static Pipeline GetPipeline(this BinaryClassificationContext context, IDataView dataView, string label)
101-
{
102-
return PipelineSuggesterApi.GetPipeline(TaskKind.BinaryClassification, dataView, label);
103-
}
10494
}
10595

10696
public static class MulticlassExtensions
@@ -144,11 +134,6 @@ internal static MulticlassClassificationResult AutoFit(this MulticlassClassifica
144134
var bestResult = new MulticlassClassificationIterationResult(bestPipeline.Model, (MultiClassClassifierMetrics)bestPipeline.EvaluatedMetrics, bestPipeline.ScoredValidationData, bestPipeline.Pipeline.ToPipeline());
145135
return new MulticlassClassificationResult(bestResult, results);
146136
}
147-
148-
public static Pipeline GetPipeline(this MulticlassClassificationContext context, IDataView dataView, string label)
149-
{
150-
return PipelineSuggesterApi.GetPipeline(TaskKind.MulticlassClassification, dataView, label);
151-
}
152137
}
153138

154139
public class BinaryClassificationResult

src/AutoML/API/Pipeline.cs

+10-10
Original file line number | Diff line number | Diff line change
@@ -5,11 +5,11 @@ namespace Microsoft.ML.Auto
55
{
66
public class Pipeline
77
{
8-
public PipelineNode[] Elements { get; set; }
8+
public PipelineNode[] Nodes { get; set; }
99

10-
public Pipeline(PipelineNode[] elements)
10+
public Pipeline(PipelineNode[] nodes)
1111
{
12-
Elements = elements;
12+
Nodes = nodes;
1313
}
1414

1515
// (used by Newtonsoft)
@@ -27,31 +27,31 @@ public IEstimator<ITransformer> ToEstimator()
2727
public class PipelineNode
2828
{
2929
public string Name { get; set; }
30-
public PipelineNodeType ElementType { get; set; }
30+
public PipelineNodeType NodeType { get; set; }
3131
public string[] InColumns { get; set; }
3232
public string[] OutColumns { get; set; }
3333
public IDictionary<string, object> Properties { get; set; }
3434

35-
public PipelineNode(string name, PipelineNodeType elementType,
35+
public PipelineNode(string name, PipelineNodeType nodeType,
3636
string[] inColumns, string[] outColumns,
3737
IDictionary<string, object> properties = null)
3838
{
3939
Name = name;
40-
ElementType = elementType;
40+
NodeType = nodeType;
4141
InColumns = inColumns;
4242
OutColumns = outColumns;
4343
Properties = properties ?? new Dictionary<string, object>();
4444
}
4545

46-
public PipelineNode(string name, PipelineNodeType elementType,
46+
public PipelineNode(string name, PipelineNodeType nodeType,
4747
string inColumn, string outColumn, IDictionary<string, object> properties = null) :
48-
this(name, elementType, new string[] { inColumn }, new string[] { outColumn }, properties)
48+
this(name, nodeType, new string[] { inColumn }, new string[] { outColumn }, properties)
4949
{
5050
}
5151

52-
public PipelineNode(string name, PipelineNodeType elementType,
52+
public PipelineNode(string name, PipelineNodeType nodeType,
5353
string[] inColumns, string outColumn, IDictionary<string, object> properties = null) :
54-
this(name, elementType, inColumns, new string[] { outColumn }, properties)
54+
this(name, nodeType, inColumns, new string[] { outColumn }, properties)
5555
{
5656
}
5757

src/AutoML/AutoFitter/AutoFitter.cs

+9-9
Original file line number | Diff line number | Diff line change
@@ -16,22 +16,22 @@ internal class AutoFitter
1616
private readonly IDebugLogger _debugLogger;
1717
private readonly IList<InferredPipelineRunResult> _history;
1818
private readonly string _label;
19-
private readonly MLContext _mlContext;
19+
private readonly MLContext _context;
2020
private readonly OptimizingMetricInfo _optimizingMetricInfo;
2121
private readonly IDictionary<string, ColumnPurpose> _purposeOverrides;
2222
private readonly AutoFitSettings _settings;
2323
private readonly IDataView _trainData;
2424
private readonly TaskKind _task;
2525
private readonly IDataView _validationData;
2626

27-
public AutoFitter(MLContext mlContext, OptimizingMetricInfo metricInfo, AutoFitSettings settings,
27+
public AutoFitter(MLContext context, OptimizingMetricInfo metricInfo, AutoFitSettings settings,
2828
TaskKind task, string label, IDataView trainData, IDataView validationData,
2929
IDictionary<string, ColumnPurpose> purposeOverrides, IDebugLogger debugLogger)
3030
{
3131
_debugLogger = debugLogger;
3232
_history = new List<InferredPipelineRunResult>();
3333
_label = label;
34-
_mlContext = mlContext;
34+
_context = context;
3535
_optimizingMetricInfo = metricInfo;
3636
_settings = settings ?? new AutoFitSettings();
3737
_purposeOverrides = purposeOverrides;
@@ -49,13 +49,13 @@ public InferredPipelineRunResult[] Fit()
4949
private void IteratePipelinesAndFit()
5050
{
5151
var stopwatch = Stopwatch.StartNew();
52-
var transforms = TransformInferenceApi.InferTransforms(_mlContext, _trainData, _label, _purposeOverrides);
53-
var availableTrainers = RecipeInference.AllowedTrainers(_mlContext, _task, _settings.StoppingCriteria.MaxIterations);
52+
var columns = AutoMlUtils.GetColumnInfoTuples(_context, _trainData, _label, _purposeOverrides);
5453

5554
do
5655
{
5756
// get next pipeline
58-
var pipeline = PipelineSuggester.GetNextInferredPipeline(_history, transforms, availableTrainers, _optimizingMetricInfo.IsMaximizing);
57+
var iterationsRemaining = _settings.StoppingCriteria.MaxIterations - _history.Count;
58+
var pipeline = PipelineSuggester.GetNextInferredPipeline(_history, columns, _task, iterationsRemaining, _optimizingMetricInfo.IsMaximizing);
5959

6060
// break if no candidates returned, means no valid pipeline available
6161
if (pipeline == null)
@@ -113,11 +113,11 @@ private object GetEvaluatedMetrics(IDataView scoredData)
113113
switch(_task)
114114
{
115115
case TaskKind.BinaryClassification:
116-
return _mlContext.BinaryClassification.EvaluateNonCalibrated(scoredData);
116+
return _context.BinaryClassification.EvaluateNonCalibrated(scoredData);
117117
case TaskKind.MulticlassClassification:
118-
return _mlContext.MulticlassClassification.Evaluate(scoredData);
118+
return _context.MulticlassClassification.Evaluate(scoredData);
119119
case TaskKind.Regression:
120-
return _mlContext.Regression.Evaluate(scoredData);
120+
return _context.Regression.Evaluate(scoredData);
121121
// should not be possible to reach here
122122
default:
123123
throw new InvalidOperationException($"unsupported machine learning task type {_task}");

src/AutoML/AutoFitter/InferredPipeline.cs

+3-3
Original file line number | Diff line number | Diff line change
@@ -70,17 +70,17 @@ public static InferredPipeline FromPipeline(Pipeline pipeline)
7070
var transforms = new List<SuggestedTransform>();
7171
SuggestedTrainer trainer = null;
7272

73-
foreach(var pipelineNode in pipeline.Elements)
73+
foreach(var pipelineNode in pipeline.Nodes)
7474
{
75-
if(pipelineNode.ElementType == PipelineNodeType.Trainer)
75+
if(pipelineNode.NodeType == PipelineNodeType.Trainer)
7676
{
7777
var trainerName = (TrainerName)Enum.Parse(typeof(TrainerName), pipelineNode.Name);
7878
var trainerExtension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
7979
var stringParamVals = pipelineNode.Properties.Select(prop => new StringParameterValue(prop.Key, prop.Value.ToString()));
8080
var hyperParamSet = new ParameterSet(stringParamVals);
8181
trainer = new SuggestedTrainer(context, trainerExtension, hyperParamSet);
8282
}
83-
else if (pipelineNode.ElementType == PipelineNodeType.Transform)
83+
else if (pipelineNode.NodeType == PipelineNodeType.Transform)
8484
{
8585
var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), pipelineNode.Name);
8686
var estimatorExtension = EstimatorExtensionCatalog.GetExtension(estimatorName);

src/AutoML/AutoMlUtils.cs

+16
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using System.Linq;
78
using Microsoft.ML.Data;
89
using Microsoft.ML.Transforms;
@@ -29,5 +30,20 @@ public static IDataView Take(this IDataView data, int count)
2930
var take = SkipTakeFilter.Create(env, new SkipTakeFilter.TakeArguments { Count = count }, data);
3031
return new CacheDataView(env, data, Enumerable.Range(0, data.Schema.Count).ToArray());
3132
}
33+
34+
/// <summary>
/// Produces one (name, type, purpose, dimensions) tuple for every column in the
/// schema of <paramref name="data"/>, in schema order.
/// </summary>
/// <param name="context">Context forwarded to <c>PurposeInference.InferPurposes</c>.</param>
/// <param name="data">Dataset whose schema columns are described.</param>
/// <param name="label">Label column name forwarded to purpose inference.</param>
/// <param name="purposeOverrides">Purpose overrides forwarded to purpose inference.</param>
/// <returns>An array with <c>data.Schema.Count</c> entries; entry <c>i</c> describes schema column <c>i</c>.</returns>
public static (string, ColumnType, ColumnPurpose, ColumnDimensions)[] GetColumnInfoTuples(MLContext context,
    IDataView data, string label, IDictionary<string, ColumnPurpose> purposeOverrides)
{
    var inferredPurposes = PurposeInference.InferPurposes(context, data, label, purposeOverrides);
    var dimensions = DatasetDimensionsApi.CalcColumnDimensions(data, inferredPurposes);

    // Purposes and dimensions are both indexed by schema position here.
    (string, ColumnType, ColumnPurpose, ColumnDimensions) Describe(int index)
    {
        var schemaColumn = data.Schema[index];
        return (schemaColumn.Name, schemaColumn.Type, inferredPurposes[index].Purpose, dimensions[index]);
    }

    return Enumerable.Range(0, data.Schema.Count)
        .Select(index => Describe(index))
        .ToArray();
}
3248
}
3349
}
Original file line number | Diff line number | Diff line change (newly added file; its path heading is missing from this capture)
@@ -0,0 +1,14 @@
1+
namespace Microsoft.ML.Auto
2+
{
3+
/// <summary>
/// Holds the dimensions recorded for a single dataset column. Either member may be
/// <c>null</c>, meaning that value was not supplied for the column.
/// </summary>
internal class ColumnDimensions
{
    /// <summary>Cardinality recorded for the column, or <c>null</c> when none was supplied.</summary>
    public int? Cardinality { get; set; }

    /// <summary>Whether the column has missing values, or <c>null</c> when this was not determined.</summary>
    public bool? HasMissing { get; set; }

    /// <summary>
    /// Creates a dimensions record from the given values.
    /// </summary>
    /// <param name="cardinality">Value stored in <see cref="Cardinality"/>; may be <c>null</c>.</param>
    /// <param name="hasMissing">Value stored in <see cref="HasMissing"/>; may be <c>null</c>.</param>
    public ColumnDimensions(int? cardinality, bool? hasMissing)
    {
        Cardinality = cardinality;
        HasMissing = hasMissing;
    }
}
14+
}
Original file line number | Diff line number | Diff line change (newly added file; its path heading is missing from this capture)
@@ -0,0 +1,45 @@
1+
using Microsoft.ML.Data;
2+
3+
namespace Microsoft.ML.Auto
4+
{
5+
/// <summary>
/// Computes a <see cref="ColumnDimensions"/> entry for every schema column of a dataset.
/// </summary>
internal class DatasetDimensionsApi
{
    // Row count passed to the Take extension before any column is scanned.
    private const int MaxRowsToRead = 1000;

    /// <summary>
    /// Builds the per-column dimensions for <paramref name="data"/>.
    /// </summary>
    /// <param name="data">Dataset to inspect; it is passed through <c>Take(MaxRowsToRead)</c> first.</param>
    /// <param name="purposes">Inferred purposes, indexed by schema column position.</param>
    /// <returns>
    /// One entry per schema column. Cardinality is filled only for text columns whose purpose is
    /// <c>CategoricalFeature</c>; HasMissing is filled only for columns whose item type is R4.
    /// </returns>
    public static ColumnDimensions[] CalcColumnDimensions(IDataView data, PurposeInference.Column[] purposes)
    {
        var sample = data.Take(MaxRowsToRead);
        var columnCount = sample.Schema.Count;
        var dimensions = new ColumnDimensions[columnCount];

        for (var colIndex = 0; colIndex < columnCount; colIndex++)
        {
            var schemaColumn = sample.Schema[colIndex];
            var inferred = purposes[colIndex];

            // if categorical text feature, calc cardinality; otherwise leave it unset
            var isCategoricalText = schemaColumn.Type.ItemType().IsText()
                && inferred.Purpose == ColumnPurpose.CategoricalFeature;
            int? distinctValues = isCategoricalText
                ? DatasetDimensionsUtil.GetTextColumnCardinality(sample, colIndex)
                : (int?)null;

            // if numeric feature, discover missing values; otherwise leave it unset
            // todo: upgrade logic to consider R8?
            bool? containsMissing = null;
            if (schemaColumn.Type.ItemType() == NumberType.R4)
            {
                if (schemaColumn.Type.IsVector())
                {
                    containsMissing = DatasetDimensionsUtil.HasMissingNumericVector(sample, colIndex);
                }
                else
                {
                    containsMissing = DatasetDimensionsUtil.HasMissingNumericSingleValue(sample, colIndex);
                }
            }

            dimensions[colIndex] = new ColumnDimensions(distinctValues, containsMissing);
        }

        return dimensions;
    }
}
45+
}
Original file line number | Diff line number | Diff line change (newly added file; its path heading is missing from this capture)
@@ -0,0 +1,62 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using Microsoft.ML.Data;
4+
5+
namespace Microsoft.ML.Auto
6+
{
7+
/// <summary>
/// Helpers that scan a single column of an <see cref="IDataView"/> with a row cursor.
/// </summary>
internal static class DatasetDimensionsUtil
{
    /// <summary>
    /// Counts the distinct string values found in a text column.
    /// </summary>
    /// <param name="data">Dataset to scan.</param>
    /// <param name="colIndex">Index of the text column to read.</param>
    /// <returns>The number of distinct values seen across all rows the cursor yields.</returns>
    public static int GetTextColumnCardinality(IDataView data, int colIndex)
    {
        var distinct = new HashSet<string>();
        using (var rows = data.GetRowCursor(col => col == colIndex))
        {
            var readText = rows.GetGetter<ReadOnlyMemory<char>>(colIndex);
            while (rows.MoveNext())
            {
                var text = default(ReadOnlyMemory<char>);
                readText(ref text);
                distinct.Add(text.ToString());
            }
        }
        return distinct.Count;
    }

    /// <summary>
    /// Reports whether any row of a scalar <see cref="Single"/> column holds NaN.
    /// </summary>
    /// <param name="data">Dataset to scan.</param>
    /// <param name="colIndex">Index of the scalar column to read.</param>
    /// <returns><c>true</c> as soon as a NaN value is read; <c>false</c> if none is found.</returns>
    public static bool HasMissingNumericSingleValue(IDataView data, int colIndex)
    {
        return AnyRowMatches<float>(data, colIndex, number => float.IsNaN(number));
    }

    /// <summary>
    /// Reports whether any row of a <see cref="Single"/> vector column contains NaN,
    /// as decided by <c>VBufferUtils.HasNaNs</c>.
    /// </summary>
    /// <param name="data">Dataset to scan.</param>
    /// <param name="colIndex">Index of the vector column to read.</param>
    /// <returns><c>true</c> as soon as a row with NaNs is read; <c>false</c> if none is found.</returns>
    public static bool HasMissingNumericVector(IDataView data, int colIndex)
    {
        return AnyRowMatches<VBuffer<float>>(data, colIndex, vector => VBufferUtils.HasNaNs(vector));
    }

    // Walks the column row by row and stops at the first value accepted by isMatch.
    private static bool AnyRowMatches<T>(IDataView data, int colIndex, Func<T, bool> isMatch)
    {
        using (var rows = data.GetRowCursor(col => col == colIndex))
        {
            var read = rows.GetGetter<T>(colIndex);
            // A single buffer is reused for every row, as the getter fills it by reference.
            var current = default(T);
            while (rows.MoveNext())
            {
                read(ref current);
                if (isMatch(current))
                {
                    return true;
                }
            }
        }
        return false;
    }
}
62+
}

src/AutoML/PipelineSuggesters/PipelineSuggester.cs

+14-7
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8+
using Microsoft.ML.Data;
89

910
namespace Microsoft.ML.Auto
1011
{
@@ -13,23 +14,29 @@ internal static class PipelineSuggester
1314
private const int TopKTrainers = 3;
1415

1516
public static Pipeline GetNextPipeline(IEnumerable<PipelineRunResult> history,
16-
IEnumerable<SuggestedTransform> transforms,
17-
IEnumerable<SuggestedTrainer> availableTrainers,
17+
(string, ColumnType, ColumnPurpose, ColumnDimensions)[] columns,
18+
TaskKind task,
19+
int iterationsRemaining,
1820
bool isMaximizingMetric = true)
1921
{
2022
var inferredHistory = history.Select(r => InferredPipelineRunResult.FromPipelineRunResult(r));
21-
var nextInferredPipeline = GetNextInferredPipeline(inferredHistory,
22-
transforms, availableTrainers, isMaximizingMetric);
23+
var nextInferredPipeline = GetNextInferredPipeline(inferredHistory, columns, task, iterationsRemaining, isMaximizingMetric);
2324
return nextInferredPipeline.ToPipeline();
2425
}
2526

2627
public static InferredPipeline GetNextInferredPipeline(IEnumerable<InferredPipelineRunResult> history,
27-
IEnumerable<SuggestedTransform> transforms,
28-
IEnumerable<SuggestedTrainer> availableTrainers,
28+
(string, ColumnType, ColumnPurpose, ColumnDimensions)[] columns,
29+
TaskKind task,
30+
int iterationsRemaining,
2931
bool isMaximizingMetric = true)
3032
{
33+
var context = new MLContext();
34+
35+
var availableTrainers = RecipeInference.AllowedTrainers(context, TaskKind.BinaryClassification, history.Count() + iterationsRemaining);
36+
var transforms = TransformInferenceApi.InferTransforms(context, columns);
37+
3138
// if we haven't run all pipelines once
32-
if(history.Count() < availableTrainers.Count())
39+
if (history.Count() < availableTrainers.Count())
3340
{
3441
return GetNextFirstStagePipeline(history, availableTrainers, transforms);
3542
}

src/AutoML/PipelineSuggesters/PipelineSuggesterApi.cs

-18
This file was deleted.

0 commit comments

Comments
 (0)