2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
- using Float = System . Single ;
6
-
7
5
using System ;
8
6
using System . Collections . Generic ;
9
7
using System . IO ;
8
+ using Microsoft . ML . Core . Data ;
10
9
using Microsoft . ML . Runtime ;
11
10
using Microsoft . ML . Runtime . HalLearners ;
12
11
using Microsoft . ML . Runtime . Internal . Internallearn ;
34
33
namespace Microsoft . ML . Runtime . HalLearners
35
34
{
36
35
/// <include file='doc.xml' path='doc/members/member[@name="OLS"]/*' />
37
- public sealed class OlsLinearRegressionTrainer : TrainerBase < OlsLinearRegressionPredictor >
36
+ public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase < RegressionPredictionTransformer < OlsLinearRegressionPredictor > , OlsLinearRegressionPredictor >
38
37
{
39
38
public sealed class Arguments : LearnerInputBaseWithWeight
40
39
{
@@ -44,7 +43,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
44
43
[ Argument ( ArgumentType . AtMostOnce , HelpText = "L2 regularization weight" , ShortName = "l2" , SortOrder = 50 ) ]
45
44
[ TGUI ( SuggestedSweeps = "1e-6,0.1,1" ) ]
46
45
[ TlcModule . SweepableDiscreteParamAttribute ( "L2Weight" , new object [ ] { 1e-6f , 0.1f , 1f } ) ]
47
- public Float L2Weight = 1e-6f ;
46
+ public float L2Weight = 1e-6f ;
48
47
49
48
[ Argument ( ArgumentType . LastOccurenceWins , HelpText = "Whether to calculate per parameter significance statistics" , ShortName = "sig" ) ]
50
49
public bool PerParameterSignificance = true ;
@@ -56,7 +55,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
56
55
internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features "
57
56
+ "that minimizes the square loss function." ;
58
57
59
- private readonly Float _l2Weight ;
58
+ private readonly float _l2Weight ;
60
59
private readonly bool _perParameterSignificance ;
61
60
62
61
public override PredictionKind PredictionKind => PredictionKind . Regression ;
@@ -65,15 +64,59 @@ public sealed class Arguments : LearnerInputBaseWithWeight
65
64
private static readonly TrainerInfo _info = new TrainerInfo ( caching : false ) ;
66
65
public override TrainerInfo Info => _info ;
67
66
68
- public OlsLinearRegressionTrainer ( IHostEnvironment env , Arguments args )
69
- : base ( env , LoadNameValue )
67
+ /// <summary>
68
+ /// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
69
+ /// </summary>
70
+ /// <param name="env">The environment to use.</param>
71
+ /// <param name="labelColumn">The name of the label column.</param>
72
+ /// <param name="featureColumn">The name of the feature column.</param>
73
+ /// <param name="weightColumn">The name for the example weight column.</param>
74
+ /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
75
+ public OlsLinearRegressionTrainer ( IHostEnvironment env , string featureColumn , string labelColumn ,
76
+ string weightColumn = null , Action < Arguments > advancedSettings = null )
77
+ : this ( env , ArgsInit ( featureColumn , labelColumn , weightColumn , advancedSettings ) )
78
+ {
79
+ Host . CheckNonEmpty ( featureColumn , nameof ( featureColumn ) ) ;
80
+ Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
81
+ }
82
+
83
+ /// <summary>
84
+ /// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/>
85
+ /// </summary>
86
+ internal OlsLinearRegressionTrainer ( IHostEnvironment env , Arguments args )
87
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( LoadNameValue ) , TrainerUtils . MakeR4VecFeature ( args . FeatureColumn ) ,
88
+ TrainerUtils . MakeR4ScalarLabel ( args . LabelColumn ) , TrainerUtils . MakeR4ScalarWeightColumn ( args . WeightColumn ) )
70
89
{
71
90
Host . CheckValue ( args , nameof ( args ) ) ;
72
91
Host . CheckUserArg ( args . L2Weight >= 0 , nameof ( args . L2Weight ) , "L2 regularization term cannot be negative" ) ;
73
92
_l2Weight = args . L2Weight ;
74
93
_perParameterSignificance = args . PerParameterSignificance ;
75
94
}
76
95
96
+ private static Arguments ArgsInit ( string featureColumn , string labelColumn ,
97
+ string weightColumn , Action < Arguments > advancedSettings )
98
+ {
99
+ var args = new Arguments ( ) ;
100
+
101
+ // Apply the advanced args, if the user supplied any.
102
+ advancedSettings ? . Invoke ( args ) ;
103
+ args . FeatureColumn = featureColumn ;
104
+ args . LabelColumn = labelColumn ;
105
+ args . WeightColumn = weightColumn ;
106
+ return args ;
107
+ }
108
+
109
+ protected override RegressionPredictionTransformer < OlsLinearRegressionPredictor > MakeTransformer ( OlsLinearRegressionPredictor model , ISchema trainSchema )
110
+ => new RegressionPredictionTransformer < OlsLinearRegressionPredictor > ( Host , model , trainSchema , FeatureColumn . Name ) ;
111
+
112
+ protected override SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
113
+ {
114
+ return new [ ]
115
+ {
116
+ new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 , false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) )
117
+ } ;
118
+ }
119
+
77
120
/// <summary>
78
121
/// In several calculations, we calculate probabilities or other quantities that should range
79
122
/// from 0 to 1, but because of numerical imprecision may, in entirely innocent circumstances,
@@ -84,7 +127,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
84
127
private static Double ProbClamp ( Double p )
85
128
=> Math . Max ( 0 , Math . Min ( p , 1 ) ) ;
86
129
87
- public override OlsLinearRegressionPredictor Train ( TrainContext context )
130
+ protected override OlsLinearRegressionPredictor TrainModelCore ( TrainContext context )
88
131
{
89
132
using ( var ch = Host . Start ( "Training" ) )
90
133
{
@@ -234,24 +277,24 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
234
277
for ( int i = 0 ; i < beta . Length ; ++ i )
235
278
ch . Check ( FloatUtils . IsFinite ( beta [ i ] ) , "Non-finite values detected in OLS solution" ) ;
236
279
237
- var weights = VBufferUtils . CreateDense < Float > ( beta . Length - 1 ) ;
280
+ var weights = VBufferUtils . CreateDense < float > ( beta . Length - 1 ) ;
238
281
for ( int i = 1 ; i < beta . Length ; ++ i )
239
- weights . Values [ i - 1 ] = ( Float ) beta [ i ] ;
240
- var bias = ( Float ) beta [ 0 ] ;
282
+ weights . Values [ i - 1 ] = ( float ) beta [ i ] ;
283
+ var bias = ( float ) beta [ 0 ] ;
241
284
if ( ! ( _l2Weight > 0 ) && m == n )
242
285
{
243
286
// We would expect the solution to the problem to be exact in this case.
244
287
ch . Info ( "Number of examples equals number of parameters, solution is exact but no statistics can be derived" ) ;
245
- return new OlsLinearRegressionPredictor ( Host , ref weights , bias , null , null , null , 1 , Float . NaN ) ;
288
+ return new OlsLinearRegressionPredictor ( Host , ref weights , bias , null , null , null , 1 , float . NaN ) ;
246
289
}
247
290
248
291
Double rss = 0 ; // residual sum of squares
249
292
Double tss = 0 ; // total sum of squares
250
293
using ( var cursor = cursorFactory . Create ( ) )
251
294
{
252
295
var lrPredictor = new LinearRegressionPredictor ( Host , ref weights , bias ) ;
253
- var lrMap = lrPredictor . GetMapper < VBuffer < Float > , Float > ( ) ;
254
- Float yh = default ;
296
+ var lrMap = lrPredictor . GetMapper < VBuffer < float > , float > ( ) ;
297
+ float yh = default ;
255
298
while ( cursor . MoveNext ( ) )
256
299
{
257
300
var features = cursor . Features ;
@@ -298,7 +341,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
298
341
{
299
342
// Iterate through all entries of inverse Hessian to make adjustment to variance.
300
343
int ioffset = 1 ;
301
- Float reg = _l2Weight * _l2Weight * n ;
344
+ float reg = _l2Weight * _l2Weight * n ;
302
345
for ( int iRow = 1 ; iRow < m ; iRow ++ )
303
346
{
304
347
for ( int iCol = 0 ; iCol <= iRow ; iCol ++ )
@@ -321,7 +364,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
321
364
standardErrors [ i ] = Math . Sqrt ( s2 * standardErrors [ i ] ) ;
322
365
ch . Check ( FloatUtils . IsFinite ( standardErrors [ i ] ) , "Non-finite standard error detected from OLS solution" ) ;
323
366
tValues [ i ] = beta [ i ] / standardErrors [ i ] ;
324
- pValues [ i ] = ( Float ) MathUtils . TStatisticToPValue ( tValues [ i ] , n - m ) ;
367
+ pValues [ i ] = ( float ) MathUtils . TStatisticToPValue ( tValues [ i ] , n - m ) ;
325
368
ch . Check ( 0 <= pValues [ i ] && pValues [ i ] <= 1 , "p-Value calculated outside expected [0,1] range" ) ;
326
369
}
327
370
@@ -558,7 +601,7 @@ public IReadOnlyCollection<Double> TValues
558
601
public IReadOnlyCollection < Double > PValues
559
602
{ get { return _pValues . AsReadOnly ( ) ; } }
560
603
561
- internal OlsLinearRegressionPredictor ( IHostEnvironment env , ref VBuffer < Float > weights , Float bias ,
604
+ internal OlsLinearRegressionPredictor ( IHostEnvironment env , ref VBuffer < float > weights , float bias ,
562
605
Double [ ] standardErrors , Double [ ] tValues , Double [ ] pValues , Double rSquared , Double rSquaredAdjusted )
563
606
: base ( env , RegistrationName , ref weights , bias )
564
607
{
@@ -726,7 +769,7 @@ public override void SaveSummary(TextWriter writer, RoleMappedSchema schema)
726
769
}
727
770
}
728
771
729
- public override void GetFeatureWeights ( ref VBuffer < Float > weights )
772
+ public override void GetFeatureWeights ( ref VBuffer < float > weights )
730
773
{
731
774
if ( _pValues == null )
732
775
{
@@ -737,15 +780,15 @@ public override void GetFeatureWeights(ref VBuffer<Float> weights)
737
780
var values = weights . Values ;
738
781
var size = _pValues . Length - 1 ;
739
782
if ( Utils . Size ( values ) < size )
740
- values = new Float [ size ] ;
783
+ values = new float [ size ] ;
741
784
for ( int i = 0 ; i < size ; i ++ )
742
785
{
743
- var score = - ( Float ) Math . Log ( _pValues [ i + 1 ] ) ;
744
- if ( score > Float . MaxValue )
745
- score = Float . MaxValue ;
786
+ var score = - ( float ) Math . Log ( _pValues [ i + 1 ] ) ;
787
+ if ( score > float . MaxValue )
788
+ score = float . MaxValue ;
746
789
values [ i ] = score ;
747
790
}
748
- weights = new VBuffer < Float > ( size , values , weights . Indices ) ;
791
+ weights = new VBuffer < float > ( size , values , weights . Indices ) ;
749
792
}
750
793
}
751
794
}
0 commit comments