Skip to content

Commit b1cc8eb

Browse files
authored
Towards #1798 . (#2170)
* Towards #1798 . This PR addresses the estimators inside HalLearners: Two public extension methods, one for simple arguments and the other for advanced options Delete unecessary constructors Pass Options objects as arguments instead of Action delegate Rename Arguments to Options Rename Options objects as options (instead of args or advancedSettings used so far)
1 parent bb92c06 commit b1cc8eb

File tree

12 files changed

+179
-146
lines changed

12 files changed

+179
-146
lines changed

src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
5757
/// Predict a target using a decision tree binary classification model trained with the <see cref="FastTreeBinaryClassificationTrainer"/>.
5858
/// </summary>
5959
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
60-
/// <param name="labelColumn">The labelColumn column.</param>
60+
/// <param name="labelColumn">The label column.</param>
6161
/// <param name="featureColumn">The featureColumn column.</param>
6262
/// <param name="weights">The optional weights column.</param>
6363
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
@@ -97,7 +97,7 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
9797
/// Ranks a series of inputs based on their relevance, training a decision tree ranking model through the <see cref="FastTreeRankingTrainer"/>.
9898
/// </summary>
9999
/// <param name="ctx">The <see cref="RankingContext"/>.</param>
100-
/// <param name="labelColumn">The labelColumn column.</param>
100+
/// <param name="labelColumn">The label column.</param>
101101
/// <param name="featureColumn">The featureColumn column.</param>
102102
/// <param name="groupId">The groupId column.</param>
103103
/// <param name="weights">The optional weights column.</param>
@@ -139,7 +139,7 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer
139139
/// Predict a target using generalized additive models trained with the <see cref="BinaryClassificationGamTrainer"/>.
140140
/// </summary>
141141
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
142-
/// <param name="labelColumn">The labelColumn column.</param>
142+
/// <param name="labelColumn">The label column.</param>
143143
/// <param name="featureColumn">The featureColumn column.</param>
144144
/// <param name="weights">The optional weights column.</param>
145145
/// <param name="numIterations">The number of iterations to use in learning the features.</param>
@@ -164,7 +164,7 @@ public static BinaryClassificationGamTrainer GeneralizedAdditiveModels(this Bina
164164
/// Predict a target using generalized additive models trained with the <see cref="RegressionGamTrainer"/>.
165165
/// </summary>
166166
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
167-
/// <param name="labelColumn">The labelColumn column.</param>
167+
/// <param name="labelColumn">The label column.</param>
168168
/// <param name="featureColumn">The featureColumn column.</param>
169169
/// <param name="weights">The optional weights column.</param>
170170
/// <param name="numIterations">The number of iterations to use in learning the features.</param>
@@ -189,7 +189,7 @@ public static RegressionGamTrainer GeneralizedAdditiveModels(this RegressionCont
189189
/// Predict a target using a decision tree regression model trained with the <see cref="FastTreeTweedieTrainer"/>.
190190
/// </summary>
191191
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
192-
/// <param name="labelColumn">The labelColumn column.</param>
192+
/// <param name="labelColumn">The label column.</param>
193193
/// <param name="featureColumn">The featureColumn column.</param>
194194
/// <param name="weights">The optional weights column.</param>
195195
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
@@ -229,7 +229,7 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
229229
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestRegression"/>.
230230
/// </summary>
231231
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
232-
/// <param name="labelColumn">The labelColumn column.</param>
232+
/// <param name="labelColumn">The label column.</param>
233233
/// <param name="featureColumn">The featureColumn column.</param>
234234
/// <param name="weights">The optional weights column.</param>
235235
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
@@ -269,7 +269,7 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT
269269
/// Predict a target using a decision tree regression model trained with the <see cref="FastForestClassification"/>.
270270
/// </summary>
271271
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
272-
/// <param name="labelColumn">The labelColumn column.</param>
272+
/// <param name="labelColumn">The label column.</param>
273273
/// <param name="featureColumn">The featureColumn column.</param>
274274
/// <param name="weights">The optional weights column.</param>
275275
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>

src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs

+60-15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using Microsoft.ML.Data;
7+
using Microsoft.ML.EntryPoints;
78
using Microsoft.ML.Trainers.HalLearners;
89
using Microsoft.ML.Trainers.SymSgd;
910
using Microsoft.ML.Transforms.Projections;
@@ -19,36 +20,78 @@ public static class HalLearnersCatalog
1920
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
2021
/// </summary>
2122
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
22-
/// <param name="labelColumn">The labelColumn column.</param>
23+
/// <param name="labelColumn">The label column.</param>
2324
/// <param name="featureColumn">The features column.</param>
2425
/// <param name="weights">The weights column.</param>
25-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
2626
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(this RegressionContext.RegressionTrainers ctx,
2727
string labelColumn = DefaultColumnNames.Label,
2828
string featureColumn = DefaultColumnNames.Features,
29-
string weights = null,
30-
Action<OlsLinearRegressionTrainer.Arguments> advancedSettings = null)
29+
string weights = null)
3130
{
3231
Contracts.CheckValue(ctx, nameof(ctx));
3332
var env = CatalogUtils.GetEnvironment(ctx);
34-
return new OlsLinearRegressionTrainer(env, labelColumn, featureColumn, weights, advancedSettings);
33+
var options = new OlsLinearRegressionTrainer.Options
34+
{
35+
LabelColumn = labelColumn,
36+
FeatureColumn = featureColumn,
37+
WeightColumn = weights != null ? Optional<string>.Explicit(weights) : Optional<string>.Implicit(DefaultColumnNames.Weight)
38+
};
39+
40+
return new OlsLinearRegressionTrainer(env, options);
41+
}
42+
43+
/// <summary>
44+
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
45+
/// </summary>
46+
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
47+
/// <param name="options">Algorithm advanced options. See <see cref="OlsLinearRegressionTrainer.Options"/>.</param>
48+
public static OlsLinearRegressionTrainer OrdinaryLeastSquares(
49+
this RegressionContext.RegressionTrainers ctx,
50+
OlsLinearRegressionTrainer.Options options)
51+
{
52+
Contracts.CheckValue(ctx, nameof(ctx));
53+
Contracts.CheckValue(options, nameof(options));
54+
55+
var env = CatalogUtils.GetEnvironment(ctx);
56+
return new OlsLinearRegressionTrainer(env, options);
3557
}
3658

3759
/// <summary>
3860
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
3961
/// </summary>
4062
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
41-
/// <param name="labelColumn">The labelColumn column.</param>
63+
/// <param name="labelColumn">The label column.</param>
4264
/// <param name="featureColumn">The features column.</param>
43-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
44-
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
65+
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(
66+
this BinaryClassificationContext.BinaryClassificationTrainers ctx,
4567
string labelColumn = DefaultColumnNames.Label,
46-
string featureColumn = DefaultColumnNames.Features,
47-
Action<SymSgdClassificationTrainer.Arguments> advancedSettings = null)
68+
string featureColumn = DefaultColumnNames.Features)
4869
{
4970
Contracts.CheckValue(ctx, nameof(ctx));
5071
var env = CatalogUtils.GetEnvironment(ctx);
51-
return new SymSgdClassificationTrainer(env, labelColumn, featureColumn, advancedSettings);
72+
73+
var options = new SymSgdClassificationTrainer.Options
74+
{
75+
LabelColumn = labelColumn,
76+
FeatureColumn = featureColumn,
77+
};
78+
79+
return new SymSgdClassificationTrainer(env, options);
80+
}
81+
82+
/// <summary>
83+
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
84+
/// </summary>
85+
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
86+
/// <param name="options">Algorithm advanced options. See <see cref="SymSgdClassificationTrainer.Options"/>.</param>
87+
public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(
88+
this BinaryClassificationContext.BinaryClassificationTrainers ctx,
89+
SymSgdClassificationTrainer.Options options)
90+
{
91+
Contracts.CheckValue(ctx, nameof(ctx));
92+
Contracts.CheckValue(options, nameof(options));
93+
var env = CatalogUtils.GetEnvironment(ctx);
94+
return new SymSgdClassificationTrainer(env, options);
5295
}
5396

