Skip to content

Commit 4bd866e

Browse files
zeahmedshauheen
authored and committed
Added convenience constructors for ScoreTransform and TrainAndScoreTransform. (#614)
1 parent 697ea3d commit 4bd866e

File tree

1 file changed

+76
-4
lines changed

1 file changed

+76
-4
lines changed

src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs

+76-4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,32 @@ public sealed class Arguments : TransformInputBase
4646

4747
internal const string Summary = "Runs a previously trained predictor on the data.";
4848

/// <summary>
/// Convenience method for creating a <see cref="ScoreTransform"/>.
/// The <see cref="ScoreTransform"/> allows for model stacking (i.e. combining information from multiple
/// predictive models to generate a new model) in the pipeline by using the scores from an already trained model.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
/// <param name="inputModelFile">The model file.</param>
/// <param name="featureColumn">Role name for the features.</param>
/// <param name="groupColumn">Role name for the group column.</param>
/// <returns>The <see cref="IDataTransform"/> returned by the <see cref="Arguments"/>-based <c>Create</c> overload.</returns>
public static IDataTransform Create(IHostEnvironment env,
    IDataView input,
    string inputModelFile,
    string featureColumn = DefaultColumnNames.Features,
    string groupColumn = DefaultColumnNames.GroupId)
{
    // Fail fast at the public entry point, consistent with the TrainAndScoreTransform
    // convenience overload. The column names and the model file path are deliberately
    // passed through unchanged so that the Arguments-based overload keeps deciding how
    // to treat them.
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(input, nameof(input));

    var args = new Arguments()
    {
        FeatureColumn = featureColumn,
        GroupColumn = groupColumn,
        InputModelFile = inputModelFile
    };

    return Create(env, args, input);
}
74+
4975
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
5076
{
5177
Contracts.CheckValue(env, nameof(env));
@@ -62,9 +88,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
6288
}
6389

6490
string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
65-
"featureColumn", args.FeatureColumn, DefaultColumnNames.Features);
91+
nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
6692
string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
67-
"groupColumn", args.GroupColumn, DefaultColumnNames.GroupId);
93+
nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
6894
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
6995

7096
return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
@@ -131,20 +157,66 @@ public sealed class Arguments : ArgumentsBase<SignatureTrainer>
131157

132158
internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data.";
133159

/// <summary>
/// Convenience factory for <see cref="TrainAndScoreTransform"/>.
/// A <see cref="TrainAndScoreTransform"/> enables model stacking (i.e. combining information from multiple
/// predictive models to generate a new model) in the pipeline: it first trains a model and then uses the
/// scores produced by that trained model.
///
/// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly,
/// as the name indicates.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
/// <param name="trainer">The <see cref="ITrainer"/> object, i.e. the learning algorithm that will be used for training the model.</param>
/// <param name="featureColumn">Role name for features.</param>
/// <param name="labelColumn">Role name for label.</param>
/// <param name="groupColumn">Role name for the group column.</param>
/// <returns>The <see cref="IDataTransform"/> returned by the trainer-taking <c>Create</c> overload.</returns>
public static IDataTransform Create(IHostEnvironment env,
    IDataView input,
    ITrainer trainer,
    string featureColumn = DefaultColumnNames.Features,
    string labelColumn = DefaultColumnNames.Label,
    string groupColumn = DefaultColumnNames.GroupId)
{
    // The environment is checked first because every later check is reported through it.
    Contracts.CheckValue(env, nameof(env));

    // Remaining arguments are validated in declaration order.
    env.CheckValue(input, nameof(input));
    env.CheckValue(trainer, nameof(trainer));
    env.CheckValue(featureColumn, nameof(featureColumn));
    env.CheckValue(labelColumn, nameof(labelColumn));
    env.CheckValue(groupColumn, nameof(groupColumn));

    // Only the column roles are populated; the trainer instance is handed over directly
    // instead of going through the Arguments object.
    return Create(
        env,
        new Arguments
        {
            FeatureColumn = featureColumn,
            LabelColumn = labelColumn,
            GroupColumn = groupColumn
        },
        trainer,
        input);
}
196+
134197
/// <summary>
/// Creates the transform from an <see cref="Arguments"/> object.
/// Checks <paramref name="env"/>, <paramref name="args"/>, that <c>args.Trainer</c> is specified
/// (<c>IsGood()</c>) and <paramref name="input"/>, then instantiates the trainer with
/// <c>args.Trainer.CreateInstance(env)</c> and forwards to the overload that accepts an <see cref="ITrainer"/>.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="args">Transform arguments; <c>args.Trainer</c> must be specified.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
/// <returns>The <see cref="IDataTransform"/> returned by the trainer-taking <c>Create</c> overload.</returns>
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
135198
{
136199
Contracts.CheckValue(env, nameof(env));
137200
env.CheckValue(args, nameof(args));
138-
env.CheckValue(input, nameof(input));
139201
env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer),
140202
"Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");
203+
env.CheckValue(input, nameof(input));
204+
205+
// The trainer is instantiated here so the trainer-taking overload can also serve
// callers that already hold an ITrainer instance.
return Create(env, args, args.Trainer.CreateInstance(env), input);
206+
}
207+
208+
private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input)
209+
{
210+
Contracts.AssertValue(env, nameof(env));
211+
env.AssertValue(args, nameof(args));
212+
env.AssertValue(trainer, nameof(trainer));
213+
env.AssertValue(input, nameof(input));
141214

142215
var host = env.Register("TrainAndScoreTransform");
143216

144217
using (var ch = host.Start("Train"))
145218
{
146219
ch.Trace("Constructing trainer");
147-
ITrainer trainer = args.Trainer.CreateInstance(host);
148220
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
149221
string feat;
150222
string group;

0 commit comments

Comments
 (0)