Skip to content

Commit bd98b16

Browse files
committed
feature selection transforms mainly
1 parent 28ae548 commit bd98b16

File tree

18 files changed

+187
-160
lines changed

18 files changed

+187
-160
lines changed

src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public static TypeConvertingEstimator ConvertType(this TransformsCatalog.Convers
5555
/// </summary>
5656
/// <param name="catalog">The transform's catalog.</param>
5757
/// <param name="columns">Description of dataset columns and how to process them.</param>
58-
public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingTransformer.ColumnInfo[] columns)
58+
public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingEstimator.ColumnInfo[] columns)
5959
=> new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), columns);
6060

6161
/// <summary>

src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs renamed to src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs

+14-10
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ namespace Microsoft.ML.Data
7676
/// </example>
7777
public sealed class FeatureContributionCalculatingTransformer : OneToOneTransformerBase
7878
{
79-
public sealed class Arguments : TransformInputBase
79+
internal sealed class Options : TransformInputBase
8080
{
8181
[Argument(ArgumentType.Required, HelpText = "The predictor model to apply to data", SortOrder = 1)]
8282
public PredictorModel PredictorModel;
@@ -128,7 +128,7 @@ private static VersionInfo GetVersionInfo()
128128
/// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
129129
/// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param>
130130
/// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param>
131-
public FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
131+
internal FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
132132
string featureColumn = DefaultColumnNames.Features,
133133
int numPositiveContributions = FeatureContributionCalculatingEstimator.Defaults.NumPositiveContributions,
134134
int numNegativeContributions = FeatureContributionCalculatingEstimator.Defaults.NumNegativeContributions,
@@ -300,7 +300,7 @@ public static class Defaults
300300
/// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
301301
/// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param>
302302
/// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param>
303-
public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
303+
internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
304304
string featureColumn = DefaultColumnNames.Features,
305305
int numPositiveContributions = Defaults.NumPositiveContributions,
306306
int numNegativeContributions = Defaults.NumNegativeContributions,
@@ -312,6 +312,10 @@ public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateF
312312
_predictor = modelParameters;
313313
}
314314

315+
/// <summary>
316+
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
317+
/// Used for schema propagation and verification in a pipeline.
318+
/// </summary>
315319
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
316320
{
317321
// Check that the featureColumn is present.
@@ -341,20 +345,20 @@ internal static class FeatureContributionEntryPoint
341345
[TlcModule.EntryPoint(Name = "Transforms.FeatureContributionCalculationTransformer",
342346
Desc = FeatureContributionCalculatingTransformer.Summary,
343347
UserName = FeatureContributionCalculatingTransformer.FriendlyName)]
344-
public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Arguments args)
348+
public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Options options)
345349
{
346350
Contracts.CheckValue(env, nameof(env));
347351
var host = env.Register(nameof(FeatureContributionCalculatingTransformer));
348-
host.CheckValue(args, nameof(args));
349-
EntryPointUtils.CheckInputArgs(host, args);
350-
host.CheckValue(args.PredictorModel, nameof(args.PredictorModel));
352+
host.CheckValue(options, nameof(options));
353+
EntryPointUtils.CheckInputArgs(host, options);
354+
host.CheckValue(options.PredictorModel, nameof(options.PredictorModel));
351355

352-
var predictor = args.PredictorModel.Predictor as ICalculateFeatureContribution;
356+
var predictor = options.PredictorModel.Predictor as ICalculateFeatureContribution;
353357
if (predictor == null)
354358
throw host.ExceptUserArg(nameof(predictor), "The provided model parameters do not support feature contribution calculation.");
355-
var outData = new FeatureContributionCalculatingTransformer(host, predictor, args.FeatureColumn, args.Top, args.Bottom, args.Normalize).Transform(args.Data);
359+
var outData = new FeatureContributionCalculatingTransformer(host, predictor, options.FeatureColumn, options.Top, options.Bottom, options.Normalize).Transform(options.Data);
356360

357-
return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, args.Data), OutputData = outData};
361+
return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, options.Data), OutputData = outData};
358362
}
359363
}
360364
}

src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs renamed to src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs

+11-9
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
using Microsoft.ML.Model;
1717
using Microsoft.ML.Transforms.FeatureSelection;
1818

19-
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Arguments), typeof(SignatureDataTransform),
19+
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Options), typeof(SignatureDataTransform),
2020
SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature, "DropSlots")]
2121

2222
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -37,14 +37,15 @@ namespace Microsoft.ML.Transforms.FeatureSelection
3737
/// </summary>
3838
public sealed class SlotsDroppingTransformer : OneToOneTransformerBase
3939
{
40-
public sealed class Arguments
40+
internal sealed class Options
4141
{
4242
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to drop the slots for",
4343
Name = "Column", ShortName = "col", SortOrder = 1)]
4444
public Column[] Columns;
4545
}
4646

47-
public sealed class Column : OneToOneColumn
47+
[BestFriend]
48+
internal sealed class Column : OneToOneColumn
4849
{
4950
[Argument(ArgumentType.Multiple, HelpText = "Source slot index range(s) of the column to drop")]
5051
public Range[] Slots;
@@ -112,7 +113,7 @@ internal bool TryUnparse(StringBuilder sb)
112113
}
113114
}
114115

115-
public sealed class Range
116+
internal sealed class Range
116117
{
117118
[Argument(ArgumentType.Required, HelpText = "First index in the range")]
118119
public int Min;
@@ -191,7 +192,8 @@ public bool IsValid()
191192
/// <summary>
192193
/// Describes how the transformer handles one input-output column pair.
193194
/// </summary>
194-
public sealed class ColumnInfo
195+
[BestFriend]
196+
internal sealed class ColumnInfo
195197
{
196198
public readonly string Name;
197199
public readonly string InputColumnName;
@@ -258,7 +260,7 @@ private static VersionInfo GetVersionInfo()
258260
/// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
259261
/// <param name="min">Specifies the lower bound of the range of slots to be dropped. The lower bound is inclusive. </param>
260262
/// <param name="max">Specifies the upper bound of the range of slots to be dropped. The upper bound is exclusive.</param>
261-
public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null)
263+
internal SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null)
262264
: this(env, new ColumnInfo(outputColumnName, inputColumnName, (min, max)))
263265
{
264266
}
@@ -268,7 +270,7 @@ public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, s
268270
/// </summary>
269271
/// <param name="env">The environment to use.</param>
270272
/// <param name="columns">Specifies the ranges of slots to drop for each column pair.</param>
271-
public SlotsDroppingTransformer(IHostEnvironment env, params ColumnInfo[] columns)
273+
internal SlotsDroppingTransformer(IHostEnvironment env, params ColumnInfo[] columns)
272274
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
273275
{
274276
Host.AssertNonEmpty(ColumnPairs);
@@ -308,9 +310,9 @@ private static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadCo
308310
}
309311

310312
// Factory method for SignatureDataTransform.
311-
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
313+
private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
312314
{
313-
var columns = args.Columns.Select(column => new ColumnInfo(column)).ToArray();
315+
var columns = options.Columns.Select(column => new ColumnInfo(column)).ToArray();
314316
return new SlotsDroppingTransformer(env, columns).MakeDataTransform(input);
315317
}
316318

0 commit comments

Comments
 (0)