-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Modify API for advanced settings (RandomizedPcaTrainer) #2390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
77746d6
03e98b8
e5b1a74
8c47da1
398475a
8adb8b1
62f5db8
02b8651
b82c326
b1526d1
d90ded5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using Microsoft.Data.DataView; | ||
|
||
namespace Microsoft.ML.Data.Evaluators.Metrics | ||
{ | ||
public sealed class AnomalyDetectionMetrics | ||
{ | ||
public double Auc { get; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Summaries, Remarks, and links to relevant documentation. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added basic summaries for now. I also wanted to add the remarks from the TLC website, but the explanations there were not clear, especially for the detection rate metrics. In reply to: 255583277 [](ancestors = 255583277) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For these summaries, check in with @shmoradims; he's building a set of generic docs for things like AUC, F1, RMSE, etc. In reply to: 255703503 [](ancestors = 255703503,255583277) |
||
public double DrAtK { get; } | ||
public double DrAtPFpr { get; } | ||
public double DrAtNumPos { get; } | ||
public double NumAnomalies { get; } | ||
public double ThreshAtK { get; } | ||
public double ThreshAtP { get; } | ||
public double ThreshAtNumPos { get; } | ||
|
||
/// <summary> | ||
/// Builds the metrics object by reading every metric value out of <paramref name="overallResult"/>. | ||
/// </summary> | ||
/// <param name="ectx">Exception context forwarded to each <c>RowCursorUtils.Fetch</c> call.</param> | ||
/// <param name="overallResult">Row from which each metric is looked up by name.</param> | ||
internal AnomalyDetectionMetrics(IExceptionContext ectx, Row overallResult) | ||
{ | ||
// Typed lookup helpers: each one reads the entry called <name> from overallResult. | ||
long FetchInt(string name) => RowCursorUtils.Fetch<long>(ectx, overallResult, name); | ||
float FetchFloat(string name) => RowCursorUtils.Fetch<float>(ectx, overallResult, name); | ||
double FetchDouble(string name) => RowCursorUtils.Fetch<double>(ectx, overallResult, name); | ||
|
||
// AUC is looked up under the name defined by BinaryClassifierEvaluator; all other names come from | ||
// AnomalyDetectionEvaluator.OverallMetrics. | ||
Auc = FetchDouble(BinaryClassifierEvaluator.Auc); | ||
DrAtK = FetchDouble(AnomalyDetectionEvaluator.OverallMetrics.DrAtK); | ||
DrAtPFpr = FetchDouble(AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr); | ||
DrAtNumPos = FetchDouble(AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos); | ||
// Fetched as long, then implicitly converted to the double-typed property. | ||
NumAnomalies = FetchInt(AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies); | ||
// The thresholds are fetched as float and implicitly widened to the double-typed properties. | ||
ThreshAtK = FetchFloat(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK); | ||
ThreshAtP = FetchFloat(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP); | ||
ThreshAtNumPos = FetchFloat(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Data.Evaluators.Metrics; | ||
using Microsoft.ML.Transforms; | ||
using Microsoft.ML.Transforms.Conversions; | ||
|
||
|
@@ -564,4 +565,51 @@ public RankerMetrics Evaluate(IDataView data, string label, string groupId, stri | |
return eval.Evaluate(data, label, groupId, score); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// The central catalog for anomaly detection tasks and trainers. | ||
/// </summary> | ||
public sealed class AnomalyDetectionCatalog : TrainCatalogBase | ||
{ | ||
/// <summary> | ||
/// The list of trainers for anomaly detection. | ||
/// </summary> | ||
public AnomalyDetectionTrainers Trainers { get; } | ||
|
||
/// <summary> | ||
/// Creates the catalog, handing <paramref name="env"/> and this class's own name to the base class, | ||
/// and creates the <see cref="Trainers"/> entry point bound to this catalog. | ||
/// </summary> | ||
/// <param name="env">The host environment passed through to <see cref="TrainCatalogBase"/>.</param> | ||
internal AnomalyDetectionCatalog(IHostEnvironment env) | ||
: base(env, nameof(AnomalyDetectionCatalog)) | ||
{ | ||
Trainers = new AnomalyDetectionTrainers(this); | ||
} | ||
|
||
/// <summary> | ||
/// Anchor type for anomaly detection trainers. It declares no trainer methods itself; trainers are | ||
/// attached as extension methods on this type (for example <c>PcaCatalog.RandomizedPca</c>). | ||
/// </summary> | ||
public sealed class AnomalyDetectionTrainers : CatalogInstantiatorBase | ||
{ | ||
// Internal: instances are created only by the owning AnomalyDetectionCatalog. | ||
internal AnomalyDetectionTrainers(AnomalyDetectionCatalog catalog) | ||
: base(catalog) | ||
{ | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Evaluates scored anomaly detection data. | ||
/// </summary> | ||
/// <param name="data">The scored data.</param> | ||
/// <param name="label">The name of the label column in <paramref name="data"/>.</param> | ||
/// <param name="score">The name of the score column in <paramref name="data"/>.</param> | ||
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param> | ||
/// <returns>The anomaly detection metrics computed from the scored data.</returns> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
What is calibrated here? Could you explain in a few more words? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
public AnomalyDetectionMetrics Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score, | ||
string predictedLabel = DefaultColumnNames.PredictedLabel) | ||
{ | ||
// Argument checks on the data view and on each of the three column names. | ||
Host.CheckValue(data, nameof(data)); | ||
Host.CheckNonEmpty(label, nameof(label)); | ||
Host.CheckNonEmpty(score, nameof(score)); | ||
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel)); | ||
|
||
// This overload exposes no evaluator settings, so Arguments is used exactly as its constructor leaves it. | ||
var args = new AnomalyDetectionEvaluator.Arguments(); | ||
|
||
||
// The evaluator does the actual work; note the argument order is (label, score, predictedLabel). | ||
var eval = new AnomalyDetectionEvaluator(Host, args); | ||
return eval.Evaluate(data, label, score, predictedLabel); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,14 @@ | |
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Trainers.PCA; | ||
using Microsoft.ML.Transforms.Projections; | ||
using static Microsoft.ML.Trainers.PCA.RandomizedPcaTrainer; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
public static class PcaCatalog | ||
{ | ||
|
||
/// <summary>Initializes a new instance of <see cref="PrincipalComponentAnalysisEstimator"/>.</summary> | ||
/// <param name="catalog">The transform's catalog.</param> | ||
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param> | ||
|
@@ -35,5 +36,25 @@ public static PrincipalComponentAnalysisEstimator ProjectToPrincipalComponents(t | |
/// <param name="columns">Input columns to apply PrincipalComponentAnalysis on.</param> | ||
public static PrincipalComponentAnalysisEstimator ProjectToPrincipalComponents(this TransformsCatalog.ProjectionTransforms catalog, params PrincipalComponentAnalysisEstimator.ColumnInfo[] columns) | ||
=> new PrincipalComponentAnalysisEstimator(CatalogUtils.GetEnvironment(catalog), columns); | ||
|
||
/// <summary> | ||
/// Creates a <see cref="RandomizedPcaTrainer"/>, which trains an approximate PCA using the randomized SVD algorithm. | ||
/// </summary> | ||
/// <param name="catalog">The anomaly detection catalog trainer object.</param> | ||
/// <param name="featureColumn">The name of the feature column.</param> | ||
/// <param name="weights">The name of the weight column; may be left <see langword="null"/>.</param> | ||
/// <param name="rank">The number of components in the PCA.</param> | ||
/// <param name="oversampling">Oversampling parameter for randomized PCA training.</param> | ||
/// <param name="center">If enabled, data is centered to be zero mean.</param> | ||
/// <param name="seed">The seed for random number generation.</param> | ||
/// <returns>A new, untrained <see cref="RandomizedPcaTrainer"/> configured with the given values.</returns> | ||
public static RandomizedPcaTrainer RandomizedPca(this AnomalyDetectionCatalog.AnomalyDetectionTrainers catalog, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs xml docs with remarks and links to a sample. Here or add to #1209 . #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
string featureColumn = DefaultColumnNames.Features, | ||
string weights = null, | ||
int rank = 20, | ||
int oversampling = 20, | ||
bool center = true, | ||
int? seed = null) | ||
{ | ||
Contracts.CheckValue(catalog, nameof(catalog)); | ||
// The trainer needs a host environment; it is obtained from the catalog the extension was invoked on. | ||
var env = CatalogUtils.GetEnvironment(catalog); | ||
return new RandomizedPcaTrainer(env, featureColumn, weights, rank, oversampling, center, seed); | ||
} | ||
|
||
/// <summary> | ||
/// Creates a <see cref="RandomizedPcaTrainer"/> from an advanced options object. | ||
/// </summary> | ||
/// <param name="catalog">The anomaly detection catalog trainer object.</param> | ||
/// <param name="options">Advanced settings for the trainer (<c>RandomizedPcaTrainer.Options</c>, in scope via the file's <c>using static</c>).</param> | ||
/// <returns>A new, untrained <see cref="RandomizedPcaTrainer"/> built from <paramref name="options"/>.</returns> | ||
public static RandomizedPcaTrainer RandomizedPca(this AnomalyDetectionCatalog.AnomalyDetectionTrainers catalog, Options options) | ||
{ | ||
Contracts.CheckValue(catalog, nameof(catalog)); | ||
// NOTE(review): options is not validated here; presumably the trainer constructor checks it — confirm. | ||
var env = CatalogUtils.GetEnvironment(catalog); | ||
return new RandomizedPcaTrainer(env, options); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ | |
using Microsoft.ML.Trainers.PCA; | ||
using Microsoft.ML.Training; | ||
|
||
[assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Arguments), | ||
[assembly: LoadableClass(RandomizedPcaTrainer.Summary, typeof(RandomizedPcaTrainer), typeof(RandomizedPcaTrainer.Options), | ||
new[] { typeof(SignatureAnomalyDetectorTrainer), typeof(SignatureTrainer) }, | ||
RandomizedPcaTrainer.UserNameValue, | ||
RandomizedPcaTrainer.LoadNameValue, | ||
|
@@ -49,7 +49,7 @@ public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictio | |
internal const string Summary = "This algorithm trains an approximate PCA using Randomized SVD algorithm. " | ||
+ "This PCA can be made into Kernel PCA by using Random Fourier Features transform."; | ||
|
||
public class Arguments : UnsupervisedLearnerInputBaseWithWeight | ||
public class Options : UnsupervisedLearnerInputBaseWithWeight | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml docs are coming later? #Pending There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
{ | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "The number of components in the PCA", ShortName = "k", SortOrder = 50)] | ||
[TGUI(SuggestedSweeps = "10,20,40,80")] | ||
|
@@ -91,7 +91,7 @@ public class Arguments : UnsupervisedLearnerInputBaseWithWeight | |
/// <param name="oversampling">Oversampling parameter for randomized PCA training.</param> | ||
/// <param name="center">If enabled, data is centered to be zero mean.</param> | ||
/// <param name="seed">The seed for random number generation.</param> | ||
public RandomizedPcaTrainer(IHostEnvironment env, | ||
internal RandomizedPcaTrainer(IHostEnvironment env, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we just make the class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not really. we would want to expose these through mlcontext. not via constructors In reply to: 254391018 [](ancestors = 254391018) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Resolved. I hadn't seen the pattern for trainable transforms where the class is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i believe in ML.NET terms, this is "trainer estimator" (for anomaly detection tasks) most other "trainer estimator"s follow the same pattern e.g. KMeansPlusPlusTrainer In reply to: 255587702 [](ancestors = 255587702) |
||
string features, | ||
string weights = null, | ||
int rank = 20, | ||
|
@@ -103,23 +103,23 @@ public RandomizedPcaTrainer(IHostEnvironment env, | |
|
||
} | ||
|
||
internal RandomizedPcaTrainer(IHostEnvironment env, Arguments args) | ||
:this(env, args, args.FeatureColumn, args.WeightColumn) | ||
internal RandomizedPcaTrainer(IHostEnvironment env, Options options) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It's strange... I noticed that renaming Arguments to Options did not modify anything in the mlContext catalog. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I looked it up, and I don't think there is an entry for this trainer in mlContext. Can you add it? In reply to: 253319239 [](ancestors = 253319239) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah. i noticed couple more components which do not have mlcontext extension. will add In reply to: 253319255 [](ancestors = 253319255,253319239) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i added mlcontext extension for this. Also added a test for it that exercises the Fit() and Transform() APIs. Evaluate() API currently missing from Anomaly Detection. i will create a separate issue for that. In reply to: 253584603 [](ancestors = 253584603,253319255,253319239) |
||
:this(env, options, options.FeatureColumn, options.WeightColumn) | ||
{ | ||
|
||
} | ||
|
||
private RandomizedPcaTrainer(IHostEnvironment env, Arguments args, string featureColumn, string weightColumn, | ||
private RandomizedPcaTrainer(IHostEnvironment env, Options options, string featureColumn, string weightColumn, | ||
int rank = 20, int oversampling = 20, bool center = true, int? seed = null) | ||
: base(Contracts.CheckRef(env, nameof(env)).Register(LoadNameValue), TrainerUtils.MakeR4VecFeature(featureColumn), default, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn)) | ||
{ | ||
// if the args are not null, we got here from maml, and the internal ctor. | ||
if (args != null) | ||
if (options != null) | ||
{ | ||
_rank = args.Rank; | ||
_center = args.Center; | ||
_oversampling = args.Oversampling; | ||
_seed = args.Seed ?? Host.Rand.Next(); | ||
_rank = options.Rank; | ||
_center = options.Center; | ||
_oversampling = options.Oversampling; | ||
_seed = options.Seed ?? Host.Rand.Next(); | ||
} | ||
else | ||
{ | ||
|
@@ -347,14 +347,14 @@ protected override AnomalyPredictionTransformer<PcaModelParameters> MakeTransfor | |
Desc = "Train an PCA Anomaly model.", | ||
UserName = UserNameValue, | ||
ShortName = ShortName)] | ||
internal static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironment env, Arguments input) | ||
internal static CommonOutputs.AnomalyDetectionOutput TrainPcaAnomaly(IHostEnvironment env, Options input) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above; these can be kept There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The In reply to: 254391534 [](ancestors = 254391534) |
||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
var host = env.Register("TrainPCAAnomaly"); | ||
host.CheckValue(input, nameof(input)); | ||
EntryPointUtils.CheckInputArgs(host, input); | ||
|
||
return LearnerEntryPointsUtils.Train<Arguments, CommonOutputs.AnomalyDetectionOutput>(host, input, | ||
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.AnomalyDetectionOutput>(host, input, | ||
() => new RandomizedPcaTrainer(host, input), | ||
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn)); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Drawing; | ||
using System.Drawing.Imaging; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.ImageAnalytics; | ||
using Microsoft.ML.Model; | ||
using Microsoft.ML.RunTests; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace Microsoft.ML.Tests | ||
{ | ||
public class AnomalyDetectionTests : TestDataPipeBase | ||
{ | ||
public AnomalyDetectionTests(ITestOutputHelper output) : base(output) | ||
{ | ||
} | ||
|
||
/// <summary> | ||
/// RandomizedPcaTrainer test | ||
/// </summary> | ||
[Fact] | ||
public void RandomizedPcaTrainer() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
{ | ||
var mlContext = new MLContext(seed: 1, conc: 1); | ||
string featureColumn = "NumericFeatures"; | ||
|
||
var reader = new TextLoader(Env, new TextLoader.Arguments() | ||
{ | ||
HasHeader = true, | ||
Separator = "\t", | ||
Columns = new[] | ||
{ | ||
new TextLoader.Column("Label", DataKind.R4, 0), | ||
new TextLoader.Column(featureColumn, DataKind.R4, new [] { new TextLoader.Range(1, 784) }) | ||
} | ||
}); | ||
|
||
var trainData = reader.Read(GetDataPath(TestDatasets.mnistOneClass.trainFilename)); | ||
var testData = reader.Read(GetDataPath(TestDatasets.mnistOneClass.testFilename)); | ||
|
||
var pipeline = ML.AnomalyDetection.Trainers.RandomizedPca(featureColumn); | ||
|
||
var transformer = pipeline.Fit(trainData); | ||
var transformedData = transformer.Transform(testData); | ||
|
||
// Evaluate | ||
var metrics = ML.AnomalyDetection.Evaluate(transformedData); | ||
|
||
Assert.Equal(0.99, metrics.Auc, 2); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we know that these numbers are correct? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried out the same dataset in TLC with the same trainer, and the numbers are close. Not exact, though. In general, in this PR I am only exposing the trainer / evaluators as they exist currently in the codebase. The PR does not have any algorithmic changes or changes in evaluation metrics themselves. In reply to: 255589211 [](ancestors = 255589211) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess the big question is, what do we want to test here?
If we do want a baseline test, can we mark it as such, and check to further decimal places? As an aside, are there correctness tests on these metrics that we can migrate from the internal repo? If so, can you file it as an issue to be done later?) In reply to: 255680814 [](ancestors = 255680814,255589211) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good point. as per the classification above, these seem like "baseline" tests, I have increased the precision to 5 places of decimal. as for test migration from internal repo, it seems we ported over the In reply to: 255796608 [](ancestors = 255796608,255680814,255589211) |
||
Assert.Equal(0.90, metrics.DrAtK, 2); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This one @ 5 places too, please :) #Resolved |
||
Assert.Equal(0.90, metrics.DrAtPFpr, 2); | ||
Assert.Equal(0.90, metrics.DrAtNumPos, 2); | ||
Assert.Equal(10, metrics.NumAnomalies); | ||
Assert.Equal(0.57, metrics.ThreshAtK, 2); | ||
Assert.Equal(0.63, metrics.ThreshAtP, 2); | ||
Assert.Equal(0.65, metrics.ThreshAtNumPos, 2); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defaults for `label`, `score`, and `predictedLabel`? #Resolved