Skip to content

Commit 95cedf5

Browse files
authored
Hash estimator (#944)
1 parent 6812cb5 commit 95cedf5

File tree

8 files changed

+582
-342
lines changed

8 files changed

+582
-342
lines changed

src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs

+2-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using Microsoft.ML.Runtime.Data;
1414
using Microsoft.ML.Runtime.Internal.Calibration;
1515
using Microsoft.ML.Runtime.Internal.Utilities;
16+
using Microsoft.ML.Transforms;
1617

1718
[assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand),
1819
"Cross Validation", CrossValidationCommand.LoadName)]
@@ -329,10 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
329330
int inc = 0;
330331
while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
331332
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
332-
var hashargs = new HashTransform.Arguments();
333-
hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } };
334-
hashargs.HashBits = 30;
335-
output = new HashTransform(Host, hashargs, input);
333+
output = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
336334
}
337335
}
338336

src/Microsoft.ML.Data/Transforms/HashTransform.cs

+415-304
Large diffs are not rendered by default.

src/Microsoft.ML.Data/Transforms/TermEstimator.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ public static class Defaults
2525
/// Convenience constructor for public facing API.
2626
/// </summary>
2727
/// <param name="env">Host Environment.</param>
28-
/// <param name="name">Name of the output column.</param>
29-
/// <param name="source">Name of the column to be transformed. If this is null '<paramref name="name"/>' will be used.</param>
28+
/// <param name="inputColumn">Name of the column to be transformed.</param>
29+
/// <param name="outputColumn">Name of the output column. If this is null, '<paramref name="inputColumn"/>' will be used.</param>
3030
/// <param name="maxNumTerms">Maximum number of terms to keep per column when auto-training.</param>
3131
/// <param name="sort">How items should be ordered when vectorized. By default, they will be in the order encountered.
3232
/// If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a').</param>
33-
public TermEstimator(IHostEnvironment env, string name, string source = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
34-
this(env, new TermTransform.ColumnInfo(name, source ?? name, maxNumTerms, sort))
33+
public TermEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, TermTransform.SortOrder sort = Defaults.Sort) :
34+
this(env, new TermTransform.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort))
3535
{
3636
}
3737

src/Microsoft.ML.Transforms/CategoricalHashTransform.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.Runtime.Data;
1111
using Microsoft.ML.Runtime.EntryPoints;
1212
using Microsoft.ML.Runtime.Internal.Utilities;
13+
using Microsoft.ML.Transforms;
1314

1415
[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform),
1516
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")]
@@ -91,7 +92,7 @@ private static class Defaults
9192
}
9293

9394
/// <summary>
94-
/// This class is a merger of <see cref="HashTransform.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
95+
/// This class is a merger of <see cref="HashTransformer.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
9596
/// with join option removed
9697
/// </summary>
9798
public sealed class Arguments : TransformInputBase
@@ -169,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
169170
throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1);
170171

