Skip to content

Commit b2851e5

Browse files
authored
Creation of components through MLContext and cleanup (Convert, DropSlots, FeatureSelection) (#2365)
1 parent e192a18 commit b2851e5

File tree

17 files changed

+236
-177
lines changed

17 files changed

+236
-177
lines changed

src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ public static class ConversionsExtensionsCatalog
2222
/// </summary>
2323
/// <param name="catalog">The transform's catalog.</param>
2424
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
25-
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param> /// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
25+
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
26+
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
2627
/// <param name="invertHash">During hashing we construct mappings between original values and the produced hash values.
2728
/// Text representations of original values are stored in the slot names of the metadata for the new column. Hashing, as such, can map many initial values to one.
2829
/// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
@@ -55,7 +56,7 @@ public static TypeConvertingEstimator ConvertType(this TransformsCatalog.Convers
5556
/// </summary>
5657
/// <param name="catalog">The transform's catalog.</param>
5758
/// <param name="columns">Description of dataset columns and how to process them.</param>
58-
public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingTransformer.ColumnInfo[] columns)
59+
public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, params TypeConvertingEstimator.ColumnInfo[] columns)
5960
=> new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), columns);
6061

6162
/// <summary>

src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs renamed to src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransformer.cs

+18-14
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ namespace Microsoft.ML.Data
7676
/// </example>
7777
public sealed class FeatureContributionCalculatingTransformer : OneToOneTransformerBase
7878
{
79-
public sealed class Arguments : TransformInputBase
79+
internal sealed class Options : TransformInputBase
8080
{
8181
[Argument(ArgumentType.Required, HelpText = "The predictor model to apply to data", SortOrder = 1)]
8282
public PredictorModel PredictorModel;
@@ -99,9 +99,9 @@ public sealed class Arguments : TransformInputBase
9999
internal const string FriendlyName = "Feature Contribution Calculation";
100100
internal const string LoaderSignature = "FeatureContribution";
101101

102-
public readonly int Top;
103-
public readonly int Bottom;
104-
public readonly bool Normalize;
102+
internal readonly int Top;
103+
internal readonly int Bottom;
104+
internal readonly bool Normalize;
105105

106106
private readonly IFeatureContributionMapper _predictor;
107107

@@ -128,7 +128,7 @@ private static VersionInfo GetVersionInfo()
128128
/// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
129129
/// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param>
130130
/// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param>
131-
public FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
131+
internal FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
132132
string featureColumn = DefaultColumnNames.Features,
133133
int numPositiveContributions = FeatureContributionCalculatingEstimator.Defaults.NumPositiveContributions,
134134
int numNegativeContributions = FeatureContributionCalculatingEstimator.Defaults.NumNegativeContributions,
@@ -281,7 +281,7 @@ public sealed class FeatureContributionCalculatingEstimator : TrivialEstimator<F
281281
private readonly string _featureColumn;
282282
private readonly ICalculateFeatureContribution _predictor;
283283

284-
public static class Defaults
284+
internal static class Defaults
285285
{
286286
public const int NumPositiveContributions = 10;
287287
public const int NumNegativeContributions = 10;
@@ -300,7 +300,7 @@ public static class Defaults
300300
/// <param name="numNegativeContributions">The number of negative contributions to report, sorted from highest magnitude to lowest magnitude.
301301
/// Note that if there are fewer features with negative contributions than <paramref name="numNegativeContributions"/>, the rest will be returned as zeros.</param>
302302
/// <param name="normalize">Whether the feature contributions should be normalized to the [-1, 1] interval.</param>
303-
public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
303+
internal FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateFeatureContribution modelParameters,
304304
string featureColumn = DefaultColumnNames.Features,
305305
int numPositiveContributions = Defaults.NumPositiveContributions,
306306
int numNegativeContributions = Defaults.NumNegativeContributions,
@@ -312,6 +312,10 @@ public FeatureContributionCalculatingEstimator(IHostEnvironment env, ICalculateF
312312
_predictor = modelParameters;
313313
}
314314

315+
/// <summary>
316+
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
317+
/// Used for schema propagation and verification in a pipeline.
318+
/// </summary>
315319
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
316320
{
317321
// Check that the featureColumn is present.
@@ -341,20 +345,20 @@ internal static class FeatureContributionEntryPoint
341345
[TlcModule.EntryPoint(Name = "Transforms.FeatureContributionCalculationTransformer",
342346
Desc = FeatureContributionCalculatingTransformer.Summary,
343347
UserName = FeatureContributionCalculatingTransformer.FriendlyName)]
344-
public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Arguments args)
348+
public static CommonOutputs.TransformOutput FeatureContributionCalculation(IHostEnvironment env, FeatureContributionCalculatingTransformer.Options options)
345349
{
346350
Contracts.CheckValue(env, nameof(env));
347351
var host = env.Register(nameof(FeatureContributionCalculatingTransformer));
348-
host.CheckValue(args, nameof(args));
349-
EntryPointUtils.CheckInputArgs(host, args);
350-
host.CheckValue(args.PredictorModel, nameof(args.PredictorModel));
352+
host.CheckValue(options, nameof(options));
353+
EntryPointUtils.CheckInputArgs(host, options);
354+
host.CheckValue(options.PredictorModel, nameof(options.PredictorModel));
351355

352-
var predictor = args.PredictorModel.Predictor as ICalculateFeatureContribution;
356+
var predictor = options.PredictorModel.Predictor as ICalculateFeatureContribution;
353357
if (predictor == null)
354358
throw host.ExceptUserArg(nameof(predictor), "The provided model parameters do not support feature contribution calculation.");
355-
var outData = new FeatureContributionCalculatingTransformer(host, predictor, args.FeatureColumn, args.Top, args.Bottom, args.Normalize).Transform(args.Data);
359+
var outData = new FeatureContributionCalculatingTransformer(host, predictor, options.FeatureColumn, options.Top, options.Bottom, options.Normalize).Transform(options.Data);
356360

357-
return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, args.Data), OutputData = outData};
361+
return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, outData, options.Data), OutputData = outData};
358362
}
359363
}
360364
}

