Skip to content

PcaTrainer as estimator #996

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 25, 2018
4 changes: 1 addition & 3 deletions src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,6 @@ private static ColumnType GetPredColType(ColumnType scoreType, ISchemaBoundRowMa
}

private static bool OutputTypeMatches(ColumnType scoreType)
{
return scoreType == NumberType.Float;
}
=> scoreType == NumberType.Float;
}
Copy link
Contributor

@TomFinley TomFinley Sep 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bit of a nit, but while you're at it, maybe indent properly. #Closed

}
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public abstract class SingleFeaturePredictionTransformerBase<TModel> : Predictio
public ColumnType FeatureColumnType { get; }

public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
:base(host, model, trainSchema)
: base(host, model, trainSchema)
{
FeatureColumn = featureColumn;

Expand All @@ -148,7 +148,7 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema
}

internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
:base(host, ctx)
: base(host, ctx)
{
FeatureColumn = ctx.LoadStringOrNull();

Expand All @@ -166,7 +166,7 @@ public override ISchema GetOutputSchema(ISchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));

if(FeatureColumn != null)
if (FeatureColumn != null)
{
if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
Expand Down
86 changes: 74 additions & 12 deletions src/Microsoft.ML.PCA/PcaTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using Microsoft.ML.Runtime.PCA;
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Core.Data;

[assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Arguments),
new[] { typeof(SignatureAnomalyDetectorTrainer), typeof(SignatureTrainer) },
Expand All @@ -41,7 +42,7 @@ namespace Microsoft.ML.Runtime.PCA
/// <remarks>
/// This PCA can be made into Kernel PCA by using Random Fourier Features transform
/// </remarks>
public sealed class RandomizedPcaTrainer : TrainerBase<PcaPredictor>
public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PcaPredictor>, PcaPredictor>
Copy link
Contributor

@TomFinley TomFinley Sep 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BinaryPredictionTransformer [](start = 68, length = 27)

So, why is this a binary prediction transformer? Prior to your change it was not a binary predictor, it was part of the anomaly detection task. #Closed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, thanks.


In reply to: 220048006 [](ancestors = 220048006)

{
public const string LoadNameValue = "pcaAnomaly";
internal const string UserNameValue = "PCA Anomaly Detector";
Expand Down Expand Up @@ -73,28 +74,56 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight
private readonly int _oversampling;
private readonly bool _center;
private readonly int _seed;
private readonly string _featureColumn;

public override PredictionKind PredictionKind => PredictionKind.AnomalyDetection;

// The training performs two passes, only. Probably not worth caching.
private static readonly TrainerInfo _info = new TrainerInfo(caching: false);
public override TrainerInfo Info => _info;

public RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
: base(env, LoadNameValue)
public RandomizedPcaTrainer(IHostEnvironment env, string featureColumn, string weightColumn = null,
int rank = 20, int oversampling = 20, bool center = true, int? seed = null)
: this(env, null, featureColumn, weightColumn, rank, oversampling, center, seed)
{
Host.CheckValue(args, nameof(args));
Host.CheckUserArg(args.Rank > 0, nameof(args.Rank), "Rank must be positive");
Host.CheckUserArg(args.Oversampling >= 0, nameof(args.Oversampling), "Oversampling must be non-negative");

_rank = args.Rank;
_center = args.Center;
_oversampling = args.Oversampling;
_seed = args.Seed ?? Host.Rand.Next();

}

/// <summary>
/// Creates the trainer from an <see cref="Arguments"/> object, forwarding its feature and
/// weight column names to the shared private constructor.
/// </summary>
/// <param name="env">The host environment; validated by the private constructor.</param>
/// <param name="args">The trainer arguments. Must not be null.</param>
internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args)
    // Validate args before args.FeatureColumn / args.WeightColumn are read: initializer
    // arguments are evaluated left to right, so a null args now fails the contract check
    // instead of surfacing as a NullReferenceException.
    : this(env, Contracts.CheckRef(args, nameof(args)), args.FeatureColumn, args.WeightColumn)
{
}

/// <summary>
/// Shared constructor that the other constructors chain into. Registers the host, builds the
/// feature/weight column descriptions for the base class, and stores the training settings.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="args">Optional arguments object. When non-null its Rank, Center, Oversampling
/// and Seed values are used and the corresponding individual parameters are ignored.</param>
/// <param name="featureColumn">Name of the feature column.</param>
/// <param name="weightColumn">Name of the weight column, or null.</param>
/// <param name="rank">Used only when <paramref name="args"/> is null; must be positive.</param>
/// <param name="oversampling">Used only when <paramref name="args"/> is null; must be non-negative.</param>
/// <param name="center">Used only when <paramref name="args"/> is null.</param>
/// <param name="seed">Used only when <paramref name="args"/> is null; a value is drawn from
/// Host.Rand when the effective seed is null.</param>
private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn,
    int rank = 20, int oversampling = 20, bool center = true, int? seed = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), MakeFeatureColumn(featureColumn), null, MakeWeightColumn(weightColumn))
{
    // A non-null args means we arrived from maml through the internal constructor,
    // so the settings come from args rather than from the individual parameters.
    bool fromArgs = args != null;

    _rank = fromArgs ? args.Rank : rank;
    _center = fromArgs ? args.Center : center;
    _oversampling = fromArgs ? args.Oversampling : oversampling;
    _seed = (fromArgs ? args.Seed : seed) ?? Host.Rand.Next();
    _featureColumn = featureColumn;

    // Validate the effective values, whichever source they came from.
    Host.CheckUserArg(_rank > 0, nameof(_rank), "Rank must be positive");
    Host.CheckUserArg(_oversampling >= 0, nameof(_oversampling), "Oversampling must be non-negative");
}

