diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
index 341e3a72af..e940ea9d4d 100644
--- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
+++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
@@ -198,6 +198,7 @@ private Delegate CreateGetter(int index)
                         Ch.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer<>));
                         Ch.Assert(outputType.GetGenericArguments()[0] == colType.ItemType.RawType);
                         del = CreateDirectVBufferGetterDelegate;
+                        genericType = colType.ItemType.RawType;
                     }
                     else if (colType.IsPrimitive)
                     {
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index f6ebaf687f..cd8198e14d 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -349,6 +349,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit
                     Ch.Assert(fieldType.GetGenericTypeDefinition() == typeof(VBuffer<>));
                     Ch.Assert(fieldType.GetGenericArguments()[0] == colType.ItemType.RawType);
                     del = CreateVBufferToVBufferSetter;
+                    genericType = colType.ItemType.RawType;
                 }
                 else if (colType.IsPrimitive)
                 {
diff --git a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs
new file mode 100644
index 0000000000..08c9e17a28
--- /dev/null
+++ b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs
@@ -0,0 +1,130 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Microsoft.ML.Runtime.RunTests
+{
+    /// <summary>
+    /// Tests that streaming data views created from in-memory rows with vector
+    /// columns can be read back through a row cursor and as a typed enumerable.
+    /// </summary>
+    public sealed class TestSparseDataView : TestDataViewBase
+    {
+        private const string Cat = "DataView";
+
+        public TestSparseDataView(ITestOutputHelper obj) : base(obj)
+        {
+        }
+
+        // Row type whose vector column is stored as a plain array.
+        private class DenseExample<T>
+        {
+            [VectorType(2)]
+            public T[] X;
+        }
+
+        // Row type whose vector column is stored as a VBuffer.
+        private class SparseExample<T>
+        {
+            [VectorType(5)]
+            public VBuffer<T> X;
+        }
+
+        // Runs the VBuffer-backed check once per item type.
+        [Fact]
+        [TestCategory(Cat)]
+        public void SparseDataView()
+        {
+            GenericSparseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f });
+            GenericSparseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 });
+            GenericSparseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false });
+            GenericSparseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 });
+            GenericSparseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") },
+                                  new DvText[] { new DvText("aa"), new DvText("bb"), new DvText("cc") });
+        }
+
+        // Wraps v1 and v2 in two SparseExample rows, then checks that the row cursor
+        // and the typed enumerable each yield exactly two rows.
+        private void GenericSparseDataView<T>(T[] v1, T[] v2)
+        {
+            var inputs = new[] {
+                new SparseExample<T>() { X = new VBuffer<T>(5, 3, v1, new int[] { 0, 2, 4 }) },
+                new SparseExample<T>() { X = new VBuffer<T>(5, 3, v2, new int[] { 0, 1, 3 }) }
+            };
+            using (var host = new TlcEnvironment())
+            {
+                var data = host.CreateStreamingDataView(inputs);
+                var value = new VBuffer<T>();
+                int n = 0;
+                using (var cur = data.GetRowCursor(i => true))
+                {
+                    var getter = cur.GetGetter<VBuffer<T>>(0);
+                    while (cur.MoveNext())
+                    {
+                        getter(ref value);
+                        Assert.Equal(3, value.Count);
+                        ++n;
+                    }
+                }
+                Assert.Equal(2, n);
+
+                // foreach guarantees the enumerator is disposed when iteration ends.
+                n = 0;
+                foreach (var row in data.AsEnumerable<SparseExample<T>>(host, false))
+                    ++n;
+                Assert.Equal(2, n);
+            }
+        }
+
+        // Runs the array-backed check once per item type.
+        [Fact]
+        [TestCategory(Cat)]
+        public void DenseDataView()
+        {
+            GenericDenseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f });
+            GenericDenseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 });
+            GenericDenseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false });
+            GenericDenseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 });
+            GenericDenseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") },
+                                 new DvText[] { new DvText("aa"), new DvText("bb"), new DvText("cc") });
+        }
+
+        // Wraps v1 and v2 in two DenseExample rows, then checks that the row cursor
+        // and the typed enumerable each yield exactly two rows.
+        private void GenericDenseDataView<T>(T[] v1, T[] v2)
+        {
+            var inputs = new[] {
+                new DenseExample<T>() { X = v1 },
+                new DenseExample<T>() { X = v2 }
+            };
+            using (var host = new TlcEnvironment())
+            {
+                var data = host.CreateStreamingDataView(inputs);
+                var value = new VBuffer<T>();
+                int n = 0;
+                using (var cur = data.GetRowCursor(i => true))
+                {
+                    var getter = cur.GetGetter<VBuffer<T>>(0);
+                    while (cur.MoveNext())
+                    {
+                        getter(ref value);
+                        Assert.Equal(3, value.Count);
+                        ++n;
+                    }
+                }
+                Assert.Equal(2, n);
+
+                // foreach guarantees the enumerator is disposed when iteration ends.
+                n = 0;
+                foreach (var row in data.AsEnumerable<DenseExample<T>>(host, false))
+                    ++n;
+                Assert.Equal(2, n);
+            }
+        }
+    }
+}