Skip to content

Commit 555b8e2

Browse files
zeahmed authored and eerhardt committed
Added convenience constructor for set of transforms. (dotnet#405)
* Added convenience constructor for set of transforms (dotnet#371).
* Removed useless validation from Concat transform.
* Added more parameters to some transforms.
* Addressed reviewers' comments.
* XML comments added to constructors/helper methods.
* Created private static class for managing default values.
* Addressed reviewers' comments.
* Resolved some formatting issues.
1 parent 3343e94 commit 555b8e2

10 files changed

+421
-23
lines changed

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

+25
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ public bool TryUnparse(StringBuilder sb)
9090

9191
public sealed class Arguments : TransformInputBase
9292
{
93+
/// <summary>
/// Parameterless constructor. It assigns nothing, so the Column field is left unset.
/// </summary>
public Arguments()
94+
{
95+
}
96+
97+
/// <summary>
/// Creates arguments holding exactly one column definition.
/// </summary>
/// <param name="name">Value assigned to the Name of the single column definition.</param>
/// <param name="source">Values assigned to the Source of the single column definition.</param>
public Arguments(string name, params string[] source)
98+
{
99+
Column = new[] { new Column()
100+
{
101+
Name = name,
102+
Source = source
103+
}};
104+
}
105+
93106
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:srcs)", ShortName = "col", SortOrder = 1)]
94107
public Column[] Column;
95108
}
@@ -527,6 +540,18 @@ private static VersionInfo GetVersionInfo()
527540

528541
public override ISchema Schema => _bindings;
529542

543+
/// <summary>
544+
/// Convenience constructor for public facing API. Wraps <paramref name="name"/> and <paramref name="source"/> in an <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
545+
/// </summary>
546+
/// <param name="env">Host Environment.</param>
547+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
548+
/// <param name="name">Name of the output column.</param>
549+
/// <param name="source">Input columns to concatenate.</param>
550+
public ConcatTransform(IHostEnvironment env, IDataView input, string name, params string[] source)
551+
: this(env, new Arguments(name, source), input)
552+
{
553+
}
554+
530555
/// <summary>
531556
/// Public constructor corresponding to SignatureDataTransform.
532557
/// </summary>

src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs

+12
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
6464

6565
private const string RegistrationName = "CopyColumns";
6666

67+
/// <summary>
68+
/// Convenience constructor for public facing API. Builds an <see cref="Arguments"/> with a single column (Source = <paramref name="source"/>, Name = <paramref name="name"/>) and forwards to the <see cref="Arguments"/>-based constructor.
69+
/// </summary>
70+
/// <param name="env">Host Environment.</param>
71+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
72+
/// <param name="name">Name of the output column.</param>
73+
/// <param name="source">Name of the column to be copied.</param>
74+
public CopyColumnsTransform(IHostEnvironment env, IDataView input, string name, string source)
75+
: this(env, new Arguments(){ Column = new[] { new Column() { Source = source, Name = name }}}, input)
76+
{
77+
}
78+
6779
public CopyColumnsTransform(IHostEnvironment env, Arguments args, IDataView input)
6880
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, null)
6981
{

src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs

+24
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,17 @@ private static VersionInfo GetVersionInfo()
237237
private const string DropRegistrationName = "DropColumns";
238238
private const string KeepRegistrationName = "KeepColumns";
239239

240+
/// <summary>
241+
/// Convenience constructor for public facing API. Assigns <paramref name="columnsToDrop"/> to the Column field of a new <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
242+
/// </summary>
243+
/// <param name="env">Host Environment.</param>
244+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
245+
/// <param name="columnsToDrop">Name of the columns to be dropped.</param>
246+
public DropColumnsTransform(IHostEnvironment env, IDataView input, params string[] columnsToDrop)
247+
:this(env, new Arguments() { Column = columnsToDrop }, input)
248+
{
249+
}
250+
240251
/// <summary>
241252
/// Public constructor corresponding to SignatureDataTransform.
242253
/// </summary>
@@ -383,4 +394,17 @@ public ValueGetter<TValue> GetGetter<TValue>(int col)
383394
}
384395
}
385396
}
397+
398+
/// <summary>
/// Holder for the helper that builds a keep-columns transform. The only member is the static <see cref="Create"/> method.
/// </summary>
public class KeepColumnsTransform
399+
{
400+
/// <summary>
401+
/// A helper method to create a keep-columns transform for public facing API. The object created is a <see cref="DropColumnsTransform"/> built from <see cref="DropColumnsTransform.KeepArguments"/>.
402+
/// </summary>
403+
/// <param name="env">Host Environment.</param>
404+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
405+
/// <param name="columnsToKeep">Name of the columns to be kept. All other columns will be removed.</param>
406+
/// <returns>A new <see cref="DropColumnsTransform"/> whose <see cref="DropColumnsTransform.KeepArguments"/> has Column set to <paramref name="columnsToKeep"/>.</returns>
407+
public static IDataTransform Create(IHostEnvironment env, IDataView input, params string[] columnsToKeep)
408+
=> new DropColumnsTransform(env, new DropColumnsTransform.KeepArguments() { Column = columnsToKeep }, input);
409+
}
386410
}

