Skip to content

Commit 8d575ba

Browse files
daholsteDmitry-A
authored andcommitted
propagate root MLContext thru AutoML (instead of creating our own) (dotnet#182)
1 parent 7b46ccc commit 8d575ba

19 files changed

+127
-97
lines changed

src/Microsoft.ML.Auto/API/Pipeline.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ internal Pipeline()
2121
{
2222
}
2323

24-
public IEstimator<ITransformer> ToEstimator()
24+
public IEstimator<ITransformer> ToEstimator(MLContext context)
2525
{
26-
var inferredPipeline = SuggestedPipeline.FromPipeline(this);
26+
var inferredPipeline = SuggestedPipeline.FromPipeline(context, this);
2727
return inferredPipeline.ToEstimator();
2828
}
2929
}

src/Microsoft.ML.Auto/AutoFitter/AutoFitter.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public AutoFitter(MLContext context,
4343
{
4444
if (validationData == null)
4545
{
46-
(trainData, validationData) = context.Regression.TestValidateSplit(trainData);
46+
(trainData, validationData) = context.Regression.TestValidateSplit(context, trainData);
4747
}
4848
_trainData = trainData;
4949
_validationData = validationData;
@@ -85,7 +85,7 @@ public List<RunResult<T>> Fit()
8585
var getPiplelineStopwatch = Stopwatch.StartNew();
8686

8787
// get next pipeline
88-
pipeline = PipelineSuggester.GetNextInferredPipeline(_history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);
88+
pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);
8989

9090
getPiplelineStopwatch.Stop();
9191

src/Microsoft.ML.Auto/AutoFitter/SuggestedPipeline.cs

+5-7
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ internal class SuggestedPipeline
2323

2424
public SuggestedPipeline(IEnumerable<SuggestedTransform> transforms,
2525
SuggestedTrainer trainer,
26-
MLContext context = null,
26+
MLContext context,
2727
bool autoNormalize = true)
2828
{
2929
Transforms = transforms.Select(t => t.Clone()).ToList();
3030
Trainer = trainer.Clone();
31-
_context = context ?? new MLContext();
31+
_context = context;
3232

3333
if(autoNormalize)
3434
{
@@ -64,10 +64,8 @@ public Pipeline ToPipeline()
6464
return new Pipeline(pipelineElements.ToArray());
6565
}
6666

67-
public static SuggestedPipeline FromPipeline(Pipeline pipeline)
67+
public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipeline)
6868
{
69-
var context = new MLContext();
70-
7169
var transforms = new List<SuggestedTransform>();
7270
SuggestedTrainer trainer = null;
7371

@@ -84,13 +82,13 @@ public static SuggestedPipeline FromPipeline(Pipeline pipeline)
8482
{
8583
var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), pipelineNode.Name);
8684
var estimatorExtension = EstimatorExtensionCatalog.GetExtension(estimatorName);
87-
var estimator = estimatorExtension.CreateInstance(new MLContext(), pipelineNode);
85+
var estimator = estimatorExtension.CreateInstance(context, pipelineNode);
8886
var transform = new SuggestedTransform(pipelineNode, estimator);
8987
transforms.Add(transform);
9088
}
9189
}
9290

93-
return new SuggestedPipeline(transforms, trainer, null, false);
91+
return new SuggestedPipeline(transforms, trainer, context, false);
9492
}
9593

9694
public IEstimator<ITransformer> ToEstimator()

src/Microsoft.ML.Auto/AutoFitter/SuggestedPipelineResult.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ public SuggestedPipelineResult(SuggestedPipeline pipeline, double score, bool ru
2020
RunSucceded = runSucceeded;
2121
}
2222

23-
public static SuggestedPipelineResult FromPipelineRunResult(PipelineScore pipelineRunResult)
23+
public static SuggestedPipelineResult FromPipelineRunResult(MLContext context, PipelineScore pipelineRunResult)
2424
{
25-
return new SuggestedPipelineResult(SuggestedPipeline.FromPipeline(pipelineRunResult.Pipeline), pipelineRunResult.Score, pipelineRunResult.RunSucceded);
25+
return new SuggestedPipelineResult(SuggestedPipeline.FromPipeline(context, pipelineRunResult.Pipeline), pipelineRunResult.Score, pipelineRunResult.RunSucceded);
2626
}
2727

