-
Notifications
You must be signed in to change notification settings - Fork 1.9k
PcaTrainer as estimator #996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
662d19c
5220393
c1d54f7
9d5783d
431cfae
6708f7f
83b1daf
2974c38
43f5775
eac7be4
f1a1213
65c1ee3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
using Microsoft.ML.Runtime.PCA; | ||
using Microsoft.ML.Runtime.Training; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
using Microsoft.ML.Core.Data; | ||
|
||
[assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Arguments), | ||
new[] { typeof(SignatureAnomalyDetectorTrainer), typeof(SignatureTrainer) }, | ||
|
@@ -41,7 +42,7 @@ namespace Microsoft.ML.Runtime.PCA | |
/// <remarks> | ||
/// This PCA can be made into Kernel PCA by using Random Fourier Features transform | ||
/// </remarks> | ||
public sealed class RandomizedPcaTrainer : TrainerBase<PcaPredictor> | ||
public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PcaPredictor>, PcaPredictor> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
So, why is this a binary prediction transformer? Prior to your change it was not a binary predictor, it was part of the anomaly detection task. #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
public const string LoadNameValue = "pcaAnomaly"; | ||
internal const string UserNameValue = "PCA Anomaly Detector"; | ||
|
@@ -73,28 +74,56 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight | |
private readonly int _oversampling; | ||
private readonly bool _center; | ||
private readonly int _seed; | ||
private readonly string _featureColumn; | ||
|
||
public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection; | ||
|
||
// The training performs two passes, only. Probably not worth caching. | ||
private static readonly TrainerInfo _info = new TrainerInfo(caching: false); | ||
public override TrainerInfo Info => _info; | ||
|
||
public RandomizedPcaTrainer(IHostEnvironment env, Arguments args) | ||
: base(env, LoadNameValue) | ||
public RandomizedPcaTrainer(IHostEnvironment env, string featureColumn, string weightColumn = null, | ||
int rank = 20, int oversampling = 20, bool center = true, int? seed = null) | ||
: this(env, null, featureColumn, weightColumn, rank, oversampling, center, seed) | ||
{ | ||
Host.CheckValue(args, nameof(args)); | ||
Host.CheckUserArg(args.Rank > 0, nameof(args.Rank), "Rank must be positive"); | ||
Host.CheckUserArg(args.Oversampling >= 0, nameof(args.Oversampling), "Oversampling must be non-negative"); | ||
|
||
_rank = args.Rank; | ||
_center = args.Center; | ||
_oversampling = args.Oversampling; | ||
_seed = args.Seed ?? Host.Rand.Next(); | ||
|
||
} | ||
|
||
internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args) | ||
:this(env, args, args.FeatureColumn, args.WeightColumn) | ||
{ | ||
|
||
} | ||
|
||
private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, | ||
int rank = 20, int oversampling = 20, bool center = true, int? seed = null) | ||
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), MakeFeatureColumn(featureColumn), null, MakeWeightColumn(weightColumn)) | ||
{ | ||
// if the args are not null, we got here from maml, and the internal ctor. | ||
if (args != null) | ||
{ | ||
_rank = args.Rank; | ||
_center = args.Center; | ||
_oversampling = args.Oversampling; | ||
_seed = args.Seed ?? Host.Rand.Next(); | ||
} | ||
else | ||
{ | ||
_rank = rank; | ||
_center = center; | ||
_oversampling = oversampling; | ||
_seed = seed ?? Host.Rand.Next(); | ||
} | ||
|
||
_featureColumn = featureColumn; | ||
|
||
Host.CheckUserArg(_rank > 0, nameof(_rank), "Rank must be positive"); | ||
Host.CheckUserArg(_oversampling >= 0, nameof(_oversampling), "Oversampling must be non-negative"); | ||
|
||
} | ||
|
||
//Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9) | ||
public override PcaPredictor Train(TrainContext context) | ||
protected override PcaPredictor TrainModelCore(TrainContext context) | ||
{ | ||
Host.CheckValue(context, nameof(context)); | ||
|
||
|
@@ -108,6 +137,18 @@ public override PcaPredictor Train(TrainContext context) | |
} | ||
} | ||
|
||
private static SchemaShape.Column MakeWeightColumn(string weightColumn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use trainerUtils #Resolved |
||
{ | ||
if (weightColumn == null) | ||
return null; | ||
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false); | ||
} | ||
|
||
private static SchemaShape.Column MakeFeatureColumn(string featureColumn) | ||
{ | ||
return new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); | ||
} | ||
|
||
private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension) | ||
{ | ||
Host.AssertValue(ch); | ||
|
@@ -266,6 +307,27 @@ private static void PostProcess(VBuffer<Float>[] y, Float[] sigma, Float[] z, in | |
} | ||
} | ||
|
||
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema) | ||
{ | ||
return new[] | ||
{ | ||
new SchemaShape.Column(DefaultColumnNames.Score, | ||
SchemaShape.Column.VectorKind.Scalar, | ||
NumberType.R4, | ||
false, | ||
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())), | ||
|
||
new SchemaShape.Column(DefaultColumnNames.PredictedLabel, | ||
SchemaShape.Column.VectorKind.Scalar, | ||
BoolType.Instance, | ||
false, | ||
new SchemaShape(MetadataUtils.GetTrainerOutputMetadata())) | ||
}; | ||
} | ||
|
||
protected override BinaryPredictionTransformer<PcaPredictor> MakeTransformer(PcaPredictor model, ISchema trainSchema) | ||
=> new BinaryPredictionTransformer<PcaPredictor>(Host, model, trainSchema, _featureColumn); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@zruty do we have a threeshold, to pass? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
[TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector", | ||
Desc = "Train an PCA Anomaly model.", | ||
UserName = UserNameValue, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.Learners; | ||
using Microsoft.ML.Runtime.PCA; | ||
using Microsoft.ML.Runtime.RunTests; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
@@ -17,6 +18,32 @@ public TrainerEstimators(ITestOutputHelper helper) : base(helper) | |
{ | ||
} | ||
|
||
/// <summary> | ||
/// FastTreeBinaryClassification TrainerEstimator test | ||
/// </summary> | ||
[Fact] | ||
public void PCATrainerEstimator() | ||
{ | ||
string featureColumn = "NumericFeatures"; | ||
|
||
var reader = new TextLoader(Env, new TextLoader.Arguments() | ||
{ | ||
HasHeader = true, | ||
Separator = "\t", | ||
Column = new[] | ||
{ | ||
new TextLoader.Column(featureColumn, DataKind.R4, new [] { new TextLoader.Range(1, 784) }) | ||
} | ||
}); | ||
var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.mnistOneClass.trainFilename))); | ||
|
||
|
||
// Pipeline. | ||
var pipeline = new RandomizedPcaTrainer(Env, featureColumn, rank:10); | ||
|
||
TestEstimatorCore(pipeline, data); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
private (IEstimator<ITransformer>, IDataView) GetBinaryClassificationPipeline() | ||
{ | ||
var data = new TextLoader(Env, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bit of a nit, but while you're at it, maybe indent properly. #Closed