Skip to content

Commit ac91db8

Browse files
author
Ivan Matantsev
committed
tests, somewhat pigsty
1 parent d68e59b commit ac91db8

File tree

8 files changed

+1507
-28
lines changed

8 files changed

+1507
-28
lines changed

src/Microsoft.ML.Data/Transforms/TermTransform.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public ColumnInfo(string input, string output, int maxNumTerms = TermEstimator.D
175175
public readonly string[] Term;
176176
public readonly bool TextKeyValues;
177177

178-
protected string Terms { get; set; }
178+
protected internal string Terms { get; set; }
179179
}
180180

181181
public const string Summary = "Converts input values (words, numbers, etc.) to index in a dictionary.";

src/Microsoft.ML.Transforms/CategoricalStaticExtensions.cs

+1,212
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
<#@ template debug="false" hostspecific="false" language="C#" #>
2+
<#@ assembly name="System.Core" #>
3+
<#@ import namespace="System.Linq" #>
4+
<#@ import namespace="System.Text" #>
5+
<#@ import namespace="System.Collections.Generic" #>
6+
<#@ output extension=".cs" #>
7+
// Licensed to the .NET Foundation under one or more agreements.
8+
// The .NET Foundation licenses this file to you under the MIT license.
9+
// See the LICENSE file in the project root for more information.
10+
11+
using System;
12+
using Microsoft.ML.Data.StaticPipe.Runtime;
13+
14+
namespace Microsoft.ML.Runtime.Data
15+
{
16+
public static partial class CategoricalStaticExtensions
17+
{
18+
// Do not edit the generated .cs file directly. Rather, it is generated out of CategoricalStaticExtensions.tt.
19+
<#
20+
// Let's skip the time-based types for now.
21+
foreach (string typeName in new string[] { "string", "float", "double", "sbyte", "short", "int", "long", "byte", "ushort", "uint", "ulong", "bool" }) {
22+
bool startRegionBlock = true;
23+
#>
24+
25+
#region For <#=typeName#> inputs.
26+
<#
27+
foreach (bool inputIsKey in new bool[] { false, true }) {
28+
foreach (string arityName in new string[] { "Scalar", "Vector", "VarVector" }) {
29+
string onFitType = typeName == "string" ? "ReadOnlyMemory<char>" : typeName;
30+
bool omitInputArity = arityName == "Scalar" && inputIsKey;
31+
bool isNumeric = typeName != "string" && typeName != "bool";
32+
bool isScalar = arityName == "Scalar";
33+
bool isVarVector = arityName == "VarVector";
34+
35+
if (!startRegionBlock) { // Put lines between the declarations to make them look pretty, but not after the region tag.
36+
#>
37+
38+
<#}
39+
startRegionBlock = false;
40+
#>
41+
/// <summary>
42+
/// Map values to a key-value representation, where the key type's values are those values observed in the input
43+
/// during fitting. During transformation, any values unobserved during fitting will map to the missing key.
44+
<#
45+
if (typeName == "string") { #>
46+
/// Because the empty string is never entered into the dictionary, it will always map to the missing key.
47+
<# }
48+
if (typeName == "float" || typeName == "double") { #>
49+
/// Because <c>NaN</c> floating point values are never entered into the dictionary, they will always map to the missing key.
50+
<# }
51+
if (!isScalar && isNumeric && !inputIsKey) { #>
52+
/// Zero is considered a valid value and so will be entered into the dictionary if observed. The potential perf
53+
/// implication in that case is that sparse input numeric vectors will map to dense output key vectors.
54+
<# }
55+
if (inputIsKey) { #>
56+
/// We are inputting a key type with values, and in that case the dictionary is considered to be built over the
57+
/// values of the keys, rather than the keys themselves. This also means the key-values learned for the output
58+
/// will be a subset of the key-values in the input.
59+
<# }
60+
61+
#>
62+
/// </summary>
63+
/// <param name="input">The input column.</param>
64+
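/// <param name="outputKind">The kind of output to produce: a bag vector, an indicator vector, a key value, or a binary encoding.</param>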
65+
/// <param name="order">The ordering policy for what order values will appear in the enumerated set.</param>
66+
/// <param name="maxItems">The maximum number of items.</param>
67+
68+
/// <returns>The column holding the encoded values, in the form selected by <paramref name="outputKind"/>.</returns>
69+
public static <#=isVarVector?"Var":""#>Vector<float> OneHotEncoding<#=inputIsKey?"<T>":""#>(this <#=omitInputArity?"":arityName+"<"#><#=inputIsKey?"Key<T, ":""#><#=typeName#>><#=inputIsKey&&!isScalar?">":""#> input,
70+
OneHotOutputKind outputKind = DefOut, KeyValueOrder order = DefSort, int maxItems = DefMax)
71+
=> new Impl<#=arityName#><<#=typeName#>>(Contracts.CheckRef(input, nameof(input)), new Config(order, maxItems, outputKind));
72+
<#
73+
} }
74+
#>
75+
#endregion
76+
<#
77+
}
78+
#>
79+
}
80+
}

src/Microsoft.ML.Transforms/CategoricalTransform.cs

+108-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using Microsoft.ML.Runtime.Internal.Utilities;
1616
using Microsoft.ML.Runtime.Internal.Internallearn;
1717
using Microsoft.ML.Core.Data;
18+
using Microsoft.ML.Data.StaticPipe.Runtime;
1819

1920
[assembly: LoadableClass(CategoricalTransform.Summary, typeof(IDataTransform), typeof(CategoricalTransform), typeof(CategoricalTransform.Arguments), typeof(SignatureDataTransform),
2021
CategoricalTransform.UserName, "CategoricalTransform", "CatTransform", "Categorical", "Cat")]
@@ -119,7 +120,7 @@ public Arguments()
119120

120121
public const string UserName = "Categorical Transform";
121122

122-
public static IDataView Create(IHostEnvironment env, Arguments args, IDataView input)
123+
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
123124
{
124125
Contracts.CheckValue(env, nameof(env));
125126
var h = env.Register("Categorical");
@@ -140,7 +141,7 @@ public static IDataView Create(IHostEnvironment env, Arguments args, IDataView i
140141
col.SetTerms(column.Terms);
141142
columns.Add(col);
142143
}
143-
return Create(env, input, columns.ToArray());
144+
return Create(env, input, columns.ToArray()) as IDataTransform;
144145
}
145146

146147
public static IDataView Create(IHostEnvironment env, IDataView input, params CategoricalEstimator.ColumnInfo[] columns)
@@ -306,4 +307,109 @@ public static CommonOutputs.TransformOutput KeyToText(IHostEnvironment env, KeyT
306307
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
307308
}
308309
}
310+
311+
/// <summary>
/// Output representations that can be requested from the statically-typed <c>OneHotEncoding</c> extensions.
/// </summary>
/// <remarks>
/// The reconciler in <c>CategoricalStaticExtensions</c> casts this value directly to
/// <c>CategoricalTransform.OutputKind</c>, so the numeric values here presumably have to stay in sync
/// with that enum (TODO confirm against its declaration).
/// </remarks>
public enum OneHotOutputKind : byte
312+
{
313+
/// <summary>
/// Output is a bag (multi-set) vector.
315+
/// </summary>
316+
Bag = 1,
317+
318+
/// <summary>
319+
/// Output is an indicator vector.
320+
/// </summary>
321+
Ind = 2,
322+
323+
/// <summary>
324+
/// Output is a key value.
325+
/// </summary>
326+
Key = 3,
327+
328+
/// <summary>
329+
/// Output is binary encoded.
330+
/// </summary>
331+
Bin = 4,
332+
}
333+
334+
public static partial class CategoricalStaticExtensions
335+
{
336+
// I am not certain I see a good way to cover the distinct types beyond complete enumeration.
337+
// Raw generics would allow illegal possible inputs, e.g., Scalar<Bitmap>. So, this is a partial
338+
// class, and all the public facing extension methods for each possible type are in a T4 generated result.
339+
340+
private const KeyValueOrder DefSort = (KeyValueOrder)TermEstimator.Defaults.Sort;
341+
private const int DefMax = TermEstimator.Defaults.MaxNumTerms;
342+
private const OneHotOutputKind DefOut = (OneHotOutputKind)CategoricalEstimator.Defaults.OutKind;
343+
344+
// Immutable bundle of the per-column options passed to OneHotEncoding (order, maxItems, outputKind).
// It is stored on the produced column and read back by Rec.Reconcile when building the ColumnInfo.
private struct Config
345+
{
346+
// Ordering policy for the enumerated values; cast to TermTransform.SortOrder in Rec.Reconcile.
public readonly KeyValueOrder Order;
347+
// Maximum number of items; passed through unchanged to CategoricalEstimator.ColumnInfo.
public readonly int Max;
348+
// Requested output representation; cast to CategoricalTransform.OutputKind in Rec.Reconcile.
public readonly OneHotOutputKind OutputKind;
349+
350+
// Plain field-by-field initialization; no validation is performed here.
public Config(KeyValueOrder order, int max, OneHotOutputKind outputKind)
351+
{
352+
Order = order;
353+
Max = max;
354+
OutputKind = outputKind;
355+
}
356+
}
357+
358+
// Common view over the Impl* column classes so that Rec.Reconcile can read the input column and
// the options without knowing the concrete arity (scalar / vector / var-vector) of the column.
private interface IOneHotCol
359+
{
360+
// The column the one-hot encoding is applied to; used as the key into the reconciler's inputNames map.
PipelineColumn Input { get; }
361+
// The options captured when the extension method was called.
Config Config { get; }
362+
}
363+
364+
// Output column for the "Scalar" arity: the T4 template instantiates it as Impl<arityName><typeName>.
// It is exposed to callers as a Vector<float>. T is not referenced in the class body; it only
// records the input item type at the construction site.
private sealed class ImplScalar<T> :Vector<float>, IOneHotCol
365+
{
366+
// The column being encoded.
public PipelineColumn Input { get; }
367+
// Options to apply when the estimator is built.
public Config Config { get; }
368+
// Registers the shared reconciler (Rec.Inst) and the input as a dependency with the base column,
// then keeps both values for Rec.Reconcile to read back through IOneHotCol.
public ImplScalar(PipelineColumn input, Config config) : base(Rec.Inst, input)
369+
{
370+
Input = input;
371+
Config = config;
372+
}
373+
}
374+
375+
// Output column for the "Vector" arity: the T4 template instantiates it as Impl<arityName><typeName>.
// It is exposed to callers as a Vector<float>. T is not referenced in the class body; it only
// records the input item type at the construction site.
private sealed class ImplVector<T> : Vector<float>, IOneHotCol
376+
{
377+
// The column being encoded.
public PipelineColumn Input { get; }
378+
// Options to apply when the estimator is built.
public Config Config { get; }
379+
// Registers the shared reconciler (Rec.Inst) and the input as a dependency with the base column,
// then keeps both values for Rec.Reconcile to read back through IOneHotCol.
public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
380+
{
381+
Input = input;
382+
Config = config;
383+
}
384+
}
385+
386+
// Output column for the "VarVector" arity: the T4 template instantiates it as Impl<arityName><typeName>.
// Unlike the scalar and vector variants it derives from VarVector<float>. T is not referenced in the
// class body; it only records the input item type at the construction site.
private sealed class ImplVarVector<T> : VarVector<float>, IOneHotCol
387+
{
388+
// The column being encoded.
public PipelineColumn Input { get; }
389+
// Options to apply when the estimator is built.
public Config Config { get; }
390+
// Registers the shared reconciler (Rec.Inst) and the input as a dependency with the base column,
// then keeps both values for Rec.Reconcile to read back through IOneHotCol.
public ImplVarVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
391+
{
392+
Input = input;
393+
Config = config;
394+
}
395+
}
396+
397+
// Reconciler shared by every Impl* column in this class (each passes Rec.Inst to its base constructor).
// It converts the requested output columns into a single CategoricalEstimator.
private sealed class Rec : EstimatorReconciler
398+
{
399+
// Single shared instance; the class holds no per-column state.
public static readonly Rec Inst = new Rec();
400+
401+
/// <summary>
/// Builds one <c>CategoricalEstimator.ColumnInfo</c> per requested output column and wraps them in a
/// <c>CategoricalEstimator</c>.
/// </summary>
/// <param name="env">Host environment handed to the estimator.</param>
/// <param name="toOutput">Columns to produce; each is cast to <c>IOneHotCol</c>, so they are expected to be
/// the Impl* columns created by this class.</param>
/// <param name="inputNames">Maps each input column to its name in the data.</param>
/// <param name="outputNames">Maps each output column to the name it should be given.</param>
/// <param name="usedNames">Not consulted by this implementation.</param>
/// <returns>The estimator covering all columns in <paramref name="toOutput"/>.</returns>
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, PipelineColumn[] toOutput,
402+
IReadOnlyDictionary<PipelineColumn, string> inputNames, IReadOnlyDictionary<PipelineColumn, string> outputNames, IReadOnlyCollection<string> usedNames)
403+
{
404+
var infos = new CategoricalEstimator.ColumnInfo[toOutput.Length];
405+
for (int i = 0; i < toOutput.Length; ++i)
406+
{
407+
// Hard cast: a column that does not implement IOneHotCol fails here rather than being skipped.
var tcol = (IOneHotCol)toOutput[i];
408+
// The static-API enums are converted to the transform's own enums by plain numeric casts,
// which relies on both pairs of enums sharing the same underlying values (TODO confirm).
infos[i] = new CategoricalEstimator.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], (CategoricalTransform.OutputKind)tcol.Config.OutputKind,
409+
tcol.Config.Max, (TermTransform.SortOrder)tcol.Config.Order);
410+
}
411+
return new CategoricalEstimator(env, infos);
412+
}
413+
}
414+
}
309415
}

src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj

+16
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@
5151
</ItemGroup>
5252

5353
<ItemGroup>
54+
<Compile Update="CategoricalStaticExtensions.cs">
55+
<DesignTime>True</DesignTime>
56+
<AutoGen>True</AutoGen>
57+
<DependentUpon>CategoricalStaticExtensions.tt</DependentUpon>
58+
</Compile>
5459
<Compile Update="Properties\Resources.Designer.cs">
5560
<DesignTime>True</DesignTime>
5661
<AutoGen>True</AutoGen>
@@ -65,4 +70,15 @@
6570
</EmbeddedResource>
6671
</ItemGroup>
6772

73+
<ItemGroup>
74+
<None Update="CategoricalStaticExtensions.tt">
75+
<Generator>TextTemplatingFileGenerator</Generator>
76+
<LastGenOutput>CategoricalStaticExtensions.cs</LastGenOutput>
77+
</None>
78+
</ItemGroup>
79+
80+
<ItemGroup>
81+
<Service Include="{508349b6-6b84-4df5-91f0-309beebad82d}" />
82+
</ItemGroup>
83+
6884
</Project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=A:R4:0-9
5+
#@ col=B:R4:10-49
6+
#@ col=C:R4:50-59
7+
#@ col=D:R4:60-99
8+
#@ }
9+
5 3 6 4 8 1 2 7 10 9 [0].5 [0].1 [0].4 [0].3 [0].6 [0].8 [0].10 [0].2 [0].7 [0].9 [1].5 [1].1 [1].4 [1].3 [1].6 [1].8 [1].10 [1].2 [1].7 [1].9 [2].5 [2].1 [2].4 [2].3 [2].6 [2].8 [2].10 [2].2 [2].7 [2].9 [3].5 [3].1 [3].4 [3].3 [3].6 [3].8 [3].10 [3].2 [3].7 [3].9 5 3 6 4 8 1 2 7 10 9 [0].5 [0].1 [0].4 [0].3 [0].6 [0].8 [0].10 [0].2 [0].7 [0].9 [1].5 [1].1 [1].4 [1].3 [1].6 [1].8 [1].10 [1].2 [1].7 [1].9 [2].5 [2].1 [2].4 [2].3 [2].6 [2].8 [2].10 [2].2 [2].7 [2].9 [3].5 [3].1 [3].4 [3].3 [3].6 [3].8 [3].10 [3].2 [3].7 [3].9
10+
100 0:1 10:1 21:1 31:1 41:1 50:1 60:1 71:1 81:1 91:1
11+
100 0:1 10:1 22:1 32:1 40:1 50:1 60:1 72:1 82:1 90:1
12+
100 1:1 13:1 21:1 31:1 41:1 51:1 63:1 71:1 81:1 91:1
13+
100 2:1 14:1 25:1 35:1 41:1 52:1 64:1 75:1 85:1 91:1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=A:R4:0-9
5+
#@ col=B:R4:10-49
6+
#@ col=C:R4:50-59
7+
#@ col=D:R4:60-99
8+
#@ }
9+
5 3 6 4 8 1 2 7 10 9 [0].5 [0].1 [0].4 [0].3 [0].6 [0].8 [0].10 [0].2 [0].7 [0].9 [1].5 [1].1 [1].4 [1].3 [1].6 [1].8 [1].10 [1].2 [1].7 [1].9 [2].5 [2].1 [2].4 [2].3 [2].6 [2].8 [2].10 [2].2 [2].7 [2].9 [3].5 [3].1 [3].4 [3].3 [3].6 [3].8 [3].10 [3].2 [3].7 [3].9 5 3 6 4 8 1 2 7 10 9 [0].5 [0].1 [0].4 [0].3 [0].6 [0].8 [0].10 [0].2 [0].7 [0].9 [1].5 [1].1 [1].4 [1].3 [1].6 [1].8 [1].10 [1].2 [1].7 [1].9 [2].5 [2].1 [2].4 [2].3 [2].6 [2].8 [2].10 [2].2 [2].7 [2].9 [3].5 [3].1 [3].4 [3].3 [3].6 [3].8 [3].10 [3].2 [3].7 [3].9
10+
100 0:1 10:1 21:1 31:1 41:1 50:1 60:1 71:1 81:1 91:1
11+
100 0:1 10:1 22:1 32:1 40:1 50:1 60:1 72:1 82:1 90:1
12+
100 1:1 13:1 21:1 31:1 41:1 51:1 63:1 71:1 81:1 91:1
13+
100 2:1 14:1 25:1 35:1 41:1 52:1 64:1 75:1 85:1 91:1

0 commit comments

Comments
 (0)