-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added convenience constructors for ScoreTransform and TrainAndScoreTransform. #614
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
76bab69
a3a0302
1828860
ac4e76a
6b1616c
aca4a2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,26 @@ public sealed class Arguments : TransformInputBase | |
|
||
internal const string Summary = "Runs a previously trained predictor on the data."; | ||
|
||
/// <summary> | ||
/// Builds a <see cref="ScoreTransform"/> directly from a model file, without the caller having to fill in an <see cref="Arguments"/> object. | ||
/// Running an already trained model over the data makes its scores available to later stages of the pipeline, | ||
/// which enables model stacking (i.e. combining information from multiple predictive models to generate a new model). | ||
/// </summary> | ||
/// <param name="env">Host environment, forwarded to the <see cref="Arguments"/>-based overload.</param> | ||
/// <param name="input">Input <see cref="IDataView"/>.</param> | ||
/// <param name="inputModelFile">The model file; stored in the <c>InputModelFile</c> argument.</param> | ||
/// <param name="featureColumn">Role name for the features; stored in the <c>FeatureColumn</c> argument.</param> | ||
/// <returns>Whatever the <see cref="Arguments"/>-based <c>Create</c> overload returns for these settings.</returns> | ||
public static IDataTransform Create(IHostEnvironment env, IDataView input, string inputModelFile, string featureColumn = DefaultColumnNames.Features) | ||
{ | ||
// Package the two settings and hand off to the Arguments-based overload. | ||
return Create( | ||
env, | ||
new Arguments { FeatureColumn = featureColumn, InputModelFile = inputModelFile }, | ||
input); | ||
} | ||
|
||
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
|
@@ -62,9 +82,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV | |
} | ||
|
||
string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, | ||
"featureColumn", args.FeatureColumn, DefaultColumnNames.Features); | ||
nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features); | ||
string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema, | ||
"groupColumn", args.GroupColumn, DefaultColumnNames.GroupId); | ||
nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId); | ||
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn); | ||
|
||
return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema); | ||
|
@@ -131,20 +151,62 @@ public sealed class Arguments : ArgumentsBase<SignatureTrainer> | |
|
||
internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data."; | ||
|
||
/// <summary> | ||
/// Convenience method for creating <see cref="TrainAndScoreTransform"/> from an already constructed <see cref="ITrainer"/>. | ||
/// The <see cref="TrainAndScoreTransform"/> allows for model stacking (i.e. combining information from multiple predictive models to generate a new model) | ||
/// in the pipeline by training a model first and then using the scores from the trained model. | ||
/// | ||
/// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly, as its name indicates. | ||
/// </summary> | ||
/// <param name="env">Host environment. Validated with <c>CheckValue</c> before anything else.</param> | ||
/// <param name="input">Input <see cref="IDataView"/>. Validated with <c>CheckValue</c>.</param> | ||
/// <param name="trainer">The <see cref="ITrainer"/> object, i.e. the learning algorithm that will be used for training the model. Validated with <c>CheckValue</c>.</param> | ||
/// <param name="featureColumn">Role name for features. Validated with <c>CheckValue</c>.</param> | ||
/// <param name="labelColumn">Role name for label. Validated with <c>CheckValue</c>.</param> | ||
/// <returns>The <see cref="IDataTransform"/> returned by the private <c>Create</c> overload that takes the trainer instance.</returns> | ||
public static IDataTransform Create(IHostEnvironment env, | ||
IDataView input, | ||
ITrainer trainer, | ||
string featureColumn = DefaultColumnNames.Features, | ||
string labelColumn = DefaultColumnNames.Label) | ||
{ | ||
// Public entry point: state the calling contract explicitly by checking every argument. | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckValue(input, nameof(input)); | ||
env.CheckValue(trainer, nameof(trainer)); | ||
env.CheckValue(featureColumn, nameof(featureColumn)); | ||
env.CheckValue(labelColumn, nameof(labelColumn)); | ||
|
||
// Only the column roles go into the arguments; the trainer instance is passed separately below. | ||
var args = new Arguments() | ||
{ | ||
FeatureColumn = featureColumn, | ||
LabelColumn = labelColumn | ||
}; | ||
|
||
return Create(env, args, trainer, input); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Can you check trainer and input? #Resolved |
||
} | ||
|
||
/// <summary> | ||
/// Creates a <see cref="TrainAndScoreTransform"/> from <paramref name="args"/>: the trainer described by | ||
/// <c>args.Trainer</c> is instantiated and handed, together with the other arguments, to the private <c>Create</c> overload. | ||
/// </summary> | ||
/// <param name="env">Host environment. Validated with <c>CheckValue</c> before anything else.</param> | ||
/// <param name="args">Transform arguments. <c>args.Trainer</c> must satisfy <c>IsGood()</c>; otherwise the user-argument check fails.</param> | ||
/// <param name="input">Input <see cref="IDataView"/>. Validated with <c>CheckValue</c>.</param> | ||
/// <returns>The <see cref="IDataTransform"/> returned by the private <c>Create</c> overload.</returns> | ||
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
can you check input in this function? #Resolved |
||
{ | ||
// Public entry point: validate each argument exactly once. | ||
Contracts.CheckValue(env, nameof(env)); | ||
env.CheckValue(args, nameof(args)); | ||
env.CheckValue(input, nameof(input)); | ||
env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer), | ||
"Trainer cannot be null. If your model is already trained, please use ScoreTransform instead."); | ||
|
||
return Create(env, args, args.Trainer.CreateInstance(env), input); | ||
} | ||
|
||
private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moving the common code into a private Create method. #Closed There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, a more common pattern is to use AssertValue in private methods, and CheckValue in public methods. The purpose of the Check* methods is twofold: first, we want to enforce the calling contract, and second, we want to actually EXPLICITLY SPECIFY the calling contract. In reply to: 206362456 [](ancestors = 206362456) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Moved all the checks into the public methods instead. In reply to: 206367797 [](ancestors = 206367797,206362456) |
||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here, use AssertValue on everything you CheckValue'd on the public side. This is, again, in the interest of self-documenting methods. #Resolved |
||
Contracts.AssertValue(env, nameof(env)); | ||
env.AssertValue(args, nameof(args)); | ||
env.AssertValue(trainer, nameof(trainer)); | ||
env.AssertValue(input, nameof(input)); | ||
|
||
var host = env.Register("TrainAndScoreTransform"); | ||
|
||
using (var ch = host.Start("Train")) | ||
{ | ||
ch.Trace("Constructing trainer"); | ||
ITrainer trainer = args.Trainer.CreateInstance(host); | ||
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn); | ||
string feat; | ||
string group; | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the interest of self-documenting functions, add CheckValue (or CheckValueOrNull) on featureColumn and labelColumn. #Resolved