src/Microsoft.ML.Data/Transforms/NAFilter.cs

+18-1
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@ namespace Microsoft.ML.Runtime.Data
2828
{
2929
public sealed class NAFilter : FilterBase
3030
{
31+
/// <summary>
/// Default values shared by the <see cref="Arguments"/> field initializers and the optional parameters of the convenience constructor.
/// </summary>
private static class Defaults
32+
{
33+
public const bool Complement = false;
34+
}
35+
3136
/// <summary>
/// Arguments for <see cref="NAFilter"/>: the column names and the Complement flag (see the HelpText on each field).
/// </summary>
public sealed class Arguments : TransformInputBase
3237
{
3338
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Column", ShortName = "col", SortOrder = 1)]
3439
public string[] Column;
3540

3641
[Argument(ArgumentType.Multiple, HelpText = "If true, keep only rows that contain NA values, and filter the rest.")]
37-
public bool Complement;
42+
public bool Complement = Defaults.Complement;
3843
}
3944

4045
private sealed class ColInfo
@@ -72,6 +77,18 @@ private static VersionInfo GetVersionInfo()
7277
private readonly bool _complement;
7378
private const string RegistrationName = "MissingValueFilter";
7479

80+
/// <summary>
81+
/// Convenience constructor for public facing API. Packs <paramref name="columns"/> and <paramref name="complement"/> into an <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
82+
/// </summary>
83+
/// <param name="env">Host Environment.</param>
84+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
85+
/// <param name="complement">If true, keep only rows that contain NA values, and filter the rest.</param>
86+
/// <param name="columns">Name of the columns. Only these columns will be used to filter rows having 'NA' values.</param>
/// <remarks>
/// <paramref name="columns"/> is a <c>params</c> array declared after the optional <paramref name="complement"/>,
/// so a positional call must pass <paramref name="complement"/> explicitly before listing any column names.
/// </remarks>
87+
public NAFilter(IHostEnvironment env, IDataView input, bool complement = Defaults.Complement, params string[] columns)
88+
: this(env, new Arguments() { Column = columns, Complement = complement }, input)
89+
{
90+
}
91+
7592
public NAFilter(IHostEnvironment env, Arguments args, IDataView input)
7693
: base(env, RegistrationName, input)
7794
{

src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs

+29-3
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,27 @@ namespace Microsoft.ML.Runtime.Data
2525
/// </summary>
2626
public sealed class BootstrapSampleTransform : FilterBase
2727
{
28+
/// <summary>
/// Default values shared by the <see cref="Arguments"/> field initializers and the optional parameters of the convenience constructor.
/// </summary>
private static class Defaults
29+
{
30+
public const bool Complement = false;
31+
public const bool ShuffleInput = true;
32+
public const int PoolSize = 1000;
33+
}
34+
2835
/// <summary>
/// Arguments for <see cref="BootstrapSampleTransform"/>. Each field is described by the HelpText of its Argument attribute.
/// </summary>
public sealed class Arguments : TransformInputBase
2936
{
3037
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.",
3138
ShortName = "comp")]
32-
public bool Complement;
39+
public bool Complement = Defaults.Complement;
3340

3441
[Argument(ArgumentType.AtMostOnce, HelpText = "The random seed. If unspecified random state will be instead derived from the environment.")]
3542
public uint? Seed;
3643

3744
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether we should attempt to shuffle the source data. By default on, but can be turned off for efficiency.", ShortName = "si")]
38-
public bool ShuffleInput = true;
45+
public bool ShuffleInput = Defaults.ShuffleInput;
3946

4047
[Argument(ArgumentType.LastOccurenceWins, HelpText = "When shuffling the output, the number of output rows to keep in that pool. Note that shuffling of output is completely distinct from shuffling of input.", ShortName = "pool")]
41-
public int PoolSize = 1000;
48+
public int PoolSize = Defaults.PoolSize;
4249
}
4350

4451
internal const string Summary = "Approximate bootstrap sampling.";
@@ -76,6 +83,25 @@ public BootstrapSampleTransform(IHostEnvironment env, Arguments args, IDataView
7683
_poolSize = args.PoolSize;
7784
}
7885

86+
/// <summary>
87+
/// Convenience constructor for public facing API. Packs the four option parameters into an <see cref="Arguments"/> and forwards to the <see cref="Arguments"/>-based constructor.
88+
/// </summary>
89+
/// <param name="env">Host Environment.</param>
90+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
91+
/// <param name="complement">Whether this is the out-of-bag sample, that is, all those rows that are not selected by the transform.</param>
92+
/// <param name="seed">The random seed. If unspecified random state will be instead derived from the environment.</param>
93+
/// <param name="shuffleInput">Whether we should attempt to shuffle the source data. By default on, but can be turned off for efficiency.</param>
94+
/// <param name="poolSize">When shuffling the output, the number of output rows to keep in that pool. Note that shuffling of output is completely distinct from shuffling of input.</param>
95+
public BootstrapSampleTransform(IHostEnvironment env,
96+
IDataView input,
97+
bool complement = Defaults.Complement,
98+
uint? seed = null,
99+
bool shuffleInput = Defaults.ShuffleInput,
100+
int poolSize = Defaults.PoolSize)
101+
: this(env, new Arguments() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
102+
{
103+
}
104+
79105
private BootstrapSampleTransform(IHost host, ModelLoadContext ctx, IDataView input)
80106
: base(host, input)
81107
{

src/Microsoft.ML.Transforms/CategoricalHashTransform.cs

+46-5
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ public bool TryUnparse(StringBuilder sb)
8686
}
8787
}
8888

89+
/// <summary>
/// Default values used by the <see cref="Arguments"/> field initializers. HashBits, InvertHash and OutputKind
/// are also the defaults of the optional parameters of the convenience Create overload.
/// </summary>
private static class Defaults
90+
{
91+
public const int HashBits = 16;
92+
public const uint Seed = 314489979;
93+
public const bool Ordered = true;
94+
public const int InvertHash = 0;
95+
public const CategoricalTransform.OutputKind OutputKind = CategoricalTransform.OutputKind.Bag;
96+
}
97+
8998
/// <summary>
9099
/// This class is a merger of <see cref="HashTransform.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
91100
/// with join option removed
@@ -97,29 +106,61 @@ public sealed class Arguments : TransformInputBase
97106

98107
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 30, inclusive.",
99108
ShortName = "bits", SortOrder = 2)]
100-
public int HashBits = 16;
109+
public int HashBits = Defaults.HashBits;
101110

