Skip to content

Commit 97cc104

Browse files
authored
Concat estimator with pigsty extensions for ConcatWith, AsVector (#881)
1 parent c01c46f commit 97cc104

File tree

7 files changed

+590
-1
lines changed

7 files changed

+590
-1
lines changed

src/Microsoft.ML.Core/Data/MetadataUtils.cs

+16
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,22 @@ public static bool IsNormalized(this SchemaShape.Column col)
366366
&& metaCol.ItemType == BoolType.Instance;
367367
}
368368

369+
/// <summary>
370+
/// Returns whether a column has the <see cref="Kinds.SlotNames"/> metadata indicated by
371+
/// the schema shape.
372+
/// </summary>
373+
/// <param name="col">The schema shape column to query</param>
374+
/// <returns>True if and only if the column is a definite sized vector type, has the
375+
/// <see cref="Kinds.SlotNames"/> metadata of definite sized vectors of text.</returns>
376+
public static bool HasSlotNames(this SchemaShape.Column col)
377+
{
378+
Contracts.CheckValue(col, nameof(col));
379+
return col.Kind == SchemaShape.Column.VectorKind.Vector
380+
&& col.Metadata.TryFindColumn(Kinds.SlotNames, out var metaCol)
381+
&& metaCol.Kind == SchemaShape.Column.VectorKind.Vector && !metaCol.IsKey
382+
&& metaCol.ItemType == TextType.Instance;
383+
}
384+
369385
/// <summary>
370386
/// Tries to get the metadata kind of the specified type for a column.
371387
/// </summary>

src/Microsoft.ML.Data/Transforms/ConcatEstimator.cs

+418
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=f1:R4:0-0
5+
#@ col=f2:R4:1-2
6+
#@ col=f3:R4:3-7
7+
#@ col=f4:R4:8-**
8+
#@ }
9+
float1 float1 float1 float4.age float4.fnlwgt float4.education-num float4.capital-gain float1
10+
25 25 25 25 226802 7 0 25 25 226802 7 0 0 40 0 25
11+
38 38 38 38 89814 9 0 38 38 89814 9 0 0 50 0 38
12+
28 28 28 28 336951 12 0 28 28 336951 12 0 0 40 1 28
13+
44 44 44 44 160323 10 7688 44 44 160323 10 7688 0 40 1 44
14+
18 18 18 18 103497 10 0 18 18 103497 10 0 0 30 0 18
15+
34 34 34 34 198693 6 0 34 34 198693 6 0 0 30 0 34
16+
29 29 29 29 227026 9 0 29 29 227026 9 0 0 40 0 29
17+
63 63 63 63 104626 15 3103 63 63 104626 15 3103 0 32 1 63
18+
24 24 24 24 369667 10 0 24 24 369667 10 0 0 40 0 24
19+
55 55 55 55 104996 4 0 55 55 104996 4 0 0 10 0 55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#@ TextLoader{
2+
#@ header+
3+
#@ sep=tab
4+
#@ col=f1:R4:0-0
5+
#@ col=f2:R4:1-2
6+
#@ col=f3:R4:3-7
7+
#@ col=f4:R4:8-**
8+
#@ }
9+
float1 float1 float1 float4.age float4.fnlwgt float4.education-num float4.capital-gain float1
10+
25 25 25 25 226802 7 0 25 25 226802 7 0 0 40 0 25
11+
38 38 38 38 89814 9 0 38 38 89814 9 0 0 50 0 38
12+
28 28 28 28 336951 12 0 28 28 336951 12 0 0 40 1 28
13+
44 44 44 44 160323 10 7688 44 44 160323 10 7688 0 40 1 44
14+
18 18 18 18 103497 10 0 18 18 103497 10 0 0 30 0 18
15+
34 34 34 34 198693 6 0 34 34 198693 6 0 0 30 0 34
16+
29 29 29 29 227026 9 0 29 29 227026 9 0 0 40 0 29
17+
63 63 63 63 104626 15 3103 63 63 104626 15 3103 0 32 1 63
18+
24 24 24 24 369667 10 0 24 24 369667 10 0 0 40 0 24
19+
55 55 55 55 104996 4 0 55 55 104996 4 0 0 10 0 55

test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs

