@@ -21,16 +21,24 @@ public sealed class PipelineResultRow
21
21
{
22
22
/// <summary>Graph JSON string; assigned from the constructor's graphJson argument.</summary>
public string GraphJson { get ; }
23
23
/// <summary>Metric value; assigned from the constructor's metricValue argument. ToResultRow falls back to -1 when PerformanceSummary yields no value.</summary>
public double MetricValue { get ; }
24
/// <summary>Training metric value; assigned from the constructor's trainingMetricValue argument. ToResultRow falls back to -1 when PerformanceSummary yields no value.</summary>
+ public double TrainingMetricValue { get ; }
24
25
/// <summary>Pipeline identifier; assigned from the constructor's pipelineId argument. ToResultRow passes UniqueId.ToString("N").</summary>
public string PipelineId { get ; }
26
/// <summary>Assigned from the constructor's firstInput argument. ToResultRow passes the result of graphDef.GetSubgraphFirstNodeDataVarName(_env).</summary>
+ public string FirstInput { get ; }
27
/// <summary>Assigned from the constructor's predictorModel argument. ToResultRow passes graphDef.ModelOutput.VarName.</summary>
+ public string PredictorModel { get ; }
25
28
26
29
/// <summary>
/// Parameterless constructor with an empty body; it assigns none of the properties.
/// </summary>
public PipelineResultRow ( )
27
30
{ }
28
31
29
- public PipelineResultRow ( string graphJson , double metricValue , string pipelineId )
32
/// <summary>
/// Creates a result row by copying each argument into the property of the same name.
/// </summary>
/// <param name="graphJson">Value stored in GraphJson.</param>
/// <param name="metricValue">Value stored in MetricValue.</param>
/// <param name="pipelineId">Value stored in PipelineId.</param>
/// <param name="trainingMetricValue">Value stored in TrainingMetricValue.</param>
/// <param name="firstInput">Value stored in FirstInput.</param>
/// <param name="predictorModel">Value stored in PredictorModel.</param>
+ public PipelineResultRow ( string graphJson , double metricValue ,
33
+ string pipelineId , double trainingMetricValue , string firstInput ,
34
+ string predictorModel )
30
35
{
31
36
GraphJson = graphJson ;
32
37
MetricValue = metricValue ;
33
38
PipelineId = pipelineId ;
39
// The three assignments below were added together with the new constructor parameters.
+ TrainingMetricValue = trainingMetricValue ;
40
+ FirstInput = firstInput ;
41
+ PredictorModel = predictorModel ;
34
42
}
35
43
}
36
44
@@ -111,7 +119,8 @@ public AutoInference.EntryPointGraphDef ToEntryPointGraph(Experiment experiment
111
119
/// <summary>
/// Compares this pattern with another one by identifier.
/// </summary>
/// <param name="obj">The pattern to compare against; may be null.</param>
/// <returns>
/// True only when <paramref name="obj"/> is non-null and its UniqueId equals this instance's UniqueId.
/// </returns>
public bool Equals(PipelinePattern obj)
{
    if (obj != null)
    {
        return UniqueId == obj.UniqueId;
    }

    return false;
}
112
120
113
121
// REVIEW: We may want to allow for sweeping with CV in the future, so we will need to add new methods like this, or refactor these in that case.
114
- public Experiment CreateTrainTestExperiment ( IDataView trainData , IDataView testData , MacroUtils . TrainerKinds trainerKind , out Models . TrainTestEvaluator . Output resultsOutput )
122
+ public Experiment CreateTrainTestExperiment ( IDataView trainData , IDataView testData , MacroUtils . TrainerKinds trainerKind ,
123
+ bool includeTrainingMetrics , out Models . TrainTestEvaluator . Output resultsOutput )
115
124
{
116
125
var graphDef = ToEntryPointGraph ( ) ;
117
126
var subGraph = graphDef . Graph ;
@@ -136,7 +145,8 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
136
145
Model = finalOutput
137
146
} ,
138
147
PipelineId = UniqueId . ToString ( "N" ) ,
139
- Kind = MacroUtils . TrainerKindApiValue < Models . MacroUtilsTrainerKinds > ( trainerKind )
148
+ Kind = MacroUtils . TrainerKindApiValue < Models . MacroUtilsTrainerKinds > ( trainerKind ) ,
149
+ IncludeTrainingMetrics = includeTrainingMetrics
140
150
} ;
141
151
142
152
var experiment = _env . CreateExperiment ( ) ;
@@ -150,7 +160,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
150
160
}
151
161
152
162
public Models . TrainTestEvaluator . Output AddAsTrainTest ( Var < IDataView > trainData , Var < IDataView > testData ,
153
- MacroUtils . TrainerKinds trainerKind , Experiment experiment = null )
163
+ MacroUtils . TrainerKinds trainerKind , Experiment experiment = null , bool includeTrainingMetrics = false )
154
164
{
155
165
experiment = experiment ?? _env . CreateExperiment ( ) ;
156
166
var graphDef = ToEntryPointGraph ( experiment ) ;
@@ -174,7 +184,8 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
174
184
TrainingData = trainData ,
175
185
TestingData = testData ,
176
186
Kind = MacroUtils . TrainerKindApiValue < Models . MacroUtilsTrainerKinds > ( trainerKind ) ,
177
- PipelineId = UniqueId . ToString ( "N" )
187
+ PipelineId = UniqueId . ToString ( "N" ) ,
188
+ IncludeTrainingMetrics = includeTrainingMetrics
178
189
} ;
179
190
var trainTestOutput = experiment . Add ( trainTestInput ) ;
180
191
return trainTestOutput ;
@@ -183,34 +194,58 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
183
194
/// <summary>
/// Runs a train-test experiment on the current pipeline, through entrypoints,
/// and reports the requested metric for both the test data and the training data.
/// </summary>
/// <param name="trainData">Training data forwarded to CreateTrainTestExperiment.</param>
/// <param name="testData">Testing data forwarded to CreateTrainTestExperiment.</param>
/// <param name="metric">Metric whose Name selects the column read from each metrics output.</param>
/// <param name="trainerKind">Trainer kind forwarded to CreateTrainTestExperiment.</param>
/// <param name="testMetricValue">Receives the value read from the experiment's OverallMetrics output.</param>
/// <param name="trainMetricValue">Receives the value read from the experiment's TrainingOverallMetrics output.</param>
public void RunTrainTestExperiment(IDataView trainData, IDataView testData,
    AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue,
    out double trainMetricValue)
{
    // Training metrics are always requested (includeTrainingMetrics = true) because both values are reported.
    var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var trainTestOutput);
    experiment.Run();

    // Same read order as before: test metrics first, then training metrics.
    testMetricValue = ReadFirstMetricValue(experiment.GetOutput(trainTestOutput.OverallMetrics), metric.Name);
    trainMetricValue = ReadFirstMetricValue(experiment.GetOutput(trainTestOutput.TrainingOverallMetrics), metric.Name);
}

