Skip to content

Commit 617f7f6

Browse files
authored
Creation of components through MLContext and cleanup (OneHotHash, Hash, CopyCol, KeyToVector) (#2364)
1 parent 140783b commit 617f7f6

File tree

19 files changed

+342
-267
lines changed

19 files changed

+342
-267
lines changed

src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env,
3838
}
3939

4040
[TlcModule.EntryPoint(Name = "Transforms.ColumnCopier", Desc = "Duplicates columns from the dataset", UserName = ColumnCopyingTransformer.UserName, ShortName = ColumnCopyingTransformer.ShortName)]
41-
public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, ColumnCopyingTransformer.Arguments input)
41+
public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, ColumnCopyingTransformer.Options input)
4242
{
4343
Contracts.CheckValue(env, nameof(env));
4444
var host = env.Register("CopyColumns");

src/Microsoft.ML.Data/Properties/AssemblyInfo.cs

+4
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,10 @@
3939
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TensorFlow" + PublicKey.Value)]
4040
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
4141
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Transforms" + PublicKey.Value)]
42+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.AlexNet" + PublicKey.Value)]
43+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet101" + PublicKey.Value)]
44+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet18" + PublicKey.Value)]
45+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.DnnImageFeaturizer.ResNet50" + PublicKey.Value)]
4246

4347
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.StaticPipe" + PublicKey.Value)]
4448

src/Microsoft.ML.Data/TrainCatalog.cs

+3 -3
Original file line number | Diff line number | Diff line change
@@ -152,11 +152,11 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific
152152
// Generate a new column with the hashed stratification column.
153153
while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
154154
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
155-
HashingTransformer.ColumnInfo columnInfo;
155+
HashingEstimator.ColumnInfo columnInfo;
156156
if (seed.HasValue)
157-
columnInfo = new HashingTransformer.ColumnInfo(stratificationColumn, origStratCol, 30, seed.Value);
157+
columnInfo = new HashingEstimator.ColumnInfo(stratificationColumn, origStratCol, 30, seed.Value);
158158
else
159-
columnInfo = new HashingTransformer.ColumnInfo(stratificationColumn, origStratCol, 30);
159+
columnInfo = new HashingEstimator.ColumnInfo(stratificationColumn, origStratCol, 30);
160160
data = new HashingEstimator(Host, columnInfo).Fit(data).Transform(data);
161161
}
162162
}

src/Microsoft.ML.Data/Transforms/ColumnCopying.cs

+21 -9
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
using Microsoft.ML.Transforms;
1919

2020
[assembly: LoadableClass(ColumnCopyingTransformer.Summary, typeof(IDataTransform), typeof(ColumnCopyingTransformer),
21-
typeof(ColumnCopyingTransformer.Arguments), typeof(SignatureDataTransform),
21+
typeof(ColumnCopyingTransformer.Options), typeof(SignatureDataTransform),
2222
ColumnCopyingTransformer.UserName, "CopyColumns", "CopyColumnsTransform", ColumnCopyingTransformer.ShortName,
2323
DocName = "transform/CopyColumnsTransformer.md")]
2424

@@ -33,18 +33,27 @@
3333

3434
namespace Microsoft.ML.Transforms
3535
{
36+
/// <summary>
37+
/// <see cref="ColumnCopyingEstimator"/> copies the input column to another column named as specified in the parameters of the transformation.
38+
/// </summary>
3639
public sealed class ColumnCopyingEstimator : TrivialEstimator<ColumnCopyingTransformer>
3740
{
38-
public ColumnCopyingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName) :
41+
[BestFriend]
42+
internal ColumnCopyingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName) :
3943
this(env, (outputColumnName, inputColumnName))
4044
{
4145
}
4246

43-
public ColumnCopyingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
47+
[BestFriend]
48+
internal ColumnCopyingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
4449
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingEstimator)), new ColumnCopyingTransformer(env, columns))
4550
{
4651
}
4752

53+
/// <summary>
54+
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
55+
/// Used for schema propagation and verification in a pipeline.
56+
/// </summary>
4857
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
4958
{
5059
Host.CheckValue(inputSchema, nameof(inputSchema));
@@ -69,6 +78,9 @@ public sealed class ColumnCopyingTransformer : OneToOneTransformerBase
6978
internal const string UserName = "Copy Columns Transform";
7079
internal const string ShortName = "Copy";
7180

81+
/// <summary>
82+
/// Names of output and input column pairs on which the transformation is applied.
83+
/// </summary>
7284
public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();
7385

7486
private static VersionInfo GetVersionInfo()
@@ -82,12 +94,12 @@ private static VersionInfo GetVersionInfo()
8294
loaderAssemblyName: typeof(ColumnCopyingTransformer).Assembly.FullName);
8395
}
8496

85-
public ColumnCopyingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
97+
internal ColumnCopyingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
8698
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingTransformer)), columns)
8799
{
88100
}
89101

90-
public sealed class Column : OneToOneColumn
102+
internal sealed class Column : OneToOneColumn
91103
{
92104
internal static Column Parse(string str)
93105
{
@@ -106,20 +118,20 @@ internal bool TryUnparse(StringBuilder sb)
106118
}
107119
}
108120

109-
public sealed class Arguments : TransformInputBase
121+
internal sealed class Options : TransformInputBase
110122
{
111123
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:src)",
112124
Name = "Column", ShortName = "col", SortOrder = 1)]
113125
public Column[] Columns;
114126
}
115127

116128
// Factory method corresponding to SignatureDataTransform.
117-
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
129+
internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
118130
{
119131
Contracts.CheckValue(env, nameof(env));
120-
env.CheckValue(args, nameof(args));
132+
env.CheckValue(options, nameof(options));
121133

122-
var transformer = new ColumnCopyingTransformer(env, args.Columns.Select(x => (x.Name, x.Source)).ToArray());
134+
var transformer = new ColumnCopyingTransformer(env, options.Columns.Select(x => (x.Name, x.Source)).ToArray());
123135
return transformer.MakeDataTransform(input);
124136
}
125137

src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs

+2 -2
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms
3636
/// </summary>
3737
/// <param name="catalog">The transform's catalog.</param>
3838
/// <param name="columns">Description of dataset columns and how to process them.</param>
39-
public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, params HashingTransformer.ColumnInfo[] columns)
39+
public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, params HashingEstimator.ColumnInfo[] columns)
4040
=> new HashingEstimator(CatalogUtils.GetEnvironment(catalog), columns);
4141

4242
/// <summary>
@@ -93,7 +93,7 @@ public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.Co
9393
/// <param name="catalog">The categorical transform's catalog.</param>
9494
/// <param name="columns">The input column to map back to vectors.</param>
9595
public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog.ConversionTransforms catalog,
96-
params KeyToVectorMappingTransformer.ColumnInfo[] columns)
96+
params KeyToVectorMappingEstimator.ColumnInfo[] columns)
9797
=> new KeyToVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns);
9898

9999
/// <summary>

0 commit comments

Comments
 (0)