Skip to content

Commit e972912

Browse files
committed
Make accessor of linear coefficients unique to the public
1. Internalize GetFeatureWeights(ref VBuffer<float> weights) 2. Internalize IHaveFeatureWeights
1 parent 6e9023f commit e972912

File tree

10 files changed

+16
-23
lines changed

10 files changed

+16
-23
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ public static void Example()
4444
var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
4545

4646
// Let's extract the weights from the linear model to use as a comparison
47-
var weights = new VBuffer<float>();
48-
model.Model.GetFeatureWeights(ref weights);
47+
var weights = model.Model.Weights;
4948

5049
// Let's now walk through the first ten records and see which feature drove the values the most
5150
// Get prediction scores and contributions
@@ -63,7 +62,7 @@ public static void Example()
6362
var value = row.Features[featureOfInterest];
6463
var contribution = row.FeatureContributions[featureOfInterest];
6564
var name = data.Schema[featureOfInterest + 1].Name;
66-
var weight = weights.GetValues()[featureOfInterest];
65+
var weight = weights[featureOfInterest];
6766

6867
Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}",
6968
row.MedianHomeValue,

docs/samples/Microsoft.ML.Samples/Static/SDCARegression.cs

+3-5
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,10 @@ public static void SdcaRegression()
4646
var model = learningPipeline.Fit(trainData);
4747

4848
// Check the weights that the model learned
49-
VBuffer<float> weights = default;
50-
pred.GetFeatureWeights(ref weights);
49+
var weights = pred.Weights;
5150

52-
var weightsValues = weights.GetValues();
53-
Console.WriteLine($"weight 0 - {weightsValues[0]}");
54-
Console.WriteLine($"weight 1 - {weightsValues[1]}");
51+
Console.WriteLine($"weight 0 - {weights[0]}");
52+
Console.WriteLine($"weight 1 - {weights[1]}");
5553

5654
// Evaluate how the model is doing on the test data
5755
var dataWithPredictions = model.Transform(testData);

src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ internal interface ICanSaveInSourceCode
146146
/// <summary>
147147
/// Interface implemented by components that can assign weights to features.
148148
/// </summary>
149-
public interface IHaveFeatureWeights
149+
[BestFriend]
150+
internal interface IHaveFeatureWeights
150151
{
151152
/// <summary>
152153
/// Returns the weights for the features.

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ private TPredictor TrainCore(IChannel ch, RoleMappedData data, LinearModelParame
671671
float bias = 0.0f;
672672
if (predictor != null)
673673
{
674-
predictor.GetFeatureWeights(ref weights);
674+
((IHaveFeatureWeights)predictor).GetFeatureWeights(ref weights);
675675
VBufferUtils.Densify(ref weights);
676676
bias = predictor.Bias;
677677
}

src/Microsoft.ML.StandardLearners/Standard/LinearModelParameters.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
using Microsoft.ML;
1212
using Microsoft.ML.Calibrators;
1313
using Microsoft.ML.Data;
14-
using Microsoft.ML.Internal.Internallearn;
1514
using Microsoft.ML.Internal.Utilities;
1615
using Microsoft.ML.Model;
1716
using Microsoft.ML.Model.OnnxConverter;
@@ -384,7 +383,7 @@ private protected virtual DataViewRow GetSummaryIRowOrNull(RoleMappedSchema sche
384383

385384
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator) => SaveAsIni(writer, schema, calibrator);
386385

387-
public void GetFeatureWeights(ref VBuffer<float> weights)
386+
void IHaveFeatureWeights.GetFeatureWeights(ref VBuffer<float> weights)
388387
{
389388
Weight.CopyTo(ref weights);
390389
}

src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ internal DataViewSchema.Annotations MakeStatisticsMetadata(LinearBinaryModelPara
438438
builder.AddPrimitiveValue("BiasPValue", NumberDataViewType.Single, biasPValue);
439439

440440
var weights = default(VBuffer<float>);
441-
parent.GetFeatureWeights(ref weights);
441+
((IHaveFeatureWeights)parent).GetFeatureWeights(ref weights);
442442
var estimate = default(VBuffer<float>);
443443
var stdErr = default(VBuffer<float>);
444444
var zScore = default(VBuffer<float>);

src/Microsoft.ML.StandardLearners/Standard/Online/OnlineLinear.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.EntryPoints;
1111
using Microsoft.ML.Internal.Internallearn;
1212
using Microsoft.ML.Internal.Utilities;
13+
using Microsoft.ML.Model;
1314
using Microsoft.ML.Numeric;
1415

1516
namespace Microsoft.ML.Trainers
@@ -130,7 +131,7 @@ protected TrainStateBase(IChannel ch, int numFeatures, LinearModelParameters pre
130131
// unless we have a lot of features.
131132
if (predictor != null)
132133
{
133-
predictor.GetFeatureWeights(ref Weights);
134+
((IHaveFeatureWeights)parent).GetFeatureWeights(ref Weights);
134135
VBufferUtils.Densify(ref Weights);
135136
Bias = predictor.Bias;
136137
}

src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1937,7 +1937,7 @@ private protected override TModel TrainCore(IChannel ch, RoleMappedData data, Li
19371937
float bias = 0.0f;
19381938
if (predictor != null)
19391939
{
1940-
predictor.GetFeatureWeights(ref weights);
1940+
((IHaveFeatureWeights)predictor).GetFeatureWeights(ref weights);
19411941
VBufferUtils.Densify(ref weights);
19421942
bias = predictor.Bias;
19431943
}

test/Microsoft.ML.StaticPipelineTesting/Training.cs

+2-6
Original file line numberDiff line numberDiff line change
@@ -627,9 +627,7 @@ public void PoissonRegression()
627627
var model = pipe.Fit(dataSource);
628628
Assert.NotNull(pred);
629629
// 11 input features, so we ought to have 11 weights.
630-
VBuffer<float> weights = new VBuffer<float>();
631-
pred.GetFeatureWeights(ref weights);
632-
Assert.Equal(11, weights.Length);
630+
Assert.Equal(11, pred.Weights.Count);
633631

634632
var data = model.Load(dataSource);
635633

@@ -751,9 +749,7 @@ public void OnlineGradientDescent()
751749
var model = pipe.Fit(dataSource);
752750
Assert.NotNull(pred);
753751
// 11 input features, so we ought to have 11 weights.
754-
VBuffer<float> weights = new VBuffer<float>();
755-
pred.GetFeatureWeights(ref weights);
756-
Assert.Equal(11, weights.Length);
752+
Assert.Equal(11, pred.Weights.Count);
757753

758754
var data = model.Load(dataSource);
759755

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ public void IntrospectiveTraining()
4242
var model = pipeline.Fit(data);
4343

4444
// Get feature weights.
45-
VBuffer<float> weights = default;
46-
model.LastTransformer.Model.GetFeatureWeights(ref weights);
45+
var weights = model.LastTransformer.Model.Weights;
4746
}
4847

4948
[Fact]

0 commit comments

Comments
 (0)