Skip to content

Commit 05482eb

Browse files
Adding support for training metrics in PipelineSweeperMacro and needed support files. Also includes new output information in PipelineSweeperMacro output graph to make consumption of returned pipelines easier.
1 parent 3102180 commit 05482eb

File tree

8 files changed

+49190
-259
lines changed

8 files changed

+49190
-259
lines changed

src/Microsoft.ML.PipelineInference/AutoInference.cs

+10-7
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,14 @@ private bool GetDataVariableName(IExceptionContext ectx, string nameOfData, JTok
172172
public sealed class RunSummary
173173
{
174174
public double MetricValue { get; }
175+
public double TrainingMetricValue { get; }
175176
public int NumRowsInTraining { get; }
176177
public long RunTimeMilliseconds { get; }
177178

178-
public RunSummary(double metricValue, int numRows, long runTimeMilliseconds)
179+
public RunSummary(double metricValue, int numRows, long runTimeMilliseconds, double trainingMetricValue)
179180
{
180181
MetricValue = metricValue;
182+
TrainingMetricValue = trainingMetricValue;
181183
NumRowsInTraining = numRows;
182184
RunTimeMilliseconds = runTimeMilliseconds;
183185
}
@@ -303,7 +305,7 @@ private void MainLearningLoop(int batchSize, int numOfTrainingRows)
303305
var stopwatch = new Stopwatch();
304306
var probabilityUtils = new Sweeper.Algorithms.SweeperProbabilityUtils(_host);
305307

306-
while (!_terminator.ShouldTerminate(_history))
308+
while (!_terminator.ShouldTerminate(_history))
307309
{
308310
// Get next set of candidates
309311
var currentBatchSize = batchSize;
@@ -341,16 +343,17 @@ private void ProcessPipeline(Sweeper.Algorithms.SweeperProbabilityUtils utils, S
341343

342344
// Run pipeline, and time how long it takes
343345
stopwatch.Restart();
344-
double d = candidate.RunTrainTestExperiment(_trainData.Take(randomizedNumberOfRows),
345-
_testData, Metric, TrainerKind);
346+
candidate.RunTrainTestExperiment(_trainData.Take(randomizedNumberOfRows),
347+
_testData, Metric, TrainerKind, out var testMetricVal, out var trainMetricVal);
346348
stopwatch.Stop();
347349

348350
// Handle key collisions on sorted list
349-
while (_sortedSampledElements.ContainsKey(d))
350-
d += 1e-10;
351+
while (_sortedSampledElements.ContainsKey(testMetricVal))
352+
testMetricVal += 1e-10;
351353

352354
// Save performance score
353-
candidate.PerformanceSummary = new RunSummary(d, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds);
355+
candidate.PerformanceSummary =
356+
new RunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal);
354357
_sortedSampledElements.Add(candidate.PerformanceSummary.MetricValue, candidate);
355358
_history.Add(candidate);
356359
}

src/Microsoft.ML.PipelineInference/AutoMlUtils.cs

+20-4
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,35 @@ namespace Microsoft.ML.Runtime.PipelineInference
1515
{
1616
public static class AutoMlUtils
1717
{
18-
public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView data, string metricColumnName)
18+
public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView result, string metricColumnName, IDataView trainResult = null)
1919
{
2020
double metricValue = 0;
21+
double trainingMetricValue = -1d;
2122
int numRows = 0;
22-
var schema = data.Schema;
23+
var schema = result.Schema;
2324
schema.TryGetColumnIndex(metricColumnName, out var metricCol);
2425

25-
using (var cursor = data.GetRowCursor(col => col == metricCol))
26+
using (var cursor = result.GetRowCursor(col => col == metricCol))
2627
{
2728
var getter = cursor.GetGetter<double>(metricCol);
2829
cursor.MoveNext();
2930
getter(ref metricValue);
3031
}
3132

32-
return new AutoInference.RunSummary(metricValue, numRows, 0);
33+
if (trainResult != null)
34+
{
35+
var trainSchema = trainResult.Schema;
36+
trainSchema.TryGetColumnIndex(metricColumnName, out var trainingMetricCol);
37+
38+
using (var cursor = trainResult.GetRowCursor(col => col == trainingMetricCol))
39+
{
40+
var getter = cursor.GetGetter<double>(trainingMetricCol);
41+
cursor.MoveNext();
42+
getter(ref trainingMetricValue);
43+
}
44+
}
45+
46+
return new AutoInference.RunSummary(metricValue, numRows, 0, trainingMetricValue);
3347
}
3448

3549
public static CommonInputs.IEvaluatorInput CloneEvaluatorInstance(CommonInputs.IEvaluatorInput evalInput) =>
@@ -618,5 +632,7 @@ public static Tuple<string, string[]>[] ConvertToSweepArgumentStrings(TlcModule.
618632
}
619633
return results;
620634
}
635+
636+
public static string GenerateOverallTrainingMetricVarName(Guid id) => $"Var_Training_OM_{id:N}";
621637
}
622638
}

