Skip to content

Hash estimator #944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 20, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.Calibration;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(typeof(CrossValidationCommand), typeof(CrossValidationCommand.Arguments), typeof(SignatureCommand),
"Cross Validation", CrossValidationCommand.LoadName)]
Expand Down Expand Up @@ -329,10 +330,7 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
int inc = 0;
while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
var hashargs = new HashTransform.Arguments();
hashargs.Column = new[] { new HashTransform.Column { Source = origStratCol, Name = stratificationColumn } };
hashargs.HashBits = 30;
output = new HashTransform(Host, hashargs, input);
output = new HashConverter(Host, origStratCol, stratificationColumn, 30).Fit(input).Transform(input);
}
}

Expand Down
712 changes: 403 additions & 309 deletions src/Microsoft.ML.Data/Transforms/HashTransform.cs

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions src/Microsoft.ML.Transforms/CategoricalHashTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(CategoricalHashTransform.Summary, typeof(IDataTransform), typeof(CategoricalHashTransform), typeof(CategoricalHashTransform.Arguments), typeof(SignatureDataTransform),
CategoricalHashTransform.UserName, "CategoricalHashTransform", "CatHashTransform", "CategoricalHash", "CatHash")]
Expand Down Expand Up @@ -91,7 +92,7 @@ private static class Defaults
}

/// <summary>
/// This class is a merger of <see cref="HashTransform.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
/// This class is a merger of <see cref="HashConverterTransformer.Arguments"/> and <see cref="KeyToVectorTransform.Arguments"/>
/// with join option removed
/// </summary>
public sealed class Arguments : TransformInputBase
Expand Down Expand Up @@ -169,13 +170,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
throw h.ExceptUserArg(nameof(args.HashBits), "Number of bits must be between 1 and {0}", NumBitsLim - 1);

// creating the Hash function
var hashArgs = new HashTransform.Arguments
var hashArgs = new HashConverterTransformer.Arguments
{
HashBits = args.HashBits,
Seed = args.Seed,
Ordered = args.Ordered,
InvertHash = args.InvertHash,
Column = new HashTransform.Column[args.Column.Length]
Column = new HashConverterTransformer.Column[args.Column.Length]
};
for (int i = 0; i < args.Column.Length; i++)
{
Expand All @@ -184,7 +185,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
throw h.ExceptUserArg(nameof(Column.Name));
h.Assert(!string.IsNullOrWhiteSpace(column.Name));
h.Assert(!string.IsNullOrWhiteSpace(column.Source));
hashArgs.Column[i] = new HashTransform.Column
hashArgs.Column[i] = new HashConverterTransformer.Column
{
HashBits = column.HashBits,
Seed = column.Seed,
Expand All @@ -198,7 +199,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
return CreateTransformCore(
args.OutputKind, args.Column,
args.Column.Select(col => col.OutputKind).ToList(),
new HashTransform(h, hashArgs, input),
HashConverterTransformer.Create(h, hashArgs, input),
h,
args);
}
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(WordBagTransform.Summary, typeof(IDataTransform), typeof(WordBagTransform), typeof(WordBagTransform.Arguments), typeof(SignatureDataTransform),
"Word Bag Transform", "WordBagTransform", "WordBag")]
Expand Down Expand Up @@ -474,7 +475,7 @@ public interface INgramExtractorFactory
{
/// <summary>
/// Whether the extractor transform created by this factory uses the hashing trick
/// (by using <see cref="HashTransform"/> or <see cref="NgramHashTransform"/>, for example).
/// (by using <see cref="HashConverterTransformer"/> or <see cref="NgramHashTransform"/>, for example).
/// </summary>
bool UseHashingTrick { get; }

Expand Down
11 changes: 6 additions & 5 deletions src/Microsoft.ML.Transforms/Text/WordHashBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Transforms;

[assembly: LoadableClass(WordHashBagTransform.Summary, typeof(IDataTransform), typeof(WordHashBagTransform), typeof(WordHashBagTransform.Arguments), typeof(SignatureDataTransform),
"Word Hash Bag Transform", "WordHashBagTransform", "WordHashBag")]
Expand Down Expand Up @@ -266,7 +267,7 @@ public bool TryUnparse(StringBuilder sb)
}