102111
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
103-
public uint Seed = 314489979;
112+
public uint Seed = Defaults.Seed;
104113

105114
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash", ShortName = "ord")]
106-
public bool Ordered = true;
115+
public bool Ordered = Defaults.Ordered;
107116

108117
[Argument(ArgumentType.AtMostOnce,
109118
HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
110119
ShortName = "ih")]
111-
public int InvertHash;
120+
public int InvertHash = Defaults.InvertHash;
112121

113122
[Argument(ArgumentType.AtMostOnce, HelpText = "Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index)",
114123
ShortName = "kind", SortOrder = 102)]
115-
public CategoricalTransform.OutputKind OutputKind = CategoricalTransform.OutputKind.Bag;
124+
public CategoricalTransform.OutputKind OutputKind = Defaults.OutputKind;
116125
}
117126

118127
internal const string Summary = "Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the "
119128
+ "bag. If the input column is a vector, a single indicator bag is returned for it.";
120129

121130
public const string UserName = "Categorical Hash Transform";
122131

132+
/// <summary>
133+
/// A helper method to create <see cref="CategoricalHashTransform"/> for public facing API.
134+
/// </summary>
135+
/// <param name="env">Host Environment.</param>
136+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
137+
/// <param name="name">Name of the output column.</param>
138+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
139+
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
140+
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
141+
/// <param name="outputKind">The type of output expected.</param>
/// <returns>The result of the <see cref="Arguments"/>-based Create overload, called with a single-column <see cref="Arguments"/>.</returns>
/// <remarks>
/// Seed and Ordered are not exposed by this overload, so they keep the values given by the
/// <see cref="Arguments"/> field initializers.
/// </remarks>
142+
public static IDataTransform Create(IHostEnvironment env,
143+
IDataView input,
144+
string name,
145+
string source =null,
146+
int hashBits = Defaults.HashBits,
147+
int invertHash = Defaults.InvertHash,
148+
CategoricalTransform.OutputKind outputKind = Defaults.OutputKind)
149+
{
150+
var args = new Arguments()
151+
{
152+
Column = new[] { new Column(){
153+
Source = source ?? name,
154+
Name = name
155+
}
156+
},
157+
HashBits = hashBits,
158+
InvertHash = invertHash,
159+
OutputKind = outputKind
160+
};
161+
return Create(env, args, input);
162+
}
163+
123164
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
124165
{
125166
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.Transforms/CategoricalTransform.cs

+40-1
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,27 @@ public static class CategoricalTransform
3838
{
3939
/// <summary>
/// Kinds of output that can be requested through an OutputKind argument; each member's label describes its output.
/// </summary>
public enum OutputKind : byte
4040
{
41+
/// <summary>
42+
/// Output is a bag (multi-set) vector
43+
/// </summary>
4144
[TGUI(Label = "Output is a bag (multi-set) vector")]
4245
Bag = 1,
4346

47+
/// <summary>
48+
/// Output is an indicator vector
49+
/// </summary>
4450
[TGUI(Label = "Output is an indicator vector")]
4551
Ind = 2,
4652

53+
/// <summary>
54+
/// Output is a key value
55+
/// </summary>
4756
[TGUI(Label = "Output is a key value")]
4857
Key = 3,
4958

59+
/// <summary>
60+
/// Output is binary encoded
61+
/// </summary>
5062
[TGUI(Label = "Output is binary encoded")]
5163
Bin = 4,
5264
}
@@ -96,14 +108,19 @@ public bool TryUnparse(StringBuilder sb)
96108
}
97109
}
98110

