@@ -46,6 +46,32 @@ public sealed class Arguments : TransformInputBase
46
46
47
47
internal const string Summary = "Runs a previously trained predictor on the data." ;
48
48
49
+ /// <summary>
50
+ /// Convenience method for creating <see cref="ScoreTransform"/>.
51
+ /// The <see cref="ScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
52
+ /// in the pipeline by using the scores from an already trained model.
53
+ /// </summary>
54
+ /// <param name="env">Host Environment.</param>
55
+ /// <param name="input">Input <see cref="IDataView"/>.</param>
56
+ /// <param name="inputModelFile">The model file.</param>
57
+ /// <param name="featureColumn">Role name for the features.</param>
58
+ /// <param name="groupColumn">Role name for the group column.</param>
59
+ public static IDataTransform Create ( IHostEnvironment env ,
60
+ IDataView input ,
61
+ string inputModelFile ,
62
+ string featureColumn = DefaultColumnNames . Features ,
63
+ string groupColumn = DefaultColumnNames . GroupId )
64
+ {
65
+ var args = new Arguments ( )
66
+ {
67
+ FeatureColumn = featureColumn ,
68
+ GroupColumn = groupColumn ,
69
+ InputModelFile = inputModelFile
70
+ } ;
71
+
72
+ return Create ( env , args , input ) ;
73
+ }
74
+
49
75
public static IDataTransform Create ( IHostEnvironment env , Arguments args , IDataView input )
50
76
{
51
77
Contracts . CheckValue ( env , nameof ( env ) ) ;
@@ -62,9 +88,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
62
88
}
63
89
64
90
string feat = TrainUtils . MatchNameOrDefaultOrNull ( env , input . Schema ,
65
- "featureColumn" , args . FeatureColumn , DefaultColumnNames . Features ) ;
91
+ nameof ( args . FeatureColumn ) , args . FeatureColumn , DefaultColumnNames . Features ) ;
66
92
string group = TrainUtils . MatchNameOrDefaultOrNull ( env , input . Schema ,
67
- "groupColumn" , args . GroupColumn , DefaultColumnNames . GroupId ) ;
93
+ nameof ( args . GroupColumn ) , args . GroupColumn , DefaultColumnNames . GroupId ) ;
68
94
var customCols = TrainUtils . CheckAndGenerateCustomColumns ( env , args . CustomColumn ) ;
69
95
70
96
return ScoreUtils . GetScorer ( args . Scorer , predictor , input , feat , group , customCols , env , trainSchema ) ;
@@ -131,20 +157,66 @@ public sealed class Arguments : ArgumentsBase<SignatureTrainer>
131
157
132
158
internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data." ;
133
159
160
+ /// <summary>
161
+ /// Convenience method for creating <see cref="TrainAndScoreTransform"/>.
162
+ /// The <see cref="TrainAndScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
163
+ /// in the pipeline by training a model first and then using the scores from the trained model.
164
+ ///
165
+ /// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly as name indicates.
166
+ /// </summary>
167
+ /// <param name="env">Host Environment.</param>
168
+ /// <param name="input">Input <see cref="IDataView"/>.</param>
169
+ /// <param name="trainer">The <see cref="ITrainer"/> object i.e. the learning algorithm that will be used for training the model.</param>
170
+ /// <param name="featureColumn">Role name for features.</param>
171
+ /// <param name="labelColumn">Role name for label.</param>
172
+ /// <param name="groupColumn">Role name for the group column.</param>
173
+ public static IDataTransform Create ( IHostEnvironment env ,
174
+ IDataView input ,
175
+ ITrainer trainer ,
176
+ string featureColumn = DefaultColumnNames . Features ,
177
+ string labelColumn = DefaultColumnNames . Label ,
178
+ string groupColumn = DefaultColumnNames . GroupId )
179
+ {
180
+ Contracts . CheckValue ( env , nameof ( env ) ) ;
181
+ env . CheckValue ( input , nameof ( input ) ) ;
182
+ env . CheckValue ( trainer , nameof ( trainer ) ) ;
183
+ env . CheckValue ( featureColumn , nameof ( featureColumn ) ) ;
184
+ env . CheckValue ( labelColumn , nameof ( labelColumn ) ) ;
185
+ env . CheckValue ( groupColumn , nameof ( groupColumn ) ) ;
186
+
187
+ var args = new Arguments ( )
188
+ {
189
+ FeatureColumn = featureColumn ,
190
+ LabelColumn = labelColumn ,
191
+ GroupColumn = groupColumn
192
+ } ;
193
+
194
+ return Create ( env , args , trainer , input ) ;
195
+ }
196
+
134
197
public static IDataTransform Create ( IHostEnvironment env , Arguments args , IDataView input )
135
198
{
136
199
Contracts . CheckValue ( env , nameof ( env ) ) ;
137
200
env . CheckValue ( args , nameof ( args ) ) ;
138
- env . CheckValue ( input , nameof ( input ) ) ;
139
201
env . CheckUserArg ( args . Trainer . IsGood ( ) , nameof ( args . Trainer ) ,
140
202
"Trainer cannot be null. If your model is already trained, please use ScoreTransform instead." ) ;
203
+ env . CheckValue ( input , nameof ( input ) ) ;
204
+
205
+ return Create ( env , args , args . Trainer . CreateInstance ( env ) , input ) ;
206
+ }
207
+
208
+ private static IDataTransform Create ( IHostEnvironment env , Arguments args , ITrainer trainer , IDataView input )
209
+ {
210
+ Contracts . AssertValue ( env , nameof ( env ) ) ;
211
+ env . AssertValue ( args , nameof ( args ) ) ;
212
+ env . AssertValue ( trainer , nameof ( trainer ) ) ;
213
+ env . AssertValue ( input , nameof ( input ) ) ;
141
214
142
215
var host = env . Register ( "TrainAndScoreTransform" ) ;
143
216
144
217
using ( var ch = host . Start ( "Train" ) )
145
218
{
146
219
ch . Trace ( "Constructing trainer" ) ;
147
- ITrainer trainer = args . Trainer . CreateInstance ( host ) ;
148
220
var customCols = TrainUtils . CheckAndGenerateCustomColumns ( env , args . CustomColumn ) ;
149
221
string feat ;
150
222
string group ;
0 commit comments