Skip to content

Convert categorical hash to estimator #1033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
367 changes: 261 additions & 106 deletions src/Microsoft.ML.Transforms/CategoricalHashTransform.cs

Large diffs are not rendered by default.

70 changes: 36 additions & 34 deletions src/Microsoft.ML.Transforms/CategoricalTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV

private readonly TransformerChain<ITransformer> _transformer;

public CategoricalTransform(TermEstimator term, IEstimator<ITransformer> keyToVector, IDataView input)
public CategoricalTransform(TermEstimator term, IEstimator<ITransformer> toVector, IDataView input)
{
var chain = term.Append(keyToVector);
var chain = term.Append(toVector);
_transformer = chain.Fit(input);
}

Expand All @@ -171,16 +171,14 @@ public CategoricalTransform(TermEstimator term, IEstimator<ITransformer> keyToVe

public bool IsRowToRowMapper => _transformer.IsRowToRowMapper;

public IRowToRowMapper GetRowToRowMapper(ISchema inputSchema)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Copy link
Member

@sfilipi sfilipi Sep 28, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Contrac [](start = 11, length = 8)

did you intend to omit the check? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @Zruty0 pointed out, the check is already performed inside _transformer.GetRowToRowMapper, so it is redundant here.


In reply to: 221369314 [](ancestors = 221369314)

return _transformer.GetRowToRowMapper(inputSchema);
}
public IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) => _transformer.GetRowToRowMapper(inputSchema);
}

/// <summary>
/// Estimator which takes a set of columns and produces an indicator array for each column.
/// </summary>
public sealed class CategoricalEstimator : IEstimator<CategoricalTransform>
{
public static class Defaults
internal static class Defaults
{
public const CategoricalTransform.OutputKind OutKind = CategoricalTransform.OutputKind.Ind;
}
Expand All @@ -204,7 +202,7 @@ internal void SetTerms(string terms)
}

private readonly IHost _host;
private readonly IEstimator<ITransformer> _keyToSomething;
private readonly IEstimator<ITransformer> _toSomething;
private TermEstimator _term;

/// A helper method to create <see cref="CategoricalEstimator"/> for public facing API.
Expand All @@ -223,13 +221,11 @@ public CategoricalEstimator(IHostEnvironment env, params ColumnInfo[] columns)
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TermEstimator));
_term = new TermEstimator(_host, columns);

var binaryCols = new List<(string input, string output)>();
var cols = new List<(string input, string output, bool bag)>();
bool binaryEncoding = false;
for (int i = 0; i < columns.Length; i++)
{
var column = columns[i];
bool bag;
CategoricalTransform.OutputKind kind = columns[i].OutputKind;
switch (kind)
{
Expand All @@ -238,31 +234,37 @@ public CategoricalEstimator(IHostEnvironment env, params ColumnInfo[] columns)
case CategoricalTransform.OutputKind.Key:
continue;
case CategoricalTransform.OutputKind.Bin:
binaryEncoding = true;
bag = false;
binaryCols.Add((column.Output, column.Output));
break;
case CategoricalTransform.OutputKind.Ind:
bag = false;
cols.Add((column.Output, column.Output, false));
break;
case CategoricalTransform.OutputKind.Bag:
bag = true;
cols.Add((column.Output, column.Output, true));
break;
}
cols.Add((column.Output, column.Output, bag));
if (binaryEncoding)
{
_keyToSomething = new KeyToBinaryVectorEstimator(_host, cols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.input, x.output)).ToArray());
}
}
IEstimator<ITransformer> toBinVector = null;
IEstimator<ITransformer> toVector = null;
if (binaryCols.Count > 0)
toBinVector = new KeyToBinaryVectorEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorTransform.ColumnInfo(x.input, x.output)).ToArray());
if (cols.Count > 0)
toVector = new KeyToVectorEstimator(_host, cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.input, x.output, x.bag)).ToArray());

if (toBinVector != null && toVector != null)
_toSomething = toVector.Append(toBinVector);
else
{
if (toBinVector != null)
_toSomething = toBinVector;
else
{
_keyToSomething = new KeyToVectorEstimator(_host, cols.Select(x => new KeyToVectorTransform.ColumnInfo(x.input, x.output, x.bag)).ToArray());
}
_toSomething = toVector;
}
}