+38
Original file line numberDiff line numberDiff line change
@@ -365,5 +365,43 @@ public void ToKey()
365365
// Because they're over exactly the same data, they ought to have the same cardinality and everything.
366366
Assert.True(valuesKeyKeyType.Equals(valuesKeyType));
367367
}
368+
369+
[Fact]
370+
public void ConcatWith()
371+
{
372+
var env = new TlcEnvironment(seed: 0);
373+
var dataPath = GetDataPath("iris.data");
374+
var reader = TextLoader.CreateReader(env,
375+
c => (label: c.LoadText(4), values: c.LoadFloat(0, 3), value: c.LoadFloat(2)),
376+
separator: ',');
377+
var dataSource = new MultiFileSource(dataPath);
378+
var data = reader.Read(dataSource);
379+
380+
var est = data.MakeNewEstimator()
381+
.Append(r => (
382+
r.label, r.values, r.value,
383+
c0: r.label.AsVector(), c1: r.label.ConcatWith(r.label),
384+
c2: r.value.ConcatWith(r.values), c3: r.values.ConcatWith(r.value, r.values)));
385+
386+
var tdata = est.Fit(data).Transform(data);
387+
var schema = tdata.AsDynamic.Schema;
388+
389+
int[] idx = new int[4];
390+
for (int i = 0; i < idx.Length; ++i)
391+
Assert.True(schema.TryGetColumnIndex("c" + i, out idx[i]), $"Could not find col c{i}");
392+
var types = new VectorType[idx.Length];
393+
int[] expectedLen = new int[] { 1, 2, 5, 9 };
394+
for (int i = 0; i < idx.Length; ++i)
395+
{
396+
var type = schema.GetColumnType(idx[i]);
397+
Assert.True(type.VectorSize > 0, $"Col c{i} had unexpected type {type}");
398+
types[i] = type.AsVector;
399+
Assert.Equal(expectedLen[i], type.VectorSize);
400+
}
401+
Assert.Equal(TextType.Instance, types[0].ItemType);
402+
Assert.Equal(TextType.Instance, types[1].ItemType);
403+
Assert.Equal(NumberType.Float, types[2].ItemType);
404+
Assert.Equal(NumberType.Float, types[3].ItemType);
405+
}
368406
}
369407
}

test/Microsoft.ML.Tests/TermEstimatorTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class TestMetaClass
4848
}
4949

5050
[Fact]
51-
void TestDifferntTypes()
51+
void TestDifferentTypes()
5252
{
5353
string dataPath = GetDataPath("adult.test");
5454

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Api;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Data.IO;
8+
using Microsoft.ML.Runtime.Model;
9+
using Microsoft.ML.Runtime.RunTests;
10+
using Microsoft.ML.Runtime.Tools;
11+
using System.IO;
12+
using Xunit;
13+
using Xunit.Abstractions;
14+
15+
namespace Microsoft.ML.Tests.Transformers
16+
{
17+
public sealed class ConcatTests : TestDataPipeBase
18+
{
19+
public ConcatTests(ITestOutputHelper output) : base(output)
20+
{
21+
}
22+
23+
[Fact]
24+
void TestConcat()
25+
{
26+
string dataPath = GetDataPath("adult.test");
27+
28+
var source = new MultiFileSource(dataPath);
29+
var loader = new TextLoader(Env, new TextLoader.Arguments
30+
{
31+
Column = new[]{
32+
new TextLoader.Column("float1", DataKind.R4, 0),
33+
new TextLoader.Column("float4", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10) }),
34+
new TextLoader.Column("vfloat", DataKind.R4, new[]{new TextLoader.Range(0), new TextLoader.Range(2), new TextLoader.Range(4), new TextLoader.Range(10, null) { AutoEnd = false, VariableEnd = true } })
35+
},
36+
Separator = ",",
37+
HasHeader = true
38+
}, new MultiFileSource(dataPath));
39+
var data = loader.Read(source);
40+
41+
ColumnType GetType(ISchema schema, string name)
42+
{
43+
Assert.True(schema.TryGetColumnIndex(name, out int cIdx), $"Could not find '{name}'");
44+
return schema.GetColumnType(cIdx);
45+
}
46+
var pipe = new ConcatEstimator(Env, "f1", "float1")
47+
.Append(new ConcatEstimator(Env, "f2", "float1", "float1"))
48+
.Append(new ConcatEstimator(Env, "f3", "float4", "float1"))
49+
.Append(new ConcatEstimator(Env, "f4", "vfloat", "float1"));
50+
51+
data = TakeFilter.Create(Env, data, 10);
52+
data = pipe.Fit(data).Transform(data);
53+
54+
ColumnType t;
55+
t = GetType(data.Schema, "f1");
56+
Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 1);
57+
t = GetType(data.Schema, "f2");
58+
Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 2);
59+
t = GetType(data.Schema, "f3");
60+
Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 5);
61+
t = GetType(data.Schema, "f4");
62+
Assert.True(t.IsVector && t.ItemType == NumberType.R4 && t.VectorSize == 0);
63+
64+
data = new ChooseColumnsTransform(Env, data, "f1", "f2", "f3", "f4");
65+
66+
var subdir = Path.Combine("Transform", "Concat");
67+
var outputPath = GetOutputPath(subdir, "Concat1.tsv");
68+
using (var ch = Env.Start("save"))
69+
{
70+
var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true, Dense = true });
71+
using (var fs = File.Create(outputPath))
72+
DataSaverUtils.SaveDataView(ch, saver, data, fs, keepHidden: false);
73+
}
74+
75+
CheckEquality(subdir, "Concat1.tsv");
76+
Done();
77+
}
78+
}
79+
}

0 commit comments

Comments
 (0)