2828
public IRunResult ToRunResult(bool isMetricMaximizing)

src/Microsoft.ML.Auto/AutoMlUtils.cs

+9-10
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,36 @@ public static void Assert(bool boolVal, string message = null)
2323
}
2424
}
2525

26-
public static IDataView Take(this IDataView data, int count)
26+
public static IDataView Take(this IDataView data, MLContext context, int count)
2727
{
28-
var context = new MLContext();
2928
return TakeFilter.Create(context, data, count);
3029
}
3130

32-
public static IDataView DropLastColumn(this IDataView data)
31+
public static IDataView DropLastColumn(this IDataView data, MLContext context)
3332
{
34-
return new MLContext().Transforms.DropColumns(data.Schema[data.Schema.Count - 1].Name).Fit(data).Transform(data);
33+
return context.Transforms.DropColumns(data.Schema[data.Schema.Count - 1].Name).Fit(data).Transform(data);
3534
}
3635

37-
public static (IDataView testData, IDataView validationData) TestValidateSplit(this TrainCatalogBase catalog, IDataView trainData)
36+
public static (IDataView testData, IDataView validationData) TestValidateSplit(this TrainCatalogBase catalog,
37+
MLContext context, IDataView trainData)
3838
{
3939
IDataView validationData;
4040
(trainData, validationData) = catalog.TrainTestSplit(trainData);
41-
trainData = trainData.DropLastColumn();
42-
validationData = validationData.DropLastColumn();
41+
trainData = trainData.DropLastColumn(context);
42+
validationData = validationData.DropLastColumn(context);
4343
return (trainData, validationData);
4444
}
4545

46-
public static IDataView Skip(this IDataView data, int count)
46+
public static IDataView Skip(this IDataView data, MLContext context, int count)
4747
{
48-
var context = new MLContext();
4948
return SkipFilter.Create(context, data, count);
5049
}
5150

5251
public static (string, ColumnType, ColumnPurpose, ColumnDimensions)[] GetColumnInfoTuples(MLContext context,
5352
IDataView data, ColumnInformation columnInfo)
5453
{
5554
var purposes = PurposeInference.InferPurposes(context, data, columnInfo);
56-
var colDimensions = DatasetDimensionsApi.CalcColumnDimensions(data, purposes);
55+
var colDimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, purposes);
5756
var cols = new (string, ColumnType, ColumnPurpose, ColumnDimensions)[data.Schema.Count];
5857
for (var i = 0; i < cols.Length; i++)
5958
{

src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
1414
bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
1515
{
1616
var sample = TextFileSample.CreateFromFullFile(path);
17-
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
17+
var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
1818
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null);
1919

2020
// if no column is named label,
@@ -32,7 +32,7 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
3232
char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
3333
{
3434
var sample = TextFileSample.CreateFromFullFile(path);
35-
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
35+
var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
3636
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, label);
3737
return InferColumns(context, path, label, true, splitInference, typeInference, trimWhitespace, groupColumns);
3838
}
@@ -93,10 +93,10 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
9393
};
9494
}
9595

