diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachinewWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachinewWithOptions.cs
index f4ec502d1f..cae0684039 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachinewWithOptions.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachinewWithOptions.cs
@@ -32,8 +32,8 @@ public static void Example()
FieldAwareFactorizationMachine(
new FieldAwareFactorizationMachineBinaryClassificationTrainer.Options
{
- FeatureColumn = "Features",
- LabelColumn = "Sentiment",
+ FeatureColumnName = "Features",
+ LabelColumnName = "Sentiment",
LearningRate = 0.1f,
NumberOfIterations = 10
}));
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SDCALogisticRegression.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SDCALogisticRegression.cs
index f9b21cc9f9..e5a01cd4f7 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SDCALogisticRegression.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SDCALogisticRegression.cs
@@ -62,8 +62,8 @@ public static void Example()
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
new SdcaBinaryTrainer.Options {
- LabelColumn = "Sentiment",
- FeatureColumn = "Features",
+ LabelColumnName = "Sentiment",
+ FeatureColumnName = "Features",
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
NumThreads = 2, // Degree of lock-free parallelism
}));
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs
index 79368c9ebf..a149aa9199 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs
@@ -35,8 +35,8 @@ public static void Example()
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Options
{
- LabelColumn = "LabelIndex",
- FeatureColumn = "Features",
+ LabelColumnName = "LabelIndex",
+ FeatureColumnName = "Features",
Booster = new DartBooster.Options
{
DropRate = 0.15,
diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs
index c1c82a9735..e93eeb3f96 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs
@@ -38,7 +38,7 @@ public static void Example()
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
.Append(mlContext.Regression.Trainers.LightGbm(new Options
{
- LabelColumn = labelName,
+ LabelColumnName = labelName,
NumLeaves = 4,
MinDataPerLeaf = 6,
LearningRate = 0.001,
diff --git a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
index f294bb3452..fac0d77753 100644
--- a/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs
@@ -509,9 +509,9 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, RunContext context,
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");
var inputInstance = _inputBuilder.GetInstance();
- SetColumnArgument(ch, inputInstance, "LabelColumn", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
- SetColumnArgument(ch, inputInstance, "GroupIdColumn", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
- SetColumnArgument(ch, inputInstance, "WeightColumn", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
+ SetColumnArgument(ch, inputInstance, "LabelColumnName", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
+ SetColumnArgument(ch, inputInstance, "RowGroupColumnName", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
+ SetColumnArgument(ch, inputInstance, "ExampleWeightColumnName", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");
// Validate outputs.
diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
index 36393f9aca..f28caf6be9 100644
--- a/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
+++ b/src/Microsoft.ML.Data/EntryPoints/InputBase.cs
@@ -9,24 +9,10 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
+using Microsoft.ML.Trainers;
namespace Microsoft.ML.EntryPoints
{
- ///
- /// The base class for all transform inputs.
- ///
- [TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
- public abstract class TransformInputBase
- {
- ///
- /// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
- /// create an is to use the method.
- ///
- [BestFriend]
- [Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
- internal IDataView Data;
- }
-
[BestFriend]
internal enum CachingOptions
{
@@ -35,89 +21,12 @@ internal enum CachingOptions
None
}
- ///
- /// The base class for all learner inputs.
- ///
- [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
- public abstract class LearnerInputBase
- {
- ///
- /// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
- /// that the user will use the or some other train
- /// method.
- ///
- [BestFriend]
- [Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- internal IDataView TrainingData;
-
- ///
- /// Column to use for features.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- public string FeatureColumn = DefaultColumnNames.Features;
-
- ///
- /// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
- ///
- [BestFriend]
- [Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
-
- ///
- /// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
- /// is that the user will use the or other method
- /// like .
- ///
- [BestFriend]
- [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- internal CachingOptions Caching = CachingOptions.Auto;
- }
-
- ///
- /// The base class for all learner inputs that support a Label column.
- ///
- [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
- public abstract class LearnerInputBaseWithLabel : LearnerInputBase
- {
- ///
- /// Column to use for labels.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- public string LabelColumn = DefaultColumnNames.Label;
- }
-
- // REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
- ///
- /// The base class for all learner inputs that support a weight column.
- ///
- [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
- public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
- {
- ///
- /// The name of the example weight column.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- public string WeightColumn = null;
- }
-
- ///
- /// The base class for all unsupervised learner inputs that support a weight column.
- ///
- [TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
- public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
- {
- ///
- /// Column to use for example weight.
- ///
- [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- public string WeightColumn = null;
- }
-
///
/// The base class for all evaluators inputs.
///
[TlcModule.EntryPointKind(typeof(CommonInputs.IEvaluatorInput))]
- public abstract class EvaluateInputBase
+ [BestFriend]
+ internal abstract class EvaluateInputBase
{
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for evaluation.", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
public IDataView Data;
@@ -126,18 +35,8 @@ public abstract class EvaluateInputBase
public string NameColumn = DefaultColumnNames.Name;
}
- [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
- public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
- {
- ///
- /// Column to use for example groupId.
- ///
- [Argument(ArgumentType.AtMostOnce, Name = "GroupIdColumn", HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
- public string GroupIdColumn = null;
- }
-
[BestFriend]
- internal static class LearnerEntryPointsUtils
+ internal static class TrainerEntryPointsUtils
{
public static string FindColumn(IExceptionContext ectx, DataViewSchema schema, Optional value)
{
@@ -165,13 +64,13 @@ public static TOut Train(IHost host, TArg input,
Func>> getCustom = null,
ICalibratorTrainerFactory calibrator = null,
int maxCalibrationExamples = 0)
- where TArg : LearnerInputBase
+ where TArg : TrainerInputBase
where TOut : CommonOutputs.TrainerOutput, new()
{
using (var ch = host.Start("Training"))
{
var schema = input.TrainingData.Schema;
- var feature = FindColumn(ch, schema, input.FeatureColumn);
+ var feature = FindColumn(ch, schema, input.FeatureColumnName);
var label = getLabel?.Invoke();
var weight = getWeight?.Invoke();
var group = getGroup?.Invoke();
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index 028b38531a..46e8063df0 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -19,6 +19,7 @@
using Microsoft.ML.Model;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
+using Microsoft.ML.Transforms;
using Newtonsoft.Json.Linq;
[assembly: LoadableClass(PlattCalibratorTrainer.Summary, typeof(PlattCalibratorTrainer), null, typeof(SignatureCalibrator),
diff --git a/src/Microsoft.ML.Data/Training/TrainerInputBase.cs b/src/Microsoft.ML.Data/Training/TrainerInputBase.cs
new file mode 100644
index 0000000000..1387c27220
--- /dev/null
+++ b/src/Microsoft.ML.Data/Training/TrainerInputBase.cs
@@ -0,0 +1,113 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Data.DataView;
+using Microsoft.ML.Calibrators;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.Data.IO;
+using Microsoft.ML.EntryPoints;
+
+namespace Microsoft.ML.Trainers
+{
+ ///
+ /// The base class for all trainer inputs.
+ ///
+ [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
+ public abstract class TrainerInputBase
+ {
+ private protected TrainerInputBase() { }
+
+ ///
+ /// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
+ /// that the user will use the or some other train
+ /// method.
+ ///
+ [BestFriend]
+ [Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ internal IDataView TrainingData;
+
+ ///
+ /// Column to use for features.
+ ///
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ public string FeatureColumnName = DefaultColumnNames.Features;
+
+ ///
+ /// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
+ ///
+ [BestFriend]
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
+
+ ///
+ /// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
+ /// is that the user will use the or other method
+ /// like .
+ ///
+ [BestFriend]
+ [Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ internal CachingOptions Caching = CachingOptions.Auto;
+ }
+
+ ///
+ /// The base class for all learner inputs that support a Label column.
+ ///
+ [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
+ public abstract class TrainerInputBaseWithLabel : TrainerInputBase
+ {
+ private protected TrainerInputBaseWithLabel() { }
+
+ ///
+ /// Column to use for labels.
+ ///
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ public string LabelColumnName = DefaultColumnNames.Label;
+ }
+
+ // REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
+ ///
+ /// The base class for all learner inputs that support a weight column.
+ ///
+ [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
+ public abstract class TrainerInputBaseWithWeight : TrainerInputBaseWithLabel
+ {
+ private protected TrainerInputBaseWithWeight() { }
+
+ ///
+ /// Column to use for example weight.
+ ///
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ public string ExampleWeightColumnName = null;
+ }
+
+ ///
+ /// The base class for all unsupervised learner inputs that support a weight column.
+ ///
+ [TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
+ public abstract class UnsupervisedTrainerInputBaseWithWeight : TrainerInputBase
+ {
+ private protected UnsupervisedTrainerInputBaseWithWeight() { }
+
+ ///
+ /// Column to use for example weight.
+ ///
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ public string ExampleWeightColumnName = null;
+ }
+
+ [TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
+ public abstract class TrainerInputBaseWithGroupId : TrainerInputBaseWithWeight
+ {
+ private protected TrainerInputBaseWithGroupId() { }
+
+ ///
+ /// Column to use for example groupId.
+ ///
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
+ public string RowGroupColumnName = null;
+ }
+}
diff --git a/src/Microsoft.ML.Data/Transforms/NopTransform.cs b/src/Microsoft.ML.Data/Transforms/NopTransform.cs
index 526a827b4b..43068d9e21 100644
--- a/src/Microsoft.ML.Data/Transforms/NopTransform.cs
+++ b/src/Microsoft.ML.Data/Transforms/NopTransform.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model;
+using Microsoft.ML.Transforms;
[assembly: LoadableClass(NopTransform.Summary, typeof(NopTransform), null, typeof(SignatureLoadDataTransform),
"", NopTransform.LoaderSignature)]
diff --git a/src/Microsoft.ML.Data/Transforms/TransformInputBase.cs b/src/Microsoft.ML.Data/Transforms/TransformInputBase.cs
new file mode 100644
index 0000000000..e123379824
--- /dev/null
+++ b/src/Microsoft.ML.Data/Transforms/TransformInputBase.cs
@@ -0,0 +1,27 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Data.DataView;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.EntryPoints;
+
+namespace Microsoft.ML.Transforms
+{
+ ///
+ /// The base class for all transform inputs.
+ ///
+ [TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
+ public abstract class TransformInputBase
+ {
+ private protected TransformInputBase() { }
+
+ ///
+ /// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
+ /// create an is to use the method.
+ ///
+ [BestFriend]
+ [Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
+ internal IDataView Data;
+ }
+}
diff --git a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
index afec3af471..fab5ecfe2c 100644
--- a/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
+++ b/src/Microsoft.ML.Ensemble/EntryPoints/Ensemble.cs
@@ -20,9 +20,9 @@ public static CommonOutputs.BinaryClassificationOutput CreateBinaryEnsemble(IHos
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
- return LearnerEntryPointsUtils.Train(host, input,
+ return TrainerEntryPointsUtils.Train(host, input,
() => new EnsembleTrainer(host, input),
- () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
+ () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
}
[TlcModule.EntryPoint(Name = "Trainers.EnsembleClassification", Desc = "Train multiclass ensemble.", UserName = EnsembleTrainer.UserNameValue)]
@@ -33,9 +33,9 @@ public static CommonOutputs.MulticlassClassificationOutput CreateMultiClassEnsem
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
- return LearnerEntryPointsUtils.Train(host, input,
+ return TrainerEntryPointsUtils.Train(host, input,
() => new MulticlassDataPartitionEnsembleTrainer(host, input),
- () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
+ () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
}
[TlcModule.EntryPoint(Name = "Trainers.EnsembleRegression", Desc = "Train regression ensemble.", UserName = EnsembleTrainer.UserNameValue)]
@@ -46,9 +46,9 @@ public static CommonOutputs.RegressionOutput CreateRegressionEnsemble(IHostEnvir
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);
- return LearnerEntryPointsUtils.Train(host, input,
+ return TrainerEntryPointsUtils.Train(host, input,
() => new RegressionEnsembleTrainer(host, input),
- () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
+ () => TrainerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumnName));
}
}
}
diff --git a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
index 13e435e2e4..f650bb296e 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/EnsembleTrainerBase.cs
@@ -22,7 +22,7 @@ internal abstract class EnsembleTrainerBase
where TCombiner : class, IOutputCombiner
{
- public abstract class ArgumentsBase : LearnerInputBaseWithLabel
+ public abstract class ArgumentsBase : TrainerInputBaseWithLabel
{
#pragma warning disable CS0649 // These are set via reflection.
[Argument(ArgumentType.AtMostOnce,
diff --git a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
index a5f0c93604..3d44c3d30f 100644
--- a/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
+++ b/src/Microsoft.ML.Ensemble/Trainer/Multiclass/MulticlassDataPartitionEnsembleTrainer.cs
@@ -63,7 +63,7 @@ public Arguments()
// non-default column names. Unfortuantely no method of resolving this temporary strikes me as being any
// less laborious than the proper fix, which is that this "meta" component should itself be a trainer
// estimator, as opposed to a regular trainer.
- var trainerEstimator = new MulticlassLogisticRegression(env, LabelColumn, FeatureColumn);
+ var trainerEstimator = new MulticlassLogisticRegression(env, LabelColumnName, FeatureColumnName);
return TrainerUtils.MapTrainerEstimatorToTrainer(env, trainerEstimator);
})
diff --git a/src/Microsoft.ML.EntryPoints/ModelOperations.cs b/src/Microsoft.ML.EntryPoints/ModelOperations.cs
index ebd2a33fa7..41ba5d93ca 100644
--- a/src/Microsoft.ML.EntryPoints/ModelOperations.cs
+++ b/src/Microsoft.ML.EntryPoints/ModelOperations.cs
@@ -9,6 +9,7 @@
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
[assembly: LoadableClass(typeof(void), typeof(ModelOperations), null, typeof(SignatureEntryPointModule), "ModelOperations")]
@@ -52,7 +53,7 @@ public sealed class PredictorModelOutput
public PredictorModel PredictorModel;
}
- public sealed class CombineOvaPredictorModelsInput : LearnerInputBaseWithWeight
+ public sealed class CombineOvaPredictorModelsInput : TrainerInputBaseWithWeight
{
[Argument(ArgumentType.Multiple, HelpText = "Input models", SortOrder = 1)]
public PredictorModel[] ModelArray;
@@ -142,13 +143,13 @@ public static PredictorModelOutput CombineOvaModels(IHostEnvironment env, Combin
using (var ch = host.Start("CombineOvaModels"))
{
var schema = normalizedView.Schema;
- var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.LabelColumn),
- input.LabelColumn,
+ var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.LabelColumnName),
+ input.LabelColumnName,
DefaultColumnNames.Label);
- var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.FeatureColumn),
- input.FeatureColumn, DefaultColumnNames.Features);
- var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.WeightColumn),
- input.WeightColumn, DefaultColumnNames.Weight);
+ var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.FeatureColumnName),
+ input.FeatureColumnName, DefaultColumnNames.Features);
+ var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.ExampleWeightColumnName),
+ input.ExampleWeightColumnName, DefaultColumnNames.Weight);
var data = new RoleMappedData(normalizedView, label, feature, null, weight);
return new PredictorModelOutput
diff --git a/src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs b/src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs
index 716cc0b9fa..4fb062cced 100644
--- a/src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs
+++ b/src/Microsoft.ML.EntryPoints/OneVersusAllMacro.cs
@@ -28,7 +28,7 @@ public sealed class SubGraphOutput
public Var Model;
}
- public sealed class Arguments : LearnerInputBaseWithWeight
+ public sealed class Arguments : TrainerInputBaseWithWeight
{
// This is the subgraph that describes how to train a model for submodel. It should
// accept one IDataView input and output one IPredictorModel output.
@@ -119,13 +119,13 @@ private static int GetNumberOfClasses(IHostEnvironment env, Arguments input, out
{
// RoleMappedData creation
var schema = input.TrainingData.Schema;
- label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumn),
- input.LabelColumn,
+ label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.LabelColumnName),
+ input.LabelColumnName,
DefaultColumnNames.Label);
- var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumn),
- input.FeatureColumn, DefaultColumnNames.Features);
- var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.WeightColumn),
- input.WeightColumn, DefaultColumnNames.Weight);
+ var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.FeatureColumnName),
+ input.FeatureColumnName, DefaultColumnNames.Features);
+ var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(Arguments.ExampleWeightColumnName),
+ input.ExampleWeightColumnName, DefaultColumnNames.Weight);
// Get number of classes
var data = new RoleMappedData(input.TrainingData, label, feature, null, weight);
@@ -164,8 +164,8 @@ public static CommonOutputs.MacroOutput