Skip to content

Commit 760cef2

Browse files
xaduprecodemzs
authored and committed
Fix creation of dataviews inferred with .NET types with sparse vectors (dotnet#587)
`DataViewConstructionUtils`'s methods to create dataviews over .NET types will now have correctly inferred "getters" in the case of sparse vectors.
1 parent c1993d6 commit 760cef2

File tree

3 files changed

+118
-0
lines changed

3 files changed

+118
-0
lines changed

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ private Delegate CreateGetter(int index)
198198
Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
199199
Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType);
200200
del = CreateDirectVBufferGetterDelegate<int>;
201+
genericType = colType.ItemType.RawType;
201202
}
202203
else if (colType.IsPrimitive)
203204
{

src/Microsoft.ML.Api/TypedCursor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ private Action<TRow> GenerateSetter(IRow input, int index, InternalSchemaDefinit
349349
Ch.Assert(fieldType.GetGenericTypeDefinition() == typeof(VBuffer<>));
350350
Ch.Assert(fieldType.GetGenericArguments()[0] == colType.ItemType.RawType);
351351
del = CreateVBufferToVBufferSetter<int>;
352+
genericType = colType.ItemType.RawType;
352353
}
353354
else if (colType.IsPrimitive)
354355
{
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime.Api;
6+
using Microsoft.ML.Runtime.Data;
7+
using Xunit;
8+
using Xunit.Abstractions;
9+
10+
namespace Microsoft.ML.Runtime.RunTests
11+
{
12+
/// <summary>
/// Checks that data views created over in-memory .NET objects can be read back,
/// both through a row cursor and through the typed enumerable, when the vector
/// column is declared as a VBuffer field (sparse case) or as a plain array (dense case).
/// </summary>
public sealed class TestSparseDataView : TestDataViewBase
{
    private const string Cat = "DataView";

    public TestSparseDataView(ITestOutputHelper obj) : base(obj)
    {
    }

    /// <summary>Row type whose single vector column is backed by a plain array.</summary>
    private class DenseExample<T>
    {
        // NOTE(review): the declared vector size is 2 while the tests below supply
        // 3-element arrays and expect Count == 3 -- confirm this mismatch is intentional.
        [VectorType(2)]
        public T[] X;
    }

    /// <summary>Row type whose single vector column is backed by a VBuffer.</summary>
    private class SparseExample<T>
    {
        [VectorType(5)]
        public VBuffer<T> X;
    }

    /// <summary>
    /// Runs the sparse round-trip check for several item types.
    /// </summary>
    [Fact]
    [TestCategory(Cat)]
    public void SparseDataView()
    {
        GenericSparseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f });
        GenericSparseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 });
        GenericSparseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false });
        GenericSparseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 });
        GenericSparseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") },
            new DvText[] { new DvText("aa"), new DvText("bb"), new DvText("cc") });
    }

    /// <summary>
    /// Builds a two-row data view from <see cref="SparseExample{T}"/> instances and
    /// verifies that both the cursor getter and the typed enumerable see two rows,
    /// each row's buffer reporting three values.
    /// </summary>
    /// <param name="v1">Values stored in the first row's buffer (three items expected).</param>
    /// <param name="v2">Values stored in the second row's buffer (three items expected).</param>
    private void GenericSparseDataView<T>(T[] v1, T[] v2)
    {
        var inputs = new[] {
            new SparseExample<T>() { X = new VBuffer<T> (5, 3, v1, new int[] { 0, 2, 4 }) },
            new SparseExample<T>() { X = new VBuffer<T> (5, 3, v2, new int[] { 0, 1, 3 }) }
        };
        using (var host = new TlcEnvironment())
        {
            var data = host.CreateStreamingDataView(inputs);
            var value = new VBuffer<T>();
            int n = 0;
            using (var cur = data.GetRowCursor(i => true))
            {
                var getter = cur.GetGetter<VBuffer<T>>(0);
                while (cur.MoveNext())
                {
                    getter(ref value);
                    Assert.Equal(3, value.Count);
                    ++n;
                }
            }
            Assert.Equal(2, n);

            // Second pass through the typed enumerable; the enumerator is disposed
            // explicitly because a manual MoveNext loop does not do it for us.
            n = 0;
            using (var iter = data.AsEnumerable<SparseExample<T>>(host, false).GetEnumerator())
            {
                while (iter.MoveNext())
                    ++n;
            }
            Assert.Equal(2, n);
        }
    }

    /// <summary>
    /// Runs the dense round-trip check for several item types.
    /// </summary>
    [Fact]
    [TestCategory(Cat)]
    public void DenseDataView()
    {
        GenericDenseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f });
        GenericDenseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 });
        GenericDenseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false });
        GenericDenseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 });
        GenericDenseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") },
            new DvText[] { new DvText("aa"), new DvText("bb"), new DvText("cc") });
    }

    /// <summary>
    /// Builds a two-row data view from <see cref="DenseExample{T}"/> instances and
    /// verifies that both the cursor getter and the typed enumerable see two rows,
    /// each row's buffer reporting three values.
    /// </summary>
    /// <param name="v1">Array used as the first row's vector (three items expected).</param>
    /// <param name="v2">Array used as the second row's vector (three items expected).</param>
    private void GenericDenseDataView<T>(T[] v1, T[] v2)
    {
        var inputs = new[] {
            new DenseExample<T>() { X = v1 },
            new DenseExample<T>() { X = v2 }
        };
        using (var host = new TlcEnvironment())
        {
            var data = host.CreateStreamingDataView(inputs);
            var value = new VBuffer<T>();
            int n = 0;
            using (var cur = data.GetRowCursor(i => true))
            {
                var getter = cur.GetGetter<VBuffer<T>>(0);
                while (cur.MoveNext())
                {
                    getter(ref value);
                    Assert.Equal(3, value.Count);
                    ++n;
                }
            }
            Assert.Equal(2, n);

            // Second pass through the typed enumerable; the enumerator is disposed
            // explicitly because a manual MoveNext loop does not do it for us.
            n = 0;
            using (var iter = data.AsEnumerable<DenseExample<T>>(host, false).GetEnumerator())
            {
                while (iter.MoveNext())
                    ++n;
            }
            Assert.Equal(2, n);
        }
    }
}
116+
}

0 commit comments

Comments
 (0)