//Note: the notations used here are the same as in https://web.stanford.edu/group/mmds/slides2010/Martinsson.pdf (pg. 9)
public override PcaPredictor Train(TrainContext context)
protected override PcaPredictor TrainModelCore(TrainContext context)
{
Host.CheckValue(context, nameof(context));

Expand All @@ -108,6 +137,18 @@ public override PcaPredictor Train(TrainContext context)
}
}

private static SchemaShape.Column MakeWeightColumn(string weightColumn)
Copy link
Member Author

@sfilipi sfilipi Sep 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use trainerUtils #Resolved

{
if (weightColumn == null)
return null;
return new SchemaShape.Column(weightColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
}

/// <summary>
/// Builds the <see cref="SchemaShape.Column"/> describing the feature column:
/// a vector-kind column of type <c>NumberType.R4</c> with the given name.
/// </summary>
/// <param name="featureColumn">Name of the feature column.</param>
private static SchemaShape.Column MakeFeatureColumn(string featureColumn)
    => new SchemaShape.Column(featureColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false);

private PcaPredictor TrainCore(IChannel ch, RoleMappedData data, int dimension)
{
Host.AssertValue(ch);
Expand Down Expand Up @@ -266,6 +307,27 @@ private static void PostProcess(VBuffer<Float>[] y, Float[] sigma, Float[] z, in
}
}

/// <summary>
/// Describes the output columns reported by this trainer: a scalar R4 score column followed
/// by a scalar boolean predicted-label column, each carrying the trainer output metadata.
/// </summary>
/// <param name="inputSchema">The input schema shape; not consulted by this implementation.</param>
/// <returns>A two-element array: score column, then predicted-label column.</returns>
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
{
    // Each column gets its own metadata SchemaShape instance, as before.
    var scoreColumn = new SchemaShape.Column(
        DefaultColumnNames.Score,
        SchemaShape.Column.VectorKind.Scalar,
        NumberType.R4,
        false,
        new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));

    var predictedLabelColumn = new SchemaShape.Column(
        DefaultColumnNames.PredictedLabel,
        SchemaShape.Column.VectorKind.Scalar,
        BoolType.Instance,
        false,
        new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));

    return new[] { scoreColumn, predictedLabelColumn };
}

/// <summary>
/// Wraps a trained <see cref="PcaPredictor"/> in a prediction transformer bound to the
/// training schema and this trainer's feature column.
/// </summary>
/// <param name="model">The trained predictor.</param>
/// <param name="trainSchema">The schema of the training data.</param>
protected override BinaryPredictionTransformer<PcaPredictor> MakeTransformer(PcaPredictor model, ISchema trainSchema)
{
    // No explicit threshold is supplied to the transformer here.
    return new BinaryPredictionTransformer<PcaPredictor>(Host, model, trainSchema, _featureColumn);
}
Copy link
Member Author

@sfilipi sfilipi Sep 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a, _featu [](start = 84, length = 9)

@zruty do we have a threshold to pass? #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we use default


In reply to: 219663367 [](ancestors = 219663367)


[TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector",
Desc = "Train an PCA Anomaly model.",
UserName = UserNameValue,
Expand Down
27 changes: 27 additions & 0 deletions test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.PCA;
using Microsoft.ML.Runtime.RunTests;
using Xunit;
using Xunit.Abstractions;
Expand All @@ -17,6 +18,32 @@ public TrainerEstimators(ITestOutputHelper helper) : base(helper)
{
}

/// <summary>
/// RandomizedPcaTrainer TrainerEstimator test: reads the MNIST one-class training file
/// and runs the PCA trainer through the shared estimator checks.
/// </summary>
[Fact]
public void PCATrainerEstimator()
{
    string featureColumn = "NumericFeatures";

    // Load columns 1..784 of the tab-separated file as a single R4 feature vector.
    var reader = new TextLoader(Env, new TextLoader.Arguments()
    {
        HasHeader = true,
        Separator = "\t",
        Column = new[]
        {
            new TextLoader.Column(featureColumn, DataKind.R4, new [] { new TextLoader.Range(1, 784) })
        }
    });
    var data = reader.Read(new MultiFileSource(GetDataPath(TestDatasets.mnistOneClass.trainFilename)));

    // Pipeline.
    var pipeline = new RandomizedPcaTrainer(Env, featureColumn, rank: 10);

    TestEstimatorCore(pipeline, data);

    // Requested in review; presumably the test base class's completion check.
    Done();
}
Copy link
Contributor

@Zruty0 Zruty0 Sep 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

} [](start = 7, length = 2)

Done() #Resolved


private (IEstimator<ITransformer>, IDataView) GetBinaryClassificationPipeline()
{
var data = new TextLoader(Env,
Expand Down