Skip to content

Commit 2207a27

Browse files
Adding support for training metrics in PipelineSweeperMacro + new graph variable outputs (#152)
* Adding support for training metrics in PipelineSweeperMacro and needed support files. Also includes new output information in PipelineSweeperMacro output graph to make consumption of returned pipelines easier. * Changed where XML comment was placed. * Added more tests (uncommented and fixed) for auto inference. Changed magic number in AutoMlUtils to be mix double value, per review comment. * Added another test, TestPipelineSweeperMacroNoTransforms. * Updated tests to include warning disabling (following Zeeshan S's example) to get build working. * Changes to checks in AutoMlUtils (more correct usage of them). * Fixing issue with ExceptParam using value and not name of parameters. * Fixing errors on use of ExceptParam.
1 parent 7a5b303 commit 2207a27

File tree

7 files changed

+506
-385
lines changed

7 files changed

+506
-385
lines changed

src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -832,21 +832,21 @@ public static class SweepableDiscreteParam
832832
public static class PipelineSweeperSupportedMetrics
833833
{
834834
public new static string ToString() => "SupportedMetric";
835-
public const string Auc = "Auc";
835+
public const string Auc = "AUC";
836836
public const string AccuracyMicro = "AccuracyMicro";
837837
public const string AccuracyMacro = "AccuracyMacro";
838838
public const string F1 = "F1";
839-
public const string AuPrc = "AuPrc";
839+
public const string AuPrc = "AUPRC";
840840
public const string TopKAccuracy = "TopKAccuracy";
841841
public const string L1 = "L1";
842842
public const string L2 = "L2";
843-
public const string Rms = "Rms";
843+
public const string Rms = "RMS";
844844
public const string LossFn = "LossFn";
845845
public const string RSquared = "RSquared";
846846
public const string LogLoss = "LogLoss";
847847
public const string LogLossReduction = "LogLossReduction";
848-
public const string Ndcg = "Ndcg";
849-
public const string Dcg = "Dcg";
848+
public const string Ndcg = "NDCG";
849+
public const string Dcg = "DCG";
850850
public const string PositivePrecision = "PositivePrecision";
851851
public const string PositiveRecall = "PositiveRecall";
852852
public const string NegativePrecision = "NegativePrecision";
@@ -858,9 +858,9 @@ public static class PipelineSweeperSupportedMetrics
858858
public const string ThreshAtK = "ThreshAtK";
859859
public const string ThreshAtP = "ThreshAtP";
860860
public const string ThreshAtNumPos = "ThreshAtNumPos";
861-
public const string Nmi = "Nmi";
861+
public const string Nmi = "NMI";
862862
public const string AvgMinScore = "AvgMinScore";
863-
public const string Dbi = "Dbi";
863+
public const string Dbi = "DBI";
864864
}
865865
}
866866
}

src/Microsoft.ML.PipelineInference/AutoInference.cs

