diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index bba2f58256..458e83aaed 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -46,6 +46,32 @@ public sealed class Arguments : TransformInputBase internal const string Summary = "Runs a previously trained predictor on the data."; + /// <summary> + /// Convenience method for creating <see cref="ScoreTransform"/>. + /// The <see cref="ScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model) + /// in the pipeline by using the scores from an already trained model. + /// </summary> + /// <param name="env">Host Environment.</param> + /// <param name="input">Input <see cref="IDataView"/>.</param> + /// <param name="inputModelFile">The model file.</param> + /// <param name="featureColumn">Role name for the features.</param> + /// <param name="groupColumn">Role name for the group column.</param> + public static IDataTransform Create(IHostEnvironment env, + IDataView input, + string inputModelFile, + string featureColumn = DefaultColumnNames.Features, + string groupColumn = DefaultColumnNames.GroupId) + { + var args = new Arguments() + { + FeatureColumn = featureColumn, + GroupColumn = groupColumn, + InputModelFile = inputModelFile + }; + + return Create(env, args, input); + } + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -62,9 +88,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, - "featureColumn", args.FeatureColumn, DefaultColumnNames.Features); + nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features); string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, - "groupColumn", args.GroupColumn, DefaultColumnNames.GroupId); + nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId); var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn); return 
ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema); @@ -131,20 +157,66 @@ public sealed class Arguments : ArgumentsBase internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data."; + /// <summary> + /// Convenience method for creating <see cref="TrainAndScoreTransform"/>. + /// The <see cref="TrainAndScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model) + /// in the pipeline by training a model first and then using the scores from the trained model. + /// + /// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly, as its name indicates. + /// </summary> + /// <param name="env">Host Environment.</param> + /// <param name="input">Input <see cref="IDataView"/>.</param> + /// <param name="trainer">The <see cref="ITrainer"/> object, i.e. the learning algorithm that will be used for training the model.</param> + /// <param name="featureColumn">Role name for features.</param> + /// <param name="labelColumn">Role name for label.</param> + /// <param name="groupColumn">Role name for the group column.</param> + public static IDataTransform Create(IHostEnvironment env, + IDataView input, + ITrainer trainer, + string featureColumn = DefaultColumnNames.Features, + string labelColumn = DefaultColumnNames.Label, + string groupColumn = DefaultColumnNames.GroupId) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(input, nameof(input)); + env.CheckValue(trainer, nameof(trainer)); + env.CheckValue(featureColumn, nameof(featureColumn)); + env.CheckValue(labelColumn, nameof(labelColumn)); + env.CheckValue(groupColumn, nameof(groupColumn)); + + var args = new Arguments() + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn, + GroupColumn = groupColumn + }; + + return Create(env, args, trainer, input); + } + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); - env.CheckValue(input, nameof(input)); env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), "Trainer cannot be null. 
If your model is already trained, please use ScoreTransform instead."); + env.CheckValue(input, nameof(input)); + + return Create(env, args, args.Trainer.CreateInstance(env), input); + } + + private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) + { + Contracts.AssertValue(env, nameof(env)); + env.AssertValue(args, nameof(args)); + env.AssertValue(trainer, nameof(trainer)); + env.AssertValue(input, nameof(input)); var host = env.Register("TrainAndScoreTransform"); using (var ch = host.Start("Train")) { ch.Trace("Constructing trainer"); - ITrainer trainer = args.Trainer.CreateInstance(host); var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn); string feat; string group;