Skip to content

Commit 6eb3551

Browse files
ganikeerhardt
authored and committed
GetSummaryDataView/Row implementation for Pca and Linear Predictors (dotnet#185)
* Implement `ICanGetSummaryAsIDataView` on `PcaPredictor` class * Implement `ICanGetSummaryAsIRow` on `LinearPredictor` class
1 parent 6ff387b commit 6eb3551

32 files changed

+15142
-58
lines changed

src/Microsoft.ML.PCA/PcaTrainer.cs

+21
Original file line number | Diff line number | Diff line change
@@ -307,6 +307,7 @@ public static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironm
307307
// REVIEW: move the predictor to a different file and fold EigenUtils.cs to this file.
308308
public sealed class PcaPredictor : PredictorBase<Float>,
309309
IValueMapper,
310+
ICanGetSummaryAsIDataView,
310311
ICanSaveInTextFormat, ICanSaveModel, ICanSaveSummary
311312
{
312313
public const string LoaderSignature = "pcaAnomExec";
@@ -468,6 +469,26 @@ public void SaveAsText(TextWriter writer, RoleMappedSchema schema)
468469
}
469470
}
470471

472+
/// <summary>
/// Builds a summary of the PCA model as a data view with two columns:
/// "VectorName" (text) and "VectorData" (vectors of R4). There is one row per
/// eigenvector, named "EigenVector0" .. "EigenVector{_rank - 1}", followed by a
/// final row named "MeanVector" that holds <c>_mean</c>.
/// </summary>
/// <param name="schema">Not read by this implementation.</param>
/// <returns>The data view produced by the <c>ArrayDataViewBuilder</c>.</returns>
public IDataView GetSummaryDataView(RoleMappedSchema schema)
473+
{
474+
var bldr = new ArrayDataViewBuilder(Host);
475+
476+
// One slot per eigenvector plus one trailing slot for the mean vector.
var cols = new VBuffer<Float>[_rank + 1];
477+
var names = new string[_rank + 1];
478+
for (var i = 0; i < _rank; ++i)
479+
{
480+
names[i] = "EigenVector" + i;
481+
cols[i] = _eigenVectors[i];
482+
}
483+
// The last row (index _rank) carries the mean vector.
names[_rank] = "MeanVector";
484+
cols[_rank] = _mean;
485+
486+
// names and cols are parallel arrays: row i is (names[i], cols[i]).
bldr.AddColumn("VectorName", names);
487+
bldr.AddColumn("VectorData", NumberType.R4, cols);
488+
489+
return bldr.GetDataView();
490+
}
491+
471492
public ColumnType InputType
472493
{
473494
get { return _inputType; }

src/Microsoft.ML.StandardLearners/Standard/LinearPredictor.cs

+28-48
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ public abstract class LinearPredictor : PredictorBase<Float>,
4444
ICanSaveInTextFormat,
4545
ICanSaveInSourceCode,
4646
ICanSaveModel,
47+
ICanGetSummaryAsIRow,
4748
ICanSaveSummary,
4849
IPredictorWithFeatureWeights<Float>,
4950
IWhatTheFeatureValueMapper,
@@ -343,6 +344,30 @@ public void SaveAsCode(TextWriter writer, RoleMappedSchema schema)
343344

344345
public abstract void SaveSummary(TextWriter writer, RoleMappedSchema schema);
345346

347+
/// <summary>
/// Builds a single-row summary of the linear model with two columns:
/// "Bias" (R4) and "Weights" (a vector of R4 of length <c>Weight.Length</c>).
/// The slot names of the feature column are looked up from
/// <paramref name="schema"/> and passed along with the "Weights" column.
/// Virtual, so derived predictors may override it.
/// </summary>
/// <param name="schema">Role-mapped schema used to look up the feature slot names.</param>
/// <returns>The row built from the "Bias" and "Weights" columns.</returns>
public virtual IRow GetSummaryIRowOrNull(RoleMappedSchema schema)
348+
{
349+
var cols = new List<IColumn>();
350+
351+
// Fetch the feature slot names and wrap them in a one-column row keyed by the SlotNames kind.
var names = default(VBuffer<DvText>);
352+
MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names);
353+
var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames,
354+
new VectorType(TextType.Instance, Weight.Length), ref names);
355+
var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol);
356+
var colType = new VectorType(NumberType.R4, Weight.Length);
357+
358+
// Add the bias and the weight columns.
// Bias and Weight are copied into locals first because GetColumn takes its value by ref.
359+
var bias = Bias;
360+
cols.Add(RowColumnUtils.GetColumn("Bias", NumberType.R4, ref bias));
361+
var weights = Weight;
362+
// NOTE(review): slotNamesRow is presumably attached as the column's metadata — confirm against RowColumnUtils.GetColumn.
cols.Add(RowColumnUtils.GetColumn("Weights", colType, ref weights, slotNamesRow));
363+
return RowColumnUtils.GetRow(null, cols.ToArray());
364+
}
365+
366+
/// <summary>
/// Base implementation of the statistics row: always returns <c>null</c>.
/// Virtual, so derived predictors can override it to supply a row.
/// </summary>
/// <param name="schema">Not read by this implementation.</param>
/// <returns>Always <c>null</c>.</returns>
public virtual IRow GetStatsIRowOrNull(RoleMappedSchema schema)
367+
{
368+
return null;
369+
}
370+
346371
public abstract void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator = null);
347372

