-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Categorical estimator #899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
d68e59b
ac91db8
25d46e7
2570eb9
8174f3b
114fec9
9d51927
4711a8d
120d3e8
cbbf5d0
47a00a9
7bfedf5
d2d2a88
c645cbd
486d65a
49c5514
71c1d45
a397200
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,16 +6,31 @@ | |
using Microsoft.ML.Data.StaticPipe.Runtime; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Collections.Immutable; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
public sealed class TermEstimator : IEstimator<TermTransform> | ||
{ | ||
public static class Defaults | ||
{ | ||
public const int MaxNumTerms = 1000000; | ||
public const TermTransform.SortOrder Sort = TermTransform.SortOrder.Occurrence; | ||
} | ||
|
||
private readonly IHost _host; | ||
private readonly TermTransform.ColumnInfo[] _columns; | ||
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = TermTransform.Defaults.MaxNumTerms, TermTransform.SortOrder sort = TermTransform.Defaults.Sort) : | ||
|
||
/// <summary> | ||
/// Convenience constructor for public facing API. | ||
/// </summary> | ||
/// <param name="env">Host Environment.</param> | ||
/// <param name="name">Name of the output column.</param> | ||
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param> | ||
/// <param name="maxNumTerms">Maximum number of terms to keep per column when auto-training.</param> | ||
/// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered. | ||
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param> | ||
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) : | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
xml doc #Resolved |
||
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort)) | ||
{ | ||
} | ||
|
@@ -47,7 +62,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) | |
if (!col.IsKey || !col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var kv) || kv.Kind != SchemaShape.Column.VectorKind.Vector) | ||
{ | ||
kv = new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, | ||
col.ItemType, col.IsKey); | ||
colInfo.TextKeyValues ? TextType.Instance : col.ItemType, col.IsKey); | ||
} | ||
Contracts.AssertValue(kv); | ||
|
||
|
@@ -90,7 +105,7 @@ public sealed class ToKeyFitResult<T> | |
// At the moment this is empty. Once PR #863 clears, we can change this class to hold the output | ||
// key-values metadata. | ||
|
||
internal ToKeyFitResult(TermTransform.TermMap map) | ||
public ToKeyFitResult(TermTransform.TermMap map) | ||
{ | ||
} | ||
} | ||
|
@@ -101,8 +116,8 @@ public static partial class TermStaticExtensions | |
// Raw generics would allow illegal possible inputs, e.g., Scalar<Bitmap>. So, this is a partial | ||
// class, and all the public facing extension methods for each possible type are in a T4 generated result. | ||
|
||
private const KeyValueOrder DefSort = (KeyValueOrder)TermTransform.Defaults.Sort; | ||
private const int DefMax = TermTransform.Defaults.MaxNumTerms; | ||
private const KeyValueOrder DefSort = (KeyValueOrder)TermEstimator.Defaults.Sort; | ||
private const int DefMax = TermEstimator.Defaults.MaxNumTerms; | ||
|
||
private struct Config | ||
{ | ||
|
@@ -176,7 +191,7 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, Pipelin | |
{ | ||
var infos = new TermTransform.ColumnInfo[toOutput.Length]; | ||
Action<TermTransform> onFit = null; | ||
for (int i=0; i<toOutput.Length; ++i) | ||
for (int i = 0; i < toOutput.Length; ++i) | ||
{ | ||
var tcol = (ITermCol)toOutput[i]; | ||
infos[i] = new TermTransform.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,17 +2,14 @@ | |
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Float = System.Single; | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Text; | ||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.CommandLine; | ||
using Microsoft.ML.Runtime.Data; | ||
using Microsoft.ML.Runtime.EntryPoints; | ||
using Microsoft.ML.Runtime.Internal.Utilities; | ||
using Microsoft.ML.Runtime.Internal.Internallearn; | ||
|
||
[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform), | ||
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")] | ||
|
@@ -62,14 +59,11 @@ protected override bool TryParse(string str) | |
|
||
// We accept N:B:S where N is the new column name, B is the number of bits, | ||
// and S is source column names. | ||
string extra; | ||
if (!base.TryParse(str, out extra)) | ||
if (!TryParse(str, out string extra)) | ||
return false; | ||
if (extra == null) | ||
return true; | ||
|
||
int bits; | ||
if (!int.TryParse(extra, out bits)) | ||
if (!int.TryParse(extra, out int bits)) | ||
return false; | ||
HashBits = bits; | ||
return true; | ||
|
@@ -201,14 +195,81 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV | |
}; | ||
} | ||
|
||
return CategoricalTransform.CreateTransformCore( | ||
return CreateTransformCore( | ||
args.OutputKind, args.Column, | ||
args.Column.Select(col => col.OutputKind).ToList(), | ||
new HashTransform(h, hashArgs, input), | ||
h, | ||
env, | ||
args); | ||
} | ||
} | ||
|
||
private static IDataTransform CreateTransformCore(CategoricalTransform.OutputKind argsOutputKind, OneToOneColumn[] columns, | ||
List<CategoricalTransform.OutputKind?> columnOutputKinds, IDataTransform input, IHost h, Arguments catHashArgs = null) | ||
{ | ||
Contracts.CheckValue(columns, nameof(columns)); | ||
Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds)); | ||
Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns)); | ||
|
||
using (var ch = h.Start("Create Transform Core")) | ||
{ | ||
// Create the KeyToVectorTransform, if needed. | ||
var cols = new List<KeyToVectorTransform.Column>(); | ||
bool binaryEncoding = argsOutputKind == CategoricalTransform.OutputKind.Bin; | ||
for (int i = 0; i < columns.Length; i++) | ||
{ | ||
var column = columns[i]; | ||
if (!column.TrySanitize()) | ||
throw h.ExceptUserArg(nameof(Column.Name)); | ||
|
||
bool? bag; | ||
CategoricalTransform.OutputKind kind = columnOutputKinds[i] ?? argsOutputKind; | ||
switch (kind) | ||
{ | ||
default: | ||
throw ch.ExceptUserArg(nameof(Column.OutputKind)); | ||
case CategoricalTransform.OutputKind.Key: | ||
continue; | ||
case CategoricalTransform.OutputKind.Bin: | ||
binaryEncoding = true; | ||
bag = false; | ||
break; | ||
case CategoricalTransform.OutputKind.Ind: | ||
bag = false; | ||
break; | ||
case CategoricalTransform.OutputKind.Bag: | ||
bag = true; | ||
break; | ||
} | ||
var col = new KeyToVectorTransform.Column(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
object initializer? #ByDesign There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's gonna be dead after hash transform become estimator. In reply to: 217236161 [](ancestors = 217236161) |
||
col.Name = column.Name; | ||
col.Source = column.Name; | ||
col.Bag = bag; | ||
cols.Add(col); | ||
} | ||
|
||
if (cols.Count == 0) | ||
return input; | ||
|
||
IDataTransform transform; | ||
if (binaryEncoding) | ||
{ | ||
if ((catHashArgs?.InvertHash ?? 0) != 0) | ||
ch.Warning("Invert hashing is being used with binary encoding."); | ||
|
||
var keyToBinaryVecCols = cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.Source, x.Name)).ToArray(); | ||
transform = KeyToBinaryVectorTransform.Create(h, input, keyToBinaryVecCols); | ||
} | ||
else | ||
{ | ||
var keyToVecCols = cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.Source, x.Name, x.Bag ?? argsOutputKind == CategoricalTransform.OutputKind.Bag)).ToArray(); | ||
|
||
transform = KeyToVectorTransform.Create(h, input, keyToVecCols); | ||
} | ||
|
||
ch.Done(); | ||
return transform; | ||
} | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a right way to do? #Closed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good to me
In reply to: 217229486 [](ancestors = 217229486)