5497
/// <summary>
@@ -57,7 +100,8 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this
57100
/// </summary>
58101
/// <param name="catalog">The transform's catalog.</param>
59102
/// <param name="inputColumn">Name of the input column.</param>
60-
/// <param name="outputColumn">Name of the column resulting from the transformation of <paramref name="inputColumn"/>. Null means <paramref name="inputColumn"/> is replaced. </param>
103+
/// <param name="outputColumn">Name of the column resulting from the transformation of <paramref name="inputColumn"/>.
104+
/// Null means <paramref name="inputColumn"/> is replaced. </param>
61105
/// <param name="kind">Whitening kind (PCA/ZCA).</param>
62106
/// <param name="eps">Whitening constant, prevents division by zero.</param>
63107
/// <param name="maxRows">Maximum number of rows used to train the transform.</param>
@@ -69,16 +113,17 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this
69113
/// ]]>
70114
/// </format>
71115
/// </example>
72-
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, string inputColumn, string outputColumn = null,
116+
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog,
117+
string inputColumn, string outputColumn = null,
73118
WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind,
74119
float eps = VectorWhiteningTransformer.Defaults.Eps,
75120
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
76121
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
77122
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, kind, eps, maxRows, pcaNum);
78123

79124
/// <summary>
80-
/// Takes columns filled with a vector of random variables with a known covariance matrix into a set of new variables whose covariance is the identity matrix,
81-
/// meaning that they are uncorrelated and each have variance 1.
125+
/// Takes columns filled with a vector of random variables with a known covariance matrix into a set of new variables whose
126+
/// covariance is the identity matrix, meaning that they are uncorrelated and each have variance 1.
82127
/// </summary>
83128
/// <param name="catalog">The transform's catalog.</param>
84129
/// <param name="columns">Describes the parameters of the whitening process for each column pair.</param>

0 commit comments

Comments
 (0)