171172
// creating the Hash function
172-
var hashArgs = new HashTransform.Arguments
173+
var hashArgs = new HashTransformer.Arguments
173174
{
174175
HashBits = args.HashBits,
175176
Seed = args.Seed,
176177
Ordered = args.Ordered,
177178
InvertHash = args.InvertHash,
178-
Column = new HashTransform.Column[args.Column.Length]
179+
Column = new HashTransformer.Column[args.Column.Length]
179180
};
180181
for (int i = 0; i < args.Column.Length; i++)
181182
{
@@ -184,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
184185
throw h.ExceptUserArg(nameof(Column.Name));
185186
h.Assert(!string.IsNullOrWhiteSpace(column.Name));
186187
h.Assert(!string.IsNullOrWhiteSpace(column.Source));
187-
hashArgs.Column[i] = new HashTransform.Column
188+
hashArgs.Column[i] = new HashTransformer.Column
188189
{
189190
HashBits = column.HashBits,
190191
Seed = column.Seed,
@@ -198,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
198199
return CreateTransformCore(
199200
args.OutputKind, args.Column,
200201
args.Column.Select(col => col.OutputKind).ToList(),
201-
new HashTransform(h, hashArgs, input),
202+
HashTransformer.Create(h, hashArgs, input),
202203
h,
203204
args);
204205
}

src/Microsoft.ML.Transforms/Text/WordBagTransform.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
using Microsoft.ML.Runtime.Data;
1414
using Microsoft.ML.Runtime.EntryPoints;
1515
using Microsoft.ML.Runtime.Internal.Utilities;
16+
using Microsoft.ML.Transforms;
1617

1718
[assembly: LoadableClass(WordBagTransform.Summary, typeof(IDataTransform), typeof(WordBagTransform), typeof(WordBagTransform.Arguments), typeof(SignatureDataTransform),
1819
"Word Bag Transform", "WordBagTransform", "WordBag")]
@@ -474,7 +475,7 @@ public interface INgramExtractorFactory
474475
{
475476
/// <summary>
476477
/// Whether the extractor transform created by this factory uses the hashing trick
477-
/// (by using <see cref="HashTransform"/> or <see cref="NgramHashTransform"/>, for example).
478+
/// (by using <see cref="HashTransformer"/> or <see cref="NgramHashTransform"/>, for example).
478479
/// </summary>
479480
bool UseHashingTrick { get; }
480481

src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.ML.Runtime.Data;
1212
using Microsoft.ML.Runtime.EntryPoints;
1313
using Microsoft.ML.Runtime.Internal.Utilities;
14+
using Microsoft.ML.Transforms;
1415

1516
[assembly: LoadableClass(WordHashBagTransform.Summary, typeof(IDataTransform), typeof(WordHashBagTransform), typeof(WordHashBagTransform.Arguments), typeof(SignatureDataTransform),
1617
"Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")]
@@ -266,7 +267,7 @@ public bool TryUnparse(StringBuilder sb)
266267
}
267268

268269
/// <summary>
269-
/// This class is a merger of <see cref="HashTransform.Arguments"/> and
270+
/// This class is a merger of <see cref="HashTransformer.Arguments"/> and
270271
/// <see cref="NgramHashTransform.Arguments"/>, with the ordered option,
271272
/// the rehashUnigrams option and the allLength option removed.
272273
/// </summary>
@@ -340,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
340341
List<TermTransform.Column> termCols = null;
341342
if (termLoaderArgs != null)
342343
termCols = new List<TermTransform.Column>();
343-
var hashColumns = new List<HashTransform.Column>();
344+
var hashColumns = new List<HashTransformer.Column>();
344345
var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length];
345346

346347
var colCount = args.Column.Length;
@@ -371,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
371372
}
372373