+12-8
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ private bool GetDataVariableName(IExceptionContext ectx, string nameOfData, JTok
158158
return false;
159159

160160
string dataVar = firstNodeInputs.Value<String>(nameOfData);
161-
ectx.Check(VariableBinding.IsValidVariableName(ectx, dataVar), $"Invalid variable name {dataVar}.");
161+
if (!VariableBinding.IsValidVariableName(ectx, dataVar))
162+
throw ectx.ExceptParam(nameof(nameOfData), $"Invalid variable name {dataVar}.");
162163

163164
variableName = dataVar.Substring(1);
164165
return true;
@@ -172,12 +173,14 @@ private bool GetDataVariableName(IExceptionContext ectx, string nameOfData, JTok
172173
public sealed class RunSummary
173174
{
174175
public double MetricValue { get; }
176+
public double TrainingMetricValue { get; }
175177
public int NumRowsInTraining { get; }
176178
public long RunTimeMilliseconds { get; }
177179

178-
public RunSummary(double metricValue, int numRows, long runTimeMilliseconds)
180+
public RunSummary(double metricValue, int numRows, long runTimeMilliseconds, double trainingMetricValue)
179181
{
180182
MetricValue = metricValue;
183+
TrainingMetricValue = trainingMetricValue;
181184
NumRowsInTraining = numRows;
182185
RunTimeMilliseconds = runTimeMilliseconds;
183186
}
@@ -303,7 +306,7 @@ private void MainLearningLoop(int batchSize, int numOfTrainingRows)
303306
var stopwatch = new Stopwatch();
304307
var probabilityUtils = new Sweeper.Algorithms.SweeperProbabilityUtils(_host);
305308

306-
while (!_terminator.ShouldTerminate(_history))
309+
while (!_terminator.ShouldTerminate(_history))
307310
{
308311
// Get next set of candidates
309312
var currentBatchSize = batchSize;
@@ -341,16 +344,17 @@ private void ProcessPipeline(Sweeper.Algorithms.SweeperProbabilityUtils utils, S
341344

342345
// Run pipeline, and time how long it takes
343346
stopwatch.Restart();
344-
double d = candidate.RunTrainTestExperiment(_trainData.Take(randomizedNumberOfRows),
345-
_testData, Metric, TrainerKind);
347+
candidate.RunTrainTestExperiment(_trainData.Take(randomizedNumberOfRows),
348+
_testData, Metric, TrainerKind, out var testMetricVal, out var trainMetricVal);
346349
stopwatch.Stop();
347350

348351
// Handle key collisions on sorted list
349-
while (_sortedSampledElements.ContainsKey(d))
350-
d += 1e-10;
352+
while (_sortedSampledElements.ContainsKey(testMetricVal))
353+
testMetricVal += 1e-10;
351354

352355
// Save performance score
353-
candidate.PerformanceSummary = new RunSummary(d, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds);
356+
candidate.PerformanceSummary =
357+
new RunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal);
354358
_sortedSampledElements.Add(candidate.PerformanceSummary.MetricValue, candidate);
355359
_history.Add(candidate);
356360
}

src/Microsoft.ML.PipelineInference/AutoMlUtils.cs

+24-9
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,34 @@ namespace Microsoft.ML.Runtime.PipelineInference
1515
{
1616
public static class AutoMlUtils
1717
{
18-
public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView data, string metricColumnName)
18+
public static double ExtractValueFromIDV(IHostEnvironment env, IDataView result, string columnName)
1919
{
20-
double metricValue = 0;
21-
int numRows = 0;
22-
var schema = data.Schema;
23-
schema.TryGetColumnIndex(metricColumnName, out var metricCol);
20+
Contracts.CheckValue(env, nameof(env));
21+
env.CheckValue(result, nameof(result));
22+
env.CheckNonEmpty(columnName, nameof(columnName));
2423

25-
using (var cursor = data.GetRowCursor(col => col == metricCol))
24+
double outputValue = 0;
25+
var schema = result.Schema;
26+
if (!schema.TryGetColumnIndex(columnName, out var metricCol))
27+
throw env.ExceptParam(nameof(columnName), $"Schema does not contain column: {columnName}");
28+
29+
using (var cursor = result.GetRowCursor(col => col == metricCol))
2630
{
2731
var getter = cursor.GetGetter<double>(metricCol);
28-
cursor.MoveNext();
29-
getter(ref metricValue);
32+
bool moved = cursor.MoveNext();
33+
env.Check(moved, "Expected an IDataView with a single row. Results dataset has no rows to extract.");
34+
getter(ref outputValue);
35+
env.Check(!cursor.MoveNext(), "Expected an IDataView with a single row. Results dataset has too many rows.");
3036
}
3137

32-
return new AutoInference.RunSummary(metricValue, numRows, 0);
38+
return outputValue;
39+
}
40+
41+
public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView result, string metricColumnName, IDataView trainResult = null)
42+
{
43+
double testingMetricValue = ExtractValueFromIDV(env, result, metricColumnName);
44+
double trainingMetricValue = trainResult != null ? ExtractValueFromIDV(env, trainResult, metricColumnName) : double.MinValue;
45+
return new AutoInference.RunSummary(testingMetricValue, 0, 0, trainingMetricValue);
3346
}
3447

3548
public static CommonInputs.IEvaluatorInput CloneEvaluatorInstance(CommonInputs.IEvaluatorInput evalInput) =>
@@ -618,5 +631,7 @@ public static Tuple<string, string[]>[] ConvertToSweepArgumentStrings(TlcModule.
618631
}
619632
return results;
620633
}
634+
635+
public static string GenerateOverallTrainingMetricVarName(Guid id) => $"Var_Training_OM_{id:N}";
621636
}
622637
}

src/Microsoft.ML.PipelineInference/Macros/PipelineSweeperMacro.cs

