Skip to content

Commit 816e8e8

Browse files
daholsteDmitry-A
authored andcommitted
make GetNextPipeline API w/ public Pipeline method on PipelineSuggester; write GetNextPipeline API test; fix public Pipeline object serialization; fix header inferencing bug; write test utils for fetching datasets (dotnet#18)
1 parent 40e8e58 commit 816e8e8

10 files changed

+179
-38
lines changed

src/AutoML/API/Pipeline.cs

+42-6
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,36 @@
11
using System.Collections.Generic;
2+
using Microsoft.ML.Core.Data;
23

34
namespace Microsoft.ML.Auto
45
{
56
public class Pipeline
67
{
7-
public readonly PipelineNode[] Elements;
8+
public PipelineNode[] Elements { get; set; }
89

910
public Pipeline(PipelineNode[] elements)
1011
{
1112
Elements = elements;
1213
}
14+
15+
// (used by Newtonsoft)
16+
internal Pipeline()
17+
{
18+
}
19+
20+
public IEstimator<ITransformer> ToEstimator()
21+
{
22+
var inferredPipeline = InferredPipeline.FromPipeline(this);
23+
return inferredPipeline.ToEstimator();
24+
}
1325
}
1426

1527
public class PipelineNode
1628
{
17-
public readonly string Name;
18-
public readonly PipelineNodeType ElementType;
19-
public readonly string[] InColumns;
20-
public readonly string[] OutColumns;
21-
public readonly IDictionary<string, object> Properties;
29+
public string Name { get; set; }
30+
public PipelineNodeType ElementType { get; set; }
31+
public string[] InColumns { get; set; }
32+
public string[] OutColumns { get; set; }
33+
public IDictionary<string, object> Properties { get; set; }
2234

2335
public PipelineNode(string name, PipelineNodeType elementType,
2436
string[] inColumns, string[] outColumns,
@@ -42,6 +54,11 @@ public PipelineNode(string name, PipelineNodeType elementType,
4254
this(name, elementType, inColumns, new string[] { outColumn }, properties)
4355
{
4456
}
57+
58+
// (used by Newtonsoft)
59+
internal PipelineNode()
60+
{
61+
}
4562
}
4663

4764
public enum PipelineNodeType
@@ -55,4 +72,23 @@ public class CustomProperty
5572
public readonly string Name;
5673
public readonly IDictionary<string, object> Properties;
5774
}
75+
76+
public class PipelineRunResult
77+
{
78+
public readonly Pipeline Pipeline;
79+
public readonly double Score;
80+
81+
/// <summary>
82+
/// This setting is true if the pipeline run succeeded & ran to completion.
83+
/// Else, it is false if some exception was thrown before the run could complete.
84+
/// </summary>
85+
public readonly bool RunSucceded;
86+
87+
public PipelineRunResult(Pipeline pipeline, double score, bool runSucceeded)
88+
{
89+
Pipeline = pipeline;
90+
Score = score;
91+
RunSucceded = runSucceeded;
92+
}
93+
}
5894
}

src/AutoML/AutoFitter/AutoFitApi.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace Microsoft.ML.Auto
66
{
77
internal static class AutoFitApi
88
{
9-
public static (PipelineRunResult[] allPipelines, PipelineRunResult bestPipeline) Fit(IDataView trainData,
9+
public static (InferredPipelineRunResult[] allPipelines, InferredPipelineRunResult bestPipeline) Fit(IDataView trainData,
1010
IDataView validationData, string label, AutoFitSettings settings, TaskKind task, OptimizingMetric metric,
1111
IEnumerable<(string, ColumnPurpose)> purposeOverrides, IDebugLogger debugLogger)
1212
{

src/AutoML/AutoFitter/AutoFitter.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace Microsoft.ML.Auto
1414
internal class AutoFitter
1515
{
1616
private readonly IDebugLogger _debugLogger;
17-
private readonly IList<PipelineRunResult> _history;
17+
private readonly IList<InferredPipelineRunResult> _history;
1818
private readonly string _label;
1919
private readonly MLContext _mlContext;
2020
private readonly OptimizingMetricInfo _optimizingMetricInfo;
@@ -29,7 +29,7 @@ public AutoFitter(MLContext mlContext, OptimizingMetricInfo metricInfo, AutoFitS
2929
IDictionary<string, ColumnPurpose> purposeOverrides, IDebugLogger debugLogger)
3030
{
3131
_debugLogger = debugLogger;
32-
_history = new List<PipelineRunResult>();
32+
_history = new List<InferredPipelineRunResult>();
3333
_label = label;
3434
_mlContext = mlContext;
3535
_optimizingMetricInfo = metricInfo;
@@ -40,7 +40,7 @@ public AutoFitter(MLContext mlContext, OptimizingMetricInfo metricInfo, AutoFitS
4040
_validationData = validationData;
4141
}
4242

43-
public PipelineRunResult[] Fit()
43+
public InferredPipelineRunResult[] Fit()
4444
{
4545
IteratePipelinesAndFit();
4646
return _history.ToArray();
@@ -55,7 +55,7 @@ private void IteratePipelinesAndFit()
5555
do
5656
{
5757
// get next pipeline
58-
var pipeline = PipelineSuggester.GetNextPipeline(_history, transforms, availableTrainers, _optimizingMetricInfo.IsMaximizing);
58+
var pipeline = PipelineSuggester.GetNextInferredPipeline(_history, transforms, availableTrainers, _optimizingMetricInfo.IsMaximizing);
5959

6060
// break if no candidates returned, means no valid pipeline available
6161
if (pipeline == null)
@@ -75,19 +75,19 @@ private void ProcessPipeline(InferredPipeline pipeline)
7575
// run pipeline
7676
var stopwatch = Stopwatch.StartNew();
7777

78-
PipelineRunResult runResult;
78+
InferredPipelineRunResult runResult;
7979
try
8080
{
8181
var pipelineModel = pipeline.TrainTransformer(_trainData);
8282
var scoredValidationData = pipelineModel.Transform(_validationData);
8383
var evaluatedMetrics = GetEvaluatedMetrics(scoredValidationData);
8484
var score = GetPipelineScore(evaluatedMetrics);
85-
runResult = new PipelineRunResult(evaluatedMetrics, pipelineModel, pipeline, score, scoredValidationData);
85+
runResult = new InferredPipelineRunResult(evaluatedMetrics, pipelineModel, pipeline, score, scoredValidationData);
8686
}
8787
catch(Exception ex)
8888
{
8989
WriteDebugLog(DebugStream.Exception, $"{pipeline.Trainer} Crashed {ex}");
90-
runResult = new PipelineRunResult(pipeline, false);
90+
runResult = new InferredPipelineRunResult(pipeline, false);
9191
}
9292

9393
// save pipeline run

src/AutoML/AutoFitter/InferredPipeline.cs

+8-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public static InferredPipeline FromPipeline(Pipeline pipeline)
9393
return new InferredPipeline(transforms, trainer, null, false);
9494
}
9595

96-
public ITransformer TrainTransformer(IDataView trainData)
96+
public IEstimator<ITransformer> ToEstimator()
9797
{
9898
IEstimator<ITransformer> pipeline = new EstimatorChain<ITransformer>();
9999

@@ -112,7 +112,13 @@ public ITransformer TrainTransformer(IDataView trainData)
112112
// append learner to pipeline
113113
pipeline = pipeline.Append(learner);
114114

115-
return pipeline.Fit(trainData);
115+
return pipeline;
116+
}
117+
118+
public ITransformer TrainTransformer(IDataView trainData)
119+
{
120+
var estimator = ToEstimator();
121+
return estimator.Fit(trainData);
116122
}
117123

118124
private void AddNormalizationTransforms()

src/AutoML/AutoFitter/PipelineRunResult.cs renamed to src/AutoML/AutoFitter/InferredPipelineRunResult.cs

+10-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
namespace Microsoft.ML.Auto
99
{
10-
internal class PipelineRunResult
10+
internal class InferredPipelineRunResult
1111
{
1212
public readonly object EvaluatedMetrics;
1313
public readonly InferredPipeline Pipeline;
@@ -22,7 +22,7 @@ internal class PipelineRunResult
2222

2323
public ITransformer Model { get; set; }
2424

25-
public PipelineRunResult(object evaluatedMetrics, ITransformer model, InferredPipeline pipeline, double score, IDataView scoredValidationData,
25+
public InferredPipelineRunResult(object evaluatedMetrics, ITransformer model, InferredPipeline pipeline, double score, IDataView scoredValidationData,
2626
bool runSucceeded = true)
2727
{
2828
EvaluatedMetrics = evaluatedMetrics;
@@ -33,12 +33,19 @@ public PipelineRunResult(object evaluatedMetrics, ITransformer model, InferredPi
3333
RunSucceded = runSucceeded;
3434
}
3535

36-
public PipelineRunResult(InferredPipeline pipeline, bool runSucceeded)
36+
public InferredPipelineRunResult(InferredPipeline pipeline, bool runSucceeded)
3737
{
3838
Pipeline = pipeline;
3939
RunSucceded = runSucceeded;
4040
}
4141

42+
public static InferredPipelineRunResult FromPipelineRunResult(PipelineRunResult pipelineRunResult)
43+
{
44+
return new InferredPipelineRunResult(null, null,
45+
InferredPipeline.FromPipeline(pipelineRunResult.Pipeline),
46+
pipelineRunResult.Score, null, pipelineRunResult.RunSucceded);
47+
}
48+
4249
public IRunResult ToRunResult(bool isMetricMaximizing)
4350
{
4451
return new RunResult(Pipeline.Trainer.HyperParamSet, Score, isMetricMaximizing);

src/AutoML/ColumnInference/ColumnInferenceApi.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public static ColumnInferenceResult InferColumns(MLContext context, string path,
1111
{
1212
var sample = TextFileSample.CreateFromFullFile(path);
1313
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
14-
var typeInference = InferColumnTypes(context, sample, splitInference);
14+
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);
1515
var typedLoaderArgs = new TextLoader.Arguments
1616
{
1717
Column = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns),
@@ -59,7 +59,7 @@ private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample samp
5959
}
6060

6161
private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext context, TextFileSample sample,
62-
TextFileContents.ColumnSplitResult splitInference)
62+
TextFileContents.ColumnSplitResult splitInference, bool hasHeader)
6363
{
6464
// infer column types
6565
var typeInferenceResult = ColumnTypeInference.InferTextFileColumnTypes(context, sample,
@@ -69,6 +69,7 @@ private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext co
6969
Separator = splitInference.Separator,
7070
AllowSparse = splitInference.AllowSparse,
7171
AllowQuote = splitInference.AllowQuote,
72+
HasHeader = hasHeader
7273
});
7374

7475
if (!typeInferenceResult.IsSuccess)

src/AutoML/ColumnInference/ColumnTypeInference.cs

+3-11
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ internal sealed class Arguments
2727
public bool AllowSparse;
2828
public bool AllowQuote;
2929
public int ColumnCount;
30+
public bool HasHeader;
3031
public int MaxRowsToRead;
3132

3233
public Arguments()
@@ -325,23 +326,14 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
325326
suspect--;
326327
}
327328

328-
// REVIEW: Why not use this for column names as well?
329-
TextLoader.Arguments fileArgs;
330-
bool hasHeader;
331-
if (TextLoader.FileContainsValidSchema(env, fileSource, out fileArgs))
332-
hasHeader = fileArgs.HasHeader;
333-
else
334-
hasHeader = suspect > 0;
335-
hasHeader = true;
336-
337329
// suggest names
338330
var names = new List<string>();
339331
usedNames.Clear();
340332
foreach (var col in cols)
341333
{
342334
string name0;
343335
string name;
344-
name0 = name = SuggestName(col, hasHeader);
336+
name0 = name = SuggestName(col, args.HasHeader);
345337
int i = 0;
346338
while (!usedNames.Add(name))
347339
name = string.Format("{0}_{1:00}", name0, i++);
@@ -352,7 +344,7 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
352344

353345
var numerics = outCols.Count(x => x.ItemType.IsNumber());
354346

355-
return InferenceResult.Success(outCols, hasHeader, cols.Select(col => col.RawData).ToArray());
347+
return InferenceResult.Success(outCols, args.HasHeader, cols.Select(col => col.RawData).ToArray());
356348
}
357349

358350
private static string SuggestName(IntermediateColumn column, bool hasHeader)

src/AutoML/PipelineSuggesters/PipelineSuggester.cs

+17-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,18 @@ internal static class PipelineSuggester
1212
{
1313
private const int TopKTrainers = 3;
1414

15-
public static InferredPipeline GetNextPipeline(IEnumerable<PipelineRunResult> history,
15+
public static Pipeline GetNextPipeline(IEnumerable<PipelineRunResult> history,
16+
IEnumerable<SuggestedTransform> transforms,
17+
IEnumerable<SuggestedTrainer> availableTrainers,
18+
bool isMaximizingMetric = true)
19+
{
20+
var inferredHistory = history.Select(r => InferredPipelineRunResult.FromPipelineRunResult(r));
21+
var nextInferredPipeline = GetNextInferredPipeline(inferredHistory,
22+
transforms, availableTrainers, isMaximizingMetric);
23+
return nextInferredPipeline.ToPipeline();
24+
}
25+
26+
public static InferredPipeline GetNextInferredPipeline(IEnumerable<InferredPipelineRunResult> history,
1627
IEnumerable<SuggestedTransform> transforms,
1728
IEnumerable<SuggestedTrainer> availableTrainers,
1829
bool isMaximizingMetric = true)
@@ -49,15 +60,15 @@ public static InferredPipeline GetNextPipeline(IEnumerable<PipelineRunResult> hi
4960
/// <summary>
5061
/// Get top trainers from first stage
5162
/// </summary>
52-
private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<PipelineRunResult> history,
63+
private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<InferredPipelineRunResult> history,
5364
IEnumerable<SuggestedTrainer> availableTrainers,
5465
bool isMaximizingMetric)
5566
{
5667
// narrow history to first stage runs
5768
history = history.Take(availableTrainers.Count());
5869

5970
history = history.GroupBy(r => r.Pipeline.Trainer.TrainerName).Select(g => g.First());
60-
IEnumerable<PipelineRunResult> sortedHistory = history.OrderBy(r => r.Score);
71+
IEnumerable<InferredPipelineRunResult> sortedHistory = history.OrderBy(r => r.Score);
6172
if(isMaximizingMetric)
6273
{
6374
sortedHistory = sortedHistory.Reverse();
@@ -66,7 +77,7 @@ private static IEnumerable<SuggestedTrainer> GetTopTrainers(IEnumerable<Pipeline
6677
return topTrainers;
6778
}
6879

69-
private static InferredPipeline GetNextFirstStagePipeline(IEnumerable<PipelineRunResult> history,
80+
private static InferredPipeline GetNextFirstStagePipeline(IEnumerable<InferredPipelineRunResult> history,
7081
IEnumerable<SuggestedTrainer> availableTrainers,
7182
IEnumerable<SuggestedTransform> transforms)
7283
{
@@ -133,7 +144,7 @@ private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableP
133144
return results;
134145
}
135146

136-
private static void SampleHyperparameters(SuggestedTrainer trainer, IEnumerable<PipelineRunResult> history, bool isMaximizingMetric)
147+
private static void SampleHyperparameters(SuggestedTrainer trainer, IEnumerable<InferredPipelineRunResult> history, bool isMaximizingMetric)
137148
{
138149
var sps = ConvertToValueGenerators(trainer.SweepParams);
139150
var sweeper = new SmacSweeper(
@@ -142,7 +153,7 @@ private static void SampleHyperparameters(SuggestedTrainer trainer, IEnumerable<
142153
SweptParameters = sps
143154
});
144155

145-
IEnumerable<PipelineRunResult> historyToUse = history
156+
IEnumerable<InferredPipelineRunResult> historyToUse = history
146157
.Where(r => r.RunSucceded && r.Pipeline.Trainer.TrainerName == trainer.TrainerName && r.Pipeline.Trainer.HyperParamSet != null);
147158

148159
// get new set of hyperparameter values

src/Test/DatasetUtil.cs

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using System;
2+
using System.IO;
3+
using System.Net;
4+
using Microsoft.ML.Data;
5+
6+
namespace Microsoft.ML.Auto.Test
7+
{
8+
internal static class DatasetUtil
9+
{
10+
public const string UciAdultLabel = DefaultColumnNames.Label;
11+
12+
private static IDataView _uciAdultDataView;
13+
14+
public static IDataView GetUciAdultDataView()
15+
{
16+
if(_uciAdultDataView == null)
17+
{
18+
var uciAdultDataFile = DownloadUciAdultDataset();
19+
_uciAdultDataView = (new MLContext()).Data.AutoRead(uciAdultDataFile, UciAdultLabel, true);
20+
}
21+
return _uciAdultDataView;
22+
}
23+
24+
// downloads the UCI Adult dataset from the ML.Net repo
25+
private static string DownloadUciAdultDataset() =>
26+
DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/f0e639af5ffdc839aae8e65d19b5a9a1f0db634a/test/data/adult.tiny.with-schema.txt", "uciadult.dataset");
27+
28+
private static string DownloadIfNotExists(string baseGitPath, string dataFile)
29+
{
30+
// if file doesn't already exist, download it
31+
if(!File.Exists(dataFile))
32+
{
33+
using (var client = new WebClient())
34+
{
35+
client.DownloadFile(new Uri($"{baseGitPath}"), dataFile);
36+
}
37+
}
38+
39+
return dataFile;
40+
}
41+
}
42+
}

0 commit comments

Comments
 (0)