111+
/// <summary>
/// Default value shared by the OutputKind field initializer of <see cref="Arguments"/> and the optional
/// outputKind parameter of the convenience Create overload.
/// </summary>
private static class Defaults
112+
{
113+
public const OutputKind OutKind = OutputKind.Ind;
114+
}
115+
99116
public sealed class Arguments : TermTransform.ArgumentsBase
100117
{
101118
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
102119
public Column[] Column;
103120

104121
[Argument(ArgumentType.AtMostOnce, HelpText = "Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index)",
105122
ShortName = "kind", SortOrder = 102)]
106-
public OutputKind OutputKind = OutputKind.Ind;
123+
public OutputKind OutputKind = Defaults.OutKind;
107124

108125
public Arguments()
109126
{
@@ -118,6 +135,28 @@ public Arguments()
118135

119136
public const string UserName = "Categorical Transform";
120137

138+
/// <summary>
139+
/// A helper method to create <see cref="CategoricalTransform"/> for public facing API.
140+
/// </summary>
141+
/// <param name="env">Host Environment.</param>
142+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
143+
/// <param name="name">Name of the output column.</param>
144+
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
145+
/// <param name="outputKind">The type of output expected.</param>
/// <returns>The result of the <see cref="Arguments"/>-based Create overload, called with a single-column <see cref="Arguments"/>.</returns>
146+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string name, string source = null, OutputKind outputKind = Defaults.OutKind)
147+
{
148+
var args = new Arguments()
149+
{
150+
Column = new[] { new Column(){
151+
Source = source ?? name,
152+
Name = name
153+
}
154+
},
155+
OutputKind = outputKind
156+
};
157+
return Create(env, args, input);
158+
}
159+
121160
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
122161
{
123162
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.Transforms/CountFeatureSelection.cs

+24-1
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,40 @@ public static class CountFeatureSelectionTransform
2828
public const string Summary = "Selects the slots for which the count of non-default values is greater than or equal to a threshold.";
2929
public const string UserName = "Count Feature Selection Transform";
3030

31+
/// <summary>
/// Default value shared by the Count field initializer of <see cref="Arguments"/> and the optional
/// count parameter of the convenience Create overload.
/// </summary>
private static class Defaults
32+
{
33+
public const long Count = 1;
34+
}
35+
3136
/// <summary>
/// Arguments for the count feature selection transform: the columns to use and the count threshold
/// (see the HelpText on each field).
/// </summary>
public sealed class Arguments : TransformInputBase
3237
{
3338
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Columns to use for feature selection", ShortName = "col", SortOrder = 1)]
3439
public string[] Column;
3540

3641
[Argument(ArgumentType.Required, HelpText = "If the count of non-default values for a slot is greater than or equal to this threshold, the slot is preserved", ShortName = "c", SortOrder = 1)]
37-
public long Count = 1;
42+
public long Count = Defaults.Count;
3843
}
3944

4045
internal static string RegistrationName = "CountFeatureSelectionTransform";
4146

47+
/// <summary>
48+
/// A helper method to create CountFeatureSelection transform for public facing API.
49+
/// </summary>
50+
/// <param name="env">Host Environment.</param>
51+
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
52+
/// <param name="count">If the count of non-default values for a slot is greater than or equal to this threshold, the slot is preserved.</param>
53+
/// <param name="columns">Columns to use for feature selection.</param>
/// <remarks>
/// <paramref name="columns"/> is a <c>params</c> array declared after the optional <paramref name="count"/>,
/// so a positional call must pass <paramref name="count"/> explicitly before listing any column names.
/// </remarks>
54+
/// <returns>The result of the <see cref="Arguments"/>-based Create overload, called with Column = <paramref name="columns"/> and Count = <paramref name="count"/>.</returns>
55+
public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = Defaults.Count, params string[] columns)
56+
{
57+
var args = new Arguments()
58+
{
59+
Column = columns,
60+
Count = count
61+
};
62+
return Create(env, args, input);
63+
}
64+
4265
/// <summary>
4366
/// Create method corresponding to SignatureDataTransform.
4467
/// </summary>

0 commit comments

Comments
 (0)