Skip to content

Concat estimator with pigsty extensions for ConcatWith, AsVector #881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,22 @@ public static bool IsNormalized(this SchemaShape.Column col)
&& metaCol.ItemType == BoolType.Instance;
}

/// <summary>
/// Determines, from the schema shape alone, whether a column carries
/// <see cref="Kinds.SlotNames"/> metadata in the expected form.
/// </summary>
/// <param name="col">The schema shape column to inspect.</param>
/// <returns><c>true</c> if and only if the column's kind is <c>VectorKind.Vector</c> and its
/// <see cref="Kinds.SlotNames"/> metadata column is itself a non-key <c>VectorKind.Vector</c>
/// whose item type is text; otherwise <c>false</c>.</returns>
public static bool HasSlotNames(this SchemaShape.Column col)
{
    Contracts.CheckValue(col, nameof(col));

    // The column itself must be a vector before slot names are meaningful.
    if (col.Kind != SchemaShape.Column.VectorKind.Vector)
        return false;

    // The slot-names metadata must be present...
    if (!col.Metadata.TryFindColumn(Kinds.SlotNames, out var slotNamesCol))
        return false;

    // ...and must be a vector that is not a key type...
    if (slotNamesCol.Kind != SchemaShape.Column.VectorKind.Vector || slotNamesCol.IsKey)
        return false;

    // ...whose items are text.
    return slotNamesCol.ItemType == TextType.Instance;
}

/// <summary>
/// Tries to get the metadata kind of the specified type for a column.
/// </summary>
Expand Down
418 changes: 418 additions & 0 deletions src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions test/BaselineOutput/SingleDebug/Transform/Concat/Concat1.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=f1:R4:0-0
#@ col=f2:R4:1-2
#@ col=f3:R4:3-7
#@ col=f4:R4:8-**
#@ }
float1 float1 float1 float4.age float4.fnlwgt float4.education-num float4.capital-gain float1
25 25 25 25 226802 7 0 25 25 226802 7 0 0 40 0 25
38 38 38 38 89814 9 0 38 38 89814 9 0 0 50 0 38
28 28 28 28 336951 12 0 28 28 336951 12 0 0 40 1 28
44 44 44 44 160323 10 7688 44 44 160323 10 7688 0 40 1 44
18 18 18 18 103497 10 0 18 18 103497 10 0 0 30 0 18
34 34 34 34 198693 6 0 34 34 198693 6 0 0 30 0 34
29 29 29 29 227026 9 0 29 29 227026 9 0 0 40 0 29
63 63 63 63 104626 15 3103 63 63 104626 15 3103 0 32 1 63
24 24 24 24 369667 10 0 24 24 369667 10 0 0 40 0 24
55 55 55 55 104996 4 0 55 55 104996 4 0 0 10 0 55
19 changes: 19 additions & 0 deletions test/BaselineOutput/SingleRelease/Transform/Concat/Concat1.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=f1:R4:0-0
#@ col=f2:R4:1-2
#@ col=f3:R4:3-7
#@ col=f4:R4:8-**
#@ }
float1 float1 float1 float4.age float4.fnlwgt float4.education-num float4.capital-gain float1
25 25 25 25 226802 7 0 25 25 226802 7 0 0 40 0 25
38 38 38 38 89814 9 0 38 38 89814 9 0 0 50 0 38
28 28 28 28 336951 12 0 28 28 336951 12 0 0 40 1 28
44 44 44 44 160323 10 7688 44 44 160323 10 7688 0 40 1 44
18 18 18 18 103497 10 0 18 18 103497 10 0 0 30 0 18
34 34 34 34 198693 6 0 34 34 198693 6 0 0 30 0 34
29 29 29 29 227026 9 0 29 29 227026 9 0 0 40 0 29
63 63 63 63 104626 15 3103 63 63 104626 15 3103 0 32 1 63
24 24 24 24 369667 10 0 24 24 369667 10 0 0 40 0 24
55 55 55 55 104996 4 0 55 55 104996 4 0 0 10 0 55
38 changes: 38 additions & 0 deletions test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,5 +365,43 @@ public void ToKey()
// Because they're over exactly the same data, they ought to have the same cardinality and everything.
Assert.True(valuesKeyKeyType.Equals(valuesKeyType));
}