+14-5
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,24 @@ public static Output ExtractSweepResult(IHostEnvironment env, ResultInput input)
6565
var col1 = new KeyValuePair<string, ColumnType>("Graph", TextType.Instance);
6666
var col2 = new KeyValuePair<string, ColumnType>("MetricValue", PrimitiveType.FromKind(DataKind.R8));
6767
var col3 = new KeyValuePair<string, ColumnType>("PipelineId", TextType.Instance);
68+
var col4 = new KeyValuePair<string, ColumnType>("TrainingMetricValue", PrimitiveType.FromKind(DataKind.R8));
69+
var col5 = new KeyValuePair<string, ColumnType>("FirstInput", TextType.Instance);
70+
var col6 = new KeyValuePair<string, ColumnType>("PredictorModel", TextType.Instance);
6871

6972
if (rows.Count == 0)
7073
{
7174
var host = env.Register("ExtractSweepResult");
72-
outputView = new EmptyDataView(host, new SimpleSchema(host, col1, col2, col3));
75+
outputView = new EmptyDataView(host, new SimpleSchema(host, col1, col2, col3, col4, col5, col6));
7376
}
7477
else
7578
{
7679
var builder = new ArrayDataViewBuilder(env);
7780
builder.AddColumn(col1.Key, (PrimitiveType)col1.Value, rows.Select(r => new DvText(r.GraphJson)).ToArray());
7881
builder.AddColumn(col2.Key, (PrimitiveType)col2.Value, rows.Select(r => r.MetricValue).ToArray());
7982
builder.AddColumn(col3.Key, (PrimitiveType)col3.Value, rows.Select(r => new DvText(r.PipelineId)).ToArray());
83+
builder.AddColumn(col4.Key, (PrimitiveType)col4.Value, rows.Select(r => r.TrainingMetricValue).ToArray());
84+
builder.AddColumn(col5.Key, (PrimitiveType)col5.Value, rows.Select(r => new DvText(r.FirstInput)).ToArray());
85+
builder.AddColumn(col6.Key, (PrimitiveType)col6.Value, rows.Select(r => new DvText(r.PredictorModel)).ToArray());
8086
outputView = builder.GetDataView();
8187
}
8288
return new Output { Results = outputView, State = autoMlState };
@@ -132,11 +138,11 @@ public static CommonOutputs.MacroOutput<Output> PipelineSweep(
132138
// Extract performance summaries and assign to previous candidate pipelines.
133139
foreach (var pipeline in autoMlState.BatchCandidates)
134140
{
135-
if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId),
136-
out var v))
141+
if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId), out var v) &&
142+
node.Context.TryGetVariable(AutoMlUtils.GenerateOverallTrainingMetricVarName(pipeline.UniqueId), out var v2))
137143
{
138144
pipeline.PerformanceSummary =
139-
AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name);
145+
AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name, (IDataView)v2.Value);
140146
autoMlState.AddEvaluated(pipeline);
141147
}
142148
}
@@ -168,14 +174,17 @@ public static CommonOutputs.MacroOutput<Output> PipelineSweep(
168174
{
169175
// Add train test experiments to current graph for candidate pipeline
170176
var subgraph = new Experiment(env);
171-
var trainTestOutput = p.AddAsTrainTest(training, testing, autoMlState.TrainerKind, subgraph);
177+
var trainTestOutput = p.AddAsTrainTest(training, testing, autoMlState.TrainerKind, subgraph, true);
172178

173179
// Change variable name to reference pipeline ID in output map, context and entrypoint output.
174180
var uniqueName = ExperimentUtils.GenerateOverallMetricVarName(p.UniqueId);
181+
var uniqueNameTraining = AutoMlUtils.GenerateOverallTrainingMetricVarName(p.UniqueId);
175182
var sgNode = EntryPointNode.ValidateNodes(env, node.Context,
176183
new JArray(subgraph.GetNodes().Last()), node.Catalog).Last();
177184
sgNode.RenameOutputVariable(trainTestOutput.OverallMetrics.VarName, uniqueName, cascadeChanges: true);
185+
sgNode.RenameOutputVariable(trainTestOutput.TrainingOverallMetrics.VarName, uniqueNameTraining, cascadeChanges: true);
178186
trainTestOutput.OverallMetrics.VarName = uniqueName;
187+
trainTestOutput.TrainingOverallMetrics.VarName = uniqueNameTraining;
179188
expNodes.Add(sgNode);
180189

181190
// Store indicators, to pass to next iteration of macro.

src/Microsoft.ML.PipelineInference/Microsoft.ML.PipelineInference.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1818
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
1919
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
20+
<ProjectReference Include="..\Microsoft.ML\Microsoft.ML.csproj" />
2021
</ItemGroup>
2122

2223
</Project>

0 commit comments

Comments
 (0)