src/Microsoft.ML.PipelineInference/Macros/PipelineSweeperMacro.cs

+14-5
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,24 @@ public static Output ExtractSweepResult(IHostEnvironment env, ResultInput input)
6565
var col1 = new KeyValuePair<string, ColumnType>("Graph", TextType.Instance);
6666
var col2 = new KeyValuePair<string, ColumnType>("MetricValue", PrimitiveType.FromKind(DataKind.R8));
6767
var col3 = new KeyValuePair<string, ColumnType>("PipelineId", TextType.Instance);
68+
var col4 = new KeyValuePair<string, ColumnType>("TrainingMetricValue", PrimitiveType.FromKind(DataKind.R8));
69+
var col5 = new KeyValuePair<string, ColumnType>("FirstInput", TextType.Instance);
70+
var col6 = new KeyValuePair<string, ColumnType>("PredictorModel", TextType.Instance);
6871

6972
if (rows.Count == 0)
7073
{
7174
var host = env.Register("ExtractSweepResult");
72-
outputView = new EmptyDataView(host, new SimpleSchema(host, col1, col2, col3));
75+
outputView = new EmptyDataView(host, new SimpleSchema(host, col1, col2, col3, col4, col5, col6));
7376
}
7477
else
7578
{
7679
var builder = new ArrayDataViewBuilder(env);
7780
builder.AddColumn(col1.Key, (PrimitiveType)col1.Value, rows.Select(r => new DvText(r.GraphJson)).ToArray());
7881
builder.AddColumn(col2.Key, (PrimitiveType)col2.Value, rows.Select(r => r.MetricValue).ToArray());
7982
builder.AddColumn(col3.Key, (PrimitiveType)col3.Value, rows.Select(r => new DvText(r.PipelineId)).ToArray());
83+
builder.AddColumn(col4.Key, (PrimitiveType)col4.Value, rows.Select(r => r.TrainingMetricValue).ToArray());
84+
builder.AddColumn(col5.Key, (PrimitiveType)col5.Value, rows.Select(r => new DvText(r.FirstInput)).ToArray());
85+
builder.AddColumn(col6.Key, (PrimitiveType)col6.Value, rows.Select(r => new DvText(r.PredictorModel)).ToArray());
8086
outputView = builder.GetDataView();
8187
}
8288
return new Output { Results = outputView, State = autoMlState };
@@ -132,11 +138,11 @@ public static CommonOutputs.MacroOutput<Output> PipelineSweep(
132138
// Extract performance summaries and assign to previous candidate pipelines.
133139
foreach (var pipeline in autoMlState.BatchCandidates)
134140
{
135-
if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId),
136-
out var v))
141+
if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId), out var v) &&
142+
node.Context.TryGetVariable(AutoMlUtils.GenerateOverallTrainingMetricVarName(pipeline.UniqueId), out var v2))
137143
{
138144
pipeline.PerformanceSummary =
139-
AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name);
145+
AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name, (IDataView)v2.Value);
140146
autoMlState.AddEvaluated(pipeline);
141147
}
142148
}
@@ -168,14 +174,17 @@ public static CommonOutputs.MacroOutput<Output> PipelineSweep(
168174
{
169175
// Add train test experiments to current graph for candidate pipeline
170176
var subgraph = new Experiment(env);
171-
var trainTestOutput = p.AddAsTrainTest(training, testing, autoMlState.TrainerKind, subgraph);
177+
var trainTestOutput = p.AddAsTrainTest(training, testing, autoMlState.TrainerKind, subgraph, true);
172178

173179
// Change variable name to reference pipeline ID in output map, context and entrypoint output.
174180
var uniqueName = ExperimentUtils.GenerateOverallMetricVarName(p.UniqueId);
181+
var uniqueNameTraining = AutoMlUtils.GenerateOverallTrainingMetricVarName(p.UniqueId);
175182
var sgNode = EntryPointNode.ValidateNodes(env, node.Context,
176183
new JArray(subgraph.GetNodes().Last()), node.Catalog).Last();
177184
sgNode.RenameOutputVariable(trainTestOutput.OverallMetrics.VarName, uniqueName, cascadeChanges: true);
185+
sgNode.RenameOutputVariable(trainTestOutput.TrainingOverallMetrics.VarName, uniqueNameTraining, cascadeChanges: true);
178186
trainTestOutput.OverallMetrics.VarName = uniqueName;
187+
trainTestOutput.TrainingOverallMetrics.VarName = uniqueNameTraining;
179188
expNodes.Add(sgNode);
180189

181190
// Store indicators, to pass to next iteration of macro.

src/Microsoft.ML.PipelineInference/Microsoft.ML.PipelineInference.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1818
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1919
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
20+
<ProjectReference Include="..\Microsoft.ML\Microsoft.ML.csproj" />
2021
</ItemGroup>
2122

2223
</Project>

src/Microsoft.ML.PipelineInference/PipelinePattern.cs

+67-14
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,24 @@ public sealed class PipelineResultRow
2121
{
2222
public string GraphJson { get; }
2323
public double MetricValue { get; }
24+
public double TrainingMetricValue { get; }
2425
public string PipelineId { get; }
26+
public string FirstInput { get; }
27+
public string PredictorModel { get; }
2528

2629
public PipelineResultRow()
2730
{ }
2831

29-
public PipelineResultRow(string graphJson, double metricValue, string pipelineId)
32+
public PipelineResultRow(string graphJson, double metricValue,
33+
string pipelineId, double trainingMetricValue, string firstInput,
34+
string predictorModel)
3035
{
3136
GraphJson = graphJson;
3237
MetricValue = metricValue;
3338
PipelineId = pipelineId;
39+
TrainingMetricValue = trainingMetricValue;
40+
FirstInput = firstInput;
41+
PredictorModel = predictorModel;
3442
}
3543
}
3644

@@ -111,7 +119,8 @@ public AutoInference.EntryPointGraphDef ToEntryPointGraph(Experiment experiment
111119
public bool Equals(PipelinePattern obj) => obj != null && UniqueId == obj.UniqueId;
112120

113121
// REVIEW: We may want to allow for sweeping with CV in the future, so we will need to add new methods like this, or refactor these in that case.
114-
public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testData, MacroUtils.TrainerKinds trainerKind, out Models.TrainTestEvaluator.Output resultsOutput)
122+
public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testData, MacroUtils.TrainerKinds trainerKind,
123+
bool includeTrainingMetrics, out Models.TrainTestEvaluator.Output resultsOutput)
115124
{
116125
var graphDef = ToEntryPointGraph();
117126
var subGraph = graphDef.Graph;
@@ -136,7 +145,8 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
136145
Model = finalOutput
137146
},
138147
PipelineId = UniqueId.ToString("N"),
139-
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind)
148+
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind),
149+
IncludeTrainingMetrics = includeTrainingMetrics
140150
};
141151

