From 76bab69d24e2bb65a9c6fc2beeec56491752614e Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Mon, 30 Jul 2018 17:12:50 -0700 Subject: [PATCH 1/6] Added convenience constructors for ScoreTransform and TrainAndScoreTransform. --- .../Transforms/TrainAndScoreTransform.cs | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index bba2f58256..78c4c28d22 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -46,6 +46,27 @@ public sealed class Arguments : TransformInputBase internal const string Summary = "Runs a previously trained predictor on the data."; + /// + /// Convenience method for creating . + /// The allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model) + /// in the pipeline by using the scores from an already trained model. + /// + /// Host Environment. + /// Input . + /// The model file. + /// Role name for the features. + /// + public static IDataTransform Create(IHostEnvironment env, IDataView input, string inputModelFile, string featureColumn = DefaultColumnNames.Features) + { + var args = new Arguments() + { + FeatureColumn = featureColumn, + InputModelFile = inputModelFile + }; + + return Create(env, args, input); + } + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { Contracts.CheckValue(env, nameof(env)); @@ -131,7 +152,36 @@ public sealed class Arguments : ArgumentsBase internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data."; + /// + /// Convenience method for creating . + /// The allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model) + /// in the pipeline by training a model first and then using the scores from the trained model. + /// + /// Unlike , the trains the model on the fly as name indicates. + /// + /// Host Environment. + /// Input . + /// The object i.e. the learning algorithm that will be used for training the model. + /// Role name for features. + /// Role name for label. + /// + public static IDataTransform Create(IHostEnvironment env, IDataView input, ITrainer trainer, string featureColumn = DefaultColumnNames.Features, string labelColumn = DefaultColumnNames.Label) + { + var args = new Arguments() + { + FeatureColumn = featureColumn, + LabelColumn = labelColumn + }; + + return Create(env, args, trainer, input); + } + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + return Create(env, args, args.Trainer.CreateInstance(env), input); + } + + private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); @@ -144,7 +194,6 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV using (var ch = host.Start("Train")) { ch.Trace("Constructing trainer"); - ITrainer trainer = args.Trainer.CreateInstance(host); var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn); string feat; string group; From a3a0302d0433ac148b0cbd57dcb555a6af5a0eb0 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Mon, 30 Jul 2018 17:47:10 -0700 Subject: [PATCH 2/6] Moving validation/checking of arguments to relevant blocks. --- .../Transforms/TrainAndScoreTransform.cs | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index 78c4c28d22..a3aec0ab14 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -165,8 +165,13 @@ public sealed class Arguments : ArgumentsBase /// Role name for features. /// Role name for label. /// - public static IDataTransform Create(IHostEnvironment env, IDataView input, ITrainer trainer, string featureColumn = DefaultColumnNames.Features, string labelColumn = DefaultColumnNames.Label) + public static IDataTransform Create(IHostEnvironment env, + IDataView input, + ITrainer trainer, + string featureColumn = DefaultColumnNames.Features, + string labelColumn = DefaultColumnNames.Label) { + Contracts.CheckValue(env, nameof(env)); var args = new Arguments() { FeatureColumn = featureColumn, @@ -178,16 +183,17 @@ public static IDataTransform Create(IHostEnvironment env, IDataView input, ITrai public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), + "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); + return Create(env, args, args.Trainer.CreateInstance(env), input); } private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) { - Contracts.CheckValue(env, nameof(env)); - env.CheckValue(args, nameof(args)); env.CheckValue(input, nameof(input)); - env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), - "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); var host = env.Register("TrainAndScoreTransform"); From 182886005802f6bad89513270b9adba4242f806c Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Mon, 30 Jul 2018 17:56:41 -0700 Subject: [PATCH 3/6] Removed tag from doc section. --- src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index a3aec0ab14..efa782eef1 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -55,7 +55,6 @@ public sealed class Arguments : TransformInputBase /// Input . /// The model file. /// Role name for the features. - /// public static IDataTransform Create(IHostEnvironment env, IDataView input, string inputModelFile, string featureColumn = DefaultColumnNames.Features) { var args = new Arguments() @@ -164,7 +163,6 @@ public sealed class Arguments : ArgumentsBase /// The object i.e. the learning algorithm that will be used for training the model. /// Role name for features. /// Role name for label. - /// public static IDataTransform Create(IHostEnvironment env, IDataView input, ITrainer trainer, From ac4e76ae698520ea818754f9f0f831a43e5d9ea7 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Mon, 30 Jul 2018 18:34:14 -0700 Subject: [PATCH 4/6] Addressed reviewers' comments. --- .../Transforms/TrainAndScoreTransform.cs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index efa782eef1..b21e1d5097 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -82,9 +82,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV } string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, - "featureColumn", args.FeatureColumn, DefaultColumnNames.Features); + nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features); string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, - "groupColumn", args.GroupColumn, DefaultColumnNames.GroupId); + nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId); var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn); return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema); @@ -170,6 +170,9 @@ public static IDataTransform Create(IHostEnvironment env, string labelColumn = DefaultColumnNames.Label) { Contracts.CheckValue(env, nameof(env)); + env.CheckValue(input, nameof(input)); + env.CheckValue(trainer, nameof(trainer)); + var args = new Arguments() { FeatureColumn = featureColumn, @@ -185,14 +188,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV env.CheckValue(args, nameof(args)); env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), "Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); + env.CheckValue(input, nameof(input)); return Create(env, args, args.Trainer.CreateInstance(env), input); } private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) { - env.CheckValue(input, nameof(input)); - var host = env.Register("TrainAndScoreTransform"); using (var ch = host.Start("Train")) From 6b1616cabd03c701149212ce933fd18552694ae4 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Tue, 31 Jul 2018 12:01:33 -0700 Subject: [PATCH 5/6] Addressed reviewers' comments. --- src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index b21e1d5097..2abe600665 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -172,6 +172,8 @@ public static IDataTransform Create(IHostEnvironment env, Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); env.CheckValue(trainer, nameof(trainer)); + env.CheckValue(featureColumn, nameof(featureColumn)); + env.CheckValue(labelColumn, nameof(labelColumn)); var args = new Arguments() { @@ -195,6 +197,11 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) { + Contracts.AssertValue(env, nameof(env)); + env.AssertValue(args, nameof(args)); + env.AssertValue(trainer, nameof(trainer)); + env.AssertValue(input, nameof(input)); + var host = env.Register("TrainAndScoreTransform"); using (var ch = host.Start("Train")) From aca4a2edd5f7ca4338f377e48451f2c4eb8f0df7 Mon Sep 17 00:00:00 2001 From: Zeeshan Ahmed Date: Tue, 31 Jul 2018 17:01:29 -0700 Subject: [PATCH 6/6] Addressed reviewers' comments. --- .../Transforms/TrainAndScoreTransform.cs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs index 2abe600665..458e83aaed 100644 --- a/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs @@ -55,11 +55,17 @@ public sealed class Arguments : TransformInputBase /// Input . /// The model file. /// Role name for the features. - public static IDataTransform Create(IHostEnvironment env, IDataView input, string inputModelFile, string featureColumn = DefaultColumnNames.Features) + /// Role name for the group column. + public static IDataTransform Create(IHostEnvironment env, + IDataView input, + string inputModelFile, + string featureColumn = DefaultColumnNames.Features, + string groupColumn = DefaultColumnNames.GroupId) { var args = new Arguments() { FeatureColumn = featureColumn, + GroupColumn = groupColumn, InputModelFile = inputModelFile }; @@ -163,22 +169,26 @@ public sealed class Arguments : ArgumentsBase /// The object i.e. the learning algorithm that will be used for training the model. /// Role name for features. /// Role name for label. + /// Role name for the group column. public static IDataTransform Create(IHostEnvironment env, IDataView input, ITrainer trainer, string featureColumn = DefaultColumnNames.Features, - string labelColumn = DefaultColumnNames.Label) + string labelColumn = DefaultColumnNames.Label, + string groupColumn = DefaultColumnNames.GroupId) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(input, nameof(input)); env.CheckValue(trainer, nameof(trainer)); env.CheckValue(featureColumn, nameof(featureColumn)); env.CheckValue(labelColumn, nameof(labelColumn)); + env.CheckValue(groupColumn, nameof(groupColumn)); var args = new Arguments() { FeatureColumn = featureColumn, - LabelColumn = labelColumn + LabelColumn = labelColumn, + GroupColumn = groupColumn }; return Create(env, args, trainer, input);