/// <summary>
/// Reads the named double column from the first row of a metrics data view.
/// </summary>
/// <param name="metrics">Data view holding the metric columns.</param>
/// <param name="metricName">Name of the column to read.</param>
/// <returns>The value the column getter produces after the cursor's first MoveNext call.</returns>
private double ReadFirstMetricValue(IDataView metrics, string metricName)
{
    // Previously the lookup result was ignored, so a missing column left the index at its
    // default and an unrelated column was read. Fail loudly instead, using the same
    // exception helper and message shape as ExtractResults.
    if (!metrics.Schema.TryGetColumnIndex(metricName, out var metricCol))
    {
        throw _env.ExceptNotSupp($"Column name {metricName} not found");
    }

    double value = 0;
    using (var cursor = metrics.GetRowCursor(col => col == metricCol))
    {
        var getter = cursor.GetGetter<double>(metricCol);
        // NOTE(review): the MoveNext result is still not checked, as in the original;
        // behaviour on an empty metrics view is whatever the getter does — confirm.
        cursor.MoveNext();
        getter(ref value);
    }

    return value;
}
203
230
204
- public static PipelineResultRow [ ] ExtractResults ( IHostEnvironment env , IDataView data , string graphColName , string metricColName , string idColName )
231
+ public static PipelineResultRow [ ] ExtractResults ( IHostEnvironment env , IDataView data ,
232
+ string graphColName , string metricColName , string idColName , string trainingMetricColName ,
233
+ string firstInputColName , string predictorModelColName )
205
234
{
206
235
var results = new List < PipelineResultRow > ( ) ;
207
236
var schema = data . Schema ;
208
237
if ( ! schema . TryGetColumnIndex ( graphColName , out var graphCol ) )
209
238
throw env . ExceptNotSupp ( $ "Column name { graphColName } not found") ;
210
239
if ( ! schema . TryGetColumnIndex ( metricColName , out var metricCol ) )
211
240
throw env . ExceptNotSupp ( $ "Column name { metricColName } not found") ;
241
+ if ( ! schema . TryGetColumnIndex ( trainingMetricColName , out var trainingMetricCol ) )
242
+ throw env . ExceptNotSupp ( $ "Column name { trainingMetricColName } not found") ;
212
243
if ( ! schema . TryGetColumnIndex ( idColName , out var pipelineIdCol ) )
213
244
throw env . ExceptNotSupp ( $ "Column name { idColName } not found") ;
245
+ if ( ! schema . TryGetColumnIndex ( firstInputColName , out var firstInputCol ) )
246
+ throw env . ExceptNotSupp ( $ "Column name { firstInputColName } not found") ;
247
+ if ( ! schema . TryGetColumnIndex ( predictorModelColName , out var predictorModelCol ) )
248
+ throw env . ExceptNotSupp ( $ "Column name { predictorModelColName } not found") ;
214
249
215
250
using ( var cursor = data . GetRowCursor ( col => true ) )
216
251
{
@@ -225,15 +260,33 @@ public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView
225
260
var getter3 = cursor . GetGetter < DvText > ( pipelineIdCol ) ;
226
261
DvText pipelineId = new DvText ( ) ;
227
262
getter3 ( ref pipelineId ) ;
228
- results . Add ( new PipelineResultRow ( graphJson . ToString ( ) , metricValue , pipelineId . ToString ( ) ) ) ;
263
+ var getter4 = cursor . GetGetter < double > ( trainingMetricCol ) ;
264
+ double trainingMetricValue = 0 ;
265
+ getter4 ( ref trainingMetricValue ) ;
266
+ var getter5 = cursor . GetGetter < DvText > ( firstInputCol ) ;
267
+ DvText firstInput = new DvText ( ) ;
268
+ getter5 ( ref firstInput ) ;
269
+ var getter6 = cursor . GetGetter < DvText > ( predictorModelCol ) ;
270
+ DvText predictorModel = new DvText ( ) ;
271
+ getter6 ( ref predictorModel ) ;
272
+
273
+ results . Add ( new PipelineResultRow ( graphJson . ToString ( ) ,
274
+ metricValue , pipelineId . ToString ( ) , trainingMetricValue ,
275
+ firstInput . ToString ( ) , predictorModel . ToString ( ) ) ) ;
229
276
}
230
277
}
231
278
232
279
return results . ToArray ( ) ;
233
280
}
234
281
235
- public PipelineResultRow ToResultRow ( ) =>
236
- new PipelineResultRow ( ToEntryPointGraph ( ) . Graph . ToJsonString ( ) ,
237
- PerformanceSummary ? . MetricValue ?? - 1d , UniqueId . ToString ( "N" ) ) ;
282
/// <summary>
/// Builds a PipelineResultRow describing this pipeline from its entry-point graph,
/// its PerformanceSummary and its UniqueId.
/// </summary>
/// <returns>A new PipelineResultRow populated from the values described above.</returns>
+ public PipelineResultRow ToResultRow ( ) {
283
// The graph definition is built once and reused for the JSON, the first-input name and the model output name.
+ var graphDef = ToEntryPointGraph ( ) ;
284
+
285
// The serialized graph is wrapped in a 'Nodes' array envelope rather than passed bare (as the removed version did).
+ return new PipelineResultRow ( $ "{{'Nodes' : [{ graphDef . Graph . ToJsonString ( ) } ]}}",
286
// Both metric arguments fall back to -1 when the null-conditional access on PerformanceSummary yields null.
+ PerformanceSummary ? . MetricValue ?? - 1d , UniqueId . ToString ( "N" ) ,
287
+ PerformanceSummary ? . TrainingMetricValue ?? - 1d ,
288
+ graphDef . GetSubgraphFirstNodeDataVarName ( _env ) ,
289
+ graphDef . ModelOutput . VarName ) ;
290
+ }
238
291
}
239
292
}
0 commit comments