/// <summary>
/// This class is a merger of <see cref="HashTransform.Arguments"/> and
/// This class is a merger of <see cref="HashConverterTransformer.Arguments"/> and
/// <see cref="NgramHashTransform.Arguments"/>, with the ordered option,
/// the rehashUnigrams option and the allLength option removed.
/// </summary>
Expand Down Expand Up @@ -340,7 +341,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
List<TermTransform.Column> termCols = null;
if (termLoaderArgs != null)
termCols = new List<TermTransform.Column>();
var hashColumns = new List<HashTransform.Column>();
var hashColumns = new List<HashConverterTransformer.Column>();
var ngramHashColumns = new NgramHashTransform.Column[args.Column.Length];

var colCount = args.Column.Length;
Expand Down Expand Up @@ -371,7 +372,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
}

hashColumns.Add(
new HashTransform.Column
new HashConverterTransformer.Column
{
Name = tmpName,
Source = termLoaderArgs == null ? column.Source[isrc] : tmpName,
Expand Down Expand Up @@ -435,7 +436,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV

// Args for the Hash function with multiple columns
var hashArgs =
new HashTransform.Arguments
new HashConverterTransformer.Arguments
{
HashBits = 31,
Seed = args.Seed,
Expand All @@ -444,7 +445,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
InvertHash = args.InvertHash
};

view = new HashTransform(h, hashArgs, view);
view = HashConverterTransformer.Create(h, hashArgs, view);

// creating the NgramHash function
var ngramHashArgs =
Expand Down
19 changes: 10 additions & 9 deletions test/Microsoft.ML.TestFramework/DataPipe/TestDataPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TextAnalytics;
using Xunit;
using Microsoft.ML.Transforms;
Copy link
Contributor

@Zruty0 Zruty0 Sep 19, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using [](start = 0, length = 5)

(placeholder for Tom's comment) #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey Ivan, Pete asked me to tell you to sort your usings. :)


In reply to: 218936325 [](ancestors = 218936325)


namespace Microsoft.ML.Runtime.RunTests
{
Expand Down Expand Up @@ -82,14 +83,14 @@ private void TestHashTransformHelper<T>(T[] data, uint[] results, NumberType typ
builder.AddColumn("F1", type, data);
var srcView = builder.GetDataView();

HashTransform.Column col = new HashTransform.Column();
col.Source = "F1";
var col = new HashConverterTransformer.Column();
col.Name = "F1";
col.HashBits = 5;
col.Seed = 42;
HashTransform.Arguments args = new HashTransform.Arguments();
args.Column = new HashTransform.Column[] { col };
var args = new HashConverterTransformer.Arguments();
args.Column = new HashConverterTransformer.Column[] { col };

var hashTransform = new HashTransform(Env, args, srcView);
var hashTransform = HashConverterTransformer.Create(Env, args, srcView);
using (var cursor = hashTransform.GetRowCursor(c => true))
{
var resultGetter = cursor.GetGetter<uint>(1);
Expand Down Expand Up @@ -120,14 +121,14 @@ private void TestHashTransformVectorHelper<T>(VBuffer<T> data, uint[][] results,
private void TestHashTransformVectorHelper(ArrayDataViewBuilder builder, uint[][] results)
{
var srcView = builder.GetDataView();
HashTransform.Column col = new HashTransform.Column();
var col = new HashConverterTransformer.Column();
col.Source = "F1V";
col.HashBits = 5;
col.Seed = 42;
HashTransform.Arguments args = new HashTransform.Arguments();
args.Column = new HashTransform.Column[] { col };
var args = new HashConverterTransformer.Arguments();
args.Column = new HashConverterTransformer.Column[] { col };

var hashTransform = new HashTransform(Env, args, srcView);
var hashTransform = HashConverterTransformer.Create(Env, args, srcView);
using (var cursor = hashTransform.GetRowCursor(c => true))
{
var resultGetter = cursor.GetGetter<VBuffer<uint>>(1);
Expand Down
131 changes: 131 additions & 0 deletions test/Microsoft.ML.Tests/Transformers/HashTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Tools;
using Microsoft.ML.Transforms;
using System.IO;
using System.Linq;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests.Transformers
{
public class HashTests : TestDataPipeBase
{
/// <summary>
/// Creates the test class, handing the xUnit output helper to the base test class.
/// </summary>
/// <param name="output">xUnit sink for test output, forwarded to the base constructor.</param>
public HashTests(ITestOutputHelper output)
    : base(output)
{
}

// Row type with three scalar float fields; the tests in this class hash each of them.
private class TestClass
{
    public float A, B, C;
}

// Row type used by the metadata test: two fixed-size (length 2) vector fields
// plus one scalar field of each of the float and double types.
private class TestMeta
{
    [VectorType(2)] public float[] A;
    public float B;
    [VectorType(2)] public double[] C;
    public double D;
}

/// <summary>
/// Runs the shared estimator checks (<c>TestEstimatorCore</c>) over a hash estimator
/// configured with four output columns that vary hash bits, ordering, seed and invert-hash.
/// </summary>
[Fact]
public void HashWorkout()
{
    var rows = new[]
    {
        new TestClass { A = 1, B = 2, C = 3 },
        new TestClass { A = 4, B = 5, C = 6 },
    };
    var dataView = ComponentCreation.CreateDataView(Env, rows);

    // Column "A" is hashed twice: once with explicit options, once with the defaults.
    var columns = new[]
    {
        new HashConverterTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
        new HashConverterTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
        new HashConverterTransformer.ColumnInfo("C", "HashC", seed: 42),
        new HashConverterTransformer.ColumnInfo("A", "HashD"),
    };
    var estimator = new HashConverter(Env, columns);

    TestEstimatorCore(estimator, dataView);
    Done();
}

[Fact]
public void TestMetadata()
{

var data = new[] {
new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7},
new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7},
new TestMeta() { A=new float[2] { 3.5f, 2.5f}, B=1, C= new double[2] { 5.1f, 6.1f}, D= 7}};


var dataView = ComponentCreation.CreateDataView(Env, data);
var pipe = new HashConverter(Env, new[] {
new HashConverterTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HashConverterTransformer [](start = 20, length = 24)

hmm, now that I think of it, we should probably have ColumnInfo be part of the estimator as well, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want this to be part of this PR or as separate PR?


In reply to: 218936515 [](ancestors = 218936515)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR, if not too cumbersome


In reply to: 218941758 [](ancestors = 218941758,218936515)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, we have this pattern in all estimators. Wouldn't it be better to move everything in one PR rather than in separate ones? Or do you want me to update all the estimators here as well?


In reply to: 218944749 [](ancestors = 218944749,218941758,218936515)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to keep one style in the code and, if needed, flip everything at once, rather than have two styles in the code.


In reply to: 218945227 [](ancestors = 218945227,218944749,218941758,218936515)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair. Let's do it in a different PR for all at once, where applicable


In reply to: 218954028 [](ancestors = 218954028,218945227,218944749,218941758,218936515)

new HashConverterTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10),
new HashConverterTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true)
});
var result = pipe.Fit(dataView).Transform(dataView);
ValidateMetadata(result);
Done();
}

/// <summary>
/// Validates the key-value metadata of the three hashed outputs of column "A":
/// "HashA", "HashAUnlim" and "HashAUnlimOrdered".
/// </summary>
/// <param name="result">Transformed data view that must contain the three hashed columns.</param>
private void ValidateMetadata(IDataView result)
{
    Assert.True(result.Schema.TryGetColumnIndex("HashA", out int hashA));
    Assert.True(result.Schema.TryGetColumnIndex("HashAUnlim", out int hashAUnlim));
    Assert.True(result.Schema.TryGetColumnIndex("HashAUnlimOrdered", out int hashAUnlimOrdered));

    // Reads the KeyValues metadata of a single column after checking that KeyValues is the
    // only metadata kind on it and that it holds 1024 entries (the caller uses hashBits:10).
    // Returns the text of the entries enumerated by Items().
    string[] ReadKeyTexts(int col)
    {
        var kinds = result.Schema.GetMetadataTypes(col).Select(x => x.Key);
        Assert.Equal(new[] { MetadataUtils.Kinds.KeyValues }, kinds);

        VBuffer<DvText> keys = default;
        result.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, col, ref keys);
        Assert.Equal(1024, keys.Length);
        return keys.Items().Select(x => x.Value.ToString()).ToArray();
    }

    var expected = new[] { "2.5", "3.5" };
    Assert.Equal(expected, ReadKeyTexts(hashA));

    // Previously this section re-read the metadata of "HashA", so "HashAUnlim" was never checked.
    Assert.Equal(expected, ReadKeyTexts(hashAUnlim));

    // Previously this section also re-read "HashA". The ordered column is now read for real.
    // NOTE(review): the exact key text produced for an ordered hash is not visible here (it may
    // carry slot-position information), so only the presence of both source values is asserted.
    // Tighten to an exact match once the format is confirmed.
    var orderedTexts = ReadKeyTexts(hashAUnlimOrdered);
    Assert.Contains(orderedTexts, t => t.Contains("2.5"));
    Assert.Contains(orderedTexts, t => t.Contains("3.5"));
}
/// <summary>
/// Invokes the "Hash" transform through the Maml command line (showschema command)
/// and asserts that <c>Maml.Main</c> returns 0.
/// </summary>
[Fact]
public void TestCommandLine()
{
    var exitCode = Maml.Main(new[] { @"showschema loader=Text{col=A:R4:0} xf=Hash{col=B:A} in=f:\2.txt" });

    // xUnit's Assert.Equal takes (expected, actual); passing them in that order keeps the
    // failure message accurate.
    Assert.Equal(0, exitCode);
}

/// <summary>
/// Fits and applies a four-column hash estimator, saves the resulting pipeline with
/// <c>TrainUtils.SaveModel</c> into a memory stream, and loads it back onto the original
/// data with <c>ModelFileUtils.LoadTransforms</c>.
/// </summary>
[Fact]
public void TestOldSavingAndLoading()
{
    var data = new[]
    {
        new TestClass { A = 1, B = 2, C = 3 },
        new TestClass { A = 4, B = 5, C = 6 },
    };
    var dataView = ComponentCreation.CreateDataView(Env, data);
    var pipe = new HashConverter(Env, new[]
    {
        new HashConverterTransformer.ColumnInfo("A", "HashA", hashBits: 4, invertHash: -1),
        new HashConverterTransformer.ColumnInfo("B", "HashB", hashBits: 3, ordered: true),
        new HashConverterTransformer.ColumnInfo("C", "HashC", seed: 42),
        new HashConverterTransformer.ColumnInfo("A", "HashD"),
    });
    var result = pipe.Fit(dataView).Transform(dataView);
    var resultRoles = new RoleMappedData(result);

    using (var ms = new MemoryStream())
    {
        TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles);

        // Rewind so the loader reads the model from the start of the stream.
        ms.Position = 0;
        var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);

        // The loaded view was previously assigned and never inspected; at minimum make sure
        // the round trip produced a view.
        Assert.NotNull(loadedView);
    }
}
}
}