Skip to content

Commit 76bab69

Browse files
committed
Added convenience constructors for ScoreTransform and TrainAndScoreTransform.
1 parent 2107b82 commit 76bab69

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs

+50-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,27 @@ public sealed class Arguments : TransformInputBase
4646

4747
internal const string Summary = "Runs a previously trained predictor on the data.";
4848

49+
/// <summary>
50+
/// Convenience method for creating <see cref="ScoreTransform"/>.
51+
/// The <see cref="ScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
52+
/// in the pipeline by using the scores from an already trained model.
53+
/// </summary>
54+
/// <param name="env">Host Environment.</param>
55+
/// <param name="input">Input <see cref="IDataView"/>.</param>
56+
/// <param name="inputModelFile">The model file.</param>
57+
/// <param name="featureColumn">Role name for the features.</param>
58+
/// <returns></returns>
59+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string inputModelFile, string featureColumn = DefaultColumnNames.Features)
60+
{
61+
var args = new Arguments()
62+
{
63+
FeatureColumn = featureColumn,
64+
InputModelFile = inputModelFile
65+
};
66+
67+
return Create(env, args, input);
68+
}
69+
4970
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
5071
{
5172
Contracts.CheckValue(env, nameof(env));
@@ -131,7 +152,36 @@ public sealed class Arguments : ArgumentsBase<SignatureTrainer>
131152

132153
internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data.";
133154

155+
/// <summary>
156+
/// Convenience method for creating <see cref="TrainAndScoreTransform"/>.
157+
/// The <see cref="TrainAndScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
158+
/// in the pipeline by training a model first and then using the scores from the trained model.
159+
///
160+
/// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly as name indicates.
161+
/// </summary>
162+
/// <param name="env">Host Environment.</param>
163+
/// <param name="input">Input <see cref="IDataView"/>.</param>
164+
/// <param name="trainer">The <see cref="ITrainer"/> object i.e. the learning algorithm that will be used for training the model.</param>
165+
/// <param name="featureColumn">Role name for features.</param>
166+
/// <param name="labelColumn">Role name for label.</param>
167+
/// <returns></returns>
168+
public static IDataTransform Create(IHostEnvironment env, IDataView input, ITrainer trainer, string featureColumn = DefaultColumnNames.Features, string labelColumn = DefaultColumnNames.Label)
169+
{
170+
var args = new Arguments()
171+
{
172+
FeatureColumn = featureColumn,
173+
LabelColumn = labelColumn
174+
};
175+
176+
return Create(env, args, trainer, input);
177+
}
178+
134179
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
180+
{
181+
return Create(env, args, args.Trainer.CreateInstance(env), input);
182+
}
183+
184+
private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input)
135185
{
136186
Contracts.CheckValue(env, nameof(env));
137187
env.CheckValue(args, nameof(args));
@@ -144,7 +194,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
144194
using (var ch = host.Start("Train"))
145195
{
146196
ch.Trace("Constructing trainer");
147-
ITrainer trainer = args.Trainer.CreateInstance(host);
148197
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
149198
string feat;
150199
string group;

0 commit comments

Comments
 (0)