96-
private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
96+
private static TextFileContents.ColumnSplitResult InferSplit(MLContext context, TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
9797
{
9898
var separatorCandidates = separatorChar == null ? TextFileContents.DefaultSeparators : new char[] { separatorChar.Value };
99-
var splitInference = TextFileContents.TrySplitColumns(sample, separatorCandidates);
99+
var splitInference = TextFileContents.TrySplitColumns(context, sample, separatorCandidates);
100100

101101
// respect passed-in overrides
102102
if (allowQuotedStrings != null)

src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,12 @@ private static IEnumerable<ITypeInferenceExpert> GetExperts()
236236
/// <summary>
237237
/// Auto-detect column types of the file.
238238
/// </summary>
239-
public static InferenceResult InferTextFileColumnTypes(MLContext env, IMultiStreamSource fileSource, Arguments args)
239+
public static InferenceResult InferTextFileColumnTypes(MLContext context, IMultiStreamSource fileSource, Arguments args)
240240
{
241-
return InferTextFileColumnTypesCore(env, fileSource, args);
241+
return InferTextFileColumnTypesCore(context, fileSource, args);
242242
}
243243

244-
private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMultiStreamSource fileSource, Arguments args)
244+
private static InferenceResult InferTextFileColumnTypesCore(MLContext context, IMultiStreamSource fileSource, Arguments args)
245245
{
246246
if (args.ColumnCount == 0)
247247
{
@@ -263,9 +263,9 @@ private static InferenceResult InferTextFileColumnTypesCore(MLContext env, IMult
263263
AllowSparse = args.AllowSparse,
264264
AllowQuoting = args.AllowQuote,
265265
};
266-
var textLoader = new TextLoader(env, textLoaderArgs);
266+
var textLoader = new TextLoader(context, textLoaderArgs);
267267
var idv = textLoader.Read(fileSource);
268-
idv = idv.Take(args.MaxRowsToRead);
268+
idv = idv.Take(context, args.MaxRowsToRead);
269269

270270
// read all the data into memory.
271271
// list items are rows of the dataset.

src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ private static IEnumerable<IPurposeInferenceExpert> GetExperts()
266266
public static PurposeInference.Column[] InferPurposes(MLContext context, IDataView data,
267267
ColumnInformation columnInfo)
268268
{
269-
data = data.Take(MaxRowsToRead);
269+
data = data.Take(context, MaxRowsToRead);
270270

271271
var allColumns = new List<IntermediateColumn>();
272272
var columnsToInfer = new List<IntermediateColumn>();

src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public ColumnSplitResult(bool isSuccess, char? separator, bool allowQuote, bool
4646
/// and this number of columns is more than 1.
4747
/// We sweep on separator, allow sparse and allow quote parameter.
4848
/// </summary>
49-
public static ColumnSplitResult TrySplitColumns(IMultiStreamSource source, char[] separatorCandidates)
49+
public static ColumnSplitResult TrySplitColumns(MLContext context, IMultiStreamSource source, char[] separatorCandidates)
5050
{
5151
var sparse = new[] { true, false };
5252
var quote = new[] { true, false };
@@ -69,7 +69,7 @@ from _sep in separatorCandidates
6969
AllowSparse = perm._allowSparse
7070
};
7171

72-
if (TryParseFile(args, source, out result))
72+
if (TryParseFile(context, args, source, out result))
7373
{
7474
foundAny = true;
7575
break;
@@ -78,15 +78,16 @@ from _sep in separatorCandidates
7878
return foundAny ? result : new ColumnSplitResult(false, null, true, true, 0);
7979
}
8080

81-
private static bool TryParseFile(TextLoader.Arguments args, IMultiStreamSource source, out ColumnSplitResult result)
81+
private static bool TryParseFile(MLContext context, TextLoader.Arguments args, IMultiStreamSource source,
82+
out ColumnSplitResult result)
8283
{
8384
result = null;
8485
// try to instantiate data view with swept arguments
8586
try
8687
{
8788

88-
var textLoader = new TextLoader(new MLContext(), args, source);
89-
var idv = textLoader.Read(source).Take(1000);
89+
var textLoader = new TextLoader(context, args, source);
90+
var idv = textLoader.Read(source).Take(context, 1000);
9091
var columnCounts = new List<int>();
9192
var column = idv.Schema["C"];
9293
var columnIndex = column.Index;

src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ internal class DatasetDimensionsApi
1010
{
1111
private const int MaxRowsToRead = 1000;
1212

13-
public static ColumnDimensions[] CalcColumnDimensions(IDataView data, PurposeInference.Column[] purposes)
13+
public static ColumnDimensions[] CalcColumnDimensions(MLContext context, IDataView data, PurposeInference.Column[] purposes)
1414
{
15-
data = data.Take(MaxRowsToRead);
15+
data = data.Take(context, MaxRowsToRead);
1616

1717
var colDimensions = new ColumnDimensions[data.Schema.Count];
1818

src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs

+14-13
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,32 @@ internal static class PipelineSuggester
1414
{
1515
private const int TopKTrainers = 3;
1616

17-
public static Pipeline GetNextPipeline(IEnumerable<PipelineScore> history,
17+
public static Pipeline GetNextPipeline(MLContext context,
18+
IEnumerable<PipelineScore> history,
1819
(string, ColumnType, ColumnPurpose, ColumnDimensions)[] columns,
1920
TaskKind task,
2021
bool isMaximizingMetric = true)
2122
{
22-
var inferredHistory = history.Select(r => SuggestedPipelineResult.FromPipelineRunResult(r));
23-
var nextInferredPipeline = GetNextInferredPipeline(inferredHistory, columns, task, isMaximizingMetric);
23+
var inferredHistory = history.Select(r => SuggestedPipelineResult.FromPipelineRunResult(context, r));
24+
var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric);
2425
return nextInferredPipeline?.ToPipeline();
2526
}
2627

27-
public static SuggestedPipeline GetNextInferredPipeline(IEnumerable<SuggestedPipelineResult> history,
28+
public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
29+
IEnumerable<SuggestedPipelineResult> history,
2830
(string, ColumnType, ColumnPurpose, ColumnDimensions)[] columns,
2931
TaskKind task,
3032
bool isMaximizingMetric,
3133
IEnumerable<TrainerName> trainerWhitelist = null)
3234
{
33-
var context = new MLContext();
34-
3535
var availableTrainers = RecipeInference.AllowedTrainers(context, task, trainerWhitelist);
3636
var transforms = CalculateTransforms(context, columns, task);
3737
//var transforms = TransformInferenceApi.InferTransforms(context, columns, task);
3838

3939
// if we haven't run all pipelines once
4040
if (history.Count() < availableTrainers.Count())
4141
{
42-
return GetNextFirstStagePipeline(history, availableTrainers, transforms);
42+
return GetNextFirstStagePipeline(context, history, availableTrainers, transforms);
4343
}
4444

4545
// get top trainers from stage 1 runs
@@ -63,14 +63,14 @@ public static SuggestedPipeline GetNextInferredPipeline(IEnumerable<SuggestedPip
6363
do
6464
{
6565
// sample new hyperparameters for the learner
66-
if (!SampleHyperparameters(newTrainer, history, isMaximizingMetric))
66+
if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric))
6767
{
6868
// if unable to sample new hyperparameters for the learner
6969
// (ie SMAC returned 0 suggestions), break
7070
break;
7171
}
7272

73-
var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer);
73+
var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer, context);
7474

7575
// make sure we have not seen pipeline before
7676
if (!visitedPipelines.Contains(suggestedPipeline))
@@ -113,12 +113,13 @@ private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerabl
113113
.Select(x => x.First().Pipeline.Trainer);
114114
}
115115

116-
private static SuggestedPipeline GetNextFirstStagePipeline(IEnumerable<SuggestedPipelineResult> history,
116+
private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context,
117+
IEnumerable<SuggestedPipelineResult> history,
117118
IEnumerable<SuggestedTrainer> availableTrainers,
118119
IEnumerable<SuggestedTransform> transforms)
119120
{
120121
var trainer = availableTrainers.ElementAt(history.Count());
121-
return new SuggestedPipeline(transforms, trainer);
122+
return new SuggestedPipeline(transforms, trainer, context);
122123
}
123124

124125
private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableParam> hps)
@@ -184,10 +185,10 @@ private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableP
184185
/// Samples new hyperparameters for the trainer, and sets them.
185186
/// Returns true if success (new hyperparams were suggested and set). Else, returns false.
186187
/// </summary>
187-
private static bool SampleHyperparameters(SuggestedTrainer trainer, IEnumerable<SuggestedPipelineResult> history, bool isMaximizingMetric)
188+
private static bool SampleHyperparameters(MLContext context, SuggestedTrainer trainer, IEnumerable<SuggestedPipelineResult> history, bool isMaximizingMetric)
188189
{
189190
var sps = ConvertToValueGenerators(trainer.SweepParams);
190-
var sweeper = new SmacSweeper(
191+
var sweeper = new SmacSweeper(context,
191192
new SmacSweeper.Arguments
192193
{
193194
SweptParameters = sps

src/Microsoft.ML.Auto/Sweepers/SmacSweeper.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,13 @@ public sealed class Arguments
5555

5656
private readonly ISweeper _randomSweeper;
5757
private readonly Arguments _args;
58-
private readonly MLContext _context = new MLContext();
58+
private readonly MLContext _context;
5959

6060
private readonly IValueGenerator[] _sweepParameters;
6161

62-
public SmacSweeper(Arguments args)
62+
public SmacSweeper(MLContext context, Arguments args)
6363
{
64+
_context = context;
6465
_args = args;
6566
_sweepParameters = args.SweptParameters;
6667
_randomSweeper = new UniformRandomSweeper(new SweeperBase.ArgumentsBase(), _sweepParameters);

0 commit comments

Comments
 (0)