src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs renamed to src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs

+11-9
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
using Microsoft.ML.Model;
1717
using Microsoft.ML.Transforms.FeatureSelection;
1818

19-
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Arguments), typeof(SignatureDataTransform),
19+
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), typeof(SlotsDroppingTransformer.Options), typeof(SignatureDataTransform),
2020
SlotsDroppingTransformer.FriendlyName, SlotsDroppingTransformer.LoaderSignature, "DropSlots")]
2121

2222
[assembly: LoadableClass(SlotsDroppingTransformer.Summary, typeof(IDataTransform), typeof(SlotsDroppingTransformer), null, typeof(SignatureLoadDataTransform),
@@ -37,14 +37,15 @@ namespace Microsoft.ML.Transforms.FeatureSelection
3737
/// </summary>
3838
public sealed class SlotsDroppingTransformer : OneToOneTransformerBase
3939
{
40-
public sealed class Arguments
40+
internal sealed class Options
4141
{
4242
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to drop the slots for",
4343
Name = "Column", ShortName = "col", SortOrder = 1)]
4444
public Column[] Columns;
4545
}
4646

47-
public sealed class Column : OneToOneColumn
47+
[BestFriend]
48+
internal sealed class Column : OneToOneColumn
4849
{
4950
[Argument(ArgumentType.Multiple, HelpText = "Source slot index range(s) of the column to drop")]
5051
public Range[] Slots;
@@ -112,7 +113,7 @@ internal bool TryUnparse(StringBuilder sb)
112113
}
113114
}
114115

115-
public sealed class Range
116+
internal sealed class Range
116117
{
117118
[Argument(ArgumentType.Required, HelpText = "First index in the range")]
118119
public int Min;
@@ -191,7 +192,8 @@ public bool IsValid()
191192
/// <summary>
192193
/// Describes how the transformer handles one input-output column pair.
193194
/// </summary>
194-
public sealed class ColumnInfo
195+
[BestFriend]
196+
internal sealed class ColumnInfo
195197
{
196198
public readonly string Name;
197199
public readonly string InputColumnName;
@@ -258,7 +260,7 @@ private static VersionInfo GetVersionInfo()
258260
/// <param name="inputColumnName">Name of column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
259261
/// <param name="min">Specifies the lower bound of the range of slots to be dropped. The lower bound is inclusive. </param>
260262
/// <param name="max">Specifies the upper bound of the range of slots to be dropped. The upper bound is exclusive.</param>
261-
public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null)
263+
internal SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null)
262264
: this(env, new ColumnInfo(outputColumnName, inputColumnName, (min, max)))
263265
{
264266
}
@@ -268,7 +270,7 @@ public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, s
268270
/// </summary>
269271
/// <param name="env">The environment to use.</param>
270272
/// <param name="columns">Specifies the ranges of slots to drop for each column pair.</param>
271-
public SlotsDroppingTransformer(IHostEnvironment env, params ColumnInfo[] columns)
273+
internal SlotsDroppingTransformer(IHostEnvironment env, params ColumnInfo[] columns)
272274
: base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), GetColumnPairs(columns))
273275
{
274276
Host.AssertNonEmpty(ColumnPairs);
@@ -308,9 +310,9 @@ private static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadCo
308310
}
309311

310312
// Factory method for SignatureDataTransform.
311-
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
313+
private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
312314
{
313-
var columns = args.Columns.Select(column => new ColumnInfo(column)).ToArray();
315+
var columns = options.Columns.Select(column => new ColumnInfo(column)).ToArray();
314316
return new SlotsDroppingTransformer(env, columns).MakeDataTransform(input);
315317
}
316318

0 commit comments

Comments
 (0)