Skip to content

Commit a18d296

Browse files
artidoro authored and TomFinley committed
Conversion of ordinary least square linear regression (OlsLinearRegression) to estimator (#1002)
1 parent b790195 commit a18d296

File tree

2 files changed

+107
-23
lines changed

2 files changed

+107
-23
lines changed

src/Microsoft.ML.HalLearners/OlsLinearRegression.cs

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
75
using System;
86
using System.Collections.Generic;
97
using System.IO;
8+
using Microsoft.ML.Core.Data;
109
using Microsoft.ML.Runtime;
1110
using Microsoft.ML.Runtime.HalLearners;
1211
using Microsoft.ML.Runtime.Internal.Internallearn;
@@ -34,7 +33,7 @@
3433
namespace Microsoft.ML.Runtime.HalLearners
3534
{
3635
/// <include file='doc.xml' path='doc/members/member[@name="OLS"]/*' />
37-
public sealed class OlsLinearRegressionTrainer : TrainerBase<OlsLinearRegressionPredictor>
36+
public sealed class OlsLinearRegressionTrainer : TrainerEstimatorBase<RegressionPredictionTransformer<OlsLinearRegressionPredictor>, OlsLinearRegressionPredictor>
3837
{
3938
public sealed class Arguments : LearnerInputBaseWithWeight
4039
{
@@ -44,7 +43,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
4443
[Argument(ArgumentType.AtMostOnce, HelpText = "L2 regularization weight", ShortName = "l2", SortOrder = 50)]
4544
[TGUI(SuggestedSweeps = "1e-6,0.1,1")]
4645
[TlcModule.SweepableDiscreteParamAttribute("L2Weight", new object[] { 1e-6f, 0.1f, 1f })]
47-
public Float L2Weight = 1e-6f;
46+
public float L2Weight = 1e-6f;
4847

4948
[Argument(ArgumentType.LastOccurenceWins, HelpText = "Whether to calculate per parameter significance statistics", ShortName = "sig")]
5049
public bool PerParameterSignificance = true;
@@ -56,7 +55,7 @@ public sealed class Arguments : LearnerInputBaseWithWeight
5655
internal const string Summary = "The ordinary least square regression fits the target function as a linear function of the numerical features "
5756
+ "that minimizes the square loss function.";
5857

59-
private readonly Float _l2Weight;
58+
private readonly float _l2Weight;
6059
private readonly bool _perParameterSignificance;
6160

6261
public override PredictionKind PredictionKind => PredictionKind.Regression;
@@ -65,15 +64,59 @@ public sealed class Arguments : LearnerInputBaseWithWeight
6564
private static readonly TrainerInfo _info = new TrainerInfo(caching: false);
6665
public override TrainerInfo Info => _info;
6766

68-
public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
69-
: base(env, LoadNameValue)
67+
/// <summary>
68+
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/> from column names and an optional settings delegate.
69+
/// </summary>
70+
/// <param name="env">The environment to use.</param>
71+
/// <param name="featureColumn">The name of the feature column.</param>
72+
/// <param name="labelColumn">The name of the label column.</param>
73+
/// <param name="weightColumn">The name of the example weight column; defaults to null.</param>
74+
/// <param name="advancedSettings">An optional delegate that receives the <see cref="Arguments"/> instance so the caller can set the advanced options; defaults to null.</param>
75+
public OlsLinearRegressionTrainer(IHostEnvironment env, string featureColumn, string labelColumn,
76+
string weightColumn = null, Action<Arguments> advancedSettings = null)
77+
: this(env, ArgsInit(featureColumn, labelColumn, weightColumn, advancedSettings))
78+
{
// These checks run only after the chained constructor above has already consumed the column names.
79+
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
80+
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
81+
}
82+
83+
/// <summary>
84+
/// Initializes a new instance of <see cref="OlsLinearRegressionTrainer"/> from a populated <see cref="Arguments"/> object.
85+
/// The feature, label and weight column names are read from <paramref name="args"/> and handed to the base class;
/// the L2 weight and the per-parameter-significance flag are copied into this instance.
/// </summary>
/// <param name="env">The environment to use; checked for null before the host is registered.</param>
/// <param name="args">The training arguments. <c>L2Weight</c> must satisfy <c>L2Weight &gt;= 0</c>.</param>
86+
internal OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
87+
// NOTE(review): args is dereferenced here (args.FeatureColumn, args.LabelColumn, args.WeightColumn)
// before Host.CheckValue(args, ...) in the body runs, so a null args fails in this initializer
// rather than at that check.
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(args.FeatureColumn),
88+
TrainerUtils.MakeR4ScalarLabel(args.LabelColumn), TrainerUtils.MakeR4ScalarWeightColumn(args.WeightColumn))
7089
{
7190
Host.CheckValue(args, nameof(args));
7291
// The comparison is false for negative values (and for NaN), so both are reported as a user-argument error.
Host.CheckUserArg(args.L2Weight >= 0, nameof(args.L2Weight), "L2 regularization term cannot be negative");
7392
_l2Weight = args.L2Weight;
7493
_perParameterSignificance = args.PerParameterSignificance;
7594
}
7695

96+
/// <summary>
/// Builds the <see cref="Arguments"/> object used by the column-name constructor.
/// </summary>
/// <param name="featureColumn">Value assigned to <c>FeatureColumn</c>.</param>
/// <param name="labelColumn">Value assigned to <c>LabelColumn</c>.</param>
/// <param name="weightColumn">Value assigned to <c>WeightColumn</c>.</param>
/// <param name="advancedSettings">Optional delegate invoked on the new instance; skipped when null.</param>
/// <returns>A new <see cref="Arguments"/> with the delegate applied and the three column names set.</returns>
private static Arguments ArgsInit(string featureColumn, string labelColumn,
97+
string weightColumn, Action<Arguments> advancedSettings)
98+
{
99+
var args = new Arguments();
100+
101+
// Apply the advanced args, if the user supplied any.
102+
advancedSettings?.Invoke(args);
103+
// The column names are assigned after the delegate runs, so they overwrite
// any column names the delegate may have set.
args.FeatureColumn = featureColumn;
104+
args.LabelColumn = labelColumn;
105+
args.WeightColumn = weightColumn;
106+
return args;
107+
}
108+
109+
/// <summary>
/// Wraps a trained predictor in a <see cref="RegressionPredictionTransformer{TModel}"/>.
/// </summary>
/// <param name="model">The trained predictor to wrap.</param>
/// <param name="trainSchema">The schema passed through to the transformer constructor.</param>
/// <returns>A new transformer built from <c>Host</c>, the model, the schema and the name of <c>FeatureColumn</c>.</returns>
protected override RegressionPredictionTransformer<OlsLinearRegressionPredictor> MakeTransformer(OlsLinearRegressionPredictor model, ISchema trainSchema)
110+
=> new RegressionPredictionTransformer<OlsLinearRegressionPredictor>(Host, model, trainSchema, FeatureColumn.Name);
111+
112+
/// <summary>
/// Describes the columns this trainer's model adds: a single scalar <c>NumberType.R4</c> column
/// named <c>DefaultColumnNames.Score</c>, carrying the metadata from <c>MetadataUtils.GetTrainerOutputMetadata()</c>.
/// </summary>
/// <param name="inputSchema">The input schema shape; not read by this implementation.</param>
/// <returns>A one-element array containing the score column description.</returns>
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
113+
{
114+
return new[]
115+
{
116+
// NOTE(review): the literal 'false' is presumably the "is key" flag of SchemaShape.Column — confirm against its constructor.
new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()))
117+
};
118+
}
119+
77120
/// <summary>
78121
/// In several calculations, we calculate probabilities or other quantities that should range
79122
/// from 0 to 1, but because of numerical imprecision may, in entirely innocent circumstances,
@@ -84,7 +127,7 @@ public OlsLinearRegressionTrainer(IHostEnvironment env, Arguments args)
84127
private static Double ProbClamp(Double p)
85128
=> Math.Max(0, Math.Min(p, 1));
86129

87-
public override OlsLinearRegressionPredictor Train(TrainContext context)
130+
protected override OlsLinearRegressionPredictor TrainModelCore(TrainContext context)
88131
{
89132
using (var ch = Host.Start("Training"))
90133
{
@@ -234,24 +277,24 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
234277
for (int i = 0; i < beta.Length; ++i)
235278
ch.Check(FloatUtils.IsFinite(beta[i]), "Non-finite values detected in OLS solution");
236279

237-
var weights = VBufferUtils.CreateDense<Float>(beta.Length - 1);
280+
var weights = VBufferUtils.CreateDense<float>(beta.Length - 1);
238281
for (int i = 1; i < beta.Length; ++i)
239-
weights.Values[i - 1] = (Float)beta[i];
240-
var bias = (Float)beta[0];
282+
weights.Values[i - 1] = (float)beta[i];
283+
var bias = (float)beta[0];
241284
if (!(_l2Weight > 0) && m == n)
242285
{
243286
// We would expect the solution to the problem to be exact in this case.
244287
ch.Info("Number of examples equals number of parameters, solution is exact but no statistics can be derived");
245-
return new OlsLinearRegressionPredictor(Host, ref weights, bias, null, null, null, 1, Float.NaN);
288+
return new OlsLinearRegressionPredictor(Host, ref weights, bias, null, null, null, 1, float.NaN);
246289
}
247290

248291
Double rss = 0; // residual sum of squares
249292
Double tss = 0; // total sum of squares
250293
using (var cursor = cursorFactory.Create())
251294
{
252295
var lrPredictor = new LinearRegressionPredictor(Host, ref weights, bias);
253-
var lrMap = lrPredictor.GetMapper<VBuffer<Float>, Float>();
254-
Float yh = default;
296+
var lrMap = lrPredictor.GetMapper<VBuffer<float>, float>();
297+
float yh = default;
255298
while (cursor.MoveNext())
256299
{
257300
var features = cursor.Features;
@@ -298,7 +341,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
298341
{
299342
// Iterate through all entries of inverse Hessian to make adjustment to variance.
300343
int ioffset = 1;
301-
Float reg = _l2Weight * _l2Weight * n;
344+
float reg = _l2Weight * _l2Weight * n;
302345
for (int iRow = 1; iRow < m; iRow++)
303346
{
304347
for (int iCol = 0; iCol <= iRow; iCol++)
@@ -321,7 +364,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac
321364
standardErrors[i] = Math.Sqrt(s2 * standardErrors[i]);
322365
ch.Check(FloatUtils.IsFinite(standardErrors[i]), "Non-finite standard error detected from OLS solution");
323366
tValues[i] = beta[i] / standardErrors[i];
324-
pValues[i] = (Float)MathUtils.TStatisticToPValue(tValues[i], n - m);
367+
pValues[i] = (float)MathUtils.TStatisticToPValue(tValues[i], n - m);
325368
ch.Check(0 <= pValues[i] && pValues[i] <= 1, "p-Value calculated outside expected [0,1] range");
326369
}
327370

@@ -558,7 +601,7 @@ public IReadOnlyCollection<Double> TValues
558601
public IReadOnlyCollection<Double> PValues
559602
{ get { return _pValues.AsReadOnly(); } }
560603

561-
internal OlsLinearRegressionPredictor(IHostEnvironment env, ref VBuffer<Float> weights, Float bias,
604+
internal OlsLinearRegressionPredictor(IHostEnvironment env, ref VBuffer<float> weights, float bias,
562605
Double[] standardErrors, Double[] tValues, Double[] pValues, Double rSquared, Double rSquaredAdjusted)
563606
: base(env, RegistrationName, ref weights, bias)
564607
{
@@ -726,7 +769,7 @@ public override void SaveSummary(TextWriter writer, RoleMappedSchema schema)
726769
}
727770
}
728771

729-
public override void GetFeatureWeights(ref VBuffer<Float> weights)
772+
public override void GetFeatureWeights(ref VBuffer<float> weights)
730773
{
731774
if (_pValues == null)
732775
{
@@ -737,15 +780,15 @@ public override void GetFeatureWeights(ref VBuffer<Float> weights)
737780
var values = weights.Values;
738781
var size = _pValues.Length - 1;
739782
if (Utils.Size(values) < size)
740-
values = new Float[size];
783+
values = new float[size];
741784
for (int i = 0; i < size; i++)
742785
{
743-
var score = -(Float)Math.Log(_pValues[i + 1]);
744-
if (score > Float.MaxValue)
745-
score = Float.MaxValue;
786+
var score = -(float)Math.Log(_pValues[i + 1]);
787+
if (score > float.MaxValue)
788+
score = float.MaxValue;
746789
values[i] = score;
747790
}
748-
weights = new VBuffer<Float>(size, values, weights.Indices);
791+
weights = new VBuffer<float>(size, values, weights.Indices);
749792
}
750793
}
751794
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.EntryPoints;
8+
using Microsoft.ML.Runtime.HalLearners;
9+
using Microsoft.ML.Runtime.Learners;
10+
using Microsoft.ML.Runtime.RunTests;
11+
using Xunit;
12+
13+
namespace Microsoft.ML.Tests.TrainerEstimators
14+
{
15+
public partial class TrainerEstimators
16+
{
17+
/// <summary>
/// Loads the training file of <c>TestDatasets.generatedRegressionDataset</c> through a <see cref="TextLoader"/>
/// configured for a ';'-separated file with a header row.
/// </summary>
/// <returns>The data view produced by reading that file.</returns>
private IDataView GetGeneratedRegressionDataview()
18+
{
19+
return new TextLoader(Env,
20+
new TextLoader.Arguments()
21+
{
22+
Separator = ";",
23+
HasHeader = true,
24+
Column = new[]
25+
{
26+
// Source column 11 is read as the R4 label; source columns 0 through 10 form the R4 "Features" column.
new TextLoader.Column("Label", DataKind.R4, 11),
27+
new TextLoader.Column("Features", DataKind.R4, new [] { new TextLoader.Range(0, 10) } )
28+
}
29+
}).Read(new MultiFileSource(GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename)));
30+
}
31+
32+
/// <summary>
/// Builds an <see cref="OlsLinearRegressionTrainer"/> over the "Features" and "Label" columns of the
/// generated regression data and hands it to <c>TestEstimatorCore</c>.
/// </summary>
[Fact]
33+
public void TestEstimatorOlsLinearRegression()
34+
{
35+
var dataView = GetGeneratedRegressionDataview();
36+
// Argument order is (env, featureColumn, labelColumn), matching the trainer's public constructor.
var pipe = new OlsLinearRegressionTrainer(Env, "Features", "Label");
37+
// NOTE(review): TestEstimatorCore and Done are defined outside this view; what they verify should be confirmed there.
TestEstimatorCore(pipe, dataView);
38+
Done();
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)