[Fact]
public void ConcatWith()
{
    // Read the iris data with a text label, a float vector over columns 0-3,
    // and a scalar float from column 2.
    var env = new TlcEnvironment(seed: 0);
    var irisPath = GetDataPath("iris.data");
    var reader = TextLoader.CreateReader(env,
        c => (label: c.LoadText(4), values: c.LoadFloat(0, 3), value: c.LoadFloat(2)),
        separator: ',');
    var data = reader.Read(new MultiFileSource(irisPath));

    // c0..c3 exercise AsVector and ConcatWith over scalar and vector inputs of both item types.
    var estimator = data.MakeNewEstimator()
        .Append(r => (
            r.label, r.values, r.value,
            c0: r.label.AsVector(), c1: r.label.ConcatWith(r.label),
            c2: r.value.ConcatWith(r.values), c3: r.values.ConcatWith(r.value, r.values)));

    var transformed = estimator.Fit(data).Transform(data);
    var schema = transformed.AsDynamic.Schema;

    // Every output column must exist before any type is inspected.
    var columnIndices = new int[4];
    for (int i = 0; i < columnIndices.Length; ++i)
    {
        bool found = schema.TryGetColumnIndex("c" + i, out columnIndices[i]);
        Assert.True(found, $"Could not find col c{i}");
    }

    // Each output must be a vector of known size with the expected length.
    int[] expectedLengths = { 1, 2, 5, 9 };
    var vectorTypes = new VectorType[columnIndices.Length];
    for (int i = 0; i < columnIndices.Length; ++i)
    {
        var columnType = schema.GetColumnType(columnIndices[i]);
        Assert.True(columnType.VectorSize > 0, $"Col c{i} had unexpected type {columnType}");
        vectorTypes[i] = columnType.AsVector;
        Assert.Equal(expectedLengths[i], columnType.VectorSize);
    }

    // Concatenation must preserve the item type of its inputs.
    Assert.Equal(TextType.Instance, vectorTypes[0].ItemType);
    Assert.Equal(TextType.Instance, vectorTypes[1].ItemType);
    Assert.Equal(NumberType.Float, vectorTypes[2].ItemType);
    Assert.Equal(NumberType.Float, vectorTypes[3].ItemType);
}
}
}
2 changes: 1 addition & 1 deletion test/Microsoft.ML.Tests/TermEstimatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TestMetaClass
}

[Fact]
void TestDifferntTypes()
void TestDifferentTypes()
{
string dataPath = GetDataPath("adult.test");

Expand Down
79 changes: 79 additions & 0 deletions test/Microsoft.ML.Tests/Transformers/ConcatTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.Runtime.Tools;
using System.IO;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Tests.Transformers
{
/// <summary>
/// Tests for <see cref="ConcatEstimator"/>: checks the output column types produced by
/// concatenation and compares the saved output against the checked-in baseline file.
/// </summary>
public sealed class ConcatTests : TestDataPipeBase
{
    public ConcatTests(ITestOutputHelper output) : base(output)
    {
    }

    /// <summary>
    /// Concatenates scalar, fixed-size vector and variable-end vector float columns, verifies
    /// the resulting column types, then saves the first ten rows and compares them with the
    /// <c>Transform/Concat/Concat1.tsv</c> baseline.
    /// </summary>
    [Fact]
    void TestConcat()
    {
        string dataPath = GetDataPath("adult.test");

        // A single source is shared by the loader constructor and the read call below.
        var source = new MultiFileSource(dataPath);
        var loader = new TextLoader(Env, new TextLoader.Arguments
        {
            Column = new[]{
                new TextLoader.Column("float1", DataKind.R4, 0),
                new TextLoader.Column("float4", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }),
                new TextLoader.Column("vfloat", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10, null) { AutoEnd = false, VariableEnd = true } })
            },
            Separator = ",",
            HasHeader = true
        }, source);
        var data = loader.Read(source);

        // Asserts that the named column exists and is an R4 vector of the expected size.
        // Each condition is asserted on its own so a failure identifies what went wrong.
        void AssertFloatVector(ISchema schema, string name, int expectedSize)
        {
            Assert.True(schema.TryGetColumnIndex(name, out int colIndex), $"Could not find '{name}'");
            ColumnType type = schema.GetColumnType(colIndex);
            Assert.True(type.IsVector, $"Column '{name}' is not a vector: {type}");
            Assert.True(type.ItemType == NumberType.R4, $"Column '{name}' has unexpected item type: {type}");
            Assert.Equal(expectedSize, type.VectorSize);
        }

        var pipe = new ConcatEstimator(Env, "f1", "float1")
            .Append(new ConcatEstimator(Env, "f2", "float1", "float1"))
            .Append(new ConcatEstimator(Env, "f3", "float4", "float1"))
            .Append(new ConcatEstimator(Env, "f4", "vfloat", "float1"));

        // Ten rows keep the baseline file small.
        data = TakeFilter.Create(Env, data, 10);
        data = pipe.Fit(data).Transform(data);

        AssertFloatVector(data.Schema, "f1", 1);
        AssertFloatVector(data.Schema, "f2", 2);
        AssertFloatVector(data.Schema, "f3", 5);
        // "vfloat" is loaded with VariableEnd = true; the expected size of 0 presumably
        // denotes a vector of unknown length (confirm against ColumnType.VectorSize).
        AssertFloatVector(data.Schema, "f4", 0);

        data = new ChooseColumnsTransform(Env, data, "f1", "f2", "f3", "f4");

        var subdir = Path.Combine("Transform", "Concat");
        var outputPath = GetOutputPath(subdir, "Concat1.tsv");
        using (var ch = Env.Start("save"))
        {
            var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true, Dense = true });
            using (var fs = File.Create(outputPath))
                DataSaverUtils.SaveDataView(ch, saver, data, fs, keepHidden: false);
        }

        CheckEquality(subdir, "Concat1.tsv");
        Done();
    }
}
}