Skip to content

Commit ef5dbc5

Browse files
authored
Pigsty extensions for term estimators (#870)
1 parent 0727b95 commit ef5dbc5

File tree

7 files changed

+1451
-22
lines changed

7 files changed

+1451
-22
lines changed

src/Microsoft.ML.Data/Microsoft.ML.Data.csproj

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,18 @@
77
<DefineConstants>CORECLR</DefineConstants>
88
</PropertyGroup>
99

10+
<ItemGroup>
11+
<None Include="Transforms\TermStaticExtensions.cs">
12+
<DesignTime>True</DesignTime>
13+
<AutoGen>True</AutoGen>
14+
<DependentUpon>TermStaticExtensions.tt</DependentUpon>
15+
</None>
16+
</ItemGroup>
17+
1018
<ItemGroup>
1119
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
1220
<PackageReference Include="System.Collections.Immutable" Version="$(SystemCollectionsImmutableVersion)" />
21+
<PackageReference Include="System.Memory" Version="4.5.1" />
1322
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />
1423
</ItemGroup>
1524

@@ -18,4 +27,25 @@
1827
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
1928
</ItemGroup>
2029

30+
<ItemGroup>
31+
<None Update="Transforms\TermStaticExtensions.tt">
32+
<Generator>TextTemplatingFileGenerator</Generator>
33+
<LastGenOutput>TermStaticExtensions.cs</LastGenOutput>
34+
</None>
35+
</ItemGroup>
36+
37+
<ItemGroup>
38+
<Service Include="{508349b6-6b84-4df5-91f0-309beebad82d}" />
39+
</ItemGroup>
40+
41+
<ItemGroup>
42+
<Compile Update="Transforms\TermStaticExtensions.cs">
43+
<DesignTime>True</DesignTime>
44+
<AutoGen>True</AutoGen>
45+
<DependentUpon>TermStaticExtensions.tt</DependentUpon>
46+
<Generator>TextTemplatingFileGenerator</Generator>
47+
<LastGenOutput>TermStaticExtensions.cs</LastGenOutput>
48+
</Compile>
49+
</ItemGroup>
50+
2151
</Project>

src/Microsoft.ML.Data/Transforms/TermEstimator.cs

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Data.StaticPipe.Runtime;
7+
using System;
8+
using System.Collections.Generic;
9+
using System.Collections.Immutable;
610
using System.Linq;
711

812
namespace Microsoft.ML.Runtime.Data
@@ -57,4 +61,137 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
5761
return new SchemaShape(result.Values);
5862
}
5963
}
64+
65+
public enum KeyValueOrder : byte
66+
{
67+
/// <summary>
68+
/// Terms will be assigned ID in the order in which they appear.
69+
/// </summary>
70+
Occurence = TermTransform.SortOrder.Occurrence,
71+
72+
/// <summary>
73+
/// Terms will be assigned ID according to their sort via an ordinal comparison for the type.
74+
/// </summary>
75+
Value = TermTransform.SortOrder.Value
76+
}
77+
78+
/// <summary>
79+
/// Information on the result of fitting a to-key transform.
80+
/// </summary>
81+
/// <typeparam name="T">The type of the values.</typeparam>
82+
public sealed class ToKeyFitResult<T>
83+
{
84+
/// <summary>
85+
/// For user defined delegates that accept instances of the containing type.
86+
/// </summary>
87+
/// <param name="result"></param>
88+
public delegate void OnFit(ToKeyFitResult<T> result);
89+
90+
// At the moment this is empty. Once PR #863 clears, we can change this class to hold the output
91+
// key-values metadata.
92+
93+
internal ToKeyFitResult(TermTransform.TermMap map)
94+
{
95+
}
96+
}
97+
98+
public static partial class TermStaticExtensions
99+
{
100+
// I am not certain I see a good way to cover the distinct types beyond complete enumeration.
101+
// Raw generics would allow illegal possible inputs, e.g., Scalar<Bitmap>. So, this is a partial
102+
// class, and all the public facing extension methods for each possible type are in a T4 generated result.
103+
104+
private const KeyValueOrder DefSort = (KeyValueOrder)TermTransform.Defaults.Sort;
105+
private const int DefMax = TermTransform.Defaults.MaxNumTerms;
106+
107+
private struct Config
108+
{
109+
public readonly KeyValueOrder Order;
110+
public readonly int Max;
111+
public readonly Action<TermTransform.TermMap> OnFit;
112+
113+
public Config(KeyValueOrder order, int max, Action<TermTransform.TermMap> onFit)
114+
{
115+
Order = order;
116+
Max = max;
117+
OnFit = onFit;
118+
}
119+
}
120+
121+
private static Action<TermTransform.TermMap> Wrap<T>(ToKeyFitResult<T>.OnFit onFit)
122+
{
123+
if (onFit == null)
124+
return null;
125+
// The type T asociated with the delegate will be the actual value type once #863 goes in.
126+
// However, until such time as #863 goes in, it would be too awkward to attempt to extract the metadata.
127+
// For now construct the useless object then pass it into the delegate.
128+
return map => onFit(new ToKeyFitResult<T>(map));
129+
}
130+
131+
private interface ITermCol
132+
{
133+
PipelineColumn Input { get; }
134+
Config Config { get; }
135+
}
136+
137+
private sealed class ImplScalar<T> : Key<uint, T>, ITermCol
138+
{
139+
public PipelineColumn Input { get; }
140+
public Config Config { get; }
141+
public ImplScalar(PipelineColumn input, Config config) : base(Rec.Inst, input)
142+
{
143+
Input = input;
144+
Config = config;
145+
}
146+
}
147+
148+
private sealed class ImplVector<T> : Vector<Key<uint, T>>, ITermCol
149+
{
150+
public PipelineColumn Input { get; }
151+
public Config Config { get; }
152+
public ImplVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
153+
{
154+
Input = input;
155+
Config = config;
156+
}
157+
}
158+
159+
private sealed class ImplVarVector<T> : VarVector<Key<uint, T>>, ITermCol
160+
{
161+
public PipelineColumn Input { get; }
162+
public Config Config { get; }
163+
public ImplVarVector(PipelineColumn input, Config config) : base(Rec.Inst, input)
164+
{
165+
Input = input;
166+
Config = config;
167+
}
168+
}
169+
170+
private sealed class Rec : EstimatorReconciler
171+
{
172+
public static readonly Rec Inst = new Rec();
173+
174+
public override IEstimator<ITransformer> Reconcile(IHostEnvironment env, PipelineColumn[] toOutput,
175+
IReadOnlyDictionary<PipelineColumn, string> inputNames, IReadOnlyDictionary<PipelineColumn, string> outputNames, IReadOnlyCollection<string> usedNames)
176+
{
177+
var infos = new TermTransform.ColumnInfo[toOutput.Length];
178+
Action<TermTransform> onFit = null;
179+
for (int i=0; i<toOutput.Length; ++i)
180+
{
181+
var tcol = (ITermCol)toOutput[i];
182+
infos[i] = new TermTransform.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]],
183+
tcol.Config.Max, (TermTransform.SortOrder)tcol.Config.Order);
184+
if (tcol.Config.OnFit != null)
185+
{
186+
int ii = i; // Necessary because if we capture i that will change to toOutput.Length on call.
187+
onFit += tt => tcol.Config.OnFit(tt.GetTermMap(ii));
188+
}
189+
}
190+
var est = new TermEstimator(env, infos);
191+
if (onFit == null)
192+
return est;
193+
return est.WithOnFitDelegate(onFit);
194+
}
195+
}
196+
}
60197
}

0 commit comments

Comments
 (0)