Skip to content

Commit 0ada2fd

Browse files
ganikTomFinley
authored andcommitted
Move Learner* input base and Transform* input base out of Entrypoints… (#2748)
* Move Learner* input base and Transform* input base out of Entrypoints namespace
1 parent 129b47c commit 0ada2fd

File tree

76 files changed

+602
-557
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+602
-557
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/FieldAwareFactorizationMachinewWithOptions.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ public static void Example()
3232
FieldAwareFactorizationMachine(
3333
new FieldAwareFactorizationMachineBinaryClassificationTrainer.Options
3434
{
35-
FeatureColumn = "Features",
36-
LabelColumn = "Sentiment",
35+
FeatureColumnName = "Features",
36+
LabelColumnName = "Sentiment",
3737
LearningRate = 0.1f,
3838
NumberOfIterations = 10
3939
}));

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/BinaryClassification/SDCALogisticRegression.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ public static void Example()
6262
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
6363
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
6464
new SdcaBinaryTrainer.Options {
65-
LabelColumn = "Sentiment",
66-
FeatureColumn = "Features",
65+
LabelColumnName = "Sentiment",
66+
FeatureColumnName = "Features",
6767
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
6868
NumThreads = 2, // Degree of lock-free parallelism
6969
}));

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/MulticlassClassification/LightGbmWithOptions.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ public static void Example()
3535
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelIndex", "Label")
3636
.Append(mlContext.MulticlassClassification.Trainers.LightGbm(new Options
3737
{
38-
LabelColumn = "LabelIndex",
39-
FeatureColumn = "Features",
38+
LabelColumnName = "LabelIndex",
39+
FeatureColumnName = "Features",
4040
Booster = new DartBooster.Options
4141
{
4242
DropRate = 0.15,

docs/samples/Microsoft.ML.Samples/Dynamic/Trainers/Regression/LightGbmWithOptions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static void Example()
3838
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
3939
.Append(mlContext.Regression.Trainers.LightGbm(new Options
4040
{
41-
LabelColumn = labelName,
41+
LabelColumnName = labelName,
4242
NumLeaves = 4,
4343
MinDataPerLeaf = 6,
4444
LearningRate = 0.001,

src/Microsoft.ML.Data/EntryPoints/EntryPointNode.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,9 @@ private EntryPointNode(IHostEnvironment env, IChannel ch, RunContext context,
509509
throw _host.Except($"The following required inputs were not provided: {String.Join(", ", missing)}");
510510

511511
var inputInstance = _inputBuilder.GetInstance();
512-
SetColumnArgument(ch, inputInstance, "LabelColumn", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
513-
SetColumnArgument(ch, inputInstance, "GroupIdColumn", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
514-
SetColumnArgument(ch, inputInstance, "WeightColumn", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
512+
SetColumnArgument(ch, inputInstance, "LabelColumnName", label, "label", typeof(CommonInputs.ITrainerInputWithLabel));
513+
SetColumnArgument(ch, inputInstance, "RowGroupColumnName", group, "group Id", typeof(CommonInputs.ITrainerInputWithGroupId));
514+
SetColumnArgument(ch, inputInstance, "ExampleWeightColumnName", weight, "weight", typeof(CommonInputs.ITrainerInputWithWeight), typeof(CommonInputs.IUnsupervisedTrainerWithWeight));
515515
SetColumnArgument(ch, inputInstance, "NameColumn", name, "name");
516516

517517
// Validate outputs.

src/Microsoft.ML.Data/EntryPoints/InputBase.cs

+6-107
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,10 @@
99
using Microsoft.ML.CommandLine;
1010
using Microsoft.ML.Data;
1111
using Microsoft.ML.Data.IO;
12+
using Microsoft.ML.Trainers;
1213

1314
namespace Microsoft.ML.EntryPoints
1415
{
15-
/// <summary>
16-
/// The base class for all transform inputs.
17-
/// </summary>
18-
[TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
19-
public abstract class TransformInputBase
20-
{
21-
/// <summary>
22-
/// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
23-
/// create an <see cref="ITransformer"/> is to use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method.
24-
/// </summary>
25-
[BestFriend]
26-
[Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
27-
internal IDataView Data;
28-
}
29-
3016
[BestFriend]
3117
internal enum CachingOptions
3218
{
@@ -35,89 +21,12 @@ internal enum CachingOptions
3521
None
3622
}
3723

38-
/// <summary>
39-
/// The base class for all learner inputs.
40-
/// </summary>
41-
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
42-
public abstract class LearnerInputBase
43-
{
44-
/// <summary>
45-
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
46-
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
47-
/// method.
48-
/// </summary>
49-
[BestFriend]
50-
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
51-
internal IDataView TrainingData;
52-
53-
/// <summary>
54-
/// Column to use for features.
55-
/// </summary>
56-
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
57-
public string FeatureColumn = DefaultColumnNames.Features;
58-
59-
/// <summary>
60-
/// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
61-
/// </summary>
62-
[BestFriend]
63-
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
64-
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
65-
66-
/// <summary>
67-
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
68-
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
69-
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
70-
/// </summary>
71-
[BestFriend]
72-
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
73-
internal CachingOptions Caching = CachingOptions.Auto;
74-
}
75-
76-
/// <summary>
77-
/// The base class for all learner inputs that support a Label column.
78-
/// </summary>
79-
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
80-
public abstract class LearnerInputBaseWithLabel : LearnerInputBase
81-
{
82-
/// <summary>
83-
/// Column to use for labels.
84-
/// </summary>
85-
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
86-
public string LabelColumn = DefaultColumnNames.Label;
87-
}
88-
89-
// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
90-
/// <summary>
91-
/// The base class for all learner inputs that support a weight column.
92-
/// </summary>
93-
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
94-
public abstract class LearnerInputBaseWithWeight : LearnerInputBaseWithLabel
95-
{
96-
/// <summary>
97-
/// The name of the example weight column.
98-
/// </summary>
99-
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
100-
public string WeightColumn = null;
101-
}
102-
103-
/// <summary>
104-
/// The base class for all unsupervised learner inputs that support a weight column.
105-
/// </summary>
106-
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
107-
public abstract class UnsupervisedLearnerInputBaseWithWeight : LearnerInputBase
108-
{
109-
/// <summary>
110-
/// Column to use for example weight.
111-
/// </summary>
112-
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
113-
public string WeightColumn = null;
114-
}
115-
11624
/// <summary>
11725
/// The base class for all evaluators inputs.
11826
/// </summary>
11927
[TlcModule.EntryPointKind(typeof(CommonInputs.IEvaluatorInput))]
120-
public abstract class EvaluateInputBase
28+
[BestFriend]
29+
internal abstract class EvaluateInputBase
12130
{
12231
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for evaluation.", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
12332
public IDataView Data;
@@ -126,18 +35,8 @@ public abstract class EvaluateInputBase
12635
public string NameColumn = DefaultColumnNames.Name;
12736
}
12837

129-
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
130-
public abstract class LearnerInputBaseWithGroupId : LearnerInputBaseWithWeight
131-
{
132-
/// <summary>
133-
/// Column to use for example groupId.
134-
/// </summary>
135-
[Argument(ArgumentType.AtMostOnce, Name = "GroupIdColumn", HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
136-
public string GroupIdColumn = null;
137-
}
138-
13938
[BestFriend]
140-
internal static class LearnerEntryPointsUtils
39+
internal static class TrainerEntryPointsUtils
14140
{
14241
public static string FindColumn(IExceptionContext ectx, DataViewSchema schema, Optional<string> value)
14342
{
@@ -165,13 +64,13 @@ public static TOut Train<TArg, TOut>(IHost host, TArg input,
16564
Func<IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>>> getCustom = null,
16665
ICalibratorTrainerFactory calibrator = null,
16766
int maxCalibrationExamples = 0)
168-
where TArg : LearnerInputBase
67+
where TArg : TrainerInputBase
16968
where TOut : CommonOutputs.TrainerOutput, new()
17069
{
17170
using (var ch = host.Start("Training"))
17271
{
17372
var schema = input.TrainingData.Schema;
174-
var feature = FindColumn(ch, schema, input.FeatureColumn);
73+
var feature = FindColumn(ch, schema, input.FeatureColumnName);
17574
var label = getLabel?.Invoke();
17675
var weight = getWeight?.Invoke();
17776
var group = getGroup?.Invoke();

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using Microsoft.ML.Model;
2020
using Microsoft.ML.Model.OnnxConverter;
2121
using Microsoft.ML.Model.Pfa;
22+
using Microsoft.ML.Transforms;
2223
using Newtonsoft.Json.Linq;
2324

2425
[assembly: LoadableClass(PlattCalibratorTrainer.Summary, typeof(PlattCalibratorTrainer), null, typeof(SignatureCalibrator),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.Data.DataView;
8+
using Microsoft.ML.Calibrators;
9+
using Microsoft.ML.CommandLine;
10+
using Microsoft.ML.Data;
11+
using Microsoft.ML.Data.IO;
12+
using Microsoft.ML.EntryPoints;
13+
14+
namespace Microsoft.ML.Trainers
15+
{
16+
/// <summary>
17+
/// The base class for all trainer inputs.
18+
/// </summary>
19+
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInput))]
20+
public abstract class TrainerInputBase
21+
{
22+
private protected TrainerInputBase() { }
23+
24+
/// <summary>
25+
/// The data to be used for training. Used only in entry-points, since in the API the expected mechanism is
26+
/// that the user will use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> or some other train
27+
/// method.
28+
/// </summary>
29+
[BestFriend]
30+
[Argument(ArgumentType.Required, ShortName = "data", HelpText = "The data to be used for training", SortOrder = 1, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
31+
internal IDataView TrainingData;
32+
33+
/// <summary>
34+
/// Column to use for features.
35+
/// </summary>
36+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for features", ShortName = "feat", SortOrder = 2, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
37+
public string FeatureColumnName = DefaultColumnNames.Features;
38+
39+
/// <summary>
40+
/// Normalize option for the feature column. Used only in entry-points, since in the API the user is expected to do this themselves.
41+
/// </summary>
42+
[BestFriend]
43+
[Argument(ArgumentType.AtMostOnce, HelpText = "Normalize option for the feature column", ShortName = "norm", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
44+
internal NormalizeOption NormalizeFeatures = NormalizeOption.Auto;
45+
46+
/// <summary>
47+
/// Whether learner should cache input training data. Used only in entry-points, since the intended API mechanism
48+
/// is that the user will use the <see cref="DataOperationsCatalog.Cache(IDataView, string[])"/> or other method
49+
/// like <see cref="EstimatorChain{TLastTransformer}.AppendCacheCheckpoint(IHostEnvironment)"/>.
50+
/// </summary>
51+
[BestFriend]
52+
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether learner should cache input training data", ShortName = "cache", SortOrder = 6, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
53+
internal CachingOptions Caching = CachingOptions.Auto;
54+
}
55+
56+
/// <summary>
57+
/// The base class for all learner inputs that support a Label column.
58+
/// </summary>
59+
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithLabel))]
60+
public abstract class TrainerInputBaseWithLabel : TrainerInputBase
61+
{
62+
private protected TrainerInputBaseWithLabel() { }
63+
64+
/// <summary>
65+
/// Column to use for labels.
66+
/// </summary>
67+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for labels", ShortName = "lab", SortOrder = 3, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
68+
public string LabelColumnName = DefaultColumnNames.Label;
69+
}
70+
71+
// REVIEW: This is a known antipattern, but the solution involves the decorator pattern which can't be used in this case.
72+
/// <summary>
73+
/// The base class for all learner inputs that support a weight column.
74+
/// </summary>
75+
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithWeight))]
76+
public abstract class TrainerInputBaseWithWeight : TrainerInputBaseWithLabel
77+
{
78+
private protected TrainerInputBaseWithWeight() { }
79+
80+
/// <summary>
81+
/// Column to use for example weight.
82+
/// </summary>
83+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
84+
public string ExampleWeightColumnName = null;
85+
}
86+
87+
/// <summary>
88+
/// The base class for all unsupervised learner inputs that support a weight column.
89+
/// </summary>
90+
[TlcModule.EntryPointKind(typeof(CommonInputs.IUnsupervisedTrainerWithWeight))]
91+
public abstract class UnsupervisedTrainerInputBaseWithWeight : TrainerInputBase
92+
{
93+
private protected UnsupervisedTrainerInputBaseWithWeight() { }
94+
95+
/// <summary>
96+
/// Column to use for example weight.
97+
/// </summary>
98+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example weight", ShortName = "weight", SortOrder = 4, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
99+
public string ExampleWeightColumnName = null;
100+
}
101+
102+
[TlcModule.EntryPointKind(typeof(CommonInputs.ITrainerInputWithGroupId))]
103+
public abstract class TrainerInputBaseWithGroupId : TrainerInputBaseWithWeight
104+
{
105+
private protected TrainerInputBaseWithGroupId() { }
106+
107+
/// <summary>
108+
/// Column to use for example groupId.
109+
/// </summary>
110+
[Argument(ArgumentType.AtMostOnce, HelpText = "Column to use for example groupId", ShortName = "groupId", SortOrder = 5, Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly)]
111+
public string RowGroupColumnName = null;
112+
}
113+
}

src/Microsoft.ML.Data/Transforms/NopTransform.cs

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.ML.Data;
1010
using Microsoft.ML.EntryPoints;
1111
using Microsoft.ML.Model;
12+
using Microsoft.ML.Transforms;
1213

1314
[assembly: LoadableClass(NopTransform.Summary, typeof(NopTransform), null, typeof(SignatureLoadDataTransform),
1415
"", NopTransform.LoaderSignature)]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Data.DataView;
6+
using Microsoft.ML.CommandLine;
7+
using Microsoft.ML.EntryPoints;
8+
9+
namespace Microsoft.ML.Transforms
10+
{
11+
/// <summary>
12+
/// The base class for all transform inputs.
13+
/// </summary>
14+
[TlcModule.EntryPointKind(typeof(CommonInputs.ITransformInput))]
15+
public abstract class TransformInputBase
16+
{
17+
private protected TransformInputBase() { }
18+
19+
/// <summary>
20+
/// The input dataset. Used only in entry-point methods, since the normal API mechanism for feeding in a dataset to
21+
/// create an <see cref="ITransformer"/> is to use the <see cref="IEstimator{TTransformer}.Fit(IDataView)"/> method.
22+
/// </summary>
23+
[BestFriend]
24+
[Argument(ArgumentType.Required, HelpText = "Input dataset", Visibility = ArgumentAttribute.VisibilityType.EntryPointsOnly, SortOrder = 1)]
25+
internal IDataView Data;
26+
}
27+
}

0 commit comments

Comments
 (0)