373374
hashColumns.Add(
374-
new HashTransform.Column
375+
new HashTransformer.Column
375376
{
376377
Name = tmpName,
377378
Source = termLoaderArgs == null ? column.Source[isrc] : tmpName,
@@ -435,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
435436

436437
// Args for the Hash function with multiple columns
437438
var hashArgs =
438-
new HashTransform.Arguments
439+
new HashTransformer.Arguments
439440
{
440441
HashBits = 31,
441442
Seed = args.Seed,
@@ -444,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
444445
InvertHash = args.InvertHash
445446
};
446447

447-
view = new HashTransform(h, hashArgs, view);
448+
view = HashTransformer.Create(h, hashArgs, view);
448449

449450
// creating the NgramHash function
450451
var ngramHashArgs =

test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs

+13-19
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,12 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using Float = System.Single;
6-
7-
using System;
8-
using System.Collections.Generic;
9-
using System.IO;
10-
using Microsoft.ML.Runtime.CommandLine;
115
using Microsoft.ML.Runtime.Data;
12-
using Microsoft.ML.Runtime.Data.IO;
13-
using Microsoft.ML.Runtime.Internal.Utilities;
14-
using Microsoft.ML.Runtime.Model;
156
using Microsoft.ML.Runtime.TextAnalytics;
7+
using Microsoft.ML.Transforms;
8+
using System;
169
using Xunit;
10+
using Float = System.Single;
1711

1812
namespace Microsoft.ML.Runtime.RunTests
1913
{
@@ -82,14 +76,14 @@ private void TestHashTransformHelper<T>(T[] data, uint[] results, NumberType typ
8276
builder.AddColumn("F1", type, data);
8377
var srcView = builder.GetDataView();
8478

85-
HashTransform.Column col = new HashTransform.Column();
86-
col.Source = "F1";
79+
var col = new HashTransformer.Column();
80+
col.Name = "F1";
8781
col.HashBits = 5;
8882
col.Seed = 42;
89-
HashTransform.Arguments args = new HashTransform.Arguments();
90-
args.Column = new HashTransform.Column[] { col };
83+
var args = new HashTransformer.Arguments();
84+
args.Column = new HashTransformer.Column[] { col };
9185

92-
var hashTransform = new HashTransform(Env, args, srcView);
86+
var hashTransform = HashTransformer.Create(Env, args, srcView);
9387
using (var cursor = hashTransform.GetRowCursor(c => true))
9488
{
9589
var resultGetter = cursor.GetGetter<uint>(1);
@@ -120,14 +114,14 @@ private void TestHashTransformVectorHelper<T>(VBuffer<T> data, uint[][] results,
120114
private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results)
121115
{
122116
var srcView = builder.GetDataView();
123-
HashTransform.Column col = new HashTransform.Column();
124-
col.Source = "F1V";
117+
var col = new HashTransformer.Column();
118+
col.Name = "F1V";
125119
col.HashBits = 5;
126120
col.Seed = 42;
127-
HashTransform.Arguments args = new HashTransform.Arguments();
128-
args.Column = new HashTransform.Column[] { col };
121+
var args = new HashTransformer.Arguments();
122+
args.Column = new HashTransformer.Column[] { col };
129123

130-
var hashTransform = new HashTransform(Env, args, srcView);
124+
var hashTransform = HashTransformer.Create(Env, args, srcView);
131125
using (var cursor = hashTransform.GetRowCursor(c => true))
132126
{
133127
var resultGetter = cursor.GetGetter<VBuffer<uint>>(1);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Api;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Model;
8+
using Microsoft.ML.Runtime.RunTests;
9+
using Microsoft.ML.Runtime.Tools;
10+
using Microsoft.ML.Transforms;
11+
using System;
12+
using System.IO;
13+
using System.Linq;
14+
using Xunit;
15+
using Xunit.Abstractions;
16+
17+
namespace Microsoft.ML.Tests.Transformers
18+
{
19+
/// <summary>
/// Tests for <see cref="HashEstimator"/>: estimator workout, key-value metadata produced
/// when hash inversion is requested, command-line construction, and model save/load.
/// </summary>
public class HashTests : TestDataPipeBase
{
    public HashTests(ITestOutputHelper output) : base(output)
    {
    }

    /// <summary>Row type with three scalar float columns.</summary>
    private class TestClass
    {
        public float A;
        public float B;
        public float C;
    }

    /// <summary>Row type mixing fixed-size vector columns and scalar columns.</summary>
    private class TestMeta
    {
        [VectorType(2)]
        public float[] A;
        public float B;
        [VectorType(2)]
        public double[] C;
        public double D;
    }

    /// <summary>
    /// Builds the four-column estimator shared by <see cref="HashWorkout"/> and
    /// <see cref="TestOldSavingAndLoading"/>; each column exercises a different option
    /// (hash bits, invert hash, ordered, seed, and all defaults).
    /// </summary>
    private HashEstimator CreateWorkoutEstimator()
    {
        return new HashEstimator(Env, new[]
        {
            new HashTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
            new HashTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
            new HashTransformer.ColumnInfo("C", "HashC", seed: 42),
            new HashTransformer.ColumnInfo("A", "HashD"),
        });
    }

    /// <summary>
    /// Passes the estimator and a two-row data view to the base class's TestEstimatorCore.
    /// </summary>
    [Fact]
    public void HashWorkout()
    {
        var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
        var dataView = ComponentCreation.CreateDataView(Env, data);

        var pipe = CreateWorkoutEstimator();

        TestEstimatorCore(pipe, dataView);
        Done();
    }

    /// <summary>
    /// Hashes the vector column "A" with three invert-hash configurations and checks the
    /// key-value metadata attached to each output column.
    /// </summary>
    [Fact]
    public void TestMetadata()
    {
        var data = new[]
        {
            new TestMeta() { A = new float[2] { 3.5f, 2.5f }, B = 1, C = new double[2] { 5.1f, 6.1f }, D = 7 },
            new TestMeta() { A = new float[2] { 3.5f, 2.5f }, B = 1, C = new double[2] { 5.1f, 6.1f }, D = 7 },
            new TestMeta() { A = new float[2] { 3.5f, 2.5f }, B = 1, C = new double[2] { 5.1f, 6.1f }, D = 7 },
        };
        var dataView = ComponentCreation.CreateDataView(Env, data);

        var pipe = new HashEstimator(Env, new[]
        {
            new HashTransformer.ColumnInfo("A", "HashA", invertHash: 1, hashBits: 10),
            new HashTransformer.ColumnInfo("A", "HashAUnlim", invertHash: -1, hashBits: 10),
            new HashTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash: -1, hashBits: 10, ordered: true),
        });

        var result = pipe.Fit(dataView).Transform(dataView);

        ValidateMetadata(result);
        Done();
    }

    /// <summary>
    /// Checks the key-value metadata of every column produced by <see cref="TestMetadata"/>.
    /// Each column is validated against its OWN metadata; previously all three reads used
    /// the "HashA" column index, so the other two columns were never actually inspected.
    /// </summary>
    /// <param name="result">The transformed data view whose schema is inspected.</param>
    private void ValidateMetadata(IDataView result)
    {
        // REVIEW: This is weird. invertHash is set to 1 for HashA, so only one value is expected
        // in the key values, but two are returned.
        AssertKeyValues(result, "HashA", "2.5", "3.5");

        AssertKeyValues(result, "HashAUnlim", "2.5", "3.5");

        // NOTE(review): the original assertion here read HashA's metadata by mistake and so
        // expected "2.5", "3.5". With ordered hashing the slot position is presumably part of
        // the inverted key text ("<slot>:<value>") -- confirm these expected strings on first run.
        AssertKeyValues(result, "HashAUnlimOrdered", "0:3.5", "1:2.5");
    }

    /// <summary>
    /// Asserts that <paramref name="columnName"/> exists, that its only metadata kind is
    /// key values, that the key-value vector has 1024 slots (all columns in
    /// <see cref="TestMetadata"/> use hashBits: 10), and that the items returned by
    /// <c>Items()</c> render to <paramref name="expectedKeys"/> in order.
    /// </summary>
    /// <param name="result">The data view whose schema is inspected.</param>
    /// <param name="columnName">Name of the hashed output column to check.</param>
    /// <param name="expectedKeys">Expected key texts, in the order <c>Items()</c> yields them.</param>
    private static void AssertKeyValues(IDataView result, string columnName, params string[] expectedKeys)
    {
        Assert.True(result.Schema.TryGetColumnIndex(columnName, out int columnIndex));

        var metadataKinds = result.Schema.GetMetadataTypes(columnIndex).Select(x => x.Key);
        Assert.Equal(new string[] { MetadataUtils.Kinds.KeyValues }, metadataKinds);

        VBuffer<ReadOnlyMemory<char>> keys = default;
        result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, columnIndex, ref keys);
        Assert.Equal(1024, keys.Length);
        Assert.Equal(expectedKeys, keys.Items().Select(x => x.Value.ToString()));
    }

    /// <summary>
    /// Verifies that the Hash transform can be constructed from a command line; the
    /// command is expected to return exit code 0.
    /// </summary>
    [Fact]
    public void TestCommandLine()
    {
        // NOTE(review): the "in=" argument is a machine-specific path (f:\2.txt). It is kept
        // verbatim here; confirm whether showschema actually needs the file to exist.
        Assert.Equal(0, Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Hash{col=B:A} in=f:\2.txt" }));
    }

    /// <summary>
    /// Saves the transformed pipeline to an in-memory model stream and loads the
    /// transforms back onto the original data. The test passes if neither step throws.
    /// </summary>
    [Fact]
    public void TestOldSavingAndLoading()
    {
        var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } };
        var dataView = ComponentCreation.CreateDataView(Env, data);

        var pipe = CreateWorkoutEstimator();
        var result = pipe.Fit(dataView).Transform(dataView);
        var resultRoles = new RoleMappedData(result);

        using (var ms = new MemoryStream())
        {
            TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles);
            // Rewind so the loader reads the model from the beginning of the stream.
            ms.Position = 0;
            var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
        }
    }
}
134+
}

0 commit comments

Comments
 (0)