Skip to content

Commit bfac765

Browse files
committed
Added convenience constructor for set of transforms (dotnet#371).
1 parent ecc6857 commit bfac765

10 files changed

+303
-2
lines changed

src/Microsoft.ML.Data/Transforms/ConcatTransform.cs

+22
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,28 @@ private static VersionInfo GetVersionInfo()
527527

528528
public override ISchema Schema => _bindings;
529529

530+
public ConcatTransform(IHostEnvironment env, IDataView input, string outputColumn, params string[] inputColumns)
531+
: base(env, RegistrationName, input)
532+
{
533+
var cols = new Column[1];
534+
cols[0] = new Column()
535+
{
536+
Name = outputColumn,
537+
Source = inputColumns
538+
};
539+
540+
var args = new Arguments()
541+
{
542+
Column = cols
543+
};
544+
Host.CheckValue(args, nameof(args));
545+
Host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));
546+
for (int i = 0; i < args.Column.Length; i++)
547+
Host.CheckUserArg(Utils.Size(args.Column[i].Source) > 0, nameof(args.Column));
548+
549+
_bindings = new Bindings(args.Column, null, Source.Schema);
550+
}
551+
530552
/// <summary>
531553
/// Public constructor corresponding to SignatureDataTransform.
532554
/// </summary>

src/Microsoft.ML.Data/Transforms/CopyColumnsTransform.cs

+26
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,32 @@ private static VersionInfo GetVersionInfo()
6464

6565
private const string RegistrationName = "CopyColumns";
6666