348373
public virtual void GetFeatureWeights(ref VBuffer<Float> weights)
@@ -366,8 +391,7 @@ public ValueMapper<TSrc, VBuffer<Float>> GetWhatTheFeatureMapper<TSrc, TDstContr
366391

367392
public sealed partial class LinearBinaryPredictor : LinearPredictor,
368393
ICanGetSummaryInKeyValuePairs,
369-
IParameterMixer<Float>,
370-
ICanGetSummaryAsIRow
394+
IParameterMixer<Float>
371395
{
372396
public const string LoaderSignature = "Linear2CExec";
373397
public const string RegistrationName = "LinearBinaryPredictor";
@@ -503,26 +527,7 @@ public IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedS
503527
return results;
504528
}
505529

506-
public IRow GetSummaryIRowOrNull(RoleMappedSchema schema)
507-
{
508-
var cols = new List<IColumn>();
509-
510-
var names = default(VBuffer<DvText>);
511-
MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names);
512-
var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames,
513-
new VectorType(TextType.Instance, Weight.Length), ref names);
514-
var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol);
515-
var colType = new VectorType(NumberType.R4, Weight.Length);
516-
517-
// Add the bias and the weight columns.
518-
var bias = Bias;
519-
cols.Add(RowColumnUtils.GetColumn("Bias", NumberType.R4, ref bias));
520-
var weights = Weight;
521-
cols.Add(RowColumnUtils.GetColumn("Weights", colType, ref weights, slotNamesRow));
522-
return RowColumnUtils.GetRow(null, cols.ToArray());
523-
}
524-
525-
public IRow GetStatsIRowOrNull(RoleMappedSchema schema)
530+
public override IRow GetStatsIRowOrNull(RoleMappedSchema schema)
526531
{
527532
if (_stats == null)
528533
return null;
@@ -582,8 +587,7 @@ public override void SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICali
582587

583588
public sealed class LinearRegressionPredictor : RegressionPredictor,
584589
IParameterMixer<Float>,
585-
ICanGetSummaryInKeyValuePairs,
586-
ICanGetSummaryAsIRow
590+
ICanGetSummaryInKeyValuePairs
587591
{
588592
public const string LoaderSignature = "LinearRegressionExec";
589593
public const string RegistrationName = "LinearRegressionPredictor";
@@ -663,30 +667,6 @@ public IList<KeyValuePair<string, object>> GetSummaryInKeyValuePairs(RoleMappedS
663667

664668
return results;
665669
}
666-
667-
public IRow GetSummaryIRowOrNull(RoleMappedSchema schema)
668-
{
669-
var cols = new List<IColumn>();
670-
671-
var names = default(VBuffer<DvText>);
672-
MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, Weight.Length, ref names);
673-
var slotNamesCol = RowColumnUtils.GetColumn(MetadataUtils.Kinds.SlotNames,
674-
new VectorType(TextType.Instance, Weight.Length), ref names);
675-
var slotNamesRow = RowColumnUtils.GetRow(null, slotNamesCol);
676-
var colType = new VectorType(NumberType.R4, Weight.Length);
677-
678-
// Add the bias and the weight columns.
679-
var bias = Bias;
680-
cols.Add(RowColumnUtils.GetColumn("Bias", NumberType.R4, ref bias));
681-
var weights = Weight;
682-
cols.Add(RowColumnUtils.GetColumn("Weights", colType, ref weights, slotNamesRow));
683-
return RowColumnUtils.GetRow(null, cols.ToArray());
684-
}
685-
686-
public IRow GetStatsIRowOrNull(RoleMappedSchema schema)
687-
{
688-
return null;
689-
}
690670
}
691671

692672
public sealed class PoissonRegressionPredictor : RegressionPredictor, IParameterMixer<Float>
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=Bias:R4:0
5+
#@ col=Weights:R4:1-9
6+
#@ }
7+
Bias thickness uniform_size uniform_shape adhesion epit_size bare_nuclei bland_chromatin normal_nucleoli mitoses
8+
-6.186806 2.65800762 1.68089855 1.944068 1.42514718 0.8536965 2.9325006 1.74816787 1.58165014 0.595681
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col={name={Count of training examples} type=I8 src=0}
5+
#@ col={name={Residual Deviance} type=R4 src=1}
6+
#@ col={name={Null Deviance} type=R4 src=2}
7+
#@ col=AIC:R4:3
8+
#@ }
9+
Count of training examples Residual Deviance Null Deviance AIC
10+
683 119.098892 884.3502 159.098892
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,10 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=Bias:R4:0
5+
#@ col=Weights:R4:1-9
6+
#@ col=ClassNames:TX:10
7+
#@ }
8+
Bias thickness uniform_size uniform_shape adhesion epit_size bare_nuclei bland_chromatin normal_nucleoli mitoses ClassNames
9+
3.36404228 -1.579712 -0.8266232 -1.051891 -0.79305464 -0.386733949 -1.59106934 -1.01550019 -0.8356989 -0.332574666 Class_0
10+
-3.36404562 1.57971311 0.826623559 1.051891 0.7930542 0.386735022 1.59107041 1.015499 0.8356983 0.332574 Class_1

0 commit comments

Comments
 (0)