diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
index bba2f58256..458e83aaed 100644
--- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
@@ -46,6 +46,32 @@ public sealed class Arguments : TransformInputBase
internal const string Summary = "Runs a previously trained predictor on the data.";
+ /// <summary>
+ /// Convenience method for creating <see cref="ScoreTransform"/>.
+ /// The <see cref="ScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
+ /// in the pipeline by using the scores from an already trained model.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>.</param>
+ /// <param name="inputModelFile">The model file.</param>
+ /// <param name="featureColumn">Role name for the features.</param>
+ /// <param name="groupColumn">Role name for the group column.</param>
+ public static IDataTransform Create(IHostEnvironment env,
+ IDataView input,
+ string inputModelFile,
+ string featureColumn = DefaultColumnNames.Features,
+ string groupColumn = DefaultColumnNames.GroupId)
+ {
+ var args = new Arguments()
+ {
+ FeatureColumn = featureColumn,
+ GroupColumn = groupColumn,
+ InputModelFile = inputModelFile
+ };
+
+ return Create(env, args, input);
+ }
+
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
@@ -62,9 +88,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}
string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
- "featureColumn", args.FeatureColumn, DefaultColumnNames.Features);
+ nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
- "groupColumn", args.GroupColumn, DefaultColumnNames.GroupId);
+ nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
@@ -131,20 +157,66 @@ public sealed class Arguments : ArgumentsBase
internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data.";
+ /// <summary>
+ /// Convenience method for creating <see cref="TrainAndScoreTransform"/>.
+ /// The <see cref="TrainAndScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
+ /// in the pipeline by training a model first and then using the scores from the trained model.
+ ///
+ /// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly, as the name indicates.
+ /// </summary>
+ /// <param name="env">Host Environment.</param>
+ /// <param name="input">Input <see cref="IDataView"/>.</param>
+ /// <param name="trainer">The <see cref="ITrainer"/> object i.e. the learning algorithm that will be used for training the model.</param>
+ /// <param name="featureColumn">Role name for features.</param>
+ /// <param name="labelColumn">Role name for label.</param>
+ /// <param name="groupColumn">Role name for the group column.</param>
+ public static IDataTransform Create(IHostEnvironment env,
+ IDataView input,
+ ITrainer trainer,
+ string featureColumn = DefaultColumnNames.Features,
+ string labelColumn = DefaultColumnNames.Label,
+ string groupColumn = DefaultColumnNames.GroupId)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(input, nameof(input));
+ env.CheckValue(trainer, nameof(trainer));
+ env.CheckValue(featureColumn, nameof(featureColumn));
+ env.CheckValue(labelColumn, nameof(labelColumn));
+ env.CheckValue(groupColumn, nameof(groupColumn));
+
+ var args = new Arguments()
+ {
+ FeatureColumn = featureColumn,
+ LabelColumn = labelColumn,
+ GroupColumn = groupColumn
+ };
+
+ return Create(env, args, trainer, input);
+ }
+
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
- env.CheckValue(input, nameof(input));
env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer),
"Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");
+ env.CheckValue(input, nameof(input));
+
+ // Instantiate the trainer from args.Trainer and delegate to the private overload that accepts an ITrainer.
+ return Create(env, args, args.Trainer.CreateInstance(env), input);
+ }
+
+ private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input)
+ {
+ Contracts.AssertValue(env, nameof(env));
+ env.AssertValue(args, nameof(args));
+ env.AssertValue(trainer, nameof(trainer));
+ env.AssertValue(input, nameof(input));
var host = env.Register("TrainAndScoreTransform");
using (var ch = host.Start("Train"))
{
ch.Trace("Constructing trainer");
- ITrainer trainer = args.Trainer.CreateInstance(host);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
string feat;
string group;