Skip to content

Move Learner* input base and Transform* input base out of Entrypoints… #2748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 28, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public static void Example()
FieldAwareFactorizationMachine(
new FieldAwareFactorizationMachineBinaryClassificationTrainer.Options
{
FeatureColumn = "Features",
LabelColumn = "Sentiment",
FeatureColumnName = "Features",
LabelColumnName = "Sentiment",
LearningRate = 0.1f,
NumberOfIterations = 10
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ public static void Example()
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
new SdcaBinaryTrainer.Options {
LabelColumn = "Sentiment",
FeatureColumn = "Features",
LabelColumnName = "Sentiment",
FeatureColumnName = "Features",
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
NumThreads = 2, // Degree of lock-free parallelism
}));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ public static void Example()
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Options
{
LabelColumn = "LabelIndex",
FeatureColumn = "Features",
LabelColumnName = "LabelIndex",
FeatureColumnName = "Features",
Booster = new DartBooster.Options
{
DropRate = 0.15,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public static void Example()
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
.Append(mlContext.Regression.Trainers.LightGbm(new Options
{
LabelColumn = labelName,
LabelColumnName = labelName,
NumLeaves = 4,
MinDataPerLeaf = 6,
LearningRate = 0.001,
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/EntryPoints/Cache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(void), typeof(Cache), null, typeof(SignatureEntryPointModule), "Cache")]
namespace Microsoft.ML.EntryPoints
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, RunContext context,
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");

var inputInstance = _inputBuilder.GetInstance();
SetColumnArgument(ch, inputInstance, "LabelColumn", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
SetColumnArgument(ch, inputInstance, "GroupIdColumn", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
SetColumnArgument(ch, inputInstance, "WeightColumn", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
SetColumnArgument(ch, inputInstance, "LabelColumnName", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
SetColumnArgument(ch, inputInstance, "RowGroupColumnName", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
SetColumnArgument(ch, inputInstance, "ExampleWeightColumnName", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");

// Validate outputs.
Expand Down
113 changes: 6 additions & 107 deletions src/Microsoft.ML.Data/EntryPoints/InputBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,10 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.EntryPoints
{
/// <summary>
/// The base class for all transform inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
public abstract class TransformInputBase
{
/// <summary>
/// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
/// create an <see cref="ITransformer"/> is to use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method.
/// </summary>
[BestFriend]
[Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
internal IDataView Data;
}

[BestFriend]
internal enum CachingOptions
{
Expand All @@ -36,89 +22,12 @@ internal enum CachingOptions
None
}

/// <summary>
/// The base class for all learner inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
public abstract class LearnerInputBase
{
/// <summary>
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
/// method.
/// </summary>
[BestFriend]
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal IDataView TrainingData;

/// <summary>
/// Column to use for features.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string FeatureColumn = DefaultColumnNames.Features;

/// <summary>
/// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
/// </summary>
[BestFriend]
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;

/// <summary>
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
/// </summary>
[BestFriend]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal CachingOptions Caching = CachingOptions.Auto;
}

/// <summary>
/// The base class for all learner inputs that support a Label column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
public abstract class LearnerInputBaseWithLabel : LearnerInputBase
{
/// <summary>
/// Column to use for labels.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string LabelColumn = DefaultColumnNames.Label;
}

// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
/// <summary>
/// The base class for all learner inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
{
/// <summary>
/// The name of the example weight column.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string WeightColumn = null;
}

/// <summary>
/// The base class for all unsupervised learner inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
{
/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string WeightColumn = null;
}

/// <summary>
/// The base class for all evaluators inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IEvaluatorInput))]
public abstract class EvaluateInputBase
[BestFriend]
internal abstract class EvaluateInputBase
{
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for evaluation.", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public IDataView Data;
Expand All @@ -127,18 +36,8 @@ public abstract class EvaluateInputBase
public string NameColumn = DefaultColumnNames.Name;
}

[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
{
/// <summary>
/// Column to use for example groupId.
/// </summary>
[Argument(ArgumentType.AtMostOnce, Name = "GroupIdColumn", HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string GroupIdColumn = null;
}

[BestFriend]
internal static class LearnerEntryPointsUtils
internal static class TrainerEntryPointsUtils
{
public static string FindColumn(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
{
Expand Down Expand Up @@ -166,13 +65,13 @@ public static TOut Train<TArg, TOut>(IHost host, TArg input,
Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
ICalibratorTrainerFactory calibrator = null,
int maxCalibrationExamples = 0)
where TArg : LearnerInputBase
where TArg : TrainerInputBase
where TOut : CommonOutputs.TrainerOutput, new()
{
using (var ch = host.Start("Training"))
{
var schema = input.TrainingData.Schema;
var feature = FindColumn(ch, schema, input.FeatureColumn);
var feature = FindColumn(ch, schema, input.FeatureColumnName);
var label = getLabel?.Invoke();
var weight = getWeight?.Invoke();
var group = getGroup?.Invoke();
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;

[assembly: LoadableClass(PlattCalibratorTrainer.Summary, typeof(PlattCalibratorTrainer), null, typeof(SignatureCalibrator),
Expand Down
113 changes: 113 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainerInputBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.Data.DataView;
using Microsoft.ML.Calibrators;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;

namespace Microsoft.ML.Trainers
{
/// <summary>
/// The base class for all trainer inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
public abstract class TrainerInputBase
Copy link
Contributor

@TomFinley TomFinley Feb 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TrainerInputBase [](start = 26, length = 16)

This class should have a [BestFriend] private protected constructor. We do not want people forming base classes from it willy nilly. Similar with all the other abstract classes defined here. #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could potentially be part of another PR. But, it's relatively easy to do if you have to change the PR for some other reason anyway.


In reply to: 260837670 [](ancestors = 260837670)

{
private protected TrainerInputBase() { }

/// <summary>
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
/// method.
/// </summary>
[BestFriend]
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal IDataView TrainingData;

/// <summary>
/// Column to use for features.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string FeatureColumnName = DefaultColumnNames.Features;

/// <summary>
/// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
/// </summary>
[BestFriend]
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;

/// <summary>
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
/// </summary>
[BestFriend]
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
internal CachingOptions Caching = CachingOptions.Auto;
}

/// <summary>
/// The base class for all learner inputs that support a Label column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
public abstract class TrainerInputBaseWithLabel : TrainerInputBase
{
private protected TrainerInputBaseWithLabel() { }

/// <summary>
/// Column to use for labels.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string LabelColumnName = DefaultColumnNames.Label;
}

// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
/// <summary>
/// The base class for all learner inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
public abstract class TrainerInputBaseWithWeight : TrainerInputBaseWithLabel
{
private protected TrainerInputBaseWithWeight() { }

/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string ExampleWeightColumnName = null;
}

/// <summary>
/// The base class for all unsupervised learner inputs that support a weight column.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
public abstract class UnsupervisedTrainerInputBaseWithWeight : TrainerInputBase
{
private protected UnsupervisedTrainerInputBaseWithWeight() { }

/// <summary>
/// Column to use for example weight.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string ExampleWeightColumnName = null;
}

[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
public abstract class TrainerInputBaseWithGroupId : TrainerInputBaseWithWeight
{
private protected TrainerInputBaseWithGroupId() { }

/// <summary>
/// Column to use for example groupId.
/// </summary>
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public string RowGroupColumnName = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(FeatureContributionCalculatingTransformer.Summary, typeof(FeatureContributionCalculatingTransformer), null, typeof(SignatureLoadModel),
FeatureContributionCalculatingTransformer.FriendlyName, FeatureContributionCalculatingTransformer.LoaderSignature)]
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Data/Transforms/NopTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(NopTransform.Summary, typeof(NopTransform), null, typeof(SignatureLoadDataTransform),
"", NopTransform.LoaderSignature)]
Expand Down
27 changes: 27 additions & 0 deletions src/Microsoft.ML.Data/Transforms/TransformInputBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Data.DataView;
using Microsoft.ML.CommandLine;
using Microsoft.ML.EntryPoints;

namespace Microsoft.ML.Transforms
{
/// <summary>
/// The base class for all transform inputs.
/// </summary>
[TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
public abstract class TransformInputBase
Copy link
Contributor

@TomFinley TomFinley Feb 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransformInputBase [](start = 26, length = 18)

This class should have a [BestFriend] private protected constructor. We do not want people forming base classes willy nilly. #Resolved

{
private protected TransformInputBase() { }

/// <summary>
/// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
/// create an <see cref="ITransformer"/> is to use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method.
/// </summary>
[BestFriend]
[Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
internal IDataView Data;
}
}
Loading