Skip to content

Added convenience constructors for ScoreTransform and TrainAndScoreTransform. #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 1, 2018
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions src/Microsoft.ML.Data/Transforms/TrainAndScoreTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,26 @@ public sealed class Arguments : TransformInputBase

internal const string Summary = "Runs a previously trained predictor on the data.";

/// <summary>
/// Convenience method for creating <see cref="ScoreTransform"/>.
/// The <see cref="ScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
/// in the pipeline by using the scores from an already trained model.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
/// <param name="inputModelFile">The model file.</param>
/// <param name="featureColumn">Role name for the features.</param>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string inputModelFile, string featureColumn = DefaultColumnNames.Features)
{
var args = new Arguments()
{
FeatureColumn = featureColumn,
InputModelFile = inputModelFile
};

return Create(env, args, input);
}

public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
Expand All @@ -62,9 +82,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}

string feat = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
"featureColumn", args.FeatureColumn, DefaultColumnNames.Features);
nameof(args.FeatureColumn), args.FeatureColumn, DefaultColumnNames.Features);
string group = TrainUtils.MatchNameOrDefaultOrNull(env, input.Schema,
"groupColumn", args.GroupColumn, DefaultColumnNames.GroupId);
nameof(args.GroupColumn), args.GroupColumn, DefaultColumnNames.GroupId);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);

return ScoreUtils.GetScorer(args.Scorer, predictor, input, feat, group, customCols, env, trainSchema);
Expand Down Expand Up @@ -131,20 +151,62 @@ public sealed class Arguments : ArgumentsBase<SignatureTrainer>

internal const string Summary = "Trains a predictor, or loads it from a file, and runs it on the data.";

/// <summary>
/// Convenience method for creating <see cref="TrainAndScoreTransform"/>.
/// The <see cref="TrainAndScoreTransform"/> allows for model stacking (i.e. to combine information from multiple predictive models to generate a new model)
/// in the pipeline by training a model first and then using the scores from the trained model.
///
/// Unlike <see cref="ScoreTransform"/>, the <see cref="TrainAndScoreTransform"/> trains the model on the fly as name indicates.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
/// <param name="trainer">The <see cref="ITrainer"/> object i.e. the learning algorithm that will be used for training the model.</param>
/// <param name="featureColumn">Role name for features.</param>
/// <param name="labelColumn">Role name for label.</param>
public static IDataTransform Create(IHostEnvironment env,
IDataView input,
ITrainer trainer,
string featureColumn = DefaultColumnNames.Features,
string labelColumn = DefaultColumnNames.Label)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(input, nameof(input));
env.CheckValue(trainer, nameof(trainer));
Copy link
Contributor

@Zruty0 Zruty0 Jul 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

env.CheckValue(trainer, nameof(trainer)); [](start = 12, length = 41)

in the interest of self-documenting functions, add CheckValue (or CheckValueOrNull) on featureColumn and labelColumn #Resolved

env.CheckValue(featureColumn, nameof(featureColumn));
env.CheckValue(labelColumn, nameof(labelColumn));

var args = new Arguments()
{
FeatureColumn = featureColumn,
LabelColumn = labelColumn
};

return Create(env, args, trainer, input);
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jul 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trainer, input [](start = 37, length = 14)

Can you check trainer and input? #Resolved

}

public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jul 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input [](start = 92, length = 5)

can you check input in this function? #Resolved

{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(args, nameof(args));
env.CheckValue(input, nameof(input));
env.CheckUserArg(args.Trainer.IsGood(), nameof(args.Trainer),
"Trainer cannot be null. If your model is already trained, please use ScoreTransform instead.");
env.CheckValue(input, nameof(input));

return Create(env, args, args.Trainer.CreateInstance(env), input);
}

private static IDataTransform Create(IHostEnvironment env, Arguments args, ITrainer trainer, IDataView input)
Copy link
Contributor Author

@zeahmed zeahmed Jul 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving the common code in private create method. #Closed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, a more common pattern is to use AssertValue in private methods, and CheckValue in public methods.

The purpose of the Check* is twofold: first, we want to enforce the calling contract, and second, we want to actually EXPLICITLY SPECIFY the calling contract.
So that you can look at any argument and immediately, within 3 lines of code, tell whether it can be null or not etc.


In reply to: 206362456 [](ancestors = 206362456)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved all the checks in the public methods instead.


In reply to: 206367797 [](ancestors = 206367797,206362456)

{
Copy link
Contributor

@Zruty0 Zruty0 Jul 31, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, AssertValue of all the things you CheckValue'd in the public side.

This is again in the interest of self-documenting methods #Resolved

Contracts.AssertValue(env, nameof(env));
env.AssertValue(args, nameof(args));
env.AssertValue(trainer, nameof(trainer));
env.AssertValue(input, nameof(input));

var host = env.Register("TrainAndScoreTransform");

using (var ch = host.Start("Train"))
{
ch.Trace("Constructing trainer");
ITrainer trainer = args.Trainer.CreateInstance(host);
var customCols = TrainUtils.CheckAndGenerateCustomColumns(env, args.CustomColumn);
string feat;
string group;
Expand Down