142152
var experiment = _env.CreateExperiment();
@@ -150,7 +160,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
150160
}
151161

152162
public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData, Var<IDataView> testData,
153-
MacroUtils.TrainerKinds trainerKind, Experiment experiment = null)
163+
MacroUtils.TrainerKinds trainerKind, Experiment experiment = null, bool includeTrainingMetrics = false)
154164
{
155165
experiment = experiment ?? _env.CreateExperiment();
156166
var graphDef = ToEntryPointGraph(experiment);
@@ -174,7 +184,8 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
174184
TrainingData = trainData,
175185
TestingData = testData,
176186
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind),
177-
PipelineId = UniqueId.ToString("N")
187+
PipelineId = UniqueId.ToString("N"),
188+
IncludeTrainingMetrics = includeTrainingMetrics
178189
};
179190
var trainTestOutput = experiment.Add(trainTestInput);
180191
return trainTestOutput;
@@ -183,34 +194,58 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
183194
/// <summary>
184195
/// Runs a train-test experiment on the current pipeline, through entrypoints.
185196
/// </summary>
186-
public double RunTrainTestExperiment(IDataView trainData, IDataView testData, AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind)
197+
public void RunTrainTestExperiment(IDataView trainData, IDataView testData,
198+
AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue,
199+
out double trainMetricValue)
187200
{
188-
var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, out var trainTestOutput);
201+
var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var trainTestOutput);
189202
experiment.Run();
203+
190204
var dataOut = experiment.GetOutput(trainTestOutput.OverallMetrics);
191205
var schema = dataOut.Schema;
192206
schema.TryGetColumnIndex(metric.Name, out var metricCol);
207+
double metricValue = 0;
208+
double trainingMetricValue = 0;
193209

