4
4
5
5
using System ;
6
6
using Microsoft . ML . Data ;
7
+ using Microsoft . ML . EntryPoints ;
7
8
using Microsoft . ML . Trainers . HalLearners ;
8
9
using Microsoft . ML . Trainers . SymSgd ;
9
10
using Microsoft . ML . Transforms . Projections ;
@@ -19,36 +20,78 @@ public static class HalLearnersCatalog
19
20
/// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
20
21
/// </summary>
21
22
/// <param name="ctx">The <see cref="RegressionContext"/>.</param>
22
- /// <param name="labelColumn">The labelColumn column.</param>
23
+ /// <param name="labelColumn">The label column.</param>
23
24
/// <param name="featureColumn">The features column.</param>
24
25
/// <param name="weights">The weights column.</param>
25
- /// <param name="advancedSettings">Algorithm advanced settings.</param>
26
26
public static OlsLinearRegressionTrainer OrdinaryLeastSquares ( this RegressionContext . RegressionTrainers ctx ,
27
27
string labelColumn = DefaultColumnNames . Label ,
28
28
string featureColumn = DefaultColumnNames . Features ,
29
- string weights = null ,
30
- Action < OlsLinearRegressionTrainer . Arguments > advancedSettings = null )
29
+ string weights = null )
31
30
{
32
31
Contracts . CheckValue ( ctx , nameof ( ctx ) ) ;
33
32
var env = CatalogUtils . GetEnvironment ( ctx ) ;
34
- return new OlsLinearRegressionTrainer ( env , labelColumn , featureColumn , weights , advancedSettings ) ;
33
+ var options = new OlsLinearRegressionTrainer . Options
34
+ {
35
+ LabelColumn = labelColumn ,
36
+ FeatureColumn = featureColumn ,
37
+ WeightColumn = weights != null ? Optional < string > . Explicit ( weights ) : Optional < string > . Implicit ( DefaultColumnNames . Weight )
38
+ } ;
39
+
40
+ return new OlsLinearRegressionTrainer ( env , options ) ;
41
+ }
42
+
43
+ /// <summary>
44
+ /// Predict a target using a linear regression model trained with the <see cref="OlsLinearRegressionTrainer"/>.
45
+ /// </summary>
46
+ /// <param name="ctx">The <see cref="RegressionContext"/>.</param>
47
+ /// <param name="options">Algorithm advanced options. See <see cref="OlsLinearRegressionTrainer.Options"/>.</param>
48
+ public static OlsLinearRegressionTrainer OrdinaryLeastSquares (
49
+ this RegressionContext . RegressionTrainers ctx ,
50
+ OlsLinearRegressionTrainer . Options options )
51
+ {
52
+ Contracts . CheckValue ( ctx , nameof ( ctx ) ) ;
53
+ Contracts . CheckValue ( options , nameof ( options ) ) ;
54
+
55
+ var env = CatalogUtils . GetEnvironment ( ctx ) ;
56
+ return new OlsLinearRegressionTrainer ( env , options ) ;
35
57
}
36
58
37
59
/// <summary>
38
60
/// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
39
61
/// </summary>
40
62
/// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
41
- /// <param name="labelColumn">The labelColumn column.</param>
63
+ /// <param name="labelColumn">The label column.</param>
42
64
/// <param name="featureColumn">The features column.</param>
43
- /// <param name="advancedSettings">Algorithm advanced settings.</param>
44
- public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent ( this BinaryClassificationContext . BinaryClassificationTrainers ctx ,
65
+ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent (
66
+ this BinaryClassificationContext . BinaryClassificationTrainers ctx ,
45
67
string labelColumn = DefaultColumnNames . Label ,
46
- string featureColumn = DefaultColumnNames . Features ,
47
- Action < SymSgdClassificationTrainer . Arguments > advancedSettings = null )
68
+ string featureColumn = DefaultColumnNames . Features )
48
69
{
49
70
Contracts . CheckValue ( ctx , nameof ( ctx ) ) ;
50
71
var env = CatalogUtils . GetEnvironment ( ctx ) ;
51
- return new SymSgdClassificationTrainer ( env , labelColumn , featureColumn , advancedSettings ) ;
72
+
73
+ var options = new SymSgdClassificationTrainer . Options
74
+ {
75
+ LabelColumn = labelColumn ,
76
+ FeatureColumn = featureColumn ,
77
+ } ;
78
+
79
+ return new SymSgdClassificationTrainer ( env , options ) ;
80
+ }
81
+
82
+ /// <summary>
83
+ /// Predict a target using a linear binary classification model trained with the <see cref="SymSgdClassificationTrainer"/>.
84
+ /// </summary>
85
+ /// <param name="ctx">The <see cref="BinaryClassificationContext"/>.</param>
86
+ /// <param name="options">Algorithm advanced options. See <see cref="SymSgdClassificationTrainer.Options"/>.</param>
87
+ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent (
88
+ this BinaryClassificationContext . BinaryClassificationTrainers ctx ,
89
+ SymSgdClassificationTrainer . Options options )
90
+ {
91
+ Contracts . CheckValue ( ctx , nameof ( ctx ) ) ;
92
+ Contracts . CheckValue ( options , nameof ( options ) ) ;
93
+ var env = CatalogUtils . GetEnvironment ( ctx ) ;
94
+ return new SymSgdClassificationTrainer ( env , options ) ;
52
95
}
53
96
54
97
/// <summary>
@@ -57,7 +100,8 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this
57
100
/// </summary>
58
101
/// <param name="catalog">The transform's catalog.</param>
59
102
/// <param name="inputColumn">Name of the input column.</param>
60
- /// <param name="outputColumn">Name of the column resulting from the transformation of <paramref name="inputColumn"/>. Null means <paramref name="inputColumn"/> is replaced. </param>
103
+ /// <param name="outputColumn">Name of the column resulting from the transformation of <paramref name="inputColumn"/>.
104
+ /// Null means <paramref name="inputColumn"/> is replaced. </param>
61
105
/// <param name="kind">Whitening kind (PCA/ZCA).</param>
62
106
/// <param name="eps">Whitening constant, prevents division by zero.</param>
63
107
/// <param name="maxRows">Maximum number of rows used to train the transform.</param>
@@ -69,16 +113,17 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this
69
113
/// ]]>
70
114
/// </format>
71
115
/// </example>
72
- public static VectorWhiteningEstimator VectorWhiten ( this TransformsCatalog . ProjectionTransforms catalog , string inputColumn , string outputColumn = null ,
116
+ public static VectorWhiteningEstimator VectorWhiten ( this TransformsCatalog . ProjectionTransforms catalog ,
117
+ string inputColumn , string outputColumn = null ,
73
118
WhiteningKind kind = VectorWhiteningTransformer . Defaults . Kind ,
74
119
float eps = VectorWhiteningTransformer . Defaults . Eps ,
75
120
int maxRows = VectorWhiteningTransformer . Defaults . MaxRows ,
76
121
int pcaNum = VectorWhiteningTransformer . Defaults . PcaNum )
77
122
=> new VectorWhiteningEstimator ( CatalogUtils . GetEnvironment ( catalog ) , inputColumn , outputColumn , kind , eps , maxRows , pcaNum ) ;
78
123
79
124
/// <summary>
80
- /// Takes columns filled with a vector of random variables with a known covariance matrix into a set of new variables whose covariance is the identity matrix,
81
- /// meaning that they are uncorrelated and each have variance 1.
125
+ /// Takes columns filled with a vector of random variables with a known covariance matrix into a set of new variables whose
126
+ /// covariance is the identity matrix, meaning that they are uncorrelated and each have variance 1.
82
127
/// </summary>
83
128
/// <param name="catalog">The transform's catalog.</param>
84
129
/// <param name="columns">Describes the parameters of the whitening process for each column pair.</param>
0 commit comments