67+
public static CopyColumnsTransform Create(IHostEnvironment env, IDataView input, params string[] inputColumns)
68+
{
69+
var inputOutputColumns = new(string inputColumn, string outputColumn)[inputColumns.Length];
70+
for (int i = 0; i < inputColumns.Length; i++)
71+
{
72+
inputOutputColumns[i].inputColumn = inputOutputColumns[i].outputColumn = inputColumns[i];
73+
}
74+
return Create(env, input, inputOutputColumns);
75+
}
76+
77+
public static CopyColumnsTransform Create(IHostEnvironment env, IDataView input, params (string inputColumn, string outputColumn)[] inputOutputColumns)
78+
{
79+
Column[] cols = new Column[inputOutputColumns.Length];
80+
for (int i = 0; i < inputOutputColumns.Length; i++)
81+
{
82+
cols[i] = new Column();
83+
cols[i].Source = inputOutputColumns[i].inputColumn;
84+
cols[i].Name = inputOutputColumns[i].outputColumn;
85+
}
86+
var args = new Arguments()
87+
{
88+
Column = cols
89+
};
90+
return new CopyColumnsTransform(env,args,input);
91+
}
92+
6793
public CopyColumnsTransform(IHostEnvironment env, Arguments args, IDataView input)
6894
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column, input, null)
6995
{

src/Microsoft.ML.Data/Transforms/DropColumnsTransform.cs

+10
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,16 @@ private static VersionInfo GetVersionInfo()
237237
private const string DropRegistrationName = "DropColumns";
238238
private const string KeepRegistrationName = "KeepColumns";
239239

240+
public DropColumnsTransform CreateColumnDroper(IHostEnvironment env, IDataView input, params string[] columnsToDrop)
241+
{
242+
return new DropColumnsTransform(env, new Arguments() { Column = columnsToDrop }, input);
243+
}
244+
245+
public static DropColumnsTransform CreateColumnSelector(IHostEnvironment env, IDataView input, params string[] columnsToKeep)
246+
{
247+
return new DropColumnsTransform(env, new KeepArguments() { Column = columnsToKeep }, input);
248+
}
249+
240250
/// <summary>
241251
/// Public constructor corresponding to SignatureDataTransform.
242252
/// </summary>

src/Microsoft.ML.Data/Transforms/NAFilter.cs

+6
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ private static VersionInfo GetVersionInfo()
7272
private readonly bool _complement;
7373
private const string RegistrationName = "MissingValueFilter";
7474

75+
public NAFilter(IHostEnvironment env, IDataView input, params string[] inputColumns)
76+
: this(env, new Arguments() { Column = inputColumns}, input)
77+
{
78+
79+
}
80+
7581
public NAFilter(IHostEnvironment env, Arguments args, IDataView input)
7682
: base(env, RegistrationName, input)
7783
{

src/Microsoft.ML.Transforms/BootstrapSampleTransform.cs

+6
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ public BootstrapSampleTransform(IHostEnvironment env, Arguments args, IDataView
7676
_poolSize = args.PoolSize;
7777
}
7878

79+
public BootstrapSampleTransform(IHostEnvironment env, IDataView input, bool complement = false, uint? seed = null, bool shuffleInput = true, int poolSize = 1000)
80+
: this(env, new Arguments() { Complement = complement, Seed = seed, ShuffleInput = shuffleInput, PoolSize = poolSize }, input)
81+
{
82+
83+
}
84+
7985
private BootstrapSampleTransform(IHost host, ModelLoadContext ctx, IDataView input)
8086
: base(host, input)
8187
{

src/Microsoft.ML.Transforms/CategoricalHashTransform.cs

+26
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,32 @@ public sealed class Arguments : TransformInputBase
120120

121121
public const string UserName = "Categorical Hash Transform";
122122

123+
public static IDataTransform Create(IHostEnvironment env, IDataView input, params string[] inputColumns)
124+
{
125+
var inputOutputColumns = new(string inputColumn, string outputColumn)[inputColumns.Length];
126+
for (int i = 0; i < inputColumns.Length; i++)
127+
{
128+
inputOutputColumns[i].inputColumn = inputOutputColumns[i].outputColumn = inputColumns[i];
129+
}
130+
return Create(env, input, inputOutputColumns);
131+
}
132+
133+
public static IDataTransform Create(IHostEnvironment env, IDataView input, params (string inputColumn, string outputColumn)[] inputOutputColumns)
134+
{
135+
Column[] cols = new Column[inputOutputColumns.Length];
136+
for (int i = 0; i < inputOutputColumns.Length; i++)
137+
{
138+
cols[i] = new Column();
139+
cols[i].Source = inputOutputColumns[i].inputColumn;
140+
cols[i].Name = inputOutputColumns[i].outputColumn;
141+
}
142+
var args = new Arguments()
143+
{
144+
Column = cols
145+
};
146+
return Create(env, args, input);
147+
}
148+
123149
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
124150
{
125151
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.Transforms/CategoricalTransform.cs

+26
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,32 @@ public Arguments()
118118

119119
public const string UserName = "Categorical Transform";
120120

121+
public static IDataTransform Create(IHostEnvironment env, IDataView input, params string[] inputColumns)
122+
{
123+
var inputOutputColumns = new (string inputColumn, string outputColumn)[inputColumns.Length];
124+
for (int i = 0; i < inputColumns.Length; i++)
125+
{
126+
inputOutputColumns[i].inputColumn = inputOutputColumns[i].outputColumn = inputColumns[i];
127+
}
128+
return Create(env, input, inputOutputColumns);
129+
}
130+
131+
public static IDataTransform Create(IHostEnvironment env, IDataView input, params (string inputColumn, string outputColumn)[] inputOutputColumns)
132+
{
133+
Column[] cols = new Column[inputOutputColumns.Length];
134+
for (int i = 0; i < inputOutputColumns.Length; i++)
135+
{
136+
cols[i] = new Column();
137+
cols[i].Source = inputOutputColumns[i].inputColumn;
138+
cols[i].Name = inputOutputColumns[i].outputColumn;
139+
}
140+
var args = new Arguments()
141+
{
142+
Column = cols
143+
};
144+
return Create(env, args, input);
145+
}
146+
121147
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
122148
{
123149
Contracts.CheckValue(env, nameof(env));

src/Microsoft.ML.Transforms/CountFeatureSelection.cs

+10
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,16 @@ public sealed class Arguments : TransformInputBase
3939

4040
internal static string RegistrationName = "CountFeatureSelectionTransform";
4141

42+
public static IDataTransform Create(IHostEnvironment env, IDataView input, long count = 1, params string[] columns)
43+
{
44+
var args = new Arguments()
45+
{
46+
Column = columns,
47+
Count = count
48+
};
49+
return Create(env, args, input);
50+
}
51+
4252
/// <summary>
4353
/// Create method corresponding to SignatureDataTransform.
4454
/// </summary>

src/Microsoft.ML.Transforms/GcnTransform.cs

+64-2
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,37 @@ private static VersionInfo GetVersionInfo()
237237

238238
private readonly ColInfoEx[] _exes;
239239

240+
public static IDataTransform CreateGlobalContrastNormalizer(IHostEnvironment env, IDataView input, params string[] inputColumns)
241+
{
242+
var inputOutputColumns = new(string inputColumn, string outputColumn)[inputColumns.Length];
243+
for (int i = 0; i < inputColumns.Length; i++)
244+
{
245+
inputOutputColumns[i].inputColumn = inputOutputColumns[i].outputColumn = inputColumns[i];
246+
}
247+
return CreateGlobalContrastNormalizer(env, input, inputOutputColumns);
248+
}
249+
250+
public static IDataTransform CreateGlobalContrastNormalizer(IHostEnvironment env, IDataView input, params (string inputColumn, string outputColumn)[] inputOutputColumns)
251+
{
252+
GcnColumn[] cols = new GcnColumn[inputOutputColumns.Length];
253+
for (int i = 0; i < inputOutputColumns.Length; i++)
254+
{
255+
cols[i] = new GcnColumn();
256+
cols[i].Source = inputOutputColumns[i].inputColumn;
257+
cols[i].Name = inputOutputColumns[i].outputColumn;
258+
}
259+
var args = new GcnArguments()
260+
{
261+
Column = cols
262+
};
263+
return new LpNormNormalizerTransform(env, args, input);
264+
}
265+
266+
public static IDataTransform CreateGlobalContrastNormalizer(IHostEnvironment env, IDataView input, GcnArguments args)
267+
{
268+
return new LpNormNormalizerTransform(env, args, input);
269+
}
270+
240271
/// <summary>
241272
/// Public constructor corresponding to SignatureDataTransform.
242273
/// </summary>
@@ -263,9 +294,40 @@ public LpNormNormalizerTransform(IHostEnvironment env, GcnArguments args, IDataV
263294
SetMetadata();
264295
}
265296

297+
public static IDataTransform CreateLpNormNormalizer(IHostEnvironment env, IDataView input, params string[] inputColumns)
298+
{
299+
var inputOutputColumns = new(string inputColumn, string outputColumn)[inputColumns.Length];
300+
for (int i = 0; i < inputColumns.Length; i++)
301+
{
302+
inputOutputColumns[i].inputColumn = inputOutputColumns[i].outputColumn = inputColumns[i];
303+
}
304+
return CreateLpNormNormalizer(env, input, inputOutputColumns);
305+
}
306+
307+
public static IDataTransform CreateLpNormNormalizer(IHostEnvironment env, IDataView input, params (string inputColumn, string outputColumn)[] inputOutputColumns)
308+
{
309+
Column[] cols = new Column[inputOutputColumns.Length];
310+
for (int i = 0; i < inputOutputColumns.Length; i++)
311+
{
312+
cols[i] = new Column();
313+
cols[i].Source = inputOutputColumns[i].inputColumn;
314+
cols[i].Name = inputOutputColumns[i].outputColumn;
315+
}
316+
var args = new Arguments()
317+
{
318+
Column = cols
319+
};
320+
return new LpNormNormalizerTransform(env, args, input);
321+
}
322+
323+
public static IDataTransform CreateLpNormNormalizer(IHostEnvironment env, IDataView input, Arguments args)
324+
{
325+
return new LpNormNormalizerTransform(env, args, input);
326+
}
327+
266328
public LpNormNormalizerTransform(IHostEnvironment env, Arguments args, IDataView input)
267-
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
268-
input, TestIsFloatVector)
329+
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
330+
input, TestIsFloatVector)
269331
{
270332
Host.AssertNonEmpty(Infos);
271333
Host.Assert(Infos.Length == Utils.Size(args.Column));

0 commit comments

Comments
 (0)