public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _term.Append(_keyToSomething).GetOutputSchema(inputSchema);
public SchemaShape GetOutputSchema(SchemaShape inputSchema) => _term.Append(_toSomething).GetOutputSchema(inputSchema);

public CategoricalTransform Fit(IDataView input) => new CategoricalTransform(_term, _keyToSomething, input);
public CategoricalTransform Fit(IDataView input) => new CategoricalTransform(_term, _toSomething, input);

internal void WrapTermWithDelegate(Action<TermTransform> onFit)
{
Expand Down Expand Up @@ -455,13 +457,13 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, Pipelin
}

/// <summary>
/// Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array
/// Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array.
/// </summary>
/// <param name="input">Incoming data.</param>
/// <param name="outputKind">Specify output type of indicator array: array or binary encoded data.</param>
/// <param name="order">How Id for each value would be assigined: by occurrence or by value.</param>
/// <param name="outputKind">Specify the output type of indicator array: array or binary encoded data.</param>
/// <param name="order">How the Id for each value would be assigned: by occurrence or by value.</param>
/// <param name="maxItems">Maximum number of ids to keep during data scanning.</param>
/// /// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param>
/// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param>
public static Vector<float> OneHotEncoding(this Scalar<string> input, OneHotScalarOutputKind outputKind = (OneHotScalarOutputKind)DefOut, KeyValueOrder order = DefSort,
int maxItems = DefMax, ToKeyFitResult<ReadOnlyMemory<char>>.OnFit onFit = null)
{
Expand All @@ -470,11 +472,11 @@ public static Vector<float> OneHotEncoding(this Scalar<string> input, OneHotScal
}

/// <summary>
/// Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array
/// Converts the categorical value into an indicator array by building a dictionary of categories based on the data and using the id in the dictionary as the index in the array.
/// </summary>
/// <param name="input">Incoming data.</param>
/// <param name="outputKind">Specify output type of indicator array: Multiarray, array or binary encoded data.</param>
/// <param name="order">How Id for each value would be assigined: by occurrence or by value.</param>
/// <param name="outputKind">Specify the output type of indicator array: Multiarray, array or binary encoded data.</param>
/// <param name="order">How the Id for each value would be assigned: by occurrence or by value.</param>
/// <param name="maxItems">Maximum number of ids to keep during data scanning.</param>
/// <param name="onFit">Called upon fitting with the learnt enumeration on the dataset.</param>
public static Vector<float> OneHotEncoding(this Vector<string> input, OneHotVectorOutputKind outputKind = DefOut, KeyValueOrder order = DefSort, int maxItems = DefMax,
Expand Down
20 changes: 10 additions & 10 deletions test/BaselineOutput/SingleDebug/Categorical/featurized.tsv
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=A:R4:0-4
#@ col=B:R4:5-24
#@ col=C:R4:25-44
#@ col=D:R4:45-49
#@ col=E:R4:50-69
#@ col=A:R4:0-9
#@ col=B:R4:10-49
#@ col=C:R4:50-59
#@ col=D:R4:60-64
#@ col=E:R4:65-84
#@ }
Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0
70 14:1 19:1 24:1 34:1 39:1 44:1 59:1 64:1 69:1
70 13:1 18:1 33:1 38:1 58:1 63:1
70 4:1 8:1 9:1 14:1 19:1 24:1 28:1 29:1 34:1 39:1 44:1 49:1 53:1 54:1 59:1 64:1 69:1
70 3:1 7:1 12:1 14:1 17:1 19:1 24:1 27:1 32:1 34:1 37:1 39:1 44:1 48:1 52:1 57:1 59:1 62:1 64:1 69:1
5 3 6 4 8 1 2 7 10 9 [0].5 [0].1 [0].4 [0].3 [0].6 [0].8 [0].10 [0].2 [0].7 [0].9 [1].5 [1].1 [1].4 [1].3 [1].6 [1].8 [1].10 [1].2 [1].7 [1].9 [2].5 [2].1 [2].4 [2].3 [2].6 [2].8 [2].10 [2].2 [2].7 [2].9 [3].5 [3].1 [3].4 [3].3 [3].6 [3].8 [3].10 [3].2 [3].7 [3].9 5 1 4 3 6 8 10 2 7 9 Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0
85 0:1 10:1 21:1 31:1 41:1 50:1 51:3 74:1 79:1 84:1
85 0:1 10:1 22:1 32:1 40:1 50:2 52:2 73:1 78:1
85 1:1 13:1 21:1 31:1 41:1 51:3 53:1 64:1 68:1 69:1 74:1 79:1 84:1
85 2:1 14:1 25:1 35:1 41:1 51:1 54:1 55:2 63:1 67:1 72:1 74:1 77:1 79:1 84:1
12 changes: 12 additions & 0 deletions test/BaselineOutput/SingleDebug/CategoricalHash/featurized.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#@ TextLoader{
#@ sep=tab
#@ col=A:R4:0-65535
#@ col=B:R4:65536-327679
#@ col=C:R4:327680-393215
#@ col=D:R4:393216-393233
#@ col=E:R4:393234-393305
#@ }
393306 11529:1 77065:1 165873:1 196777:1 326564:1 327849:1 339209:1 362481:1 392100:1 393220:1 393222:1 393223:1 393225:1 393230:1 393233:1 393238:1 393240:1 393241:1 393243:1 393248:1 393251:1 393254:1 393259:1 393260:1 393261:1 393262:1 393263:1 393264:1 393265:1 393269:1 393280:1 393282:1 393284:1 393287:1 393290:1 393291:1 393292:1 393293:1 393294:1 393296:1 393297:1 393298:1 393300:1 393303:1
393306 11529:1 77065:1 192621:1 236060:1 323071:1 339209:1 367132:1 388607:1 389229:1 393220:1 393222:1 393223:1 393225:1 393230:1 393233:1 393238:1 393240:1 393241:1 393243:1 393248:1 393251:1 393254:1 393255:1 393256:1 393257:1 393263:1 393264:1 393266:1 393267:1 393269:1 393272:1 393275:1 393276:1 393278:1 393283:1 393284:1 393285:1 393290:1 393291:1 393292:1 393294:1 393295:1 393297:1 393298:1 393299:1 393300:1 393301:1 393302:1 393303:1 393304:1 393305:1
393306 47483:1 113019:1 165873:1 196777:1 326564:1 327849:1 362481:1 375163:1 392100:1 393218:1 393220:1 393221:1 393222:1 393225:1 393227:1 393228:1 393229:1 393230:1 393232:1 393233:1 393236:1 393238:1 393239:1 393240:1 393243:1 393245:1 393246:1 393247:1 393248:1 393250:1 393251:1 393254:1 393259:1 393260:1 393261:1 393262:1 393263:1 393264:1 393265:1 393269:1 393280:1 393282:1 393284:1 393287:1 393290:1 393291:1 393292:1 393293:1 393294:1 393296:1 393297:1 393298:1 393300:1 393303:1
393306 42588:1 108124:1 173921:1 212446:1 326564:1 343518:1 370268:1 370529:1 392100:1 393218:1 393220:1 393223:1 393224:1 393227:1 393229:1 393230:1 393231:1 393236:1 393238:1 393241:1 393242:1 393245:1 393247:1 393248:1 393249:1 393254:1 393256:1 393259:1 393260:1 393261:1 393263:1 393264:1 393269:1 393274:1 393275:1 393276:1 393277:1 393279:1 393280:1 393281:1 393283:1 393284:1 393285:1 393286:1 393290:1 393291:1 393292:1 393293:1 393294:1 393296:1 393297:1 393298:1 393300:1 393303:1
20 changes: 10 additions & 10 deletions test/BaselineOutput/SingleRelease/Categorical/featurized.tsv
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=A:R4:0-4
#@ col=B:R4:5-24
#@ col=C:R4:25-44
#@ col=D:R4:45-49
#@ col=E:R4:50-69
#@ col=A:R4:0-9
#@ col=B:R4:10-49
#@ col=C:R4:50-59
#@ col=D:R4:60-64
#@ col=E:R4:65-84
#@ }
Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0 Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0
70 14:1 19:1 24:1 34:1 39:1 44:1 59:1 64:1 69:1
70 13:1 18:1 33:1 38:1 58:1 63:1
70 4:1 8:1 9:1 14:1 19:1 24:1 28:1 29:1 34:1 39:1 44:1 49:1 53:1 54:1 59:1 64:1 69:1
70 3:1 7:1 12:1 14:1 17:1 19:1 24:1 27:1 32:1 34:1 37:1 39:1 44:1 48:1 52:1 57:1 59:1 62:1 64:1 69:1
5 3 6 4 8 1 2 7 10 9 [0].5 [0].1 [0].4 [0].3 [0].6 [0].8 [0].10 [0].2 [0].7 [0].9 [1].5 [1].1 [1].4 [1].3 [1].6 [1].8 [1].10 [1].2 [1].7 [1].9 [2].5 [2].1 [2].4 [2].3 [2].6 [2].8 [2].10 [2].2 [2].7 [2].9 [3].5 [3].1 [3].4 [3].3 [3].6 [3].8 [3].10 [3].2 [3].7 [3].9 5 1 4 3 6 8 10 2 7 9 Bit4 Bit3 Bit2 Bit1 Bit0 [0].Bit4 [0].Bit3 [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit4 [1].Bit3 [1].Bit2 [1].Bit1 [1].Bit0 [2].Bit4 [2].Bit3 [2].Bit2 [2].Bit1 [2].Bit0 [3].Bit4 [3].Bit3 [3].Bit2 [3].Bit1 [3].Bit0
85 0:1 10:1 21:1 31:1 41:1 50:1 51:3 74:1 79:1 84:1
85 0:1 10:1 22:1 32:1 40:1 50:2 52:2 73:1 78:1
85 1:1 13:1 21:1 31:1 41:1 51:3 53:1 64:1 68:1 69:1 74:1 79:1 84:1
85 2:1 14:1 25:1 35:1 41:1 51:1 54:1 55:2 63:1 67:1 72:1 74:1 77:1 79:1 84:1
12 changes: 12 additions & 0 deletions test/BaselineOutput/SingleRelease/CategoricalHash/featurized.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#@ TextLoader{
#@ sep=tab
#@ col=A:R4:0-65535
#@ col=B:R4:65536-327679
#@ col=C:R4:327680-393215
#@ col=D:R4:393216-393233
#@ col=E:R4:393234-393305
#@ }
393306 11529:1 77065:1 165873:1 196777:1 326564:1 327849:1 339209:1 362481:1 392100:1 393220:1 393222:1 393223:1 393225:1 393230:1 393233:1 393238:1 393240:1 393241:1 393243:1 393248:1 393251:1 393254:1 393259:1 393260:1 393261:1 393262:1 393263:1 393264:1 393265:1 393269:1 393280:1 393282:1 393284:1 393287:1 393290:1 393291:1 393292:1 393293:1 393294:1 393296:1 393297:1 393298:1 393300:1 393303:1
393306 11529:1 77065:1 192621:1 236060:1 323071:1 339209:1 367132:1 388607:1 389229:1 393220:1 393222:1 393223:1 393225:1 393230:1 393233:1 393238:1 393240:1 393241:1 393243:1 393248:1 393251:1 393254:1 393255:1 393256:1 393257:1 393263:1 393264:1 393266:1 393267:1 393269:1 393272:1 393275:1 393276:1 393278:1 393283:1 393284:1 393285:1 393290:1 393291:1 393292:1 393294:1 393295:1 393297:1 393298:1 393299:1 393300:1 393301:1 393302:1 393303:1 393304:1 393305:1
393306 47483:1 113019:1 165873:1 196777:1 326564:1 327849:1 362481:1 375163:1 392100:1 393218:1 393220:1 393221:1 393222:1 393225:1 393227:1 393228:1 393229:1 393230:1 393232:1 393233:1 393236:1 393238:1 393239:1 393240:1 393243:1 393245:1 393246:1 393247:1 393248:1 393250:1 393251:1 393254:1 393259:1 393260:1 393261:1 393262:1 393263:1 393264:1 393265:1 393269:1 393280:1 393282:1 393284:1 393287:1 393290:1 393291:1 393292:1 393293:1 393294:1 393296:1 393297:1 393298:1 393300:1 393303:1
393306 42588:1 108124:1 173921:1 212446:1 326564:1 343518:1 370268:1 370529:1 392100:1 393218:1 393220:1 393223:1 393224:1 393227:1 393229:1 393230:1 393231:1 393236:1 393238:1 393241:1 393242:1 393245:1 393247:1 393248:1 393249:1 393254:1 393256:1 393259:1 393260:1 393261:1 393263:1 393264:1 393269:1 393274:1 393275:1 393276:1 393277:1 393279:1 393280:1 393281:1 393283:1 393284:1 393285:1 393286:1 393290:1 393291:1 393292:1 393293:1 393294:1 393296:1 393297:1 393298:1 393300:1 393303:1
Loading