Skip to content

Commit 6ec3280

Browse files
authored
Creation of components through MLContext and cleanup (GcnNorm, LpNorm, RandomFourier, CustomStopWords, VectorWhiten, PCA) (#2366)
1 parent 17b07d8 commit 6ec3280

File tree

24 files changed

+666
-545
lines changed

24 files changed

+666
-545
lines changed

src/Microsoft.ML.Data/Transforms/BootstrapSamplingTransformer.cs

+11-11
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
using Microsoft.ML.Model;
1414
using Microsoft.ML.Transforms;
1515

16-
[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), typeof(BootstrapSamplingTransformer.Arguments), typeof(SignatureDataTransform),
16+
[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), typeof(BootstrapSamplingTransformer.Options), typeof(SignatureDataTransform),
1717
BootstrapSamplingTransformer.UserName, "BootstrapSampleTransform", "BootstrapSample")]
1818

1919
[assembly: LoadableClass(BootstrapSamplingTransformer.Summary, typeof(BootstrapSamplingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -36,7 +36,7 @@ internal static class Defaults
3636
public const int PoolSize = 1000;
3737
}
3838

39-
public sealed class Arguments : TransformInputBase
39+
public sealed class Options : TransformInputBase
4040
{
4141
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.",
4242
ShortName = "comp")]
@@ -76,16 +76,16 @@ private static VersionInfo GetVersionInfo()
7676
private readonly bool _shuffleInput;
7777
private readonly int _poolSize;
7878

79-
public BootstrapSamplingTransformer(IHostEnvironment env, Arguments args, IDataView input)
79+
public BootstrapSamplingTransformer(IHostEnvironment env, Options options, IDataView input)
8080
: base(env, RegistrationName, input)
8181
{
82-
Host.CheckValue(args, nameof(args));
83-
Host.CheckUserArg(args.PoolSize >= 0, nameof(args.PoolSize), "Cannot be negative");
82+
Host.CheckValue(options, nameof(options));
83+
Host.CheckUserArg(options.PoolSize >= 0, nameof(options.PoolSize), "Cannot be negative");
8484

85-
_complement = args.Complement;
86-
_state = new TauswortheHybrid.State(args.Seed ?? (uint)Host.Rand.Next());
87-
_shuffleInput = args.ShuffleInput;
88-
_poolSize = args.PoolSize;
85+
_complement = options.Complement;
86+
_state = new TauswortheHybrid.State(options.Seed ?? (uint)Host.Rand.Next());
87+
_shuffleInput = options.ShuffleInput;
88+
_poolSize = options.PoolSize;
8989
}
9090

9191
/// <summary>
@@ -103,7 +103,7 @@ public BootstrapSamplingTransformer(IHostEnvironment env,
103103
uint? seed = null,
104104
bool shuffleInput = Defaults.ShuffleInput,
105105
int poolSize = Defaults.PoolSize)
106-
: this(env, new Arguments() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
106+
: this(env, new Options() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
107107
{
108108
}
109109

@@ -242,7 +242,7 @@ protected override bool MoveNextCore()
242242
internal static class BootstrapSample
243243
{
244244
[TlcModule.EntryPoint(Name = "Transforms.ApproximateBootstrapSampler", Desc = BootstrapSamplingTransformer.Summary, UserName = BootstrapSamplingTransformer.UserName, ShortName = BootstrapSamplingTransformer.RegistrationName)]
245-
public static CommonOutputs.TransformOutput GetSample(IHostEnvironment env, BootstrapSamplingTransformer.Arguments input)
245+
public static CommonOutputs.TransformOutput GetSample(IHostEnvironment env, BootstrapSamplingTransformer.Options input)
246246
{
247247
Contracts.CheckValue(env, nameof(env));
248248
env.CheckValue(input, nameof(input));

src/Microsoft.ML.Data/Transforms/RowShufflingTransformer.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
using Microsoft.ML.Model;
1616
using Microsoft.ML.Transforms;
1717

18-
[assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), typeof(RowShufflingTransformer.Arguments), typeof(SignatureDataTransform),
18+
[assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), typeof(RowShufflingTransformer.Options), typeof(SignatureDataTransform),
1919
"Shuffle Transform", "ShuffleTransform", "Shuffle", "shuf")]
2020

2121
[assembly: LoadableClass(RowShufflingTransformer.Summary, typeof(RowShufflingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -30,7 +30,8 @@ namespace Microsoft.ML.Transforms
3030
/// rows in the input cursor, and then, successively, the output cursor will yield one
3131
/// of these rows and replace it with another row from the input.
3232
/// </summary>
33-
public sealed class RowShufflingTransformer : RowToRowTransformBase
33+
[BestFriend]
34+
internal sealed class RowShufflingTransformer : RowToRowTransformBase
3435
{
3536
private static class Defaults
3637
{
@@ -39,7 +40,7 @@ private static class Defaults
3940
public const bool ForceShuffle = false;
4041
}
4142

42-
public sealed class Arguments
43+
public sealed class Options
4344
{
4445
// REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps?
4546
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")]
@@ -99,14 +100,14 @@ public RowShufflingTransformer(IHostEnvironment env,
99100
int poolRows = Defaults.PoolRows,
100101
bool poolOnly = Defaults.PoolOnly,
101102
bool forceShuffle = Defaults.ForceShuffle)
102-
: this(env, new Arguments() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input)
103+
: this(env, new Options() { PoolRows = poolRows, PoolOnly = poolOnly, ForceShuffle = forceShuffle }, input)
103104
{
104105
}
105106

106107
/// <summary>
107108
/// Public constructor corresponding to SignatureDataTransform.
108109
/// </summary>
109-
public RowShufflingTransformer(IHostEnvironment env, Arguments args, IDataView input)
110+
public RowShufflingTransformer(IHostEnvironment env, Options args, IDataView input)
110111
: base(env, RegistrationName, input)
111112
{
112113
Host.CheckValue(args, nameof(args));

src/Microsoft.ML.Ensemble/Selector/SubsetSelector/BootstrapSelector.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public override IEnumerable<Subset> GetSubsets(Batch batch, Random rand)
4747
for (int i = 0; i < Size; i++)
4848
{
4949
// REVIEW: Consider ways to reintroduce "balanced" samples.
50-
var viewTrain = new BootstrapSamplingTransformer(Host, new BootstrapSamplingTransformer.Arguments(), Data.Data);
50+
var viewTrain = new BootstrapSamplingTransformer(Host, new BootstrapSamplingTransformer.Options(), Data.Data);
5151
var dataTrain = new RoleMappedData(viewTrain, Data.Schema.GetColumnRoleNames());
5252
yield return FeatureSelector.SelectFeatures(dataTrain, rand);
5353
}

src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs

+8-8
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
4949
{
5050
Contracts.Assert(toOutput.Length == 1);
5151

52-
var infos = new VectorWhiteningTransformer.ColumnInfo[toOutput.Length];
52+
var infos = new VectorWhiteningEstimator.ColumnInfo[toOutput.Length];
5353
for (int i = 0; i < toOutput.Length; i++)
54-
infos[i] = new VectorWhiteningTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[((OutPipelineColumn)toOutput[i]).Input], _kind, _eps, _maxRows, _pcaNum);
54+
infos[i] = new VectorWhiteningEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[((OutPipelineColumn)toOutput[i]).Input], _kind, _eps, _maxRows, _pcaNum);
5555

5656
return new VectorWhiteningEstimator(env, infos);
5757
}
@@ -63,18 +63,18 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
6363
/// <param name="maxRows">Maximum number of rows used to train the transform.</param>
6464
/// <param name="pcaNum">In case of PCA whitening, indicates the number of components to retain.</param>
6565
public static Vector<float> PcaWhitening(this Vector<float> input,
66-
float eps = VectorWhiteningTransformer.Defaults.Eps,
67-
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
68-
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
66+
float eps = VectorWhiteningEstimator.Defaults.Eps,
67+
int maxRows = VectorWhiteningEstimator.Defaults.MaxRows,
68+
int pcaNum = VectorWhiteningEstimator.Defaults.PcaNum)
6969
=> new OutPipelineColumn(input, WhiteningKind.Pca, eps, maxRows, pcaNum);
7070

7171
/// <include file='../Microsoft.ML.HalLearners/doc.xml' path='doc/members/member[@name="Whitening"]/*'/>
7272
/// <param name="input">The column to which the transform will be applied.</param>
7373
/// <param name="eps">Whitening constant, prevents division by zero.</param>
7474
/// <param name="maxRows">Maximum number of rows used to train the transform.</param>
7575
public static Vector<float> ZcaWhitening(this Vector<float> input,
76-
float eps = VectorWhiteningTransformer.Defaults.Eps,
77-
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows)
78-
=> new OutPipelineColumn(input, WhiteningKind.Zca, eps, maxRows, VectorWhiteningTransformer.Defaults.PcaNum);
76+
float eps = VectorWhiteningEstimator.Defaults.Eps,
77+
int maxRows = VectorWhiteningEstimator.Defaults.MaxRows)
78+
=> new OutPipelineColumn(input, WhiteningKind.Zca, eps, maxRows, VectorWhiteningEstimator.Defaults.PcaNum);
7979
}
8080
}

src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(
112112
/// </format>
113113
/// </example>
114114
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, string outputColumnName, string inputColumnName = null,
115-
WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind,
116-
float eps = VectorWhiteningTransformer.Defaults.Eps,
117-
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
118-
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
115+
WhiteningKind kind = VectorWhiteningEstimator.Defaults.Kind,
116+
float eps = VectorWhiteningEstimator.Defaults.Eps,
117+
int maxRows = VectorWhiteningEstimator.Defaults.MaxRows,
118+
int pcaNum = VectorWhiteningEstimator.Defaults.PcaNum)
119119
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, kind, eps, maxRows, pcaNum);
120120

121121
/// <summary>
@@ -124,7 +124,7 @@ public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.Proje
124124
/// </summary>
125125
/// <param name="catalog">The transform's catalog.</param>
126126
/// <param name="columns">Describes the parameters of the whitening process for each column pair.</param>
127-
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, params VectorWhiteningTransformer.ColumnInfo[] columns)
127+
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, params VectorWhiteningEstimator.ColumnInfo[] columns)
128128
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), columns);
129129

130130
}

src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedDa
141141
idvToFeedTrain = idvToShuffle;
142142
else
143143
{
144-
var shuffleArgs = new RowShufflingTransformer.Arguments
144+
var shuffleArgs = new RowShufflingTransformer.Options
145145
{
146146
PoolOnly = false,
147147
ForceShuffle = _options.Shuffle

0 commit comments

Comments
 (0)