194210
using (var cursor = dataOut.GetRowCursor(col => col == metricCol))
195211
{
196212
var getter = cursor.GetGetter<double>(metricCol);
197-
double metricValue = 0;
198213
cursor.MoveNext();
199214
getter(ref metricValue);
200-
return metricValue;
215+
}
216+
217+
dataOut = experiment.GetOutput(trainTestOutput.TrainingOverallMetrics);
218+
schema = dataOut.Schema;
219+
schema.TryGetColumnIndex(metric.Name, out metricCol);
220+
221+
using (var cursor = dataOut.GetRowCursor(col => col == metricCol))
222+
{
223+
var getter = cursor.GetGetter<double>(metricCol);
224+
cursor.MoveNext();
225+
getter(ref trainingMetricValue);
226+
testMetricValue = metricValue;
227+
trainMetricValue = trainingMetricValue;
201228
}
202229
}
203230

204-
public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data, string graphColName, string metricColName, string idColName)
231+
public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data,
232+
string graphColName, string metricColName, string idColName, string trainingMetricColName,
233+
string firstInputColName, string predictorModelColName)
205234
{
206235
var results = new List<PipelineResultRow>();
207236
var schema = data.Schema;
208237
if (!schema.TryGetColumnIndex(graphColName, out var graphCol))
209238
throw env.ExceptNotSupp($"Column name {graphColName} not found");
210239
if (!schema.TryGetColumnIndex(metricColName, out var metricCol))
211240
throw env.ExceptNotSupp($"Column name {metricColName} not found");
241+
if (!schema.TryGetColumnIndex(trainingMetricColName, out var trainingMetricCol))
242+
throw env.ExceptNotSupp($"Column name {trainingMetricColName} not found");
212243
if (!schema.TryGetColumnIndex(idColName, out var pipelineIdCol))
213244
throw env.ExceptNotSupp($"Column name {idColName} not found");
245+
if (!schema.TryGetColumnIndex(firstInputColName, out var firstInputCol))
246+
throw env.ExceptNotSupp($"Column name {firstInputColName} not found");
247+
if (!schema.TryGetColumnIndex(predictorModelColName, out var predictorModelCol))
248+
throw env.ExceptNotSupp($"Column name {predictorModelColName} not found");
214249

215250
using (var cursor = data.GetRowCursor(col => true))
216251
{
@@ -225,15 +260,33 @@ public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView
225260
var getter3 = cursor.GetGetter<DvText>(pipelineIdCol);
226261
DvText pipelineId = new DvText();
227262
getter3(ref pipelineId);
228-
results.Add(new PipelineResultRow(graphJson.ToString(), metricValue, pipelineId.ToString()));
263+
var getter4 = cursor.GetGetter<double>(trainingMetricCol);
264+
double trainingMetricValue = 0;
265+
getter4(ref trainingMetricValue);
266+
var getter5 = cursor.GetGetter<DvText>(firstInputCol);
267+
DvText firstInput = new DvText();
268+
getter5(ref firstInput);
269+
var getter6 = cursor.GetGetter<DvText>(predictorModelCol);
270+
DvText predictorModel = new DvText();
271+
getter6(ref predictorModel);
272+
273+
results.Add(new PipelineResultRow(graphJson.ToString(),
274+
metricValue, pipelineId.ToString(), trainingMetricValue,
275+
firstInput.ToString(), predictorModel.ToString()));
229276
}
230277
}
231278

232279
return results.ToArray();
233280
}
234281

235-
public PipelineResultRow ToResultRow() =>
236-
new PipelineResultRow(ToEntryPointGraph().Graph.ToJsonString(),
237-
PerformanceSummary?.MetricValue ?? -1d, UniqueId.ToString("N"));
282+
public PipelineResultRow ToResultRow() {
283+
var graphDef = ToEntryPointGraph();
284+
285+
return new PipelineResultRow($"{{'Nodes' : [{graphDef.Graph.ToJsonString()}]}}",
286+
PerformanceSummary?.MetricValue ?? -1d, UniqueId.ToString("N"),
287+
PerformanceSummary?.TrainingMetricValue ?? -1d,
288+
graphDef.GetSubgraphFirstNodeDataVarName(_env),
289+
graphDef.ModelOutput.VarName);
290+
}
238291
}
239292
